author    Reynold Xin <rxin@apache.org>    2014-08-19 17:40:35 -0700
committer Reynold Xin <rxin@apache.org>    2014-08-19 17:40:35 -0700
commit    8b9dc991018842e01f4b93870a2bc2c2cb9ea4ba (patch)
tree      446609692cb2f8cad124cd2d7a516960b37f8fcf /core
parent    825d4fe47b9c4d48de88622dd48dcf83beb8b80a (diff)
[SPARK-2468] Netty based block server / client module
Previous pull request (#1907) was reverted. This brings it back. Still looking into the hang.

Author: Reynold Xin <rxin@apache.org>

Closes #1971 from rxin/netty1 and squashes the following commits:

b0be96f [Reynold Xin] Added test to make sure outstandingRequests are cleaned after firing the events.
4c6d0ee [Reynold Xin] Pass callbacks cleanly.
603dce7 [Reynold Xin] Upgrade Netty to 4.0.23 to fix the DefaultFileRegion bug.
88be1d4 [Reynold Xin] Downgrade to 4.0.21 to work around a bug in writing DefaultFileRegion.
002626a [Reynold Xin] Remove netty-test-file.txt.
db6e6e0 [Reynold Xin] Revert "Revert "[SPARK-2468] Netty based block server / client module""
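For orientation before the diff, here is a minimal usage sketch of the new client-side module, assembled from the APIs introduced below; the host, port, block id, and sleep are hypothetical placeholders, not values from this commit.

import org.apache.spark.SparkConf
import org.apache.spark.network.netty.client.{BlockClientListener, BlockFetchingClientFactory, ReferenceCountedBuffer}

object NettyFetchSketch {
  def main(args: Array[String]): Unit = {
    val factory = new BlockFetchingClientFactory(new SparkConf)
    val client = factory.createClient("localhost", 12345) // placeholder endpoint
    client.fetchBlocks(Seq("shuffle_0_0_0"), new BlockClientListener {
      override def onFetchSuccess(blockId: String, data: ReferenceCountedBuffer): Unit = {
        println(s"fetched $blockId (${data.byteBuffer().remaining()} bytes)")
        data.release() // buffers are reference counted; release when done
      }
      override def onFetchFailure(blockId: String, errorMsg: String): Unit =
        println(s"failed to fetch $blockId: $errorMsg")
    })
    Thread.sleep(1000) // fetchBlocks is asynchronous; a real caller waits on the callbacks
    client.close()
    factory.stop()
  }
}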
Diffstat (limited to 'core')
-rw-r--r--  core/src/main/scala/org/apache/spark/network/netty/FileClient.scala | 85
-rw-r--r--  core/src/main/scala/org/apache/spark/network/netty/FileClientHandler.scala | 50
-rw-r--r--  core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala | 71
-rw-r--r--  core/src/main/scala/org/apache/spark/network/netty/FileServer.scala | 91
-rw-r--r--  core/src/main/scala/org/apache/spark/network/netty/FileServerHandler.scala | 68
-rw-r--r--  core/src/main/scala/org/apache/spark/network/netty/NettyConfig.scala | 59
-rw-r--r--  core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala | 118
-rw-r--r--  core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala | 71
-rw-r--r--  core/src/main/scala/org/apache/spark/network/netty/client/BlockClientListener.scala (renamed from core/src/main/scala/org/apache/spark/network/netty/FileClientChannelInitializer.scala) | 16
-rw-r--r--  core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClient.scala | 132
-rw-r--r--  core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientFactory.scala | 99
-rw-r--r--  core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandler.scala | 103
-rw-r--r--  core/src/main/scala/org/apache/spark/network/netty/client/LazyInitIterator.scala | 44
-rw-r--r--  core/src/main/scala/org/apache/spark/network/netty/client/ReferenceCountedBuffer.scala | 47
-rw-r--r--  core/src/main/scala/org/apache/spark/network/netty/server/BlockHeader.scala | 32
-rw-r--r--  core/src/main/scala/org/apache/spark/network/netty/server/BlockHeaderEncoder.scala | 47
-rw-r--r--  core/src/main/scala/org/apache/spark/network/netty/server/BlockServer.scala | 162
-rw-r--r--  core/src/main/scala/org/apache/spark/network/netty/server/BlockServerChannelInitializer.scala (renamed from core/src/main/scala/org/apache/spark/network/netty/FileServerChannelInitializer.scala) | 22
-rw-r--r--  core/src/main/scala/org/apache/spark/network/netty/server/BlockServerHandler.scala | 140
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockDataProvider.scala | 32
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala | 137
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockManager.scala | 49
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockNotFoundException.scala | 21
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala | 13
-rw-r--r--  core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala | 161
-rw-r--r--  core/src/test/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandlerSuite.scala | 105
-rw-r--r--  core/src/test/scala/org/apache/spark/network/netty/server/BlockHeaderEncoderSuite.scala | 64
-rw-r--r--  core/src/test/scala/org/apache/spark/network/netty/server/BlockServerHandlerSuite.scala | 107
28 files changed, 1483 insertions(+), 663 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/network/netty/FileClient.scala b/core/src/main/scala/org/apache/spark/network/netty/FileClient.scala
deleted file mode 100644
index c6d35f73db..0000000000
--- a/core/src/main/scala/org/apache/spark/network/netty/FileClient.scala
+++ /dev/null
@@ -1,85 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.netty
-
-import java.util.concurrent.TimeUnit
-
-import io.netty.bootstrap.Bootstrap
-import io.netty.channel.{Channel, ChannelOption, EventLoopGroup}
-import io.netty.channel.oio.OioEventLoopGroup
-import io.netty.channel.socket.oio.OioSocketChannel
-
-import org.apache.spark.Logging
-
-class FileClient(handler: FileClientHandler, connectTimeout: Int) extends Logging {
-
- private var channel: Channel = _
- private var bootstrap: Bootstrap = _
- private var group: EventLoopGroup = _
- private val sendTimeout = 60
-
- def init(): Unit = {
- group = new OioEventLoopGroup
- bootstrap = new Bootstrap
- bootstrap.group(group)
- .channel(classOf[OioSocketChannel])
- .option(ChannelOption.SO_KEEPALIVE, java.lang.Boolean.TRUE)
- .option(ChannelOption.TCP_NODELAY, java.lang.Boolean.TRUE)
- .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, Integer.valueOf(connectTimeout))
- .handler(new FileClientChannelInitializer(handler))
- }
-
- def connect(host: String, port: Int) {
- try {
- channel = bootstrap.connect(host, port).sync().channel()
- } catch {
- case e: InterruptedException =>
- logWarning("FileClient interrupted while trying to connect", e)
- close()
- }
- }
-
- def waitForClose(): Unit = {
- try {
- channel.closeFuture.sync()
- } catch {
- case e: InterruptedException =>
- logWarning("FileClient interrupted", e)
- }
- }
-
- def sendRequest(file: String): Unit = {
- try {
- val bSent = channel.writeAndFlush(file + "\r\n").await(sendTimeout, TimeUnit.SECONDS)
- if (!bSent) {
- throw new RuntimeException("Failed to send")
- }
- } catch {
- case e: InterruptedException =>
- logError("Error", e)
- }
- }
-
- def close(): Unit = {
- if (group != null) {
- group.shutdownGracefully()
- group = null
- bootstrap = null
- }
- }
-}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/FileClientHandler.scala b/core/src/main/scala/org/apache/spark/network/netty/FileClientHandler.scala
deleted file mode 100644
index 017302ec7d..0000000000
--- a/core/src/main/scala/org/apache/spark/network/netty/FileClientHandler.scala
+++ /dev/null
@@ -1,50 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.netty
-
-import io.netty.buffer.ByteBuf
-import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler}
-
-import org.apache.spark.storage.BlockId
-
-
-abstract class FileClientHandler extends SimpleChannelInboundHandler[ByteBuf] {
-
- private var currentHeader: FileHeader = null
-
- @volatile
- private var handlerCalled: Boolean = false
-
- def isComplete: Boolean = handlerCalled
-
- def handle(ctx: ChannelHandlerContext, in: ByteBuf, header: FileHeader)
-
- def handleError(blockId: BlockId)
-
- override def channelRead0(ctx: ChannelHandlerContext, in: ByteBuf) {
- if (currentHeader == null && in.readableBytes >= FileHeader.HEADER_SIZE) {
- currentHeader = FileHeader.create(in.readBytes(FileHeader.HEADER_SIZE))
- }
- if (in.readableBytes >= currentHeader.fileLen) {
- handle(ctx, in, currentHeader)
- handlerCalled = true
- currentHeader = null
- ctx.close()
- }
- }
-}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala b/core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala
deleted file mode 100644
index 607e560ff2..0000000000
--- a/core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala
+++ /dev/null
@@ -1,71 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.netty
-
-import io.netty.buffer._
-
-import org.apache.spark.Logging
-import org.apache.spark.storage.{BlockId, TestBlockId}
-
-private[spark] class FileHeader (
- val fileLen: Int,
- val blockId: BlockId) extends Logging {
-
- lazy val buffer: ByteBuf = {
- val buf = Unpooled.buffer()
- buf.capacity(FileHeader.HEADER_SIZE)
- buf.writeInt(fileLen)
- buf.writeInt(blockId.name.length)
- blockId.name.foreach((x: Char) => buf.writeByte(x))
- // padding the rest of header
- if (FileHeader.HEADER_SIZE - buf.readableBytes > 0 ) {
- buf.writeZero(FileHeader.HEADER_SIZE - buf.readableBytes)
- } else {
- throw new Exception("too long header " + buf.readableBytes)
- logInfo("too long header")
- }
- buf
- }
-
-}
-
-private[spark] object FileHeader {
-
- val HEADER_SIZE = 40
-
- def getFileLenOffset = 0
- def getFileLenSize = Integer.SIZE/8
-
- def create(buf: ByteBuf): FileHeader = {
- val length = buf.readInt
- val idLength = buf.readInt
- val idBuilder = new StringBuilder(idLength)
- for (i <- 1 to idLength) {
- idBuilder += buf.readByte().asInstanceOf[Char]
- }
- val blockId = BlockId(idBuilder.toString())
- new FileHeader(length, blockId)
- }
-
- def main(args:Array[String]) {
- val header = new FileHeader(25, TestBlockId("my_block"))
- val buf = header.buffer
- val newHeader = FileHeader.create(buf)
- System.out.println("id=" + newHeader.blockId + ",size=" + newHeader.fileLen)
- }
-}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/FileServer.scala b/core/src/main/scala/org/apache/spark/network/netty/FileServer.scala
deleted file mode 100644
index dff7795065..0000000000
--- a/core/src/main/scala/org/apache/spark/network/netty/FileServer.scala
+++ /dev/null
@@ -1,91 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.netty
-
-import java.net.InetSocketAddress
-
-import io.netty.bootstrap.ServerBootstrap
-import io.netty.channel.{ChannelFuture, ChannelOption, EventLoopGroup}
-import io.netty.channel.oio.OioEventLoopGroup
-import io.netty.channel.socket.oio.OioServerSocketChannel
-
-import org.apache.spark.Logging
-
-/**
- * Server that accepts the path of a file and echoes back its content.
- */
-class FileServer(pResolver: PathResolver, private var port: Int) extends Logging {
-
- private val addr: InetSocketAddress = new InetSocketAddress(port)
- private var bossGroup: EventLoopGroup = new OioEventLoopGroup
- private var workerGroup: EventLoopGroup = new OioEventLoopGroup
-
- private var channelFuture: ChannelFuture = {
- val bootstrap = new ServerBootstrap
- bootstrap.group(bossGroup, workerGroup)
- .channel(classOf[OioServerSocketChannel])
- .option(ChannelOption.SO_BACKLOG, java.lang.Integer.valueOf(100))
- .option(ChannelOption.SO_RCVBUF, java.lang.Integer.valueOf(1500))
- .childHandler(new FileServerChannelInitializer(pResolver))
- bootstrap.bind(addr)
- }
-
- try {
- val boundAddress = channelFuture.sync.channel.localAddress.asInstanceOf[InetSocketAddress]
- port = boundAddress.getPort
- } catch {
- case ie: InterruptedException =>
- port = 0
- }
-
- /** Start the file server asynchronously in a new thread. */
- def start(): Unit = {
- val blockingThread: Thread = new Thread {
- override def run(): Unit = {
- try {
- channelFuture.channel.closeFuture.sync
- logInfo("FileServer exiting")
- } catch {
- case e: InterruptedException =>
- logError("File server start got interrupted", e)
- }
- // NOTE: bootstrap is shutdown in stop()
- }
- }
- blockingThread.setDaemon(true)
- blockingThread.start()
- }
-
- def getPort: Int = port
-
- def stop(): Unit = {
- if (channelFuture != null) {
- channelFuture.channel().close().awaitUninterruptibly()
- channelFuture = null
- }
- if (bossGroup != null) {
- bossGroup.shutdownGracefully()
- bossGroup = null
- }
- if (workerGroup != null) {
- workerGroup.shutdownGracefully()
- workerGroup = null
- }
- }
-}
-
diff --git a/core/src/main/scala/org/apache/spark/network/netty/FileServerHandler.scala b/core/src/main/scala/org/apache/spark/network/netty/FileServerHandler.scala
deleted file mode 100644
index 96f60b2883..0000000000
--- a/core/src/main/scala/org/apache/spark/network/netty/FileServerHandler.scala
+++ /dev/null
@@ -1,68 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.netty
-
-import java.io.FileInputStream
-
-import io.netty.channel.{DefaultFileRegion, ChannelHandlerContext, SimpleChannelInboundHandler}
-
-import org.apache.spark.Logging
-import org.apache.spark.storage.{BlockId, FileSegment}
-
-
-class FileServerHandler(pResolver: PathResolver)
- extends SimpleChannelInboundHandler[String] with Logging {
-
- override def channelRead0(ctx: ChannelHandlerContext, blockIdString: String): Unit = {
- val blockId: BlockId = BlockId(blockIdString)
- val fileSegment: FileSegment = pResolver.getBlockLocation(blockId)
- if (fileSegment == null) {
- return
- }
- val file = fileSegment.file
- if (file.exists) {
- if (!file.isFile) {
- ctx.write(new FileHeader(0, blockId).buffer)
- ctx.flush()
- return
- }
- val length: Long = fileSegment.length
- if (length > Integer.MAX_VALUE || length <= 0) {
- ctx.write(new FileHeader(0, blockId).buffer)
- ctx.flush()
- return
- }
- ctx.write(new FileHeader(length.toInt, blockId).buffer)
- try {
- val channel = new FileInputStream(file).getChannel
- ctx.write(new DefaultFileRegion(channel, fileSegment.offset, fileSegment.length))
- } catch {
- case e: Exception =>
- logError("Exception: ", e)
- }
- } else {
- ctx.write(new FileHeader(0, blockId).buffer)
- }
- ctx.flush()
- }
-
- override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = {
- logError("Exception: ", cause)
- ctx.close()
- }
-}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyConfig.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyConfig.scala
new file mode 100644
index 0000000000..b5870152c5
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/netty/NettyConfig.scala
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.netty
+
+import org.apache.spark.SparkConf
+
+/**
+ * A central location that tracks all the settings we expose to users.
+ */
+private[spark]
+class NettyConfig(conf: SparkConf) {
+
+ /** Port the server listens on. Defaults to a random port. */
+ private[netty] val serverPort = conf.getInt("spark.shuffle.io.port", 0)
+
+ /** IO mode: nio, oio, epoll, or auto (try epoll first and then nio). */
+ private[netty] val ioMode = conf.get("spark.shuffle.io.mode", "nio").toLowerCase
+
+ /** Connection timeout, configured in seconds (default 60) and stored in milliseconds. */
+ private[netty] val connectTimeoutMs = conf.getInt("spark.shuffle.io.connectionTimeout", 60) * 1000
+
+ /**
+ * Percentage of the desired amount of time spent for I/O in the child event loops.
+ * Only applicable in nio and epoll.
+ */
+ private[netty] val ioRatio = conf.getInt("spark.shuffle.io.netty.ioRatio", 80)
+
+ /** Requested maximum length of the queue of incoming connections. */
+ private[netty] val backLog: Option[Int] = conf.getOption("spark.shuffle.io.backLog").map(_.toInt)
+
+ /**
+ * Receive buffer size (SO_RCVBUF).
+ * Note: the optimal size for receive buffer and send buffer should be
+ * latency * network_bandwidth.
+ * Assuming latency = 1ms, network_bandwidth = 10Gbps
+ * buffer size should be ~ 1.25MB
+ */
+ private[netty] val receiveBuf: Option[Int] =
+ conf.getOption("spark.shuffle.io.receiveBuffer").map(_.toInt)
+
+ /** Send buffer size (SO_SNDBUF). */
+ private[netty] val sendBuf: Option[Int] =
+ conf.getOption("spark.shuffle.io.sendBuffer").map(_.toInt)
+}
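The buffer-size note in NettyConfig is the bandwidth-delay product; a quick check of the arithmetic in that comment:

// Optimal socket buffer ≈ latency × bandwidth (bandwidth-delay product).
val bandwidthBytesPerSec = 10e9 / 8        // 10 Gbps = 1.25e9 bytes/s
val latencySec = 0.001                     // 1 ms
val bufferBytes = bandwidthBytesPerSec * latencySec
println(bufferBytes)                       // 1250000.0, i.e. ~1.25 MB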
diff --git a/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala b/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala
deleted file mode 100644
index e7b2855e1e..0000000000
--- a/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala
+++ /dev/null
@@ -1,118 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.netty
-
-import java.util.concurrent.Executors
-
-import scala.collection.JavaConverters._
-
-import io.netty.buffer.ByteBuf
-import io.netty.channel.ChannelHandlerContext
-import io.netty.util.CharsetUtil
-
-import org.apache.spark.{Logging, SparkConf}
-import org.apache.spark.network.ConnectionManagerId
-import org.apache.spark.storage.BlockId
-
-private[spark] class ShuffleCopier(conf: SparkConf) extends Logging {
-
- def getBlock(host: String, port: Int, blockId: BlockId,
- resultCollectCallback: (BlockId, Long, ByteBuf) => Unit) {
-
- val handler = new ShuffleCopier.ShuffleClientHandler(resultCollectCallback)
- val connectTimeout = conf.getInt("spark.shuffle.netty.connect.timeout", 60000)
- val fc = new FileClient(handler, connectTimeout)
-
- try {
- fc.init()
- fc.connect(host, port)
- fc.sendRequest(blockId.name)
- fc.waitForClose()
- fc.close()
- } catch {
- // Handle any socket-related exceptions in FileClient
- case e: Exception => {
- logError("Shuffle copy of block " + blockId + " from " + host + ":" + port + " failed", e)
- handler.handleError(blockId)
- }
- }
- }
-
- def getBlock(cmId: ConnectionManagerId, blockId: BlockId,
- resultCollectCallback: (BlockId, Long, ByteBuf) => Unit) {
- getBlock(cmId.host, cmId.port, blockId, resultCollectCallback)
- }
-
- def getBlocks(cmId: ConnectionManagerId,
- blocks: Seq[(BlockId, Long)],
- resultCollectCallback: (BlockId, Long, ByteBuf) => Unit) {
-
- for ((blockId, size) <- blocks) {
- getBlock(cmId, blockId, resultCollectCallback)
- }
- }
-}
-
-
-private[spark] object ShuffleCopier extends Logging {
-
- private class ShuffleClientHandler(resultCollectCallBack: (BlockId, Long, ByteBuf) => Unit)
- extends FileClientHandler with Logging {
-
- override def handle(ctx: ChannelHandlerContext, in: ByteBuf, header: FileHeader) {
- logDebug("Received Block: " + header.blockId + " (" + header.fileLen + "B)")
- resultCollectCallBack(header.blockId, header.fileLen.toLong, in.readBytes(header.fileLen))
- }
-
- override def handleError(blockId: BlockId) {
- if (!isComplete) {
- resultCollectCallBack(blockId, -1, null)
- }
- }
- }
-
- def echoResultCollectCallBack(blockId: BlockId, size: Long, content: ByteBuf) {
- if (size != -1) {
- logInfo("File: " + blockId + " content is : \" " + content.toString(CharsetUtil.UTF_8) + "\"")
- }
- }
-
- def main(args: Array[String]) {
- if (args.length < 3) {
- System.err.println("Usage: ShuffleCopier <host> <port> <shuffle_block_id> <threads>")
- System.exit(1)
- }
- val host = args(0)
- val port = args(1).toInt
- val blockId = BlockId(args(2))
- val threads = if (args.length > 3) args(3).toInt else 10
-
- val copiers = Executors.newFixedThreadPool(80)
- val tasks = (for (i <- Range(0, threads)) yield {
- Executors.callable(new Runnable() {
- def run() {
- val copier = new ShuffleCopier(new SparkConf)
- copier.getBlock(host, port, blockId, echoResultCollectCallBack)
- }
- })
- }).asJava
- copiers.invokeAll(tasks)
- copiers.shutdown()
- System.exit(0)
- }
-}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala b/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala
deleted file mode 100644
index 95958e30f7..0000000000
--- a/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala
+++ /dev/null
@@ -1,71 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.netty
-
-import java.io.File
-
-import org.apache.spark.Logging
-import org.apache.spark.util.Utils
-import org.apache.spark.storage.{BlockId, FileSegment}
-
-private[spark] class ShuffleSender(portIn: Int, val pResolver: PathResolver) extends Logging {
-
- val server = new FileServer(pResolver, portIn)
- server.start()
-
- def stop() {
- server.stop()
- }
-
- def port: Int = server.getPort
-}
-
-
-/**
- * An application for testing the shuffle sender as a standalone program.
- */
-private[spark] object ShuffleSender {
-
- def main(args: Array[String]) {
- if (args.length < 3) {
- System.err.println(
- "Usage: ShuffleSender <port> <subDirsPerLocalDir> <list of shuffle_block_directories>")
- System.exit(1)
- }
-
- val port = args(0).toInt
- val subDirsPerLocalDir = args(1).toInt
- val localDirs = args.drop(2).map(new File(_))
-
- val pResovler = new PathResolver {
- override def getBlockLocation(blockId: BlockId): FileSegment = {
- if (!blockId.isShuffle) {
- throw new Exception("Block " + blockId + " is not a shuffle block")
- }
- // Figure out which local directory it hashes to, and which subdirectory in that
- val hash = Utils.nonNegativeHash(blockId)
- val dirId = hash % localDirs.length
- val subDirId = (hash / localDirs.length) % subDirsPerLocalDir
- val subDir = new File(localDirs(dirId), "%02x".format(subDirId))
- val file = new File(subDir, blockId.name)
- new FileSegment(file, 0, file.length())
- }
- }
- val sender = new ShuffleSender(port, pResovler)
- }
-}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/FileClientChannelInitializer.scala b/core/src/main/scala/org/apache/spark/network/netty/client/BlockClientListener.scala
index f4261c13f7..e28219dd77 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/FileClientChannelInitializer.scala
+++ b/core/src/main/scala/org/apache/spark/network/netty/client/BlockClientListener.scala
@@ -15,17 +15,15 @@
* limitations under the License.
*/
-package org.apache.spark.network.netty
+package org.apache.spark.network.netty.client
-import io.netty.channel.ChannelInitializer
-import io.netty.channel.socket.SocketChannel
-import io.netty.handler.codec.string.StringEncoder
+import java.util.EventListener
-class FileClientChannelInitializer(handler: FileClientHandler)
- extends ChannelInitializer[SocketChannel] {
+trait BlockClientListener extends EventListener {
+
+ def onFetchSuccess(blockId: String, data: ReferenceCountedBuffer): Unit
+
+ def onFetchFailure(blockId: String, errorMsg: String): Unit
- def initChannel(channel: SocketChannel) {
- channel.pipeline.addLast("encoder", new StringEncoder).addLast("handler", handler)
- }
}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClient.scala b/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClient.scala
new file mode 100644
index 0000000000..5aea7ba2f3
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClient.scala
@@ -0,0 +1,132 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.netty.client
+
+import java.util.concurrent.TimeoutException
+
+import io.netty.bootstrap.Bootstrap
+import io.netty.buffer.PooledByteBufAllocator
+import io.netty.channel.socket.SocketChannel
+import io.netty.channel.{ChannelFutureListener, ChannelFuture, ChannelInitializer, ChannelOption}
+import io.netty.handler.codec.LengthFieldBasedFrameDecoder
+import io.netty.handler.codec.string.StringEncoder
+import io.netty.util.CharsetUtil
+
+import org.apache.spark.Logging
+
+/**
+ * Client for fetching data blocks from [[org.apache.spark.network.netty.server.BlockServer]].
+ * Use [[BlockFetchingClientFactory]] to instantiate this client.
+ *
+ * The constructor blocks until a connection is successfully established.
+ *
+ * See [[org.apache.spark.network.netty.server.BlockServer]] for client/server protocol.
+ *
+ * Concurrency: thread safe and can be called from multiple threads.
+ */
+@throws[TimeoutException]
+private[spark]
+class BlockFetchingClient(factory: BlockFetchingClientFactory, hostname: String, port: Int)
+ extends Logging {
+
+ private val handler = new BlockFetchingClientHandler
+
+ /** Netty Bootstrap for creating the TCP connection. */
+ private val bootstrap: Bootstrap = {
+ val b = new Bootstrap
+ b.group(factory.workerGroup)
+ .channel(factory.socketChannelClass)
+ // Use pooled buffers to reduce temporary buffer allocation
+ .option(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT)
+ // Disable Nagle's Algorithm since we don't want packets to wait
+ .option(ChannelOption.TCP_NODELAY, java.lang.Boolean.TRUE)
+ .option(ChannelOption.SO_KEEPALIVE, java.lang.Boolean.TRUE)
+ .option[Integer](ChannelOption.CONNECT_TIMEOUT_MILLIS, factory.conf.connectTimeoutMs)
+
+ b.handler(new ChannelInitializer[SocketChannel] {
+ override def initChannel(ch: SocketChannel): Unit = {
+ ch.pipeline
+ .addLast("encoder", new StringEncoder(CharsetUtil.UTF_8))
+ // maxFrameLength = 2G, lengthFieldOffset = 0, lengthFieldLength = 4
+ .addLast("framedLengthDecoder", new LengthFieldBasedFrameDecoder(Int.MaxValue, 0, 4))
+ .addLast("handler", handler)
+ }
+ })
+ b
+ }
+
+ /** Netty ChannelFuture for the connection. */
+ private val cf: ChannelFuture = bootstrap.connect(hostname, port)
+ if (!cf.awaitUninterruptibly(factory.conf.connectTimeoutMs)) {
+ throw new TimeoutException(
+ s"Connecting to $hostname:$port timed out (${factory.conf.connectTimeoutMs} ms)")
+ }
+
+ /**
+ * Ask the remote server for a sequence of blocks, and execute the callback.
+ *
+ * Note that this is asynchronous and returns immediately. Upstream caller should throttle the
+ * rate of fetching; otherwise we could run out of memory.
+ *
+ * @param blockIds sequence of block ids to fetch.
+ * @param listener callback to fire on fetch success / failure.
+ */
+ def fetchBlocks(blockIds: Seq[String], listener: BlockClientListener): Unit = {
+ // It's best to limit the number of "write" calls since it needs to traverse the whole pipeline.
+ // It's also best to limit the number of "flush" calls since it requires system calls.
+ // Let's concatenate the string and then call writeAndFlush once.
+ // This is also why this implementation might be more efficient than multiple, separate
+ // fetch block calls.
+ var startTime: Long = 0
+ logTrace {
+ startTime = System.nanoTime
+ s"Sending request $blockIds to $hostname:$port"
+ }
+
+ blockIds.foreach { blockId =>
+ handler.addRequest(blockId, listener)
+ }
+
+ val writeFuture = cf.channel().writeAndFlush(blockIds.mkString("\n") + "\n")
+ writeFuture.addListener(new ChannelFutureListener {
+ override def operationComplete(future: ChannelFuture): Unit = {
+ if (future.isSuccess) {
+ logTrace {
+ val timeTaken = (System.nanoTime - startTime).toDouble / 1000000
+ s"Sending request $blockIds to $hostname:$port took $timeTaken ms"
+ }
+ } else {
+ // Fail all blocks.
+ val errorMsg =
+ s"Failed to send request $blockIds to $hostname:$port: ${future.cause.getMessage}"
+ logError(errorMsg, future.cause)
+ blockIds.foreach { blockId =>
+ listener.onFetchFailure(blockId, errorMsg)
+ handler.removeRequest(blockId)
+ }
+ }
+ }
+ })
+ }
+
+ def waitForClose(): Unit = {
+ cf.channel().closeFuture().sync()
+ }
+
+ def close(): Unit = cf.channel().close()
+}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientFactory.scala b/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientFactory.scala
new file mode 100644
index 0000000000..2b28402c52
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientFactory.scala
@@ -0,0 +1,99 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.netty.client
+
+import io.netty.channel.epoll.{EpollEventLoopGroup, EpollSocketChannel}
+import io.netty.channel.nio.NioEventLoopGroup
+import io.netty.channel.oio.OioEventLoopGroup
+import io.netty.channel.socket.nio.NioSocketChannel
+import io.netty.channel.socket.oio.OioSocketChannel
+import io.netty.channel.{EventLoopGroup, Channel}
+
+import org.apache.spark.SparkConf
+import org.apache.spark.network.netty.NettyConfig
+import org.apache.spark.util.Utils
+
+/**
+ * Factory for creating [[BlockFetchingClient]] by using createClient. This factory reuses
+ * the worker thread pool for Netty.
+ *
+ * Concurrency: createClient is safe to call from multiple threads concurrently.
+ */
+private[spark]
+class BlockFetchingClientFactory(val conf: NettyConfig) {
+
+ def this(sparkConf: SparkConf) = this(new NettyConfig(sparkConf))
+
+ /** A thread factory so the threads are named (for debugging). */
+ val threadFactory = Utils.namedThreadFactory("spark-shuffle-client")
+
+ /** The following two are instantiated by the [[init]] method, depending on ioMode. */
+ var socketChannelClass: Class[_ <: Channel] = _
+ var workerGroup: EventLoopGroup = _
+
+ init()
+
+ /** Initialize [[socketChannelClass]] and [[workerGroup]] based on ioMode. */
+ private def init(): Unit = {
+ def initOio(): Unit = {
+ socketChannelClass = classOf[OioSocketChannel]
+ workerGroup = new OioEventLoopGroup(0, threadFactory)
+ }
+ def initNio(): Unit = {
+ socketChannelClass = classOf[NioSocketChannel]
+ workerGroup = new NioEventLoopGroup(0, threadFactory)
+ }
+ def initEpoll(): Unit = {
+ socketChannelClass = classOf[EpollSocketChannel]
+ workerGroup = new EpollEventLoopGroup(0, threadFactory)
+ }
+
+ conf.ioMode match {
+ case "nio" => initNio()
+ case "oio" => initOio()
+ case "epoll" => initEpoll()
+ case "auto" =>
+ // For auto mode, first try epoll (only available on Linux), then nio.
+ try {
+ initEpoll()
+ } catch {
+ // TODO: Should we log the throwable? But that always happens on non-Linux systems.
+ // Perhaps the right thing to do is to check whether the system is Linux, and then only
+ // call initEpoll on Linux.
+ case e: Throwable => initNio()
+ }
+ }
+ }
+
+ /**
+ * Create a new BlockFetchingClient connecting to the given remote host / port.
+ *
+ * This blocks until a connection is successfully established.
+ *
+ * Concurrency: This method is safe to call from multiple threads.
+ */
+ def createClient(remoteHost: String, remotePort: Int): BlockFetchingClient = {
+ new BlockFetchingClient(this, remoteHost, remotePort)
+ }
+
+ def stop(): Unit = {
+ if (workerGroup != null) {
+ workerGroup.shutdownGracefully()
+ }
+ }
+}
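A short sketch of how the transport mode is selected through SparkConf (the key and values come from NettyConfig above):

import org.apache.spark.SparkConf
import org.apache.spark.network.netty.client.BlockFetchingClientFactory

val conf = new SparkConf().set("spark.shuffle.io.mode", "auto") // nio | oio | epoll | auto
val factory = new BlockFetchingClientFactory(conf)
// factory.socketChannelClass is now the Epoll, Nio, or Oio socket channel accordingly.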
diff --git a/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandler.scala b/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandler.scala
new file mode 100644
index 0000000000..83265b1642
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandler.scala
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.netty.client
+
+import io.netty.buffer.ByteBuf
+import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler}
+
+import org.apache.spark.Logging
+
+
+/**
+ * Handler that processes server responses. It uses the protocol documented in
+ * [[org.apache.spark.network.netty.server.BlockServer]].
+ *
+ * Concurrency: thread safe and can be called from multiple threads.
+ */
+private[client]
+class BlockFetchingClientHandler extends SimpleChannelInboundHandler[ByteBuf] with Logging {
+
+ /** Tracks the list of outstanding requests and their listeners on success/failure. */
+ private val outstandingRequests = java.util.Collections.synchronizedMap {
+ new java.util.HashMap[String, BlockClientListener]
+ }
+
+ def addRequest(blockId: String, listener: BlockClientListener): Unit = {
+ outstandingRequests.put(blockId, listener)
+ }
+
+ def removeRequest(blockId: String): Unit = {
+ outstandingRequests.remove(blockId)
+ }
+
+ override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = {
+ val errorMsg = s"Exception in connection from ${ctx.channel.remoteAddress}: ${cause.getMessage}"
+ logError(errorMsg, cause)
+
+ // Fire the failure callback for all outstanding blocks
+ outstandingRequests.synchronized {
+ val iter = outstandingRequests.entrySet().iterator()
+ while (iter.hasNext) {
+ val entry = iter.next()
+ entry.getValue.onFetchFailure(entry.getKey, errorMsg)
+ }
+ outstandingRequests.clear()
+ }
+
+ ctx.close()
+ }
+
+ override def channelRead0(ctx: ChannelHandlerContext, in: ByteBuf) {
+ val totalLen = in.readInt()
+ val blockIdLen = in.readInt()
+ val blockIdBytes = new Array[Byte](math.abs(blockIdLen))
+ in.readBytes(blockIdBytes)
+ val blockId = new String(blockIdBytes)
+ val blockSize = totalLen - math.abs(blockIdLen) - 4
+
+ def server = ctx.channel.remoteAddress.toString
+
+ // blockIdLen is negative when it is an error message.
+ if (blockIdLen < 0) {
+ val errorMessageBytes = new Array[Byte](blockSize)
+ in.readBytes(errorMessageBytes)
+ val errorMsg = new String(errorMessageBytes)
+ logTrace(s"Received block $blockId ($blockSize B) with error $errorMsg from $server")
+
+ val listener = outstandingRequests.get(blockId)
+ if (listener == null) {
+ // Ignore callback
+ logWarning(s"Got a response for block $blockId but it is not in our outstanding requests")
+ } else {
+ outstandingRequests.remove(blockId)
+ listener.onFetchFailure(blockId, errorMsg)
+ }
+ } else {
+ logTrace(s"Received block $blockId ($blockSize B) from $server")
+
+ val listener = outstandingRequests.get(blockId)
+ if (listener == null) {
+ // Ignore callback
+ logWarning(s"Got a response for block $blockId but it is not in our outstanding requests")
+ } else {
+ outstandingRequests.remove(blockId)
+ listener.onFetchSuccess(blockId, new ReferenceCountedBuffer(in))
+ }
+ }
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/client/LazyInitIterator.scala b/core/src/main/scala/org/apache/spark/network/netty/client/LazyInitIterator.scala
new file mode 100644
index 0000000000..9740ee64d1
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/netty/client/LazyInitIterator.scala
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.netty.client
+
+/**
+ * A simple iterator that lazily initializes the underlying iterator.
+ *
+ * The use case is that sometimes we might have many iterators open at the same time, and each
+ * iterator might initialize its own buffer (e.g. decompression buffer, deserialization buffer).
+ * This could lead to too many buffers being open. If this iterator is used, those buffers are
+ * initialized lazily.
+ */
+private[spark]
+class LazyInitIterator(createIterator: => Iterator[Any]) extends Iterator[Any] {
+
+ lazy val proxy = createIterator
+
+ override def hasNext: Boolean = {
+ val gotNext = proxy.hasNext
+ if (!gotNext) {
+ close()
+ }
+ gotNext
+ }
+
+ override def next(): Any = proxy.next()
+
+ def close(): Unit = Unit
+}
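A hypothetical use of LazyInitIterator: defer allocating deserialization buffers until the first element is actually requested.

import org.apache.spark.network.netty.client.LazyInitIterator

def openAndDeserialize(): Iterator[Any] = {
  // Imagine this allocates a decompression and deserialization buffer.
  Iterator("a", "b", "c")
}

// Nothing is allocated yet; the by-name argument runs on the first hasNext/next call.
val iter = new LazyInitIterator(openAndDeserialize())
while (iter.hasNext) println(iter.next())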
diff --git a/core/src/main/scala/org/apache/spark/network/netty/client/ReferenceCountedBuffer.scala b/core/src/main/scala/org/apache/spark/network/netty/client/ReferenceCountedBuffer.scala
new file mode 100644
index 0000000000..ea1abf5ecc
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/netty/client/ReferenceCountedBuffer.scala
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.netty.client
+
+import java.io.InputStream
+import java.nio.ByteBuffer
+
+import io.netty.buffer.{ByteBuf, ByteBufInputStream}
+
+
+/**
+ * A buffer abstraction based on Netty's ByteBuf so we don't expose Netty.
+ * This is a Scala value class.
+ *
+ * The buffer's life cycle is NOT managed by the JVM, so references must be declared and
+ * dropped explicitly through the retain and release methods.
+ */
+private[spark]
+class ReferenceCountedBuffer(val underlying: ByteBuf) extends AnyVal {
+
+ /** Return the nio ByteBuffer view of the underlying buffer. */
+ def byteBuffer(): ByteBuffer = underlying.nioBuffer
+
+ /** Creates a new input stream that starts from the current position of the buffer. */
+ def inputStream(): InputStream = new ByteBufInputStream(underlying)
+
+ /** Increment the reference counter by one. */
+ def retain(): Unit = underlying.retain()
+
+ /** Decrement the reference counter by one and release the buffer if the ref count is 0. */
+ def release(): Unit = underlying.release()
+}
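The retain/release contract in practice, as a minimal sketch over a standalone Netty buffer:

import io.netty.buffer.Unpooled
import org.apache.spark.network.netty.client.ReferenceCountedBuffer

val buf = new ReferenceCountedBuffer(Unpooled.copiedBuffer("payload".getBytes("UTF-8")))
buf.retain()               // take an extra reference before handing the data off
val in = buf.inputStream() // a view over the bytes; no copy is made
// ... consume the stream ...
buf.release()              // matches the retain above
buf.release()              // drops the initial reference; at ref count 0 the buffer is freed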
diff --git a/core/src/main/scala/org/apache/spark/network/netty/server/BlockHeader.scala b/core/src/main/scala/org/apache/spark/network/netty/server/BlockHeader.scala
new file mode 100644
index 0000000000..162e9cc682
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/netty/server/BlockHeader.scala
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.netty.server
+
+/**
+ * Header describing a block. This is used only in the server pipeline.
+ *
+ * [[BlockServerHandler]] creates this, and [[BlockHeaderEncoder]] encodes it.
+ *
+ * @param blockSize length of the block content, excluding the length itself.
+ * If positive, this is the header for a block whose content follows separately
+ * (the content is not part of the header).
+ * If negative, this frame carries the header plus the content of an error message.
+ * @param blockId block id
+ * @param error some error message from reading the block
+ */
+private[server]
+class BlockHeader(val blockSize: Int, val blockId: String, val error: Option[String] = None)
diff --git a/core/src/main/scala/org/apache/spark/network/netty/server/BlockHeaderEncoder.scala b/core/src/main/scala/org/apache/spark/network/netty/server/BlockHeaderEncoder.scala
new file mode 100644
index 0000000000..8e4dda4ef8
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/netty/server/BlockHeaderEncoder.scala
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.netty.server
+
+import io.netty.buffer.ByteBuf
+import io.netty.channel.ChannelHandlerContext
+import io.netty.handler.codec.MessageToByteEncoder
+
+/**
+ * A simple encoder for BlockHeader. See [[BlockServer]] for the server to client protocol.
+ */
+private[server]
+class BlockHeaderEncoder extends MessageToByteEncoder[BlockHeader] {
+ override def encode(ctx: ChannelHandlerContext, msg: BlockHeader, out: ByteBuf): Unit = {
+ // message = message length (4 bytes) + block id length (4 bytes) + block id + block data
+ // message length = block id length (4 bytes) + size of block id + size of block data
+ val blockIdBytes = msg.blockId.getBytes
+ msg.error match {
+ case Some(errorMsg) =>
+ val errorBytes = errorMsg.getBytes
+ out.writeInt(4 + blockIdBytes.length + errorBytes.size)
+ out.writeInt(-blockIdBytes.length) // use negative block id length to represent errors
+ out.writeBytes(blockIdBytes) // next is blockId itself
+ out.writeBytes(errorBytes) // error message
+ case None =>
+ out.writeInt(4 + blockIdBytes.length + msg.blockSize)
+ out.writeInt(blockIdBytes.length) // First 4 bytes is blockId length
+ out.writeBytes(blockIdBytes) // next is blockId itself
+ // msg of size blockSize will be written by ServerHandler
+ }
+ }
+}
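To illustrate the framing, this writes by hand what the encoder produces for a hypothetical block id "b1" with 10 bytes of content, and for an error message "oops" on the same id:

import io.netty.buffer.Unpooled

val out = Unpooled.buffer()
// Success case: [frame-length][id-length]["b1"]; the 10 data bytes follow separately.
out.writeInt(4 + 2 + 10)      // frame-length = 16, excludes itself
out.writeInt(2)               // positive id-length
out.writeBytes("b1".getBytes)
// Error case: id-length is negated and the message rides inside the frame.
out.writeInt(4 + 2 + 4)       // frame-length = 10
out.writeInt(-2)              // negative id-length marks an error
out.writeBytes("b1".getBytes)
out.writeBytes("oops".getBytes)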
diff --git a/core/src/main/scala/org/apache/spark/network/netty/server/BlockServer.scala b/core/src/main/scala/org/apache/spark/network/netty/server/BlockServer.scala
new file mode 100644
index 0000000000..7b2f9a8d4d
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/netty/server/BlockServer.scala
@@ -0,0 +1,162 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.netty.server
+
+import java.net.InetSocketAddress
+
+import io.netty.bootstrap.ServerBootstrap
+import io.netty.buffer.PooledByteBufAllocator
+import io.netty.channel.{ChannelFuture, ChannelInitializer, ChannelOption}
+import io.netty.channel.epoll.{EpollEventLoopGroup, EpollServerSocketChannel}
+import io.netty.channel.nio.NioEventLoopGroup
+import io.netty.channel.oio.OioEventLoopGroup
+import io.netty.channel.socket.SocketChannel
+import io.netty.channel.socket.nio.NioServerSocketChannel
+import io.netty.channel.socket.oio.OioServerSocketChannel
+import io.netty.handler.codec.LineBasedFrameDecoder
+import io.netty.handler.codec.string.StringDecoder
+import io.netty.util.CharsetUtil
+
+import org.apache.spark.{Logging, SparkConf}
+import org.apache.spark.network.netty.NettyConfig
+import org.apache.spark.storage.BlockDataProvider
+import org.apache.spark.util.Utils
+
+
+/**
+ * Server for serving Spark data blocks.
+ * This should be used together with [[org.apache.spark.network.netty.client.BlockFetchingClient]].
+ *
+ * Protocol for requesting blocks (client to server):
+ * One block id per line, e.g. to request 3 blocks: "block1\nblock2\nblock3\n"
+ *
+ * Protocol for sending blocks (server to client):
+ * frame-length (4 bytes), block-id-length (4 bytes), block-id, block-data.
+ *
+ * frame-length should not include the length of itself.
+ * If block-id-length is negative, then this is an error message rather than block-data. The real
+ * block id length is the absolute value of the block-id-length field.
+ *
+ */
+private[spark]
+class BlockServer(conf: NettyConfig, dataProvider: BlockDataProvider) extends Logging {
+
+ def this(sparkConf: SparkConf, dataProvider: BlockDataProvider) = {
+ this(new NettyConfig(sparkConf), dataProvider)
+ }
+
+ def port: Int = _port
+
+ def hostName: String = _hostName
+
+ private var _port: Int = conf.serverPort
+ private var _hostName: String = ""
+ private var bootstrap: ServerBootstrap = _
+ private var channelFuture: ChannelFuture = _
+
+ init()
+
+ /** Initialize the server. */
+ private def init(): Unit = {
+ bootstrap = new ServerBootstrap
+ val bossThreadFactory = Utils.namedThreadFactory("spark-shuffle-server-boss")
+ val workerThreadFactory = Utils.namedThreadFactory("spark-shuffle-server-worker")
+
+ // Use only one thread to accept connections, and 2 * num_cores for worker.
+ def initNio(): Unit = {
+ val bossGroup = new NioEventLoopGroup(1, bossThreadFactory)
+ val workerGroup = new NioEventLoopGroup(0, workerThreadFactory)
+ workerGroup.setIoRatio(conf.ioRatio)
+ bootstrap.group(bossGroup, workerGroup).channel(classOf[NioServerSocketChannel])
+ }
+ def initOio(): Unit = {
+ val bossGroup = new OioEventLoopGroup(1, bossThreadFactory)
+ val workerGroup = new OioEventLoopGroup(0, workerThreadFactory)
+ bootstrap.group(bossGroup, workerGroup).channel(classOf[OioServerSocketChannel])
+ }
+ def initEpoll(): Unit = {
+ val bossGroup = new EpollEventLoopGroup(1, bossThreadFactory)
+ val workerGroup = new EpollEventLoopGroup(0, workerThreadFactory)
+ workerGroup.setIoRatio(conf.ioRatio)
+ bootstrap.group(bossGroup, workerGroup).channel(classOf[EpollServerSocketChannel])
+ }
+
+ conf.ioMode match {
+ case "nio" => initNio()
+ case "oio" => initOio()
+ case "epoll" => initEpoll()
+ case "auto" =>
+ // For auto mode, first try epoll (only available on Linux), then nio.
+ try {
+ initEpoll()
+ } catch {
+ // TODO: Should we log the throwable? But that always happens on non-Linux systems.
+ // Perhaps the right thing to do is to check whether the system is Linux, and then only
+ // call initEpoll on Linux.
+ case e: Throwable => initNio()
+ }
+ }
+
+ // Use pooled buffers to reduce temporary buffer allocation
+ bootstrap.option(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT)
+ bootstrap.childOption(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT)
+
+ // Various (advanced) user-configured settings.
+ conf.backLog.foreach { backLog =>
+ bootstrap.option[java.lang.Integer](ChannelOption.SO_BACKLOG, backLog)
+ }
+ conf.receiveBuf.foreach { receiveBuf =>
+ bootstrap.option[java.lang.Integer](ChannelOption.SO_RCVBUF, receiveBuf)
+ }
+ conf.sendBuf.foreach { sendBuf =>
+ bootstrap.option[java.lang.Integer](ChannelOption.SO_SNDBUF, sendBuf)
+ }
+
+ bootstrap.childHandler(new ChannelInitializer[SocketChannel] {
+ override def initChannel(ch: SocketChannel): Unit = {
+ ch.pipeline
+ .addLast("frameDecoder", new LineBasedFrameDecoder(1024)) // max block id length 1024
+ .addLast("stringDecoder", new StringDecoder(CharsetUtil.UTF_8))
+ .addLast("blockHeaderEncoder", new BlockHeaderEncoder)
+ .addLast("handler", new BlockServerHandler(dataProvider))
+ }
+ })
+
+ channelFuture = bootstrap.bind(new InetSocketAddress(_port))
+ channelFuture.sync()
+
+ val addr = channelFuture.channel.localAddress.asInstanceOf[InetSocketAddress]
+ _port = addr.getPort
+ _hostName = addr.getHostName
+ }
+
+ /** Shutdown the server. */
+ def stop(): Unit = {
+ if (channelFuture != null) {
+ channelFuture.channel().close().awaitUninterruptibly()
+ channelFuture = null
+ }
+ if (bootstrap != null && bootstrap.group() != null) {
+ bootstrap.group().shutdownGracefully()
+ }
+ if (bootstrap != null && bootstrap.childGroup() != null) {
+ bootstrap.childGroup().shutdownGracefully()
+ }
+ bootstrap = null
+ }
+}
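A matching server-side sketch. This assumes BlockDataProvider (added by this commit but not shown in this excerpt) exposes getBlockData(blockId: String): Either[FileSegment, ByteBuffer]; treat that signature as an assumption.

import java.nio.ByteBuffer

import org.apache.spark.SparkConf
import org.apache.spark.network.netty.server.BlockServer
import org.apache.spark.storage.{BlockDataProvider, FileSegment}

// Serve every block id from an in-memory byte buffer (assumed provider signature).
val provider = new BlockDataProvider {
  override def getBlockData(blockId: String): Either[FileSegment, ByteBuffer] =
    Right(ByteBuffer.wrap(s"data for $blockId".getBytes("UTF-8")))
}

val server = new BlockServer(new SparkConf, provider)
println(s"BlockServer listening on ${server.hostName}:${server.port}")
// ... clients connect and fetch blocks ...
server.stop()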
diff --git a/core/src/main/scala/org/apache/spark/network/netty/FileServerChannelInitializer.scala b/core/src/main/scala/org/apache/spark/network/netty/server/BlockServerChannelInitializer.scala
index aaa2f913d0..cc70bd0c5c 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/FileServerChannelInitializer.scala
+++ b/core/src/main/scala/org/apache/spark/network/netty/server/BlockServerChannelInitializer.scala
@@ -15,20 +15,26 @@
* limitations under the License.
*/
-package org.apache.spark.network.netty
+package org.apache.spark.network.netty.server
import io.netty.channel.ChannelInitializer
import io.netty.channel.socket.SocketChannel
-import io.netty.handler.codec.{DelimiterBasedFrameDecoder, Delimiters}
+import io.netty.handler.codec.LineBasedFrameDecoder
import io.netty.handler.codec.string.StringDecoder
+import io.netty.util.CharsetUtil
+import org.apache.spark.storage.BlockDataProvider
-class FileServerChannelInitializer(pResolver: PathResolver)
+
+/** Channel initializer that sets up the pipeline for the BlockServer. */
+private[netty]
+class BlockServerChannelInitializer(dataProvider: BlockDataProvider)
extends ChannelInitializer[SocketChannel] {
- override def initChannel(channel: SocketChannel): Unit = {
- channel.pipeline
- .addLast("framer", new DelimiterBasedFrameDecoder(8192, Delimiters.lineDelimiter : _*))
- .addLast("stringDecoder", new StringDecoder)
- .addLast("handler", new FileServerHandler(pResolver))
+ override def initChannel(ch: SocketChannel): Unit = {
+ ch.pipeline
+ .addLast("frameDecoder", new LineBasedFrameDecoder(1024)) // max block id length 1024
+ .addLast("stringDecoder", new StringDecoder(CharsetUtil.UTF_8))
+ .addLast("blockHeaderEncoder", new BlockHeaderEncoder)
+ .addLast("handler", new BlockServerHandler(dataProvider))
}
}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/server/BlockServerHandler.scala b/core/src/main/scala/org/apache/spark/network/netty/server/BlockServerHandler.scala
new file mode 100644
index 0000000000..40dd5e5d1a
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/netty/server/BlockServerHandler.scala
@@ -0,0 +1,140 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.netty.server
+
+import java.io.FileInputStream
+import java.nio.ByteBuffer
+import java.nio.channels.FileChannel
+
+import io.netty.buffer.Unpooled
+import io.netty.channel._
+
+import org.apache.spark.Logging
+import org.apache.spark.storage.{FileSegment, BlockDataProvider}
+
+
+/**
+ * A handler that processes requests from clients and writes block data back.
+ *
+ * The messages should have been processed by a LineBasedFrameDecoder and a StringDecoder first,
+ * so channelRead0 is called once per line (i.e. once per block id).
+ */
+private[server]
+class BlockServerHandler(dataProvider: BlockDataProvider)
+ extends SimpleChannelInboundHandler[String] with Logging {
+
+ override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = {
+ logError(s"Exception in connection from ${ctx.channel.remoteAddress}", cause)
+ ctx.close()
+ }
+
+ override def channelRead0(ctx: ChannelHandlerContext, blockId: String): Unit = {
+ def client = ctx.channel.remoteAddress.toString
+
+    // A helper function to send an error message back to the client.
+ def respondWithError(error: String): Unit = {
+ ctx.writeAndFlush(new BlockHeader(-1, blockId, Some(error))).addListener(
+ new ChannelFutureListener {
+ override def operationComplete(future: ChannelFuture) {
+ if (!future.isSuccess) {
+ // TODO: Maybe log the success case as well.
+ logError(s"Error sending error back to $client", future.cause)
+ ctx.close()
+ }
+ }
+ }
+ )
+ }
+
+ def writeFileSegment(segment: FileSegment): Unit = {
+      // Send an error message back if the block is too large. Even though this end is capable
+      // of sending large (2G+) blocks, the receiving end cannot handle them, so fail fast.
+      // Once the receiving end can process large blocks, this check should be removed,
+      // and BlockHeaderEncoder should be updated to support lengths > 2G.
+
+ // See [[BlockHeaderEncoder]] for the way length is encoded.
+ if (segment.length + blockId.length + 4 > Int.MaxValue) {
+ respondWithError(s"Block $blockId size ($segment.length) greater than 2G")
+ return
+ }
+
+ var fileChannel: FileChannel = null
+ try {
+ fileChannel = new FileInputStream(segment.file).getChannel
+ } catch {
+ case e: Exception =>
+ logError(
+ s"Error opening channel for $blockId in ${segment.file} for request from $client", e)
+ respondWithError(e.getMessage)
+ }
+
+ // Found the block. Send it back.
+ if (fileChannel != null) {
+ // Write the header and block data. In the case of failures, the listener on the block data
+ // write should close the connection.
+ ctx.write(new BlockHeader(segment.length.toInt, blockId))
+
+ val region = new DefaultFileRegion(fileChannel, segment.offset, segment.length)
+ ctx.writeAndFlush(region).addListener(new ChannelFutureListener {
+ override def operationComplete(future: ChannelFuture) {
+ if (future.isSuccess) {
+ logTrace(s"Sent block $blockId (${segment.length} B) back to $client")
+ } else {
+ logError(s"Error sending block $blockId to $client; closing connection", future.cause)
+ ctx.close()
+ }
+ }
+ })
+ }
+ }
+
+ def writeByteBuffer(buf: ByteBuffer): Unit = {
+ ctx.write(new BlockHeader(buf.remaining, blockId))
+ ctx.writeAndFlush(Unpooled.wrappedBuffer(buf)).addListener(new ChannelFutureListener {
+ override def operationComplete(future: ChannelFuture) {
+ if (future.isSuccess) {
+ logTrace(s"Sent block $blockId (${buf.remaining} B) back to $client")
+ } else {
+ logError(s"Error sending block $blockId to $client; closing connection", future.cause)
+ ctx.close()
+ }
+ }
+ })
+ }
+
+ logTrace(s"Received request from $client to fetch block $blockId")
+
+ var blockData: Either[FileSegment, ByteBuffer] = null
+
+    // First make sure we can find the block. If not, send an error back to the client.
+ try {
+ blockData = dataProvider.getBlockData(blockId)
+ } catch {
+ case e: Exception =>
+ logError(s"Error opening block $blockId for request from $client", e)
+ respondWithError(e.getMessage)
+ return
+ }
+
+ blockData match {
+ case Left(segment) => writeFileSegment(segment)
+ case Right(buf) => writeByteBuffer(buf)
+ }
+
+ } // end of channelRead0
+}
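
Together with BlockHeaderEncoder, this handler produces frames of the form [frame length][block id length][block id bytes][payload], where the block id length is negated when an error message is sent instead of block data (see BlockHeaderEncoderSuite below). A sketch of assembling a success frame by hand, with a hypothetical block id:

  import io.netty.buffer.Unpooled

  val blockId = "shuffle_0_0_0"                      // hypothetical block id
  val data = "block payload".getBytes("UTF-8")

  val header = Unpooled.buffer()
  header.writeInt(4 + blockId.length + data.length)  // frame length after this field
  header.writeInt(blockId.length)                    // -blockId.length would signal an error
  header.writeBytes(blockId.getBytes("UTF-8"))

  // On the wire the payload (a ByteBuf, or a zero-copy DefaultFileRegion) follows the header.
  val frame = Unpooled.wrappedBuffer(header, Unpooled.wrappedBuffer(data))

Requests in the other direction are simply newline-terminated block id strings, which is why the server pipeline fronts this handler with a LineBasedFrameDecoder and a StringDecoder.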
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockDataProvider.scala b/core/src/main/scala/org/apache/spark/storage/BlockDataProvider.scala
new file mode 100644
index 0000000000..5b6d086630
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/storage/BlockDataProvider.scala
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.nio.ByteBuffer
+
+
+/**
+ * An interface for providing data for blocks.
+ *
+ * getBlockData returns either a FileSegment (for zero-copy send) or a ByteBuffer.
+ *
+ * Aside from unit tests, [[BlockManager]] is the main class that implements this.
+ */
+private[spark] trait BlockDataProvider {
+ def getBlockData(blockId: String): Either[FileSegment, ByteBuffer]
+}
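
The Either selects the send path on the server: Left(FileSegment) is streamed with a zero-copy DefaultFileRegion, while Right(ByteBuffer) is wrapped into a ByteBuf and written normally. A sketch of a provider that makes the choice explicit, assuming hypothetical file and in-memory lookups:

  import java.io.File
  import java.nio.ByteBuffer

  import org.apache.spark.storage.{BlockDataProvider, BlockNotFoundException, FileSegment}

  class SketchProvider(dir: File, inMemory: Map[String, ByteBuffer]) extends BlockDataProvider {
    override def getBlockData(blockId: String): Either[FileSegment, ByteBuffer] = {
      val file = new File(dir, blockId)
      if (file.exists()) {
        Left(new FileSegment(file, 0, file.length))   // zero-copy path
      } else {
        inMemory.get(blockId) match {
          case Some(buf) => Right(buf.duplicate())    // in-memory path
          case None => throw new BlockNotFoundException(blockId)
        }
      }
    }
  }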
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala
index 5f44f5f319..ca60ec78b6 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala
@@ -18,19 +18,17 @@
package org.apache.spark.storage
import java.util.concurrent.LinkedBlockingQueue
 import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashSet
import scala.collection.mutable.Queue
import scala.util.{Failure, Success}
-import io.netty.buffer.ByteBuf
-
import org.apache.spark.{Logging, SparkException}
import org.apache.spark.executor.ShuffleReadMetrics
import org.apache.spark.network.BufferMessage
import org.apache.spark.network.ConnectionManagerId
-import org.apache.spark.network.netty.ShuffleCopier
+import org.apache.spark.network.netty.client.{BlockClientListener, LazyInitIterator, ReferenceCountedBuffer}
import org.apache.spark.serializer.Serializer
import org.apache.spark.util.Utils
@@ -54,18 +52,28 @@ trait BlockFetcherIterator extends Iterator[(BlockId, Option[Iterator[Any]])] wi
private[storage]
object BlockFetcherIterator {
- // A request to fetch one or more blocks, complete with their sizes
+ /**
+ * A request to fetch blocks from a remote BlockManager.
+ * @param address remote BlockManager to fetch from.
+   * @param blocks Sequence of tuples, where the first element is the block id,
+ * and the second element is the estimated size, used to calculate bytesInFlight.
+ */
class FetchRequest(val address: BlockManagerId, val blocks: Seq[(BlockId, Long)]) {
val size = blocks.map(_._2).sum
}
- // A result of a fetch. Includes the block ID, size in bytes, and a function to deserialize
- // the block (since we want all deserializaton to happen in the calling thread); can also
- // represent a fetch failure if size == -1.
+ /**
+   * Result of fetching a remote block. A failure is represented by size == -1.
+ * @param blockId block id
+ * @param size estimated size of the block, used to calculate bytesInFlight.
+   *             Note that this is NOT the exact number of bytes.
+ * @param deserialize closure to return the result in the form of an Iterator.
+ */
class FetchResult(val blockId: BlockId, val size: Long, val deserialize: () => Iterator[Any]) {
def failed: Boolean = size == -1
}
+ // TODO: Refactor this whole thing to make code more reusable.
class BasicBlockFetcherIterator(
private val blockManager: BlockManager,
val blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])],
@@ -95,10 +103,10 @@ object BlockFetcherIterator {
// Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that
// the number of bytes in flight is limited to maxBytesInFlight
- private val fetchRequests = new Queue[FetchRequest]
+ protected val fetchRequests = new Queue[FetchRequest]
// Current bytes in flight from our requests
- private var bytesInFlight = 0L
+ protected var bytesInFlight = 0L
protected def sendRequest(req: FetchRequest) {
logDebug("Sending request for %d blocks (%s) from %s".format(
@@ -262,77 +270,58 @@ object BlockFetcherIterator {
readMetrics: ShuffleReadMetrics)
extends BasicBlockFetcherIterator(blockManager, blocksByAddress, serializer, readMetrics) {
- import blockManager._
+ override protected def sendRequest(req: FetchRequest) {
+ logDebug("Sending request for %d blocks (%s) from %s".format(
+ req.blocks.size, Utils.bytesToString(req.size), req.address.hostPort))
+ val cmId = new ConnectionManagerId(req.address.host, req.address.port)
- val fetchRequestsSync = new LinkedBlockingQueue[FetchRequest]
+ bytesInFlight += req.size
+ val sizeMap = req.blocks.toMap // so we can look up the size of each blockID
+
+ // This could throw a TimeoutException. In that case we will just retry the task.
+ val client = blockManager.nettyBlockClientFactory.createClient(
+ cmId.host, req.address.nettyPort)
+ val blocks = req.blocks.map(_._1.toString)
+
+ client.fetchBlocks(
+ blocks,
+ new BlockClientListener {
+ override def onFetchFailure(blockId: String, errorMsg: String): Unit = {
+ logError(s"Could not get block(s) from $cmId with error: $errorMsg")
+ for ((blockId, size) <- req.blocks) {
+ results.put(new FetchResult(blockId, -1, null))
+ }
+ }
- private def startCopiers(numCopiers: Int): List[_ <: Thread] = {
- (for ( i <- Range(0,numCopiers) ) yield {
- val copier = new Thread {
- override def run(){
- try {
- while(!isInterrupted && !fetchRequestsSync.isEmpty) {
- sendRequest(fetchRequestsSync.take())
+ override def onFetchSuccess(blockId: String, data: ReferenceCountedBuffer): Unit = {
+ // Increment the reference count so the buffer won't be recycled.
+          // TODO: This could result in memory leaks when the task is stopped due to an exception
+ // before the iterator is exhausted.
+ data.retain()
+ val buf = data.byteBuffer()
+ val blockSize = buf.remaining()
+ val bid = BlockId(blockId)
+
+ // TODO: remove code duplication between here and BlockManager.dataDeserialization.
+ results.put(new FetchResult(bid, sizeMap(bid), () => {
+ def createIterator: Iterator[Any] = {
+ val stream = blockManager.wrapForCompression(bid, data.inputStream())
+ serializer.newInstance().deserializeStream(stream).asIterator
}
- } catch {
- case x: InterruptedException => logInfo("Copier Interrupted")
- // case _ => throw new SparkException("Exception Throw in Shuffle Copier")
+ new LazyInitIterator(createIterator) {
+ // Release the buffer when we are done traversing it.
+ override def close(): Unit = data.release()
+ }
+ }))
+
+ readMetrics.synchronized {
+ readMetrics.remoteBytesRead += blockSize
+ readMetrics.remoteBlocksFetched += 1
}
+ logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
}
}
- copier.start
- copier
- }).toList
- }
-
- // keep this to interrupt the threads when necessary
- private def stopCopiers() {
- for (copier <- copiers) {
- copier.interrupt()
- }
- }
-
- override protected def sendRequest(req: FetchRequest) {
-
- def putResult(blockId: BlockId, blockSize: Long, blockData: ByteBuf) {
- val fetchResult = new FetchResult(blockId, blockSize,
- () => dataDeserialize(blockId, blockData.nioBuffer, serializer))
- results.put(fetchResult)
- }
-
- logDebug("Sending request for %d blocks (%s) from %s".format(
- req.blocks.size, Utils.bytesToString(req.size), req.address.host))
- val cmId = new ConnectionManagerId(req.address.host, req.address.nettyPort)
- val cpier = new ShuffleCopier(blockManager.conf)
- cpier.getBlocks(cmId, req.blocks, putResult)
- logDebug("Sent request for remote blocks " + req.blocks + " from " + req.address.host )
- }
-
- private var copiers: List[_ <: Thread] = null
-
- override def initialize() {
- // Split Local Remote Blocks and set numBlocksToFetch
- val remoteRequests = splitLocalRemoteBlocks()
- // Add the remote requests into our queue in a random order
- for (request <- Utils.randomize(remoteRequests)) {
- fetchRequestsSync.put(request)
- }
-
- copiers = startCopiers(conf.getInt("spark.shuffle.copier.threads", 6))
- logInfo("Started " + fetchRequestsSync.size + " remote fetches in " +
- Utils.getUsedTimeMs(startTime))
-
- // Get Local Blocks
- startTime = System.currentTimeMillis
- getLocalBlocks()
- logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms")
- }
-
- override def next(): (BlockId, Option[Iterator[Any]]) = {
- resultsGotten += 1
- val result = results.take()
- // If all the results has been retrieved, copiers will exit automatically
- (result.blockId, if (result.failed) None else Some(result.deserialize()))
+ )
}
}
// End of NettyBlockFetcherIterator
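
The reference-counting contract in the new sendRequest is easy to miss: a fetched buffer is only guaranteed valid inside onFetchSuccess unless retain() is called, and every retain() must eventually be paired with a release(). A condensed sketch of the pattern, in which clientFactory, host, port, and consume() are assumed to be in scope:

  import org.apache.spark.network.netty.client.{BlockClientListener, ReferenceCountedBuffer}

  val client = clientFactory.createClient(host, port)
  client.fetchBlocks(
    Seq("block_0", "block_1"),                        // hypothetical block ids
    new BlockClientListener {
      override def onFetchSuccess(blockId: String, data: ReferenceCountedBuffer): Unit = {
        data.retain()                                 // keep the buffer alive past the callback
        try consume(data.inputStream())               // consume() is a placeholder
        finally data.release()                        // balance the retain()
      }
      override def onFetchFailure(blockId: String, errorMsg: String): Unit =
        println(s"Fetch of $blockId failed: $errorMsg")
    })
  client.close()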
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index e4c3d58905..c0491fb55e 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -25,17 +25,20 @@ import scala.concurrent.{Await, Future}
import scala.concurrent.duration._
import scala.util.Random
-import akka.actor.{ActorSystem, Cancellable, Props}
+import akka.actor.{ActorSystem, Props}
import sun.nio.ch.DirectBuffer
import org.apache.spark._
import org.apache.spark.executor._
import org.apache.spark.io.CompressionCodec
import org.apache.spark.network._
+import org.apache.spark.network.netty.client.BlockFetchingClientFactory
+import org.apache.spark.network.netty.server.BlockServer
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.ShuffleManager
import org.apache.spark.util._
+
private[spark] sealed trait BlockValues
private[spark] case class ByteBufferValues(buffer: ByteBuffer) extends BlockValues
private[spark] case class IteratorValues(iterator: Iterator[Any]) extends BlockValues
@@ -60,7 +63,7 @@ private[spark] class BlockManager(
securityManager: SecurityManager,
mapOutputTracker: MapOutputTracker,
shuffleManager: ShuffleManager)
- extends Logging {
+ extends BlockDataProvider with Logging {
private val port = conf.getInt("spark.blockManager.port", 0)
val shuffleBlockManager = new ShuffleBlockManager(this, shuffleManager)
@@ -88,13 +91,25 @@ private[spark] class BlockManager(
new TachyonStore(this, tachyonBlockManager)
}
+ private val useNetty = conf.getBoolean("spark.shuffle.use.netty", false)
+
// If we use Netty for shuffle, start a new Netty-based shuffle sender service.
- private val nettyPort: Int = {
- val useNetty = conf.getBoolean("spark.shuffle.use.netty", false)
- val nettyPortConfig = conf.getInt("spark.shuffle.sender.port", 0)
- if (useNetty) diskBlockManager.startShuffleBlockSender(nettyPortConfig) else 0
+ private[storage] val nettyBlockClientFactory: BlockFetchingClientFactory = {
+ if (useNetty) new BlockFetchingClientFactory(conf) else null
}
+ private val nettyBlockServer: BlockServer = {
+ if (useNetty) {
+ val server = new BlockServer(conf, this)
+ logInfo(s"Created NettyBlockServer binding to port: ${server.port}")
+ server
+ } else {
+ null
+ }
+ }
+
+ private val nettyPort: Int = if (useNetty) nettyBlockServer.port else 0
+
val blockManagerId = BlockManagerId(
executorId, connectionManager.id.host, connectionManager.id.port, nettyPort)
@@ -219,6 +234,20 @@ private[spark] class BlockManager(
}
}
+ override def getBlockData(blockId: String): Either[FileSegment, ByteBuffer] = {
+ val bid = BlockId(blockId)
+ if (bid.isShuffle) {
+ Left(diskBlockManager.getBlockLocation(bid))
+ } else {
+ val blockBytesOpt = doGetLocal(bid, asBlockResult = false).asInstanceOf[Option[ByteBuffer]]
+ if (blockBytesOpt.isDefined) {
+ Right(blockBytesOpt.get)
+ } else {
+ throw new BlockNotFoundException(blockId)
+ }
+ }
+ }
+
/**
* Get the BlockStatus for the block identified by the given ID, if it exists.
* NOTE: This is mainly for testing, and it doesn't fetch information from Tachyon.
@@ -1064,6 +1093,14 @@ private[spark] class BlockManager(
connectionManager.stop()
shuffleBlockManager.stop()
diskBlockManager.stop()
+
+ if (nettyBlockClientFactory != null) {
+ nettyBlockClientFactory.stop()
+ }
+ if (nettyBlockServer != null) {
+ nettyBlockServer.stop()
+ }
+
actorSystem.stop(slaveActor)
blockInfo.clear()
memoryStore.clear()
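
The Netty transport is strictly opt-in: with spark.shuffle.use.netty at its default of false, neither the client factory nor the server is created, and the existing ConnectionManager path is used. Enabling it is a configuration change:

  import org.apache.spark.SparkConf

  val conf = new SparkConf()
    .set("spark.shuffle.use.netty", "true")  // fetch shuffle blocks over Netty
    .set("spark.blockManager.port", "0")     // 0 (the default) picks a free port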
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockNotFoundException.scala b/core/src/main/scala/org/apache/spark/storage/BlockNotFoundException.scala
new file mode 100644
index 0000000000..9ef453605f
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/storage/BlockNotFoundException.scala
@@ -0,0 +1,21 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+
+class BlockNotFoundException(blockId: String) extends Exception(s"Block $blockId not found")
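
BlockManager.getBlockData above throws this for a non-shuffle block that cannot be found locally; BlockServerHandler catches it (as a plain Exception) and converts it into an error frame for the client. A caller-side sketch, with blockManager assumed in scope:

  try {
    blockManager.getBlockData("nonexistent_block")
  } catch {
    case e: BlockNotFoundException => println(e.getMessage) // "Block nonexistent_block not found"
  }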
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
index 4d66ccea21..f3da816389 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
@@ -23,7 +23,7 @@ import java.util.{Date, Random, UUID}
import org.apache.spark.{SparkEnv, Logging}
import org.apache.spark.executor.ExecutorExitCode
-import org.apache.spark.network.netty.{PathResolver, ShuffleSender}
+import org.apache.spark.network.netty.PathResolver
import org.apache.spark.util.Utils
import org.apache.spark.shuffle.sort.SortShuffleManager
@@ -52,7 +52,6 @@ private[spark] class DiskBlockManager(shuffleBlockManager: ShuffleBlockManager,
System.exit(ExecutorExitCode.DISK_STORE_FAILED_TO_CREATE_DIR)
}
private val subDirs = Array.fill(localDirs.length)(new Array[File](subDirsPerLocalDir))
- private var shuffleSender : ShuffleSender = null
addShutdownHook()
@@ -186,15 +185,5 @@ private[spark] class DiskBlockManager(shuffleBlockManager: ShuffleBlockManager,
}
}
}
-
- if (shuffleSender != null) {
- shuffleSender.stop()
- }
- }
-
- private[storage] def startShuffleBlockSender(port: Int): Int = {
- shuffleSender = new ShuffleSender(port, this)
- logInfo(s"Created ShuffleSender binding to port: ${shuffleSender.port}")
- shuffleSender.port
}
}
diff --git a/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala
new file mode 100644
index 0000000000..02d0ffc86f
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala
@@ -0,0 +1,161 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.netty
+
+import java.io.{RandomAccessFile, File}
+import java.nio.ByteBuffer
+import java.util.{Collections, HashSet}
+import java.util.concurrent.{TimeUnit, Semaphore}
+
+import scala.collection.JavaConversions._
+
+import io.netty.buffer.{ByteBufUtil, Unpooled}
+
+import org.scalatest.{BeforeAndAfterAll, FunSuite}
+
+import org.apache.spark.SparkConf
+import org.apache.spark.network.netty.client.{BlockClientListener, ReferenceCountedBuffer, BlockFetchingClientFactory}
+import org.apache.spark.network.netty.server.BlockServer
+import org.apache.spark.storage.{FileSegment, BlockDataProvider}
+
+
+/**
+ * Test suite that makes sure the server and the client implementations share the same protocol.
+ */
+class ServerClientIntegrationSuite extends FunSuite with BeforeAndAfterAll {
+
+ val bufSize = 100000
+ var buf: ByteBuffer = _
+ var testFile: File = _
+ var server: BlockServer = _
+ var clientFactory: BlockFetchingClientFactory = _
+
+ val bufferBlockId = "buffer_block"
+ val fileBlockId = "file_block"
+
+ val fileContent = new Array[Byte](1024)
+ scala.util.Random.nextBytes(fileContent)
+
+ override def beforeAll() = {
+ buf = ByteBuffer.allocate(bufSize)
+ for (i <- 1 to bufSize) {
+ buf.put(i.toByte)
+ }
+ buf.flip()
+
+ testFile = File.createTempFile("netty-test-file", "txt")
+ val fp = new RandomAccessFile(testFile, "rw")
+ fp.write(fileContent)
+ fp.close()
+
+ server = new BlockServer(new SparkConf, new BlockDataProvider {
+ override def getBlockData(blockId: String): Either[FileSegment, ByteBuffer] = {
+ if (blockId == bufferBlockId) {
+ Right(buf)
+ } else if (blockId == fileBlockId) {
+ Left(new FileSegment(testFile, 10, testFile.length - 25))
+ } else {
+ throw new Exception("Unknown block id " + blockId)
+ }
+ }
+ })
+
+ clientFactory = new BlockFetchingClientFactory(new SparkConf)
+ }
+
+ override def afterAll() = {
+ server.stop()
+ clientFactory.stop()
+ }
+
+ /** A ByteBuf for buffer_block */
+ lazy val byteBufferBlockReference = Unpooled.wrappedBuffer(buf)
+
+ /** A ByteBuf for file_block */
+ lazy val fileBlockReference = Unpooled.wrappedBuffer(fileContent, 10, fileContent.length - 25)
+
+ def fetchBlocks(blockIds: Seq[String]): (Set[String], Set[ReferenceCountedBuffer], Set[String]) =
+ {
+ val client = clientFactory.createClient(server.hostName, server.port)
+ val sem = new Semaphore(0)
+ val receivedBlockIds = Collections.synchronizedSet(new HashSet[String])
+ val errorBlockIds = Collections.synchronizedSet(new HashSet[String])
+ val receivedBuffers = Collections.synchronizedSet(new HashSet[ReferenceCountedBuffer])
+
+ client.fetchBlocks(
+ blockIds,
+ new BlockClientListener {
+ override def onFetchFailure(blockId: String, errorMsg: String): Unit = {
+ errorBlockIds.add(blockId)
+ sem.release()
+ }
+
+ override def onFetchSuccess(blockId: String, data: ReferenceCountedBuffer): Unit = {
+ receivedBlockIds.add(blockId)
+ data.retain()
+ receivedBuffers.add(data)
+ sem.release()
+ }
+ }
+ )
+ if (!sem.tryAcquire(blockIds.size, 30, TimeUnit.SECONDS)) {
+ fail("Timeout getting response from the server")
+ }
+ client.close()
+ (receivedBlockIds.toSet, receivedBuffers.toSet, errorBlockIds.toSet)
+ }
+
+ test("fetch a ByteBuffer block") {
+ val (blockIds, buffers, failBlockIds) = fetchBlocks(Seq(bufferBlockId))
+ assert(blockIds === Set(bufferBlockId))
+ assert(buffers.map(_.underlying) === Set(byteBufferBlockReference))
+ assert(failBlockIds.isEmpty)
+ buffers.foreach(_.release())
+ }
+
+ test("fetch a FileSegment block via zero-copy send") {
+ val (blockIds, buffers, failBlockIds) = fetchBlocks(Seq(fileBlockId))
+ assert(blockIds === Set(fileBlockId))
+ assert(buffers.map(_.underlying) === Set(fileBlockReference))
+ assert(failBlockIds.isEmpty)
+ buffers.foreach(_.release())
+ }
+
+ test("fetch a non-existent block") {
+ val (blockIds, buffers, failBlockIds) = fetchBlocks(Seq("random-block"))
+ assert(blockIds.isEmpty)
+ assert(buffers.isEmpty)
+ assert(failBlockIds === Set("random-block"))
+ }
+
+ test("fetch both ByteBuffer block and FileSegment block") {
+ val (blockIds, buffers, failBlockIds) = fetchBlocks(Seq(bufferBlockId, fileBlockId))
+ assert(blockIds === Set(bufferBlockId, fileBlockId))
+ assert(buffers.map(_.underlying) === Set(byteBufferBlockReference, fileBlockReference))
+ assert(failBlockIds.isEmpty)
+ buffers.foreach(_.release())
+ }
+
+ test("fetch both ByteBuffer block and a non-existent block") {
+ val (blockIds, buffers, failBlockIds) = fetchBlocks(Seq(bufferBlockId, "random-block"))
+ assert(blockIds === Set(bufferBlockId))
+ assert(buffers.map(_.underlying) === Set(byteBufferBlockReference))
+ assert(failBlockIds === Set("random-block"))
+ buffers.foreach(_.release())
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandlerSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandlerSuite.scala
new file mode 100644
index 0000000000..903ab09ae4
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandlerSuite.scala
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.netty.client
+
+import java.nio.ByteBuffer
+
+import io.netty.buffer.Unpooled
+import io.netty.channel.embedded.EmbeddedChannel
+
+import org.scalatest.{PrivateMethodTester, FunSuite}
+
+
+class BlockFetchingClientHandlerSuite extends FunSuite with PrivateMethodTester {
+
+ test("handling block data (successful fetch)") {
+ val blockId = "test_block"
+ val blockData = "blahblahblahblahblah"
+ val totalLength = 4 + blockId.length + blockData.length
+
+ var parsedBlockId: String = ""
+ var parsedBlockData: String = ""
+ val handler = new BlockFetchingClientHandler
+ handler.addRequest(blockId,
+ new BlockClientListener {
+ override def onFetchFailure(blockId: String, errorMsg: String): Unit = ???
+ override def onFetchSuccess(bid: String, refCntBuf: ReferenceCountedBuffer): Unit = {
+ parsedBlockId = bid
+ val bytes = new Array[Byte](refCntBuf.byteBuffer().remaining)
+ refCntBuf.byteBuffer().get(bytes)
+ parsedBlockData = new String(bytes)
+ }
+ }
+ )
+
+ val outstandingRequests = PrivateMethod[java.util.Map[_, _]]('outstandingRequests)
+ assert(handler.invokePrivate(outstandingRequests()).size === 1)
+
+ val channel = new EmbeddedChannel(handler)
+ val buf = ByteBuffer.allocate(totalLength + 4) // 4 bytes for the length field itself
+ buf.putInt(totalLength)
+ buf.putInt(blockId.length)
+ buf.put(blockId.getBytes)
+ buf.put(blockData.getBytes)
+ buf.flip()
+
+ channel.writeInbound(Unpooled.wrappedBuffer(buf))
+ assert(parsedBlockId === blockId)
+ assert(parsedBlockData === blockData)
+
+ assert(handler.invokePrivate(outstandingRequests()).size === 0)
+
+ channel.close()
+ }
+
+ test("handling error message (failed fetch)") {
+ val blockId = "test_block"
+ val errorMsg = "error erro5r error err4or error3 error6 error erro1r"
+ val totalLength = 4 + blockId.length + errorMsg.length
+
+ var parsedBlockId: String = ""
+ var parsedErrorMsg: String = ""
+ val handler = new BlockFetchingClientHandler
+ handler.addRequest(blockId, new BlockClientListener {
+      override def onFetchFailure(bid: String, msg: String) = {
+ parsedBlockId = bid
+ parsedErrorMsg = msg
+ }
+ override def onFetchSuccess(bid: String, refCntBuf: ReferenceCountedBuffer) = ???
+ })
+
+ val outstandingRequests = PrivateMethod[java.util.Map[_, _]]('outstandingRequests)
+ assert(handler.invokePrivate(outstandingRequests()).size === 1)
+
+ val channel = new EmbeddedChannel(handler)
+ val buf = ByteBuffer.allocate(totalLength + 4) // 4 bytes for the length field itself
+ buf.putInt(totalLength)
+ buf.putInt(-blockId.length)
+ buf.put(blockId.getBytes)
+ buf.put(errorMsg.getBytes)
+ buf.flip()
+
+ channel.writeInbound(Unpooled.wrappedBuffer(buf))
+ assert(parsedBlockId === blockId)
+ assert(parsedErrorMsg === errorMsg)
+
+ assert(handler.invokePrivate(outstandingRequests()).size === 0)
+
+ channel.close()
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/network/netty/server/BlockHeaderEncoderSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/server/BlockHeaderEncoderSuite.scala
new file mode 100644
index 0000000000..3ee281cb13
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/network/netty/server/BlockHeaderEncoderSuite.scala
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.netty.server
+
+import io.netty.buffer.ByteBuf
+import io.netty.channel.embedded.EmbeddedChannel
+
+import org.scalatest.FunSuite
+
+
+class BlockHeaderEncoderSuite extends FunSuite {
+
+ test("encode normal block data") {
+ val blockId = "test_block"
+ val channel = new EmbeddedChannel(new BlockHeaderEncoder)
+ channel.writeOutbound(new BlockHeader(17, blockId, None))
+ val out = channel.readOutbound().asInstanceOf[ByteBuf]
+ assert(out.readInt() === 4 + blockId.length + 17)
+ assert(out.readInt() === blockId.length)
+
+ val blockIdBytes = new Array[Byte](blockId.length)
+ out.readBytes(blockIdBytes)
+ assert(new String(blockIdBytes) === blockId)
+ assert(out.readableBytes() === 0)
+
+ channel.close()
+ }
+
+ test("encode error message") {
+ val blockId = "error_block"
+ val errorMsg = "error encountered"
+ val channel = new EmbeddedChannel(new BlockHeaderEncoder)
+ channel.writeOutbound(new BlockHeader(17, blockId, Some(errorMsg)))
+ val out = channel.readOutbound().asInstanceOf[ByteBuf]
+ assert(out.readInt() === 4 + blockId.length + errorMsg.length)
+ assert(out.readInt() === -blockId.length)
+
+ val blockIdBytes = new Array[Byte](blockId.length)
+ out.readBytes(blockIdBytes)
+ assert(new String(blockIdBytes) === blockId)
+
+ val errorMsgBytes = new Array[Byte](errorMsg.length)
+ out.readBytes(errorMsgBytes)
+ assert(new String(errorMsgBytes) === errorMsg)
+ assert(out.readableBytes() === 0)
+
+ channel.close()
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/network/netty/server/BlockServerHandlerSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/server/BlockServerHandlerSuite.scala
new file mode 100644
index 0000000000..3239c710f1
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/network/netty/server/BlockServerHandlerSuite.scala
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.netty.server
+
+import java.io.{RandomAccessFile, File}
+import java.nio.ByteBuffer
+
+import io.netty.buffer.{Unpooled, ByteBuf}
+import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler, DefaultFileRegion}
+import io.netty.channel.embedded.EmbeddedChannel
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.storage.{BlockDataProvider, FileSegment}
+
+
+class BlockServerHandlerSuite extends FunSuite {
+
+ test("ByteBuffer block") {
+ val expectedBlockId = "test_bytebuffer_block"
+ val buf = ByteBuffer.allocate(10000)
+ for (i <- 1 to 10000) {
+ buf.put(i.toByte)
+ }
+ buf.flip()
+
+ val channel = new EmbeddedChannel(new BlockServerHandler(new BlockDataProvider {
+ override def getBlockData(blockId: String): Either[FileSegment, ByteBuffer] = Right(buf)
+ }))
+
+ channel.writeInbound(expectedBlockId)
+ assert(channel.outboundMessages().size === 2)
+
+ val out1 = channel.readOutbound().asInstanceOf[BlockHeader]
+ val out2 = channel.readOutbound().asInstanceOf[ByteBuf]
+
+ assert(out1.blockId === expectedBlockId)
+ assert(out1.blockSize === buf.remaining)
+ assert(out1.error === None)
+
+ assert(out2.equals(Unpooled.wrappedBuffer(buf)))
+
+ channel.close()
+ }
+
+ test("FileSegment block via zero-copy") {
+ val expectedBlockId = "test_file_block"
+
+ // Create random file data
+ val fileContent = new Array[Byte](1024)
+ scala.util.Random.nextBytes(fileContent)
+ val testFile = File.createTempFile("netty-test-file", "txt")
+ val fp = new RandomAccessFile(testFile, "rw")
+ fp.write(fileContent)
+ fp.close()
+
+ val channel = new EmbeddedChannel(new BlockServerHandler(new BlockDataProvider {
+ override def getBlockData(blockId: String): Either[FileSegment, ByteBuffer] = {
+ Left(new FileSegment(testFile, 15, testFile.length - 25))
+ }
+ }))
+
+ channel.writeInbound(expectedBlockId)
+ assert(channel.outboundMessages().size === 2)
+
+ val out1 = channel.readOutbound().asInstanceOf[BlockHeader]
+ val out2 = channel.readOutbound().asInstanceOf[DefaultFileRegion]
+
+ assert(out1.blockId === expectedBlockId)
+ assert(out1.blockSize === testFile.length - 25)
+ assert(out1.error === None)
+
+ assert(out2.count === testFile.length - 25)
+ assert(out2.position === 15)
+ }
+
+ test("pipeline exception propagation") {
+ val blockServerHandler = new BlockServerHandler(new BlockDataProvider {
+ override def getBlockData(blockId: String): Either[FileSegment, ByteBuffer] = ???
+ })
+ val exceptionHandler = new SimpleChannelInboundHandler[String]() {
+ override def channelRead0(ctx: ChannelHandlerContext, msg: String): Unit = {
+ throw new Exception("this is an error")
+ }
+ }
+
+ val channel = new EmbeddedChannel(exceptionHandler, blockServerHandler)
+ assert(channel.isOpen)
+ channel.writeInbound("a message to trigger the error")
+ assert(!channel.isOpen)
+ }
+}