Diffstat (limited to 'core/src')
-rw-r--r--  core/src/main/scala/org/apache/spark/SparkEnv.scala  |   9
-rw-r--r--  core/src/main/scala/org/apache/spark/network/BlockDataManager.scala  |  14
-rw-r--r--  core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala  |  11
-rw-r--r--  core/src/main/scala/org/apache/spark/network/BlockTransferService.scala  |  28
-rw-r--r--  core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala  | 166
-rw-r--r--  core/src/main/scala/org/apache/spark/network/netty/NettyBlockFetcher.scala  |  95
-rw-r--r--  core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala  |  76
-rw-r--r--  core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala  | 111
-rw-r--r--  core/src/main/scala/org/apache/spark/network/netty/NettyConfig.scala  |  59
-rw-r--r--  core/src/main/scala/org/apache/spark/network/netty/PathResolver.scala  |  25
-rw-r--r--  core/src/main/scala/org/apache/spark/network/netty/client/BlockClientListener.scala  |  29
-rw-r--r--  core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClient.scala  | 132
-rw-r--r--  core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientFactory.scala  |  99
-rw-r--r--  core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandler.scala  | 104
-rw-r--r--  core/src/main/scala/org/apache/spark/network/netty/client/LazyInitIterator.scala  |  44
-rw-r--r--  core/src/main/scala/org/apache/spark/network/netty/client/ReferenceCountedBuffer.scala  |  47
-rw-r--r--  core/src/main/scala/org/apache/spark/network/netty/server/BlockHeader.scala  |  32
-rw-r--r--  core/src/main/scala/org/apache/spark/network/netty/server/BlockHeaderEncoder.scala  |  47
-rw-r--r--  core/src/main/scala/org/apache/spark/network/netty/server/BlockServer.scala  | 162
-rw-r--r--  core/src/main/scala/org/apache/spark/network/netty/server/BlockServerChannelInitializer.scala  |  40
-rw-r--r--  core/src/main/scala/org/apache/spark/network/netty/server/BlockServerHandler.scala  | 140
-rw-r--r--  core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala  |  51
-rw-r--r--  core/src/main/scala/org/apache/spark/serializer/Serializer.scala  |   6
-rw-r--r--  core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala  |   4
-rw-r--r--  core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala  |   2
-rw-r--r--  core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockManager.scala  |   3
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockDataProvider.scala  |  32
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockManager.scala  |  52
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockNotFoundException.scala  |   1
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala  | 135
-rw-r--r--  core/src/main/scala/org/apache/spark/util/Utils.scala  |   1
-rw-r--r--  core/src/test/scala/org/apache/spark/ShuffleNettySuite.scala  |   4
-rw-r--r--  core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala  | 161
-rw-r--r--  core/src/test/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandlerSuite.scala  | 106
-rw-r--r--  core/src/test/scala/org/apache/spark/network/netty/server/BlockHeaderEncoderSuite.scala  |  64
-rw-r--r--  core/src/test/scala/org/apache/spark/network/netty/server/BlockServerHandlerSuite.scala  | 107
-rw-r--r--  core/src/test/scala/org/apache/spark/serializer/TestSerializer.scala  |  60
-rw-r--r--  core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala  |   8
-rw-r--r--  core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala  | 261
39 files changed, 697 insertions(+), 1831 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index 5c076e5f1c..6a6dfda363 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -32,6 +32,7 @@ import org.apache.spark.api.python.PythonWorkerFactory
import org.apache.spark.broadcast.BroadcastManager
import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.network.BlockTransferService
+import org.apache.spark.network.netty.{NettyBlockTransferService}
import org.apache.spark.network.nio.NioBlockTransferService
import org.apache.spark.scheduler.LiveListenerBus
import org.apache.spark.serializer.Serializer
@@ -272,7 +273,13 @@ object SparkEnv extends Logging {
val shuffleMemoryManager = new ShuffleMemoryManager(conf)
- val blockTransferService = new NioBlockTransferService(conf, securityManager)
+ val blockTransferService =
+ conf.get("spark.shuffle.blockTransferService", "nio").toLowerCase match {
+ case "netty" =>
+ new NettyBlockTransferService(conf)
+ case "nio" =>
+ new NioBlockTransferService(conf, securityManager)
+ }
val blockManagerMaster = new BlockManagerMaster(registerOrLookup(
"BlockManagerMaster",
diff --git a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala
index e0e9172427..1745d52c81 100644
--- a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala
+++ b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala
@@ -17,20 +17,20 @@
package org.apache.spark.network
-import org.apache.spark.storage.StorageLevel
-
+import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.storage.{BlockId, StorageLevel}
+private[spark]
trait BlockDataManager {
/**
- * Interface to get local block data.
- *
- * @return Some(buffer) if the block exists locally, and None if it doesn't.
+ * Interface to get local block data. Throws an exception if the block cannot be found or
+ * cannot be read successfully.
*/
- def getBlockData(blockId: String): Option[ManagedBuffer]
+ def getBlockData(blockId: BlockId): ManagedBuffer
/**
* Put the block locally, using the given storage level.
*/
- def putBlockData(blockId: String, data: ManagedBuffer, level: StorageLevel): Unit
+ def putBlockData(blockId: BlockId, data: ManagedBuffer, level: StorageLevel): Unit
}
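Note: getBlockData no longer returns an Option; callers rely on exceptions for missing or unreadable blocks. A hypothetical in-memory implementation of the revised trait, for illustration only:

    import scala.collection.concurrent.TrieMap

    import org.apache.spark.network.BlockDataManager
    import org.apache.spark.network.buffer.ManagedBuffer
    import org.apache.spark.storage.{BlockId, BlockNotFoundException, StorageLevel}

    // Sketch only: missing blocks raise an exception instead of yielding None.
    private[spark] class InMemoryBlockDataManager extends BlockDataManager {
      private val blocks = TrieMap.empty[BlockId, ManagedBuffer]

      override def getBlockData(blockId: BlockId): ManagedBuffer =
        blocks.getOrElse(blockId, throw new BlockNotFoundException(blockId.toString))

      override def putBlockData(blockId: BlockId, data: ManagedBuffer, level: StorageLevel): Unit =
        blocks.put(blockId, data) // the storage level is ignored in this sketch
    }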
diff --git a/core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala b/core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala
index 34acaa563c..645793fde8 100644
--- a/core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala
+++ b/core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala
@@ -19,19 +19,24 @@ package org.apache.spark.network
import java.util.EventListener
+import org.apache.spark.network.buffer.ManagedBuffer
+
/**
* Listener callback interface for [[BlockTransferService.fetchBlocks]].
*/
+private[spark]
trait BlockFetchingListener extends EventListener {
/**
- * Called once per successfully fetched block.
+ * Called once per successfully fetched block. After this call returns, data will be released
+ * automatically. If the data will be passed to another thread, the receiver should retain()
+ * and release() the buffer on their own, or copy the data to a new buffer.
*/
def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit
/**
- * Called upon failures. For each failure, this is called only once (i.e. not once per block).
+ * Called at least once per block upon failures.
*/
- def onBlockFetchFailure(exception: Throwable): Unit
+ def onBlockFetchFailure(blockId: String, exception: Throwable): Unit
}
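Note: the success callback now hands the receiver only borrowed access to the buffer. A hypothetical listener that passes buffers to another thread, showing the retain()/release() discipline described above:

    import java.util.concurrent.LinkedBlockingQueue

    import org.apache.spark.network.BlockFetchingListener
    import org.apache.spark.network.buffer.ManagedBuffer

    // Sketch only: buffers crossing a thread boundary are retained here and must be
    // released by the consumer once it has finished with them.
    class QueueingListener(queue: LinkedBlockingQueue[(String, ManagedBuffer)])
      extends BlockFetchingListener {

      override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = {
        data.retain()              // keep the buffer alive past this callback
        queue.put(blockId -> data) // consumer calls data.release() when done
      }

      override def onBlockFetchFailure(blockId: String, exception: Throwable): Unit = {
        queue.put(blockId -> null) // simplistic failure signal, sketch only
      }
    }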
diff --git a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala
index 84d991fa68..b083f46533 100644
--- a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala
+++ b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala
@@ -17,13 +17,19 @@
package org.apache.spark.network
+import java.io.Closeable
+import java.nio.ByteBuffer
+
import scala.concurrent.{Await, Future}
import scala.concurrent.duration.Duration
-import org.apache.spark.storage.StorageLevel
-
+import org.apache.spark.Logging
+import org.apache.spark.network.buffer.{NioManagedBuffer, ManagedBuffer}
+import org.apache.spark.storage.{BlockId, StorageLevel}
+import org.apache.spark.util.Utils
-abstract class BlockTransferService {
+private[spark]
+abstract class BlockTransferService extends Closeable with Logging {
/**
* Initialize the transfer service by giving it the BlockDataManager that can be used to fetch
@@ -34,7 +40,7 @@ abstract class BlockTransferService {
/**
* Tear down the transfer service.
*/
- def stop(): Unit
+ def close(): Unit
/**
* Port number the service is listening on, available only after [[init]] is invoked.
@@ -50,9 +56,6 @@ abstract class BlockTransferService {
* Fetch a sequence of blocks from a remote node asynchronously,
* available only after [[init]] is invoked.
*
- * Note that [[BlockFetchingListener.onBlockFetchSuccess]] is called once per block,
- * while [[BlockFetchingListener.onBlockFetchFailure]] is called once per failure (not per block).
- *
* Note that this API takes a sequence so the implementation can batch requests, and does not
* return a future so the underlying implementation can invoke onBlockFetchSuccess as soon as
* the data of a block is fetched, rather than waiting for all blocks to be fetched.
@@ -69,7 +72,7 @@ abstract class BlockTransferService {
def uploadBlock(
hostname: String,
port: Int,
- blockId: String,
+ blockId: BlockId,
blockData: ManagedBuffer,
level: StorageLevel): Future[Unit]
@@ -83,7 +86,7 @@ abstract class BlockTransferService {
val lock = new Object
@volatile var result: Either[ManagedBuffer, Throwable] = null
fetchBlocks(hostName, port, Seq(blockId), new BlockFetchingListener {
- override def onBlockFetchFailure(exception: Throwable): Unit = {
+ override def onBlockFetchFailure(blockId: String, exception: Throwable): Unit = {
lock.synchronized {
result = Right(exception)
lock.notify()
@@ -91,7 +94,10 @@ abstract class BlockTransferService {
}
override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = {
lock.synchronized {
- result = Left(data)
+ val ret = ByteBuffer.allocate(data.size.toInt)
+ ret.put(data.nioByteBuffer())
+ ret.flip()
+ result = Left(new NioManagedBuffer(ret))
lock.notify()
}
}
@@ -123,7 +129,7 @@ abstract class BlockTransferService {
def uploadBlockSync(
hostname: String,
port: Int,
- blockId: String,
+ blockId: BlockId,
blockData: ManagedBuffer,
level: StorageLevel): Unit = {
Await.result(uploadBlock(hostname, port, blockId, blockData, level), Duration.Inf)
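Note: the success path above copies the fetched bytes into a heap-backed NioManagedBuffer before signalling the waiting thread, so the result stays valid after the transport releases its own buffer. A hedged call-site sketch (the host, port and transferService value are illustrative, and the enclosing method name fetchBlockSync is assumed from the rest of the file, which this hunk does not show):

    // Blocks the calling thread until the single block arrives or the fetch fails.
    val buffer = transferService.fetchBlockSync("remote-host", 12345, "shuffle_0_0_0")
    // Safe to use after the call returns: the bytes were copied onto the heap above.
    val bytes = buffer.nioByteBuffer()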
diff --git a/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala b/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala
deleted file mode 100644
index 4211ba4e43..0000000000
--- a/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala
+++ /dev/null
@@ -1,166 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network
-
-import java.io._
-import java.nio.ByteBuffer
-import java.nio.channels.FileChannel
-import java.nio.channels.FileChannel.MapMode
-
-import scala.util.Try
-
-import com.google.common.io.ByteStreams
-import io.netty.buffer.{ByteBufInputStream, ByteBuf}
-
-import org.apache.spark.util.{ByteBufferInputStream, Utils}
-
-
-/**
- * This interface provides an immutable view for data in the form of bytes. The implementation
- * should specify how the data is provided:
- *
- * - FileSegmentManagedBuffer: data backed by part of a file
- * - NioByteBufferManagedBuffer: data backed by a NIO ByteBuffer
- * - NettyByteBufManagedBuffer: data backed by a Netty ByteBuf
- */
-sealed abstract class ManagedBuffer {
- // Note that all the methods are defined with parenthesis because their implementations can
- // have side effects (io operations).
-
- /** Number of bytes of the data. */
- def size: Long
-
- /**
- * Exposes this buffer's data as an NIO ByteBuffer. Changing the position and limit of the
- * returned ByteBuffer should not affect the content of this buffer.
- */
- def nioByteBuffer(): ByteBuffer
-
- /**
- * Exposes this buffer's data as an InputStream. The underlying implementation does not
- * necessarily check for the length of bytes read, so the caller is responsible for making sure
- * it does not go over the limit.
- */
- def inputStream(): InputStream
-}
-
-
-/**
- * A [[ManagedBuffer]] backed by a segment in a file
- */
-final class FileSegmentManagedBuffer(val file: File, val offset: Long, val length: Long)
- extends ManagedBuffer {
-
- /**
- * Memory mapping is expensive and can destabilize the JVM (SPARK-1145, SPARK-3889).
- * Avoid unless there's a good reason not to.
- */
- private val MIN_MEMORY_MAP_BYTES = 2 * 1024 * 1024;
-
- override def size: Long = length
-
- override def nioByteBuffer(): ByteBuffer = {
- var channel: FileChannel = null
- try {
- channel = new RandomAccessFile(file, "r").getChannel
- // Just copy the buffer if it's sufficiently small, as memory mapping has a high overhead.
- if (length < MIN_MEMORY_MAP_BYTES) {
- val buf = ByteBuffer.allocate(length.toInt)
- channel.position(offset)
- while (buf.remaining() != 0) {
- if (channel.read(buf) == -1) {
- throw new IOException("Reached EOF before filling buffer\n" +
- s"offset=$offset\nfile=${file.getAbsolutePath}\nbuf.remaining=${buf.remaining}")
- }
- }
- buf.flip()
- buf
- } else {
- channel.map(MapMode.READ_ONLY, offset, length)
- }
- } catch {
- case e: IOException =>
- Try(channel.size).toOption match {
- case Some(fileLen) =>
- throw new IOException(s"Error in reading $this (actual file length $fileLen)", e)
- case None =>
- throw new IOException(s"Error in opening $this", e)
- }
- } finally {
- if (channel != null) {
- Utils.tryLog(channel.close())
- }
- }
- }
-
- override def inputStream(): InputStream = {
- var is: FileInputStream = null
- try {
- is = new FileInputStream(file)
- ByteStreams.skipFully(is, offset)
- ByteStreams.limit(is, length)
- } catch {
- case e: IOException =>
- if (is != null) {
- Utils.tryLog(is.close())
- }
- Try(file.length).toOption match {
- case Some(fileLen) =>
- throw new IOException(s"Error in reading $this (actual file length $fileLen)", e)
- case None =>
- throw new IOException(s"Error in opening $this", e)
- }
- case e: Throwable =>
- if (is != null) {
- Utils.tryLog(is.close())
- }
- throw e
- }
- }
-
- override def toString: String = s"${getClass.getName}($file, $offset, $length)"
-}
-
-
-/**
- * A [[ManagedBuffer]] backed by [[java.nio.ByteBuffer]].
- */
-final class NioByteBufferManagedBuffer(buf: ByteBuffer) extends ManagedBuffer {
-
- override def size: Long = buf.remaining()
-
- override def nioByteBuffer() = buf.duplicate()
-
- override def inputStream() = new ByteBufferInputStream(buf)
-}
-
-
-/**
- * A [[ManagedBuffer]] backed by a Netty [[ByteBuf]].
- */
-final class NettyByteBufManagedBuffer(buf: ByteBuf) extends ManagedBuffer {
-
- override def size: Long = buf.readableBytes()
-
- override def nioByteBuffer() = buf.nioBuffer()
-
- override def inputStream() = new ByteBufInputStream(buf)
-
- // TODO(rxin): Promote this to top level ManagedBuffer interface and add documentation for it.
- def release(): Unit = buf.release()
-}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockFetcher.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockFetcher.scala
new file mode 100644
index 0000000000..8c5ffd8da6
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockFetcher.scala
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.netty
+
+import java.nio.ByteBuffer
+import java.util
+
+import org.apache.spark.{SparkConf, Logging}
+import org.apache.spark.network.BlockFetchingListener
+import org.apache.spark.network.netty.NettyMessages._
+import org.apache.spark.serializer.{JavaSerializer, Serializer}
+import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.network.client.{RpcResponseCallback, ChunkReceivedCallback, TransportClient}
+import org.apache.spark.storage.BlockId
+import org.apache.spark.util.Utils
+
+/**
+ * Responsible for holding the state for a request for a single set of blocks. This assumes that
+ * the chunks will be returned in the same order as requested, and that there will be exactly
+ * one chunk per block.
+ *
+ * Upon receipt of any block, the listener will be called back. Upon failure part way through,
+ * the listener will receive a failure callback for each outstanding block.
+ */
+class NettyBlockFetcher(
+ serializer: Serializer,
+ client: TransportClient,
+ blockIds: Seq[String],
+ listener: BlockFetchingListener)
+ extends Logging {
+
+ require(blockIds.nonEmpty)
+
+ private val ser = serializer.newInstance()
+
+ private var streamHandle: ShuffleStreamHandle = _
+
+ private val chunkCallback = new ChunkReceivedCallback {
+ // On receipt of a chunk, pass it upwards as a block.
+ def onSuccess(chunkIndex: Int, buffer: ManagedBuffer): Unit = Utils.logUncaughtExceptions {
+ listener.onBlockFetchSuccess(blockIds(chunkIndex), buffer)
+ }
+
+ // On receipt of a failure, fail every block from chunkIndex onwards.
+ def onFailure(chunkIndex: Int, e: Throwable): Unit = {
+ blockIds.drop(chunkIndex).foreach { blockId =>
+ listener.onBlockFetchFailure(blockId, e);
+ }
+ }
+ }
+
+ /** Begins the fetching process, calling the listener with every block fetched. */
+ def start(): Unit = {
+ // Send the RPC to open the given set of blocks. This will return a ShuffleStreamHandle.
+ client.sendRpc(ser.serialize(OpenBlocks(blockIds.map(BlockId.apply))).array(),
+ new RpcResponseCallback {
+ override def onSuccess(response: Array[Byte]): Unit = {
+ try {
+ streamHandle = ser.deserialize[ShuffleStreamHandle](ByteBuffer.wrap(response))
+ logTrace(s"Successfully opened block set: $streamHandle! Preparing to fetch chunks.")
+
+ // Immediately request all chunks -- we expect that the total size of the request is
+ // reasonable due to higher level chunking in [[ShuffleBlockFetcherIterator]].
+ for (i <- 0 until streamHandle.numChunks) {
+ client.fetchChunk(streamHandle.streamId, i, chunkCallback)
+ }
+ } catch {
+ case e: Exception =>
+ logError("Failed while starting block fetches", e)
+ blockIds.foreach(blockId => Utils.tryLog(listener.onBlockFetchFailure(blockId, e)))
+ }
+ }
+
+ override def onFailure(e: Throwable): Unit = {
+ logError("Failed while starting block fetches", e)
+ blockIds.foreach(blockId => Utils.tryLog(listener.onBlockFetchFailure(blockId, e)))
+ }
+ })
+ }
+}
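Note: the fetcher maps chunk index i back to blockIds(i), which is why the server must register exactly one chunk per requested block, in request order. A hedged usage sketch (the serializer and client values would come from the transfer service, and the block ids are illustrative):

    import org.apache.spark.network.BlockFetchingListener
    import org.apache.spark.network.buffer.ManagedBuffer

    val listener = new BlockFetchingListener {
      override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit =
        println(s"fetched $blockId (${data.size} bytes)")
      override def onBlockFetchFailure(blockId: String, exception: Throwable): Unit =
        println(s"failed to fetch $blockId: $exception")
    }
    new NettyBlockFetcher(serializer, client, Seq("shuffle_0_0_0", "shuffle_0_1_0"), listener).start()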
diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala
new file mode 100644
index 0000000000..02c657e1d6
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.netty
+
+import java.nio.ByteBuffer
+
+import org.apache.spark.Logging
+import org.apache.spark.network.BlockDataManager
+import org.apache.spark.serializer.Serializer
+import org.apache.spark.network.buffer.{NioManagedBuffer, ManagedBuffer}
+import org.apache.spark.network.client.{TransportClient, RpcResponseCallback}
+import org.apache.spark.network.server.{DefaultStreamManager, RpcHandler}
+import org.apache.spark.storage.{StorageLevel, BlockId}
+
+import scala.collection.JavaConversions._
+
+object NettyMessages {
+
+ /** Request to read a set of blocks. Returns [[ShuffleStreamHandle]] to identify the stream. */
+ case class OpenBlocks(blockIds: Seq[BlockId])
+
+ /** Request to upload a block with a certain StorageLevel. Returns nothing (empty byte array). */
+ case class UploadBlock(blockId: BlockId, blockData: Array[Byte], level: StorageLevel)
+
+ /** Identifier for a fixed number of chunks to read from a stream created by [[OpenBlocks]]. */
+ case class ShuffleStreamHandle(streamId: Long, numChunks: Int)
+}
+
+/**
+ * Serves requests to open blocks by simply registering one chunk per block requested.
+ */
+class NettyBlockRpcServer(
+ serializer: Serializer,
+ streamManager: DefaultStreamManager,
+ blockManager: BlockDataManager)
+ extends RpcHandler with Logging {
+
+ import NettyMessages._
+
+ override def receive(
+ client: TransportClient,
+ messageBytes: Array[Byte],
+ responseContext: RpcResponseCallback): Unit = {
+ val ser = serializer.newInstance()
+ val message = ser.deserialize[AnyRef](ByteBuffer.wrap(messageBytes))
+ logTrace(s"Received request: $message")
+
+ message match {
+ case OpenBlocks(blockIds) =>
+ val blocks: Seq[ManagedBuffer] = blockIds.map(blockManager.getBlockData)
+ val streamId = streamManager.registerStream(blocks.iterator)
+ logTrace(s"Registered streamId $streamId with ${blocks.size} buffers")
+ responseContext.onSuccess(
+ ser.serialize(new ShuffleStreamHandle(streamId, blocks.size)).array())
+
+ case UploadBlock(blockId, blockData, level) =>
+ blockManager.putBlockData(blockId, new NioManagedBuffer(ByteBuffer.wrap(blockData)), level)
+ responseContext.onSuccess(new Array[Byte](0))
+ }
+ }
+}
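Note: both sides of the RPC agree on the serializer, so the messages above are plain serialized case classes. A small round-trip sketch using the JavaSerializer that NettyBlockTransferService configures (the block id and stream id are illustrative):

    import java.nio.ByteBuffer

    import org.apache.spark.SparkConf
    import org.apache.spark.network.netty.NettyMessages.{OpenBlocks, ShuffleStreamHandle}
    import org.apache.spark.serializer.JavaSerializer
    import org.apache.spark.storage.BlockId

    val ser = new JavaSerializer(new SparkConf()).newInstance()

    // Client side: serialize the request that NettyBlockFetcher sends.
    val requestBytes = ser.serialize(OpenBlocks(Seq(BlockId("shuffle_0_0_0")))).array()

    // Server side: deserialize it and answer with a stream handle.
    val request = ser.deserialize[OpenBlocks](ByteBuffer.wrap(requestBytes))
    val handle = ShuffleStreamHandle(streamId = 7L, numChunks = request.blockIds.size)
    val responseBytes = ser.serialize(handle).array()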
diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
new file mode 100644
index 0000000000..38a3e94515
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
@@ -0,0 +1,111 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.netty
+
+import scala.concurrent.{Promise, Future}
+
+import org.apache.spark.SparkConf
+import org.apache.spark.network._
+import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.network.client.{RpcResponseCallback, TransportClient, TransportClientFactory}
+import org.apache.spark.network.netty.NettyMessages.UploadBlock
+import org.apache.spark.network.server._
+import org.apache.spark.network.util.{ConfigProvider, TransportConf}
+import org.apache.spark.serializer.JavaSerializer
+import org.apache.spark.storage.{BlockId, StorageLevel}
+import org.apache.spark.util.Utils
+
+/**
+ * A BlockTransferService that uses Netty to fetch a set of blocks at a time.
+ */
+class NettyBlockTransferService(conf: SparkConf) extends BlockTransferService {
+ // TODO: Don't use Java serialization, use a more cross-version compatible serialization format.
+ val serializer = new JavaSerializer(conf)
+
+ // Create a TransportConfig using SparkConf.
+ private[this] val transportConf = new TransportConf(
+ new ConfigProvider { override def get(name: String) = conf.get(name) })
+
+ private[this] var transportContext: TransportContext = _
+ private[this] var server: TransportServer = _
+ private[this] var clientFactory: TransportClientFactory = _
+
+ override def init(blockDataManager: BlockDataManager): Unit = {
+ val streamManager = new DefaultStreamManager
+ val rpcHandler = new NettyBlockRpcServer(serializer, streamManager, blockDataManager)
+ transportContext = new TransportContext(transportConf, streamManager, rpcHandler)
+ clientFactory = transportContext.createClientFactory()
+ server = transportContext.createServer()
+ }
+
+ override def fetchBlocks(
+ hostname: String,
+ port: Int,
+ blockIds: Seq[String],
+ listener: BlockFetchingListener): Unit = {
+ try {
+ val client = clientFactory.createClient(hostname, port)
+ new NettyBlockFetcher(serializer, client, blockIds, listener).start()
+ } catch {
+ case e: Exception =>
+ logError("Exception while beginning fetchBlocks", e)
+ blockIds.foreach(listener.onBlockFetchFailure(_, e))
+ }
+ }
+
+ override def hostName: String = Utils.localHostName()
+
+ override def port: Int = server.getPort
+
+ override def uploadBlock(
+ hostname: String,
+ port: Int,
+ blockId: BlockId,
+ blockData: ManagedBuffer,
+ level: StorageLevel): Future[Unit] = {
+ val result = Promise[Unit]()
+ val client = clientFactory.createClient(hostname, port)
+
+ // Convert or copy nio buffer into array in order to serialize it.
+ val nioBuffer = blockData.nioByteBuffer()
+ val array = if (nioBuffer.hasArray) {
+ nioBuffer.array()
+ } else {
+ val data = new Array[Byte](nioBuffer.remaining())
+ nioBuffer.get(data)
+ data
+ }
+
+ val ser = serializer.newInstance()
+ client.sendRpc(ser.serialize(new UploadBlock(blockId, array, level)).array(),
+ new RpcResponseCallback {
+ override def onSuccess(response: Array[Byte]): Unit = {
+ logTrace(s"Successfully uploaded block $blockId")
+ result.success()
+ }
+ override def onFailure(e: Throwable): Unit = {
+ logError(s"Error while uploading block $blockId", e)
+ result.failure(e)
+ }
+ })
+
+ result.future
+ }
+
+ override def close(): Unit = server.close()
+}
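Note: init() must be called with a BlockDataManager before the service can serve or upload blocks. A hedged end-to-end sketch (the dataManager value, host and port are illustrative):

    import java.nio.ByteBuffer

    import org.apache.spark.SparkConf
    import org.apache.spark.network.buffer.NioManagedBuffer
    import org.apache.spark.storage.{StorageLevel, TestBlockId}

    val service = new NettyBlockTransferService(new SparkConf())
    service.init(dataManager) // dataManager: any BlockDataManager implementation

    // Upload three bytes synchronously, then shut the transport down.
    val payload = new NioManagedBuffer(ByteBuffer.wrap(Array[Byte](1, 2, 3)))
    service.uploadBlockSync("remote-host", 12345, TestBlockId("demo"), payload,
      StorageLevel.MEMORY_ONLY)
    service.close()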
diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyConfig.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyConfig.scala
deleted file mode 100644
index b5870152c5..0000000000
--- a/core/src/main/scala/org/apache/spark/network/netty/NettyConfig.scala
+++ /dev/null
@@ -1,59 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.netty
-
-import org.apache.spark.SparkConf
-
-/**
- * A central location that tracks all the settings we exposed to users.
- */
-private[spark]
-class NettyConfig(conf: SparkConf) {
-
- /** Port the server listens on. Default to a random port. */
- private[netty] val serverPort = conf.getInt("spark.shuffle.io.port", 0)
-
- /** IO mode: nio, oio, epoll, or auto (try epoll first and then nio). */
- private[netty] val ioMode = conf.get("spark.shuffle.io.mode", "nio").toLowerCase
-
- /** Connect timeout in secs. Default 60 secs. */
- private[netty] val connectTimeoutMs = conf.getInt("spark.shuffle.io.connectionTimeout", 60) * 1000
-
- /**
- * Percentage of the desired amount of time spent for I/O in the child event loops.
- * Only applicable in nio and epoll.
- */
- private[netty] val ioRatio = conf.getInt("spark.shuffle.io.netty.ioRatio", 80)
-
- /** Requested maximum length of the queue of incoming connections. */
- private[netty] val backLog: Option[Int] = conf.getOption("spark.shuffle.io.backLog").map(_.toInt)
-
- /**
- * Receive buffer size (SO_RCVBUF).
- * Note: the optimal size for receive buffer and send buffer should be
- * latency * network_bandwidth.
- * Assuming latency = 1ms, network_bandwidth = 10Gbps
- * buffer size should be ~ 1.25MB
- */
- private[netty] val receiveBuf: Option[Int] =
- conf.getOption("spark.shuffle.io.sendBuffer").map(_.toInt)
-
- /** Send buffer size (SO_SNDBUF). */
- private[netty] val sendBuf: Option[Int] =
- conf.getOption("spark.shuffle.io.sendBuffer").map(_.toInt)
-}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/PathResolver.scala b/core/src/main/scala/org/apache/spark/network/netty/PathResolver.scala
deleted file mode 100644
index 0d7695072a..0000000000
--- a/core/src/main/scala/org/apache/spark/network/netty/PathResolver.scala
+++ /dev/null
@@ -1,25 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.netty
-
-import org.apache.spark.storage.{BlockId, FileSegment}
-
-trait PathResolver {
- /** Get the file segment in which the given block resides. */
- def getBlockLocation(blockId: BlockId): FileSegment
-}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/client/BlockClientListener.scala b/core/src/main/scala/org/apache/spark/network/netty/client/BlockClientListener.scala
deleted file mode 100644
index e28219dd77..0000000000
--- a/core/src/main/scala/org/apache/spark/network/netty/client/BlockClientListener.scala
+++ /dev/null
@@ -1,29 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.netty.client
-
-import java.util.EventListener
-
-
-trait BlockClientListener extends EventListener {
-
- def onFetchSuccess(blockId: String, data: ReferenceCountedBuffer): Unit
-
- def onFetchFailure(blockId: String, errorMsg: String): Unit
-
-}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClient.scala b/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClient.scala
deleted file mode 100644
index 3ab13b96d7..0000000000
--- a/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClient.scala
+++ /dev/null
@@ -1,132 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.netty.client
-
-import java.util.concurrent.TimeoutException
-
-import com.google.common.base.Charsets.UTF_8
-import io.netty.bootstrap.Bootstrap
-import io.netty.buffer.PooledByteBufAllocator
-import io.netty.channel.socket.SocketChannel
-import io.netty.channel.{ChannelFutureListener, ChannelFuture, ChannelInitializer, ChannelOption}
-import io.netty.handler.codec.LengthFieldBasedFrameDecoder
-import io.netty.handler.codec.string.StringEncoder
-
-import org.apache.spark.Logging
-
-/**
- * Client for fetching data blocks from [[org.apache.spark.network.netty.server.BlockServer]].
- * Use [[BlockFetchingClientFactory]] to instantiate this client.
- *
- * The constructor blocks until a connection is successfully established.
- *
- * See [[org.apache.spark.network.netty.server.BlockServer]] for client/server protocol.
- *
- * Concurrency: thread safe and can be called from multiple threads.
- */
-@throws[TimeoutException]
-private[spark]
-class BlockFetchingClient(factory: BlockFetchingClientFactory, hostname: String, port: Int)
- extends Logging {
-
- private val handler = new BlockFetchingClientHandler
-
- /** Netty Bootstrap for creating the TCP connection. */
- private val bootstrap: Bootstrap = {
- val b = new Bootstrap
- b.group(factory.workerGroup)
- .channel(factory.socketChannelClass)
- // Use pooled buffers to reduce temporary buffer allocation
- .option(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT)
- // Disable Nagle's Algorithm since we don't want packets to wait
- .option(ChannelOption.TCP_NODELAY, java.lang.Boolean.TRUE)
- .option(ChannelOption.SO_KEEPALIVE, java.lang.Boolean.TRUE)
- .option[Integer](ChannelOption.CONNECT_TIMEOUT_MILLIS, factory.conf.connectTimeoutMs)
-
- b.handler(new ChannelInitializer[SocketChannel] {
- override def initChannel(ch: SocketChannel): Unit = {
- ch.pipeline
- .addLast("encoder", new StringEncoder(UTF_8))
- // maxFrameLength = 2G, lengthFieldOffset = 0, lengthFieldLength = 4
- .addLast("framedLengthDecoder", new LengthFieldBasedFrameDecoder(Int.MaxValue, 0, 4))
- .addLast("handler", handler)
- }
- })
- b
- }
-
- /** Netty ChannelFuture for the connection. */
- private val cf: ChannelFuture = bootstrap.connect(hostname, port)
- if (!cf.awaitUninterruptibly(factory.conf.connectTimeoutMs)) {
- throw new TimeoutException(
- s"Connecting to $hostname:$port timed out (${factory.conf.connectTimeoutMs} ms)")
- }
-
- /**
- * Ask the remote server for a sequence of blocks, and execute the callback.
- *
- * Note that this is asynchronous and returns immediately. Upstream caller should throttle the
- * rate of fetching; otherwise we could run out of memory.
- *
- * @param blockIds sequence of block ids to fetch.
- * @param listener callback to fire on fetch success / failure.
- */
- def fetchBlocks(blockIds: Seq[String], listener: BlockClientListener): Unit = {
- // It's best to limit the number of "write" calls since it needs to traverse the whole pipeline.
- // It's also best to limit the number of "flush" calls since it requires system calls.
- // Let's concatenate the string and then call writeAndFlush once.
- // This is also why this implementation might be more efficient than multiple, separate
- // fetch block calls.
- var startTime: Long = 0
- logTrace {
- startTime = System.nanoTime
- s"Sending request $blockIds to $hostname:$port"
- }
-
- blockIds.foreach { blockId =>
- handler.addRequest(blockId, listener)
- }
-
- val writeFuture = cf.channel().writeAndFlush(blockIds.mkString("\n") + "\n")
- writeFuture.addListener(new ChannelFutureListener {
- override def operationComplete(future: ChannelFuture): Unit = {
- if (future.isSuccess) {
- logTrace {
- val timeTaken = (System.nanoTime - startTime).toDouble / 1000000
- s"Sending request $blockIds to $hostname:$port took $timeTaken ms"
- }
- } else {
- // Fail all blocks.
- val errorMsg =
- s"Failed to send request $blockIds to $hostname:$port: ${future.cause.getMessage}"
- logError(errorMsg, future.cause)
- blockIds.foreach { blockId =>
- listener.onFetchFailure(blockId, errorMsg)
- handler.removeRequest(blockId)
- }
- }
- }
- })
- }
-
- def waitForClose(): Unit = {
- cf.channel().closeFuture().sync()
- }
-
- def close(): Unit = cf.channel().close()
-}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientFactory.scala b/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientFactory.scala
deleted file mode 100644
index 2b28402c52..0000000000
--- a/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientFactory.scala
+++ /dev/null
@@ -1,99 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.netty.client
-
-import io.netty.channel.epoll.{EpollEventLoopGroup, EpollSocketChannel}
-import io.netty.channel.nio.NioEventLoopGroup
-import io.netty.channel.oio.OioEventLoopGroup
-import io.netty.channel.socket.nio.NioSocketChannel
-import io.netty.channel.socket.oio.OioSocketChannel
-import io.netty.channel.{EventLoopGroup, Channel}
-
-import org.apache.spark.SparkConf
-import org.apache.spark.network.netty.NettyConfig
-import org.apache.spark.util.Utils
-
-/**
- * Factory for creating [[BlockFetchingClient]] by using createClient. This factory reuses
- * the worker thread pool for Netty.
- *
- * Concurrency: createClient is safe to be called from multiple threads concurrently.
- */
-private[spark]
-class BlockFetchingClientFactory(val conf: NettyConfig) {
-
- def this(sparkConf: SparkConf) = this(new NettyConfig(sparkConf))
-
- /** A thread factory so the threads are named (for debugging). */
- val threadFactory = Utils.namedThreadFactory("spark-shuffle-client")
-
- /** The following two are instantiated by the [[init]] method, depending ioMode. */
- var socketChannelClass: Class[_ <: Channel] = _
- var workerGroup: EventLoopGroup = _
-
- init()
-
- /** Initialize [[socketChannelClass]] and [[workerGroup]] based on ioMode. */
- private def init(): Unit = {
- def initOio(): Unit = {
- socketChannelClass = classOf[OioSocketChannel]
- workerGroup = new OioEventLoopGroup(0, threadFactory)
- }
- def initNio(): Unit = {
- socketChannelClass = classOf[NioSocketChannel]
- workerGroup = new NioEventLoopGroup(0, threadFactory)
- }
- def initEpoll(): Unit = {
- socketChannelClass = classOf[EpollSocketChannel]
- workerGroup = new EpollEventLoopGroup(0, threadFactory)
- }
-
- conf.ioMode match {
- case "nio" => initNio()
- case "oio" => initOio()
- case "epoll" => initEpoll()
- case "auto" =>
- // For auto mode, first try epoll (only available on Linux), then nio.
- try {
- initEpoll()
- } catch {
- // TODO: Should we log the throwable? But that always happen on non-Linux systems.
- // Perhaps the right thing to do is to check whether the system is Linux, and then only
- // call initEpoll on Linux.
- case e: Throwable => initNio()
- }
- }
- }
-
- /**
- * Create a new BlockFetchingClient connecting to the given remote host / port.
- *
- * This blocks until a connection is successfully established.
- *
- * Concurrency: This method is safe to call from multiple threads.
- */
- def createClient(remoteHost: String, remotePort: Int): BlockFetchingClient = {
- new BlockFetchingClient(this, remoteHost, remotePort)
- }
-
- def stop(): Unit = {
- if (workerGroup != null) {
- workerGroup.shutdownGracefully()
- }
- }
-}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandler.scala b/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandler.scala
deleted file mode 100644
index d9d3f7bef0..0000000000
--- a/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandler.scala
+++ /dev/null
@@ -1,104 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.netty.client
-
-import com.google.common.base.Charsets.UTF_8
-import io.netty.buffer.ByteBuf
-import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler}
-
-import org.apache.spark.Logging
-
-
-/**
- * Handler that processes server responses. It uses the protocol documented in
- * [[org.apache.spark.network.netty.server.BlockServer]].
- *
- * Concurrency: thread safe and can be called from multiple threads.
- */
-private[client]
-class BlockFetchingClientHandler extends SimpleChannelInboundHandler[ByteBuf] with Logging {
-
- /** Tracks the list of outstanding requests and their listeners on success/failure. */
- private val outstandingRequests = java.util.Collections.synchronizedMap {
- new java.util.HashMap[String, BlockClientListener]
- }
-
- def addRequest(blockId: String, listener: BlockClientListener): Unit = {
- outstandingRequests.put(blockId, listener)
- }
-
- def removeRequest(blockId: String): Unit = {
- outstandingRequests.remove(blockId)
- }
-
- override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = {
- val errorMsg = s"Exception in connection from ${ctx.channel.remoteAddress}: ${cause.getMessage}"
- logError(errorMsg, cause)
-
- // Fire the failure callback for all outstanding blocks
- outstandingRequests.synchronized {
- val iter = outstandingRequests.entrySet().iterator()
- while (iter.hasNext) {
- val entry = iter.next()
- entry.getValue.onFetchFailure(entry.getKey, errorMsg)
- }
- outstandingRequests.clear()
- }
-
- ctx.close()
- }
-
- override def channelRead0(ctx: ChannelHandlerContext, in: ByteBuf) {
- val totalLen = in.readInt()
- val blockIdLen = in.readInt()
- val blockIdBytes = new Array[Byte](math.abs(blockIdLen))
- in.readBytes(blockIdBytes)
- val blockId = new String(blockIdBytes, UTF_8)
- val blockSize = totalLen - math.abs(blockIdLen) - 4
-
- def server = ctx.channel.remoteAddress.toString
-
- // blockIdLen is negative when it is an error message.
- if (blockIdLen < 0) {
- val errorMessageBytes = new Array[Byte](blockSize)
- in.readBytes(errorMessageBytes)
- val errorMsg = new String(errorMessageBytes, UTF_8)
- logTrace(s"Received block $blockId ($blockSize B) with error $errorMsg from $server")
-
- val listener = outstandingRequests.get(blockId)
- if (listener == null) {
- // Ignore callback
- logWarning(s"Got a response for block $blockId but it is not in our outstanding requests")
- } else {
- outstandingRequests.remove(blockId)
- listener.onFetchFailure(blockId, errorMsg)
- }
- } else {
- logTrace(s"Received block $blockId ($blockSize B) from $server")
-
- val listener = outstandingRequests.get(blockId)
- if (listener == null) {
- // Ignore callback
- logWarning(s"Got a response for block $blockId but it is not in our outstanding requests")
- } else {
- outstandingRequests.remove(blockId)
- listener.onFetchSuccess(blockId, new ReferenceCountedBuffer(in))
- }
- }
- }
-}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/client/LazyInitIterator.scala b/core/src/main/scala/org/apache/spark/network/netty/client/LazyInitIterator.scala
deleted file mode 100644
index 9740ee64d1..0000000000
--- a/core/src/main/scala/org/apache/spark/network/netty/client/LazyInitIterator.scala
+++ /dev/null
@@ -1,44 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.netty.client
-
-/**
- * A simple iterator that lazily initializes the underlying iterator.
- *
- * The use case is that sometimes we might have many iterators open at the same time, and each of
- * the iterator might initialize its own buffer (e.g. decompression buffer, deserialization buffer).
- * This could lead to too many buffers open. If this iterator is used, we lazily initialize those
- * buffers.
- */
-private[spark]
-class LazyInitIterator(createIterator: => Iterator[Any]) extends Iterator[Any] {
-
- lazy val proxy = createIterator
-
- override def hasNext: Boolean = {
- val gotNext = proxy.hasNext
- if (!gotNext) {
- close()
- }
- gotNext
- }
-
- override def next(): Any = proxy.next()
-
- def close(): Unit = Unit
-}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/client/ReferenceCountedBuffer.scala b/core/src/main/scala/org/apache/spark/network/netty/client/ReferenceCountedBuffer.scala
deleted file mode 100644
index ea1abf5ecc..0000000000
--- a/core/src/main/scala/org/apache/spark/network/netty/client/ReferenceCountedBuffer.scala
+++ /dev/null
@@ -1,47 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.netty.client
-
-import java.io.InputStream
-import java.nio.ByteBuffer
-
-import io.netty.buffer.{ByteBuf, ByteBufInputStream}
-
-
-/**
- * A buffer abstraction based on Netty's ByteBuf so we don't expose Netty.
- * This is a Scala value class.
- *
- * The buffer's life cycle is NOT managed by the JVM, and thus requiring explicit declaration of
- * reference by the retain method and release method.
- */
-private[spark]
-class ReferenceCountedBuffer(val underlying: ByteBuf) extends AnyVal {
-
- /** Return the nio ByteBuffer view of the underlying buffer. */
- def byteBuffer(): ByteBuffer = underlying.nioBuffer
-
- /** Creates a new input stream that starts from the current position of the buffer. */
- def inputStream(): InputStream = new ByteBufInputStream(underlying)
-
- /** Increment the reference counter by one. */
- def retain(): Unit = underlying.retain()
-
- /** Decrement the reference counter by one and release the buffer if the ref count is 0. */
- def release(): Unit = underlying.release()
-}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/server/BlockHeader.scala b/core/src/main/scala/org/apache/spark/network/netty/server/BlockHeader.scala
deleted file mode 100644
index 162e9cc682..0000000000
--- a/core/src/main/scala/org/apache/spark/network/netty/server/BlockHeader.scala
+++ /dev/null
@@ -1,32 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.netty.server
-
-/**
- * Header describing a block. This is used only in the server pipeline.
- *
- * [[BlockServerHandler]] creates this, and [[BlockHeaderEncoder]] encodes it.
- *
- * @param blockSize length of the block content, excluding the length itself.
- * If positive, this is the header for a block (not part of the header).
- * If negative, this is the header and content for an error message.
- * @param blockId block id
- * @param error some error message from reading the block
- */
-private[server]
-class BlockHeader(val blockSize: Int, val blockId: String, val error: Option[String] = None)
diff --git a/core/src/main/scala/org/apache/spark/network/netty/server/BlockHeaderEncoder.scala b/core/src/main/scala/org/apache/spark/network/netty/server/BlockHeaderEncoder.scala
deleted file mode 100644
index 8e4dda4ef8..0000000000
--- a/core/src/main/scala/org/apache/spark/network/netty/server/BlockHeaderEncoder.scala
+++ /dev/null
@@ -1,47 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.netty.server
-
-import io.netty.buffer.ByteBuf
-import io.netty.channel.ChannelHandlerContext
-import io.netty.handler.codec.MessageToByteEncoder
-
-/**
- * A simple encoder for BlockHeader. See [[BlockServer]] for the server to client protocol.
- */
-private[server]
-class BlockHeaderEncoder extends MessageToByteEncoder[BlockHeader] {
- override def encode(ctx: ChannelHandlerContext, msg: BlockHeader, out: ByteBuf): Unit = {
- // message = message length (4 bytes) + block id length (4 bytes) + block id + block data
- // message length = block id length (4 bytes) + size of block id + size of block data
- val blockIdBytes = msg.blockId.getBytes
- msg.error match {
- case Some(errorMsg) =>
- val errorBytes = errorMsg.getBytes
- out.writeInt(4 + blockIdBytes.length + errorBytes.size)
- out.writeInt(-blockIdBytes.length) // use negative block id length to represent errors
- out.writeBytes(blockIdBytes) // next is blockId itself
- out.writeBytes(errorBytes) // error message
- case None =>
- out.writeInt(4 + blockIdBytes.length + msg.blockSize)
- out.writeInt(blockIdBytes.length) // First 4 bytes is blockId length
- out.writeBytes(blockIdBytes) // next is blockId itself
- // msg of size blockSize will be written by ServerHandler
- }
- }
-}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/server/BlockServer.scala b/core/src/main/scala/org/apache/spark/network/netty/server/BlockServer.scala
deleted file mode 100644
index 9194c7ced3..0000000000
--- a/core/src/main/scala/org/apache/spark/network/netty/server/BlockServer.scala
+++ /dev/null
@@ -1,162 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.netty.server
-
-import java.net.InetSocketAddress
-
-import com.google.common.base.Charsets.UTF_8
-import io.netty.bootstrap.ServerBootstrap
-import io.netty.buffer.PooledByteBufAllocator
-import io.netty.channel.{ChannelFuture, ChannelInitializer, ChannelOption}
-import io.netty.channel.epoll.{EpollEventLoopGroup, EpollServerSocketChannel}
-import io.netty.channel.nio.NioEventLoopGroup
-import io.netty.channel.oio.OioEventLoopGroup
-import io.netty.channel.socket.SocketChannel
-import io.netty.channel.socket.nio.NioServerSocketChannel
-import io.netty.channel.socket.oio.OioServerSocketChannel
-import io.netty.handler.codec.LineBasedFrameDecoder
-import io.netty.handler.codec.string.StringDecoder
-
-import org.apache.spark.{Logging, SparkConf}
-import org.apache.spark.network.netty.NettyConfig
-import org.apache.spark.storage.BlockDataProvider
-import org.apache.spark.util.Utils
-
-
-/**
- * Server for serving Spark data blocks.
- * This should be used together with [[org.apache.spark.network.netty.client.BlockFetchingClient]].
- *
- * Protocol for requesting blocks (client to server):
- * One block id per line, e.g. to request 3 blocks: "block1\nblock2\nblock3\n"
- *
- * Protocol for sending blocks (server to client):
- * frame-length (4 bytes), block-id-length (4 bytes), block-id, block-data.
- *
- * frame-length should not include the length of itself.
- * If block-id-length is negative, then this is an error message rather than block-data. The real
- * block-id length is the absolute value of the block-id-length.
- *
- */
-private[spark]
-class BlockServer(conf: NettyConfig, dataProvider: BlockDataProvider) extends Logging {
-
- def this(sparkConf: SparkConf, dataProvider: BlockDataProvider) = {
- this(new NettyConfig(sparkConf), dataProvider)
- }
-
- def port: Int = _port
-
- def hostName: String = _hostName
-
- private var _port: Int = conf.serverPort
- private var _hostName: String = ""
- private var bootstrap: ServerBootstrap = _
- private var channelFuture: ChannelFuture = _
-
- init()
-
- /** Initialize the server. */
- private def init(): Unit = {
- bootstrap = new ServerBootstrap
- val bossThreadFactory = Utils.namedThreadFactory("spark-shuffle-server-boss")
- val workerThreadFactory = Utils.namedThreadFactory("spark-shuffle-server-worker")
-
- // Use only one thread to accept connections, and 2 * num_cores for worker.
- def initNio(): Unit = {
- val bossGroup = new NioEventLoopGroup(1, bossThreadFactory)
- val workerGroup = new NioEventLoopGroup(0, workerThreadFactory)
- workerGroup.setIoRatio(conf.ioRatio)
- bootstrap.group(bossGroup, workerGroup).channel(classOf[NioServerSocketChannel])
- }
- def initOio(): Unit = {
- val bossGroup = new OioEventLoopGroup(1, bossThreadFactory)
- val workerGroup = new OioEventLoopGroup(0, workerThreadFactory)
- bootstrap.group(bossGroup, workerGroup).channel(classOf[OioServerSocketChannel])
- }
- def initEpoll(): Unit = {
- val bossGroup = new EpollEventLoopGroup(1, bossThreadFactory)
- val workerGroup = new EpollEventLoopGroup(0, workerThreadFactory)
- workerGroup.setIoRatio(conf.ioRatio)
- bootstrap.group(bossGroup, workerGroup).channel(classOf[EpollServerSocketChannel])
- }
-
- conf.ioMode match {
- case "nio" => initNio()
- case "oio" => initOio()
- case "epoll" => initEpoll()
- case "auto" =>
- // For auto mode, first try epoll (only available on Linux), then nio.
- try {
- initEpoll()
- } catch {
- // TODO: Should we log the throwable? But that always happens on non-Linux systems.
- // Perhaps the right thing to do is to check whether the system is Linux, and then only
- // call initEpoll on Linux.
- case e: Throwable => initNio()
- }
- }
-
- // Use pooled buffers to reduce temporary buffer allocation
- bootstrap.option(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT)
- bootstrap.childOption(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT)
-
- // Various (advanced) user-configured settings.
- conf.backLog.foreach { backLog =>
- bootstrap.option[java.lang.Integer](ChannelOption.SO_BACKLOG, backLog)
- }
- conf.receiveBuf.foreach { receiveBuf =>
- bootstrap.option[java.lang.Integer](ChannelOption.SO_RCVBUF, receiveBuf)
- }
- conf.sendBuf.foreach { sendBuf =>
- bootstrap.option[java.lang.Integer](ChannelOption.SO_SNDBUF, sendBuf)
- }
-
- bootstrap.childHandler(new ChannelInitializer[SocketChannel] {
- override def initChannel(ch: SocketChannel): Unit = {
- ch.pipeline
- .addLast("frameDecoder", new LineBasedFrameDecoder(1024)) // max block id length 1024
- .addLast("stringDecoder", new StringDecoder(UTF_8))
- .addLast("blockHeaderEncoder", new BlockHeaderEncoder)
- .addLast("handler", new BlockServerHandler(dataProvider))
- }
- })
-
- channelFuture = bootstrap.bind(new InetSocketAddress(_port))
- channelFuture.sync()
-
- val addr = channelFuture.channel.localAddress.asInstanceOf[InetSocketAddress]
- _port = addr.getPort
- _hostName = addr.getHostName
- }
-
- /** Shutdown the server. */
- def stop(): Unit = {
- if (channelFuture != null) {
- channelFuture.channel().close().awaitUninterruptibly()
- channelFuture = null
- }
- if (bootstrap != null && bootstrap.group() != null) {
- bootstrap.group().shutdownGracefully()
- }
- if (bootstrap != null && bootstrap.childGroup() != null) {
- bootstrap.childGroup().shutdownGracefully()
- }
- bootstrap = null
- }
-}
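
Note on the request protocol: the removed server expected one block id per line (capped at 1024
bytes per line by the LineBasedFrameDecoder above), so a multi-block request is just the ids
joined with '\n'. A minimal request-encoding sketch (hypothetical helper, not part of this patch):

    import java.nio.charset.StandardCharsets.UTF_8

    // Sketch only: build the newline-delimited request the removed BlockServer expected,
    // e.g. encodeRequest(Seq("block1", "block2", "block3")) encodes "block1\nblock2\nblock3\n".
    def encodeRequest(blockIds: Seq[String]): Array[Byte] =
      blockIds.mkString("", "\n", "\n").getBytes(UTF_8)
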
diff --git a/core/src/main/scala/org/apache/spark/network/netty/server/BlockServerChannelInitializer.scala b/core/src/main/scala/org/apache/spark/network/netty/server/BlockServerChannelInitializer.scala
deleted file mode 100644
index 188154d51d..0000000000
--- a/core/src/main/scala/org/apache/spark/network/netty/server/BlockServerChannelInitializer.scala
+++ /dev/null
@@ -1,40 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.netty.server
-
-import com.google.common.base.Charsets.UTF_8
-import io.netty.channel.ChannelInitializer
-import io.netty.channel.socket.SocketChannel
-import io.netty.handler.codec.LineBasedFrameDecoder
-import io.netty.handler.codec.string.StringDecoder
-
-import org.apache.spark.storage.BlockDataProvider
-
-/** Channel initializer that sets up the pipeline for the BlockServer. */
-private[netty]
-class BlockServerChannelInitializer(dataProvider: BlockDataProvider)
- extends ChannelInitializer[SocketChannel] {
-
- override def initChannel(ch: SocketChannel): Unit = {
- ch.pipeline
- .addLast("frameDecoder", new LineBasedFrameDecoder(1024)) // max block id length 1024
- .addLast("stringDecoder", new StringDecoder(UTF_8))
- .addLast("blockHeaderEncoder", new BlockHeaderEncoder)
- .addLast("handler", new BlockServerHandler(dataProvider))
- }
-}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/server/BlockServerHandler.scala b/core/src/main/scala/org/apache/spark/network/netty/server/BlockServerHandler.scala
deleted file mode 100644
index 40dd5e5d1a..0000000000
--- a/core/src/main/scala/org/apache/spark/network/netty/server/BlockServerHandler.scala
+++ /dev/null
@@ -1,140 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.netty.server
-
-import java.io.FileInputStream
-import java.nio.ByteBuffer
-import java.nio.channels.FileChannel
-
-import io.netty.buffer.Unpooled
-import io.netty.channel._
-
-import org.apache.spark.Logging
-import org.apache.spark.storage.{FileSegment, BlockDataProvider}
-
-
-/**
- * A handler that processes requests from clients and writes block data back.
- *
- * The messages should have been processed by a LineBasedFrameDecoder and a StringDecoder first
- * so channelRead0 is called once per line (i.e. per block id).
- */
-private[server]
-class BlockServerHandler(dataProvider: BlockDataProvider)
- extends SimpleChannelInboundHandler[String] with Logging {
-
- override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = {
- logError(s"Exception in connection from ${ctx.channel.remoteAddress}", cause)
- ctx.close()
- }
-
- override def channelRead0(ctx: ChannelHandlerContext, blockId: String): Unit = {
- def client = ctx.channel.remoteAddress.toString
-
- // A helper function to send error message back to the client.
- def respondWithError(error: String): Unit = {
- ctx.writeAndFlush(new BlockHeader(-1, blockId, Some(error))).addListener(
- new ChannelFutureListener {
- override def operationComplete(future: ChannelFuture) {
- if (!future.isSuccess) {
- // TODO: Maybe log the success case as well.
- logError(s"Error sending error back to $client", future.cause)
- ctx.close()
- }
- }
- }
- )
- }
-
- def writeFileSegment(segment: FileSegment): Unit = {
- // Send error message back if the block is too large. Even though we are capable of sending
- // large (2G+) blocks, the receiving end cannot handle it so let's fail fast.
- // Once we fixed the receiving end to be able to process large blocks, this should be removed.
- // Also make sure we update BlockHeaderEncoder to support length > 2G.
-
- // See [[BlockHeaderEncoder]] for the way length is encoded.
- if (segment.length + blockId.length + 4 > Int.MaxValue) {
- respondWithError(s"Block $blockId size (${segment.length}) greater than 2G")
- return
- }
-
- var fileChannel: FileChannel = null
- try {
- fileChannel = new FileInputStream(segment.file).getChannel
- } catch {
- case e: Exception =>
- logError(
- s"Error opening channel for $blockId in ${segment.file} for request from $client", e)
- respondWithError(e.getMessage)
- }
-
- // Found the block. Send it back.
- if (fileChannel != null) {
- // Write the header and block data. In the case of failures, the listener on the block data
- // write should close the connection.
- ctx.write(new BlockHeader(segment.length.toInt, blockId))
-
- val region = new DefaultFileRegion(fileChannel, segment.offset, segment.length)
- ctx.writeAndFlush(region).addListener(new ChannelFutureListener {
- override def operationComplete(future: ChannelFuture) {
- if (future.isSuccess) {
- logTrace(s"Sent block $blockId (${segment.length} B) back to $client")
- } else {
- logError(s"Error sending block $blockId to $client; closing connection", future.cause)
- ctx.close()
- }
- }
- })
- }
- }
-
- def writeByteBuffer(buf: ByteBuffer): Unit = {
- ctx.write(new BlockHeader(buf.remaining, blockId))
- ctx.writeAndFlush(Unpooled.wrappedBuffer(buf)).addListener(new ChannelFutureListener {
- override def operationComplete(future: ChannelFuture) {
- if (future.isSuccess) {
- logTrace(s"Sent block $blockId (${buf.remaining} B) back to $client")
- } else {
- logError(s"Error sending block $blockId to $client; closing connection", future.cause)
- ctx.close()
- }
- }
- })
- }
-
- logTrace(s"Received request from $client to fetch block $blockId")
-
- var blockData: Either[FileSegment, ByteBuffer] = null
-
- // First make sure we can find the block. If not, send error back to the user.
- try {
- blockData = dataProvider.getBlockData(blockId)
- } catch {
- case e: Exception =>
- logError(s"Error opening block $blockId for request from $client", e)
- respondWithError(e.getMessage)
- return
- }
-
- blockData match {
- case Left(segment) => writeFileSegment(segment)
- case Right(buf) => writeByteBuffer(buf)
- }
-
- } // end of channelRead0
-}
diff --git a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala
index e3113205be..11793ea92a 100644
--- a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala
+++ b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala
@@ -19,12 +19,13 @@ package org.apache.spark.network.nio
import java.nio.ByteBuffer
-import scala.concurrent.Future
-
-import org.apache.spark.{SparkException, Logging, SecurityManager, SparkConf}
import org.apache.spark.network._
+import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
import org.apache.spark.storage.{BlockId, StorageLevel}
import org.apache.spark.util.Utils
+import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException}
+
+import scala.concurrent.Future
/**
@@ -71,7 +72,7 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa
/**
* Tear down the transfer service.
*/
- override def stop(): Unit = {
+ override def close(): Unit = {
if (cm != null) {
cm.stop()
}
@@ -95,27 +96,34 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa
future.onSuccess { case message =>
val bufferMessage = message.asInstanceOf[BufferMessage]
val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage)
+
// SPARK-4064: In some cases(eg. Remote block was removed) blockMessageArray may be empty.
if (blockMessageArray.isEmpty) {
- listener.onBlockFetchFailure(
- new SparkException(s"Received empty message from $cmId"))
+ blockIds.foreach { id =>
+ listener.onBlockFetchFailure(id, new SparkException(s"Received empty message from $cmId"))
+ }
} else {
- for (blockMessage <- blockMessageArray) {
+ for (blockMessage: BlockMessage <- blockMessageArray) {
val msgType = blockMessage.getType
if (msgType != BlockMessage.TYPE_GOT_BLOCK) {
- listener.onBlockFetchFailure(
- new SparkException(s"Unexpected message ${msgType} received from $cmId"))
+ if (blockMessage.getId != null) {
+ listener.onBlockFetchFailure(blockMessage.getId.toString,
+ new SparkException(s"Unexpected message $msgType received from $cmId"))
+ }
} else {
val blockId = blockMessage.getId
+ val networkSize = blockMessage.getData.limit()
listener.onBlockFetchSuccess(
- blockId.toString, new NioByteBufferManagedBuffer(blockMessage.getData))
+ blockId.toString, new NioManagedBuffer(blockMessage.getData))
}
}
}
}(cm.futureExecContext)
future.onFailure { case exception =>
- listener.onBlockFetchFailure(exception)
+ blockIds.foreach { blockId =>
+ listener.onBlockFetchFailure(blockId, exception)
+ }
}(cm.futureExecContext)
}
@@ -127,12 +135,12 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa
override def uploadBlock(
hostname: String,
port: Int,
- blockId: String,
+ blockId: BlockId,
blockData: ManagedBuffer,
level: StorageLevel)
: Future[Unit] = {
checkInit()
- val msg = PutBlock(BlockId(blockId), blockData.nioByteBuffer(), level)
+ val msg = PutBlock(blockId, blockData.nioByteBuffer(), level)
val blockMessageArray = new BlockMessageArray(BlockMessage.fromPutBlock(msg))
val remoteCmId = new ConnectionManagerId(hostName, port)
val reply = cm.sendMessageReliably(remoteCmId, blockMessageArray.toBufferMessage)
@@ -154,10 +162,9 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa
val responseMessages = blockMessages.map(processBlockMessage).filter(_ != None).map(_.get)
Some(new BlockMessageArray(responseMessages).toBufferMessage)
} catch {
- case e: Exception => {
+ case e: Exception =>
logError("Exception handling buffer message", e)
Some(Message.createErrorMessage(e, msg.id))
- }
}
case otherMessage: Any =>
@@ -172,13 +179,13 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa
case BlockMessage.TYPE_PUT_BLOCK =>
val msg = PutBlock(blockMessage.getId, blockMessage.getData, blockMessage.getLevel)
logDebug("Received [" + msg + "]")
- putBlock(msg.id.toString, msg.data, msg.level)
+ putBlock(msg.id, msg.data, msg.level)
None
case BlockMessage.TYPE_GET_BLOCK =>
val msg = new GetBlock(blockMessage.getId)
logDebug("Received [" + msg + "]")
- val buffer = getBlock(msg.id.toString)
+ val buffer = getBlock(msg.id)
if (buffer == null) {
return None
}
@@ -188,20 +195,20 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa
}
}
- private def putBlock(blockId: String, bytes: ByteBuffer, level: StorageLevel) {
+ private def putBlock(blockId: BlockId, bytes: ByteBuffer, level: StorageLevel) {
val startTimeMs = System.currentTimeMillis()
logDebug("PutBlock " + blockId + " started from " + startTimeMs + " with data: " + bytes)
- blockDataManager.putBlockData(blockId, new NioByteBufferManagedBuffer(bytes), level)
+ blockDataManager.putBlockData(blockId, new NioManagedBuffer(bytes), level)
logDebug("PutBlock " + blockId + " used " + Utils.getUsedTimeMs(startTimeMs)
+ " with data size: " + bytes.limit)
}
- private def getBlock(blockId: String): ByteBuffer = {
+ private def getBlock(blockId: BlockId): ByteBuffer = {
val startTimeMs = System.currentTimeMillis()
logDebug("GetBlock " + blockId + " started from " + startTimeMs)
- val buffer = blockDataManager.getBlockData(blockId).orNull
+ val buffer = blockDataManager.getBlockData(blockId)
logDebug("GetBlock " + blockId + " used " + Utils.getUsedTimeMs(startTimeMs)
+ " and got buffer " + buffer)
- if (buffer == null) null else buffer.nioByteBuffer()
+ buffer.nioByteBuffer()
}
}
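
Note on the listener change: fetch failures are now reported per block id, so a failed request
surfaces as one onBlockFetchFailure callback for every outstanding block rather than a single
callback for the whole request. A minimal listener sketch illustrating both callbacks
(hypothetical class, not part of this patch; not thread-safe):

    import scala.collection.mutable

    import org.apache.spark.network.BlockFetchingListener
    import org.apache.spark.network.buffer.ManagedBuffer

    // Sketch only: record per-block outcomes delivered by a BlockTransferService.
    class CountingListener extends BlockFetchingListener {
      val succeeded = mutable.Set[String]()
      val failed = mutable.Set[String]()

      override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = {
        succeeded += blockId
      }

      override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = {
        failed += blockId
      }
    }
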
diff --git a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala
index a9144cdd97..ca6e971d22 100644
--- a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala
@@ -17,14 +17,14 @@
package org.apache.spark.serializer
-import java.io.{ByteArrayOutputStream, EOFException, InputStream, OutputStream}
+import java.io._
import java.nio.ByteBuffer
import scala.reflect.ClassTag
-import org.apache.spark.SparkEnv
+import org.apache.spark.{SparkConf, SparkEnv}
import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.util.{ByteBufferInputStream, NextIterator}
+import org.apache.spark.util.{Utils, ByteBufferInputStream, NextIterator}
/**
* :: DeveloperApi ::
diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala
index 439981d232..1fb5b2c454 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala
@@ -24,9 +24,9 @@ import java.util.concurrent.atomic.AtomicInteger
import scala.collection.JavaConversions._
-import org.apache.spark.{SparkEnv, SparkConf, Logging}
+import org.apache.spark.{Logging, SparkConf, SparkEnv}
import org.apache.spark.executor.ShuffleWriteMetrics
-import org.apache.spark.network.{FileSegmentManagedBuffer, ManagedBuffer}
+import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.FileShuffleBlockManager.ShuffleFileGroup
import org.apache.spark.storage._
diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala
index b5cd34cacd..e9805c9c13 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala
@@ -23,7 +23,7 @@ import java.nio.ByteBuffer
import com.google.common.io.ByteStreams
import org.apache.spark.SparkEnv
-import org.apache.spark.network.{ManagedBuffer, FileSegmentManagedBuffer}
+import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
import org.apache.spark.storage._
/**
diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockManager.scala
index 63863cc025..b521f0c7fc 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockManager.scala
@@ -18,8 +18,7 @@
package org.apache.spark.shuffle
import java.nio.ByteBuffer
-
-import org.apache.spark.network.ManagedBuffer
+import org.apache.spark.network.buffer.ManagedBuffer
import org.apache.spark.storage.ShuffleBlockId
private[spark]
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockDataProvider.scala b/core/src/main/scala/org/apache/spark/storage/BlockDataProvider.scala
deleted file mode 100644
index 5b6d086630..0000000000
--- a/core/src/main/scala/org/apache/spark/storage/BlockDataProvider.scala
+++ /dev/null
@@ -1,32 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.storage
-
-import java.nio.ByteBuffer
-
-
-/**
- * An interface for providing data for blocks.
- *
- * getBlockData returns either a FileSegment (for zero-copy send), or a ByteBuffer.
- *
- * Aside from unit tests, [[BlockManager]] is the main class that implements this.
- */
-private[spark] trait BlockDataProvider {
- def getBlockData(blockId: String): Either[FileSegment, ByteBuffer]
-}
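
Note: the removed trait let a provider answer with either a FileSegment (for zero-copy sends) or
an in-memory ByteBuffer. A minimal in-memory implementation sketch of the old interface
(hypothetical class, not part of this patch):

    import java.nio.ByteBuffer

    import org.apache.spark.storage.{BlockDataProvider, FileSegment}

    // Sketch only: serve blocks from a pre-populated map; real providers returned
    // Left(FileSegment) for on-disk shuffle blocks so the server could send them zero-copy.
    class InMemoryBlockDataProvider(blocks: Map[String, ByteBuffer]) extends BlockDataProvider {
      override def getBlockData(blockId: String): Either[FileSegment, ByteBuffer] = {
        blocks.get(blockId) match {
          case Some(buf) => Right(buf.duplicate())  // duplicate so callers get independent positions
          case None => throw new Exception("Unknown block id " + blockId)
        }
      }
    }
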
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 4cc9792365..58510d7232 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -17,15 +17,13 @@
package org.apache.spark.storage
-import java.io.{File, InputStream, OutputStream, BufferedOutputStream, ByteArrayOutputStream}
+import java.io.{BufferedOutputStream, ByteArrayOutputStream, File, InputStream, OutputStream}
import java.nio.{ByteBuffer, MappedByteBuffer}
-import scala.concurrent.ExecutionContext.Implicits.global
-
-import scala.collection.mutable
import scala.collection.mutable.{ArrayBuffer, HashMap}
-import scala.concurrent.{Await, Future}
+import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration._
+import scala.concurrent.{Await, Future}
import scala.util.Random
import akka.actor.{ActorSystem, Props}
@@ -35,11 +33,11 @@ import org.apache.spark._
import org.apache.spark.executor._
import org.apache.spark.io.CompressionCodec
import org.apache.spark.network._
+import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.ShuffleManager
import org.apache.spark.util._
-
private[spark] sealed trait BlockValues
private[spark] case class ByteBufferValues(buffer: ByteBuffer) extends BlockValues
private[spark] case class IteratorValues(iterator: Iterator[Any]) extends BlockValues
@@ -212,21 +210,20 @@ private[spark] class BlockManager(
}
/**
- * Interface to get local block data.
- *
- * @return Some(buffer) if the block exists locally, and None if it doesn't.
+ * Interface to get local block data. Throws an exception if the block cannot be found or
+ * cannot be read successfully.
*/
- override def getBlockData(blockId: String): Option[ManagedBuffer] = {
- val bid = BlockId(blockId)
- if (bid.isShuffle) {
- Some(shuffleManager.shuffleBlockManager.getBlockData(bid.asInstanceOf[ShuffleBlockId]))
+ override def getBlockData(blockId: BlockId): ManagedBuffer = {
+ if (blockId.isShuffle) {
+ shuffleManager.shuffleBlockManager.getBlockData(blockId.asInstanceOf[ShuffleBlockId])
} else {
- val blockBytesOpt = doGetLocal(bid, asBlockResult = false).asInstanceOf[Option[ByteBuffer]]
+ val blockBytesOpt = doGetLocal(blockId, asBlockResult = false)
+ .asInstanceOf[Option[ByteBuffer]]
if (blockBytesOpt.isDefined) {
val buffer = blockBytesOpt.get
- Some(new NioByteBufferManagedBuffer(buffer))
+ new NioManagedBuffer(buffer)
} else {
- None
+ throw new BlockNotFoundException(blockId.toString)
}
}
}
@@ -234,8 +231,8 @@ private[spark] class BlockManager(
/**
* Put the block locally, using the given storage level.
*/
- override def putBlockData(blockId: String, data: ManagedBuffer, level: StorageLevel): Unit = {
- putBytes(BlockId(blockId), data.nioByteBuffer(), level)
+ override def putBlockData(blockId: BlockId, data: ManagedBuffer, level: StorageLevel): Unit = {
+ putBytes(blockId, data.nioByteBuffer(), level)
}
/**
@@ -341,17 +338,6 @@ private[spark] class BlockManager(
}
/**
- * A short-circuited method to get blocks directly from disk. This is used for getting
- * shuffle blocks. It is safe to do so without a lock on block info since disk store
- * never deletes (recent) items.
- */
- def getLocalShuffleFromDisk(blockId: BlockId, serializer: Serializer): Option[Iterator[Any]] = {
- val buf = shuffleManager.shuffleBlockManager.getBlockData(blockId.asInstanceOf[ShuffleBlockId])
- val is = wrapForCompression(blockId, buf.inputStream())
- Some(serializer.newInstance().deserializeStream(is).asIterator)
- }
-
- /**
* Get block from local block manager.
*/
def getLocal(blockId: BlockId): Option[BlockResult] = {
@@ -869,9 +855,9 @@ private[spark] class BlockManager(
data.rewind()
logTrace(s"Trying to replicate $blockId of ${data.limit()} bytes to $peer")
blockTransferService.uploadBlockSync(
- peer.host, peer.port, blockId.toString, new NioByteBufferManagedBuffer(data), tLevel)
- logTrace(s"Replicated $blockId of ${data.limit()} bytes to $peer in %d ms"
- .format((System.currentTimeMillis - onePeerStartTime)))
+ peer.host, peer.port, blockId, new NioManagedBuffer(data), tLevel)
+ logTrace(s"Replicated $blockId of ${data.limit()} bytes to $peer in %s ms"
+ .format(System.currentTimeMillis - onePeerStartTime))
peersReplicatedTo += peer
peersForReplication -= peer
replicationFailed = false
@@ -1126,7 +1112,7 @@ private[spark] class BlockManager(
}
def stop(): Unit = {
- blockTransferService.stop()
+ blockTransferService.close()
diskBlockManager.stop()
actorSystem.stop(slaveActor)
blockInfo.clear()
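
Note on the getBlockData contract: it no longer returns an Option; it either returns a
ManagedBuffer or throws (BlockNotFoundException when the block is missing). Callers that still
want the old optional behavior must catch the exception themselves; a minimal wrapper sketch
(hypothetical helper, not part of this patch):

    import org.apache.spark.network.buffer.ManagedBuffer
    import org.apache.spark.storage.{BlockId, BlockManager, BlockNotFoundException}

    // Sketch only: adapt the new throwing contract back into an Option for callers that
    // treat a missing block as a normal case rather than an error.
    def getBlockDataOpt(bm: BlockManager, blockId: BlockId): Option[ManagedBuffer] = {
      try {
        Some(bm.getBlockData(blockId))
      } catch {
        case _: BlockNotFoundException => None
      }
    }
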
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockNotFoundException.scala b/core/src/main/scala/org/apache/spark/storage/BlockNotFoundException.scala
index 9ef453605f..81f5f2d31d 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockNotFoundException.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockNotFoundException.scala
@@ -17,5 +17,4 @@
package org.apache.spark.storage
-
class BlockNotFoundException(blockId: String) extends Exception(s"Block $blockId not found")
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
index 71b276b5f1..0d6f3bf003 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
@@ -19,15 +19,13 @@ package org.apache.spark.storage
import java.util.concurrent.LinkedBlockingQueue
-import scala.collection.mutable.ArrayBuffer
-import scala.collection.mutable.HashSet
-import scala.collection.mutable.Queue
+import scala.collection.mutable.{ArrayBuffer, HashSet, Queue}
-import org.apache.spark.{TaskContext, Logging}
-import org.apache.spark.network.{ManagedBuffer, BlockFetchingListener, BlockTransferService}
+import org.apache.spark.{Logging, TaskContext}
+import org.apache.spark.network.{BlockFetchingListener, BlockTransferService}
+import org.apache.spark.network.buffer.ManagedBuffer
import org.apache.spark.serializer.Serializer
-import org.apache.spark.util.Utils
-
+import org.apache.spark.util.{CompletionIterator, Utils}
/**
* An iterator that fetches multiple blocks. For local blocks, it fetches from the local block
@@ -88,17 +86,51 @@ final class ShuffleBlockFetcherIterator(
*/
private[this] val results = new LinkedBlockingQueue[FetchResult]
- // Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that
- // the number of bytes in flight is limited to maxBytesInFlight
+ /**
+ * Current [[FetchResult]] being processed. We track this so we can release the current buffer
+ * in case of a runtime exception when processing the current buffer.
+ */
+ private[this] var currentResult: FetchResult = null
+
+ /**
+ * Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that
+ * the number of bytes in flight is limited to maxBytesInFlight.
+ */
private[this] val fetchRequests = new Queue[FetchRequest]
- // Current bytes in flight from our requests
+ /** Current bytes in flight from our requests */
private[this] var bytesInFlight = 0L
private[this] val shuffleMetrics = context.taskMetrics.createShuffleReadMetricsForDependency()
+ /**
+ * Whether the iterator is still active. If isZombie is true, the callback interface will no
+ * longer place fetched blocks into [[results]].
+ */
+ @volatile private[this] var isZombie = false
+
initialize()
+ /**
+ * Mark the iterator as zombie, and release all buffers that haven't been deserialized yet.
+ */
+ private[this] def cleanup() {
+ isZombie = true
+ // Release the current buffer if necessary
+ if (currentResult != null && !currentResult.failed) {
+ currentResult.buf.release()
+ }
+
+ // Release buffers in the results queue
+ val iter = results.iterator()
+ while (iter.hasNext) {
+ val result = iter.next()
+ if (!result.failed) {
+ result.buf.release()
+ }
+ }
+ }
+
private[this] def sendRequest(req: FetchRequest) {
logDebug("Sending request for %d blocks (%s) from %s".format(
req.blocks.size, Utils.bytesToString(req.size), req.address.hostPort))
@@ -110,24 +142,23 @@ final class ShuffleBlockFetcherIterator(
blockTransferService.fetchBlocks(req.address.host, req.address.port, blockIds,
new BlockFetchingListener {
- override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = {
- results.put(new FetchResult(BlockId(blockId), sizeMap(blockId),
- () => serializer.newInstance().deserializeStream(
- blockManager.wrapForCompression(BlockId(blockId), data.inputStream())).asIterator
- ))
- shuffleMetrics.remoteBytesRead += data.size
- shuffleMetrics.remoteBlocksFetched += 1
- logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
+ override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = {
+ // Only add the buffer to results queue if the iterator is not zombie,
+ // i.e. cleanup() has not been called yet.
+ if (!isZombie) {
+ // Increment the ref count because we need to pass this to a different thread.
+ // This needs to be released after use.
+ buf.retain()
+ results.put(new FetchResult(BlockId(blockId), sizeMap(blockId), buf))
+ shuffleMetrics.remoteBytesRead += buf.size
+ shuffleMetrics.remoteBlocksFetched += 1
+ }
+ logTrace("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
}
- override def onBlockFetchFailure(e: Throwable): Unit = {
+ override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = {
logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e)
- // Note that there is a chance that some blocks have been fetched successfully, but we
- // still add them to the failed queue. This is fine because when the caller see a
- // FetchFailedException, it is going to fail the entire task anyway.
- for ((blockId, size) <- req.blocks) {
- results.put(new FetchResult(blockId, -1, null))
- }
+ results.put(new FetchResult(BlockId(blockId), -1, null))
}
}
)
@@ -138,7 +169,7 @@ final class ShuffleBlockFetcherIterator(
// smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5
// nodes, rather than blocking on reading output from one node.
val targetRequestSize = math.max(maxBytesInFlight / 5, 1L)
- logInfo("maxBytesInFlight: " + maxBytesInFlight + ", targetRequestSize: " + targetRequestSize)
+ logDebug("maxBytesInFlight: " + maxBytesInFlight + ", targetRequestSize: " + targetRequestSize)
// Split local and remote blocks. Remote blocks are further split into FetchRequests of size
// at most maxBytesInFlight in order to limit the amount of data in flight.
@@ -185,26 +216,34 @@ final class ShuffleBlockFetcherIterator(
remoteRequests
}
+ /**
+ * Fetch the local blocks while we are fetching remote blocks. This is ok because
+ * [[ManagedBuffer]]'s memory is allocated lazily when we create the input stream, so all we
+ * track in-memory are the ManagedBuffer references themselves.
+ */
private[this] def fetchLocalBlocks() {
- // Get the local blocks while remote blocks are being fetched. Note that it's okay to do
- // these all at once because they will just memory-map some files, so they won't consume
- // any memory that might exceed our maxBytesInFlight
- for (id <- localBlocks) {
+ val iter = localBlocks.iterator
+ while (iter.hasNext) {
+ val blockId = iter.next()
try {
+ val buf = blockManager.getBlockData(blockId)
shuffleMetrics.localBlocksFetched += 1
- results.put(new FetchResult(
- id, 0, () => blockManager.getLocalShuffleFromDisk(id, serializer).get))
- logDebug("Got local block " + id)
+ buf.retain()
+ results.put(new FetchResult(blockId, 0, buf))
} catch {
case e: Exception =>
+ // If we see an exception, stop immediately.
logError(s"Error occurred while fetching local blocks", e)
- results.put(new FetchResult(id, -1, null))
+ results.put(new FetchResult(blockId, -1, null))
return
}
}
}
private[this] def initialize(): Unit = {
+ // Add a task completion callback (called in both success case and failure case) to cleanup.
+ context.addTaskCompletionListener(_ => cleanup())
+
// Split local and remote blocks.
val remoteRequests = splitLocalRemoteBlocks()
// Add the remote requests into our queue in a random order
@@ -229,7 +268,8 @@ final class ShuffleBlockFetcherIterator(
override def next(): (BlockId, Option[Iterator[Any]]) = {
numBlocksProcessed += 1
val startFetchWait = System.currentTimeMillis()
- val result = results.take()
+ currentResult = results.take()
+ val result = currentResult
val stopFetchWait = System.currentTimeMillis()
shuffleMetrics.fetchWaitTime += (stopFetchWait - startFetchWait)
if (!result.failed) {
@@ -240,7 +280,21 @@ final class ShuffleBlockFetcherIterator(
(bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) {
sendRequest(fetchRequests.dequeue())
}
- (result.blockId, if (result.failed) None else Some(result.deserialize()))
+
+ val iteratorOpt: Option[Iterator[Any]] = if (result.failed) {
+ None
+ } else {
+ val is = blockManager.wrapForCompression(result.blockId, result.buf.createInputStream())
+ val iter = serializer.newInstance().deserializeStream(is).asIterator
+ Some(CompletionIterator[Any, Iterator[Any]](iter, {
+ // Once the iterator is exhausted, release the buffer and set currentResult to null
+ // so we don't release it again in cleanup.
+ currentResult = null
+ result.buf.release()
+ }))
+ }
+
+ (result.blockId, iteratorOpt)
}
}
@@ -254,7 +308,7 @@ object ShuffleBlockFetcherIterator {
* @param blocks Sequence of tuple, where the first element is the block id,
* and the second element is the estimated size, used to calculate bytesInFlight.
*/
- class FetchRequest(val address: BlockManagerId, val blocks: Seq[(BlockId, Long)]) {
+ case class FetchRequest(address: BlockManagerId, blocks: Seq[(BlockId, Long)]) {
val size = blocks.map(_._2).sum
}
@@ -262,10 +316,11 @@ object ShuffleBlockFetcherIterator {
* Result of a fetch from a remote block. A failure is represented as size == -1.
* @param blockId block id
* @param size estimated size of the block, used to calculate bytesInFlight.
- * Note that this is NOT the exact bytes.
- * @param deserialize closure to return the result in the form of an Iterator.
+   *             Note that this is NOT the exact bytes. -1 if the fetch failed.
+   * @param buf [[ManagedBuffer]] for the content. null if the fetch failed.
*/
- class FetchResult(val blockId: BlockId, val size: Long, val deserialize: () => Iterator[Any]) {
+ case class FetchResult(blockId: BlockId, size: Long, buf: ManagedBuffer) {
def failed: Boolean = size == -1
+ if (failed) assert(buf == null) else assert(buf != null)
}
}
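
Note on buffer lifetimes: each fetched ManagedBuffer is retain()-ed before it is queued and
release()-d either when its deserialization iterator is exhausted (via CompletionIterator) or in
cleanup() if the task finishes early. A minimal sketch of the release-on-exhaustion pattern
(hypothetical helper, not part of this patch):

    import org.apache.spark.network.buffer.ManagedBuffer
    import org.apache.spark.util.CompletionIterator

    // Sketch only: wrap a record iterator so the backing buffer is released exactly once,
    // as soon as the last record has been consumed.
    def releasingIterator(records: Iterator[Any], buf: ManagedBuffer): Iterator[Any] =
      CompletionIterator[Any, Iterator[Any]](records, buf.release())
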
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 1e881da511..0daab91143 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -43,7 +43,6 @@ import org.json4s._
import tachyon.client.{TachyonFile,TachyonFS}
import org.apache.spark._
-import org.apache.spark.util.SparkUncaughtExceptionHandler
import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance}
/** CallSite represents a place in user code. It can have a short and a long form. */
diff --git a/core/src/test/scala/org/apache/spark/ShuffleNettySuite.scala b/core/src/test/scala/org/apache/spark/ShuffleNettySuite.scala
index d7b2d2e1e3..840d8273cb 100644
--- a/core/src/test/scala/org/apache/spark/ShuffleNettySuite.scala
+++ b/core/src/test/scala/org/apache/spark/ShuffleNettySuite.scala
@@ -24,10 +24,10 @@ class ShuffleNettySuite extends ShuffleSuite with BeforeAndAfterAll {
// This test suite should run all tests in ShuffleSuite with Netty shuffle mode.
override def beforeAll() {
- System.setProperty("spark.shuffle.use.netty", "true")
+ System.setProperty("spark.shuffle.blockTransferService", "netty")
}
override def afterAll() {
- System.clearProperty("spark.shuffle.use.netty")
+ System.clearProperty("spark.shuffle.blockTransferService")
}
}
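
Note on the configuration change: the boolean flag spark.shuffle.use.netty is replaced by
spark.shuffle.blockTransferService, which selects the transfer implementation by name. A small
configuration sketch using the value set in this test (other possible values are not shown here):

    import org.apache.spark.SparkConf

    // Sketch only: select the Netty-based block transfer service by name.
    val conf = new SparkConf().set("spark.shuffle.blockTransferService", "netty")
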
diff --git a/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala
deleted file mode 100644
index 02d0ffc86f..0000000000
--- a/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala
+++ /dev/null
@@ -1,161 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.netty
-
-import java.io.{RandomAccessFile, File}
-import java.nio.ByteBuffer
-import java.util.{Collections, HashSet}
-import java.util.concurrent.{TimeUnit, Semaphore}
-
-import scala.collection.JavaConversions._
-
-import io.netty.buffer.{ByteBufUtil, Unpooled}
-
-import org.scalatest.{BeforeAndAfterAll, FunSuite}
-
-import org.apache.spark.SparkConf
-import org.apache.spark.network.netty.client.{BlockClientListener, ReferenceCountedBuffer, BlockFetchingClientFactory}
-import org.apache.spark.network.netty.server.BlockServer
-import org.apache.spark.storage.{FileSegment, BlockDataProvider}
-
-
-/**
- * Test suite that makes sure the server and the client implementations share the same protocol.
- */
-class ServerClientIntegrationSuite extends FunSuite with BeforeAndAfterAll {
-
- val bufSize = 100000
- var buf: ByteBuffer = _
- var testFile: File = _
- var server: BlockServer = _
- var clientFactory: BlockFetchingClientFactory = _
-
- val bufferBlockId = "buffer_block"
- val fileBlockId = "file_block"
-
- val fileContent = new Array[Byte](1024)
- scala.util.Random.nextBytes(fileContent)
-
- override def beforeAll() = {
- buf = ByteBuffer.allocate(bufSize)
- for (i <- 1 to bufSize) {
- buf.put(i.toByte)
- }
- buf.flip()
-
- testFile = File.createTempFile("netty-test-file", "txt")
- val fp = new RandomAccessFile(testFile, "rw")
- fp.write(fileContent)
- fp.close()
-
- server = new BlockServer(new SparkConf, new BlockDataProvider {
- override def getBlockData(blockId: String): Either[FileSegment, ByteBuffer] = {
- if (blockId == bufferBlockId) {
- Right(buf)
- } else if (blockId == fileBlockId) {
- Left(new FileSegment(testFile, 10, testFile.length - 25))
- } else {
- throw new Exception("Unknown block id " + blockId)
- }
- }
- })
-
- clientFactory = new BlockFetchingClientFactory(new SparkConf)
- }
-
- override def afterAll() = {
- server.stop()
- clientFactory.stop()
- }
-
- /** A ByteBuf for buffer_block */
- lazy val byteBufferBlockReference = Unpooled.wrappedBuffer(buf)
-
- /** A ByteBuf for file_block */
- lazy val fileBlockReference = Unpooled.wrappedBuffer(fileContent, 10, fileContent.length - 25)
-
- def fetchBlocks(blockIds: Seq[String]): (Set[String], Set[ReferenceCountedBuffer], Set[String]) =
- {
- val client = clientFactory.createClient(server.hostName, server.port)
- val sem = new Semaphore(0)
- val receivedBlockIds = Collections.synchronizedSet(new HashSet[String])
- val errorBlockIds = Collections.synchronizedSet(new HashSet[String])
- val receivedBuffers = Collections.synchronizedSet(new HashSet[ReferenceCountedBuffer])
-
- client.fetchBlocks(
- blockIds,
- new BlockClientListener {
- override def onFetchFailure(blockId: String, errorMsg: String): Unit = {
- errorBlockIds.add(blockId)
- sem.release()
- }
-
- override def onFetchSuccess(blockId: String, data: ReferenceCountedBuffer): Unit = {
- receivedBlockIds.add(blockId)
- data.retain()
- receivedBuffers.add(data)
- sem.release()
- }
- }
- )
- if (!sem.tryAcquire(blockIds.size, 30, TimeUnit.SECONDS)) {
- fail("Timeout getting response from the server")
- }
- client.close()
- (receivedBlockIds.toSet, receivedBuffers.toSet, errorBlockIds.toSet)
- }
-
- test("fetch a ByteBuffer block") {
- val (blockIds, buffers, failBlockIds) = fetchBlocks(Seq(bufferBlockId))
- assert(blockIds === Set(bufferBlockId))
- assert(buffers.map(_.underlying) === Set(byteBufferBlockReference))
- assert(failBlockIds.isEmpty)
- buffers.foreach(_.release())
- }
-
- test("fetch a FileSegment block via zero-copy send") {
- val (blockIds, buffers, failBlockIds) = fetchBlocks(Seq(fileBlockId))
- assert(blockIds === Set(fileBlockId))
- assert(buffers.map(_.underlying) === Set(fileBlockReference))
- assert(failBlockIds.isEmpty)
- buffers.foreach(_.release())
- }
-
- test("fetch a non-existent block") {
- val (blockIds, buffers, failBlockIds) = fetchBlocks(Seq("random-block"))
- assert(blockIds.isEmpty)
- assert(buffers.isEmpty)
- assert(failBlockIds === Set("random-block"))
- }
-
- test("fetch both ByteBuffer block and FileSegment block") {
- val (blockIds, buffers, failBlockIds) = fetchBlocks(Seq(bufferBlockId, fileBlockId))
- assert(blockIds === Set(bufferBlockId, fileBlockId))
- assert(buffers.map(_.underlying) === Set(byteBufferBlockReference, fileBlockReference))
- assert(failBlockIds.isEmpty)
- buffers.foreach(_.release())
- }
-
- test("fetch both ByteBuffer block and a non-existent block") {
- val (blockIds, buffers, failBlockIds) = fetchBlocks(Seq(bufferBlockId, "random-block"))
- assert(blockIds === Set(bufferBlockId))
- assert(buffers.map(_.underlying) === Set(byteBufferBlockReference))
- assert(failBlockIds === Set("random-block"))
- buffers.foreach(_.release())
- }
-}
diff --git a/core/src/test/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandlerSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandlerSuite.scala
deleted file mode 100644
index f629322ff6..0000000000
--- a/core/src/test/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandlerSuite.scala
+++ /dev/null
@@ -1,106 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.netty.client
-
-import java.nio.ByteBuffer
-
-import com.google.common.base.Charsets.UTF_8
-import io.netty.buffer.Unpooled
-import io.netty.channel.embedded.EmbeddedChannel
-
-import org.scalatest.{PrivateMethodTester, FunSuite}
-
-
-class BlockFetchingClientHandlerSuite extends FunSuite with PrivateMethodTester {
-
- test("handling block data (successful fetch)") {
- val blockId = "test_block"
- val blockData = "blahblahblahblahblah"
- val totalLength = 4 + blockId.length + blockData.length
-
- var parsedBlockId: String = ""
- var parsedBlockData: String = ""
- val handler = new BlockFetchingClientHandler
- handler.addRequest(blockId,
- new BlockClientListener {
- override def onFetchFailure(blockId: String, errorMsg: String): Unit = ???
- override def onFetchSuccess(bid: String, refCntBuf: ReferenceCountedBuffer): Unit = {
- parsedBlockId = bid
- val bytes = new Array[Byte](refCntBuf.byteBuffer().remaining)
- refCntBuf.byteBuffer().get(bytes)
- parsedBlockData = new String(bytes, UTF_8)
- }
- }
- )
-
- val outstandingRequests = PrivateMethod[java.util.Map[_, _]]('outstandingRequests)
- assert(handler.invokePrivate(outstandingRequests()).size === 1)
-
- val channel = new EmbeddedChannel(handler)
- val buf = ByteBuffer.allocate(totalLength + 4) // 4 bytes for the length field itself
- buf.putInt(totalLength)
- buf.putInt(blockId.length)
- buf.put(blockId.getBytes)
- buf.put(blockData.getBytes)
- buf.flip()
-
- channel.writeInbound(Unpooled.wrappedBuffer(buf))
- assert(parsedBlockId === blockId)
- assert(parsedBlockData === blockData)
-
- assert(handler.invokePrivate(outstandingRequests()).size === 0)
-
- channel.close()
- }
-
- test("handling error message (failed fetch)") {
- val blockId = "test_block"
- val errorMsg = "error erro5r error err4or error3 error6 error erro1r"
- val totalLength = 4 + blockId.length + errorMsg.length
-
- var parsedBlockId: String = ""
- var parsedErrorMsg: String = ""
- val handler = new BlockFetchingClientHandler
- handler.addRequest(blockId, new BlockClientListener {
- override def onFetchFailure(bid: String, msg: String) ={
- parsedBlockId = bid
- parsedErrorMsg = msg
- }
- override def onFetchSuccess(bid: String, refCntBuf: ReferenceCountedBuffer) = ???
- })
-
- val outstandingRequests = PrivateMethod[java.util.Map[_, _]]('outstandingRequests)
- assert(handler.invokePrivate(outstandingRequests()).size === 1)
-
- val channel = new EmbeddedChannel(handler)
- val buf = ByteBuffer.allocate(totalLength + 4) // 4 bytes for the length field itself
- buf.putInt(totalLength)
- buf.putInt(-blockId.length)
- buf.put(blockId.getBytes)
- buf.put(errorMsg.getBytes)
- buf.flip()
-
- channel.writeInbound(Unpooled.wrappedBuffer(buf))
- assert(parsedBlockId === blockId)
- assert(parsedErrorMsg === errorMsg)
-
- assert(handler.invokePrivate(outstandingRequests()).size === 0)
-
- channel.close()
- }
-}
diff --git a/core/src/test/scala/org/apache/spark/network/netty/server/BlockHeaderEncoderSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/server/BlockHeaderEncoderSuite.scala
deleted file mode 100644
index 3f8d0cf8f3..0000000000
--- a/core/src/test/scala/org/apache/spark/network/netty/server/BlockHeaderEncoderSuite.scala
+++ /dev/null
@@ -1,64 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.netty.server
-
-import com.google.common.base.Charsets.UTF_8
-import io.netty.buffer.ByteBuf
-import io.netty.channel.embedded.EmbeddedChannel
-
-import org.scalatest.FunSuite
-
-class BlockHeaderEncoderSuite extends FunSuite {
-
- test("encode normal block data") {
- val blockId = "test_block"
- val channel = new EmbeddedChannel(new BlockHeaderEncoder)
- channel.writeOutbound(new BlockHeader(17, blockId, None))
- val out = channel.readOutbound().asInstanceOf[ByteBuf]
- assert(out.readInt() === 4 + blockId.length + 17)
- assert(out.readInt() === blockId.length)
-
- val blockIdBytes = new Array[Byte](blockId.length)
- out.readBytes(blockIdBytes)
- assert(new String(blockIdBytes, UTF_8) === blockId)
- assert(out.readableBytes() === 0)
-
- channel.close()
- }
-
- test("encode error message") {
- val blockId = "error_block"
- val errorMsg = "error encountered"
- val channel = new EmbeddedChannel(new BlockHeaderEncoder)
- channel.writeOutbound(new BlockHeader(17, blockId, Some(errorMsg)))
- val out = channel.readOutbound().asInstanceOf[ByteBuf]
- assert(out.readInt() === 4 + blockId.length + errorMsg.length)
- assert(out.readInt() === -blockId.length)
-
- val blockIdBytes = new Array[Byte](blockId.length)
- out.readBytes(blockIdBytes)
- assert(new String(blockIdBytes, UTF_8) === blockId)
-
- val errorMsgBytes = new Array[Byte](errorMsg.length)
- out.readBytes(errorMsgBytes)
- assert(new String(errorMsgBytes, UTF_8) === errorMsg)
- assert(out.readableBytes() === 0)
-
- channel.close()
- }
-}
diff --git a/core/src/test/scala/org/apache/spark/network/netty/server/BlockServerHandlerSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/server/BlockServerHandlerSuite.scala
deleted file mode 100644
index 3239c710f1..0000000000
--- a/core/src/test/scala/org/apache/spark/network/netty/server/BlockServerHandlerSuite.scala
+++ /dev/null
@@ -1,107 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.netty.server
-
-import java.io.{RandomAccessFile, File}
-import java.nio.ByteBuffer
-
-import io.netty.buffer.{Unpooled, ByteBuf}
-import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler, DefaultFileRegion}
-import io.netty.channel.embedded.EmbeddedChannel
-
-import org.scalatest.FunSuite
-
-import org.apache.spark.storage.{BlockDataProvider, FileSegment}
-
-
-class BlockServerHandlerSuite extends FunSuite {
-
- test("ByteBuffer block") {
- val expectedBlockId = "test_bytebuffer_block"
- val buf = ByteBuffer.allocate(10000)
- for (i <- 1 to 10000) {
- buf.put(i.toByte)
- }
- buf.flip()
-
- val channel = new EmbeddedChannel(new BlockServerHandler(new BlockDataProvider {
- override def getBlockData(blockId: String): Either[FileSegment, ByteBuffer] = Right(buf)
- }))
-
- channel.writeInbound(expectedBlockId)
- assert(channel.outboundMessages().size === 2)
-
- val out1 = channel.readOutbound().asInstanceOf[BlockHeader]
- val out2 = channel.readOutbound().asInstanceOf[ByteBuf]
-
- assert(out1.blockId === expectedBlockId)
- assert(out1.blockSize === buf.remaining)
- assert(out1.error === None)
-
- assert(out2.equals(Unpooled.wrappedBuffer(buf)))
-
- channel.close()
- }
-
- test("FileSegment block via zero-copy") {
- val expectedBlockId = "test_file_block"
-
- // Create random file data
- val fileContent = new Array[Byte](1024)
- scala.util.Random.nextBytes(fileContent)
- val testFile = File.createTempFile("netty-test-file", "txt")
- val fp = new RandomAccessFile(testFile, "rw")
- fp.write(fileContent)
- fp.close()
-
- val channel = new EmbeddedChannel(new BlockServerHandler(new BlockDataProvider {
- override def getBlockData(blockId: String): Either[FileSegment, ByteBuffer] = {
- Left(new FileSegment(testFile, 15, testFile.length - 25))
- }
- }))
-
- channel.writeInbound(expectedBlockId)
- assert(channel.outboundMessages().size === 2)
-
- val out1 = channel.readOutbound().asInstanceOf[BlockHeader]
- val out2 = channel.readOutbound().asInstanceOf[DefaultFileRegion]
-
- assert(out1.blockId === expectedBlockId)
- assert(out1.blockSize === testFile.length - 25)
- assert(out1.error === None)
-
- assert(out2.count === testFile.length - 25)
- assert(out2.position === 15)
- }
-
- test("pipeline exception propagation") {
- val blockServerHandler = new BlockServerHandler(new BlockDataProvider {
- override def getBlockData(blockId: String): Either[FileSegment, ByteBuffer] = ???
- })
- val exceptionHandler = new SimpleChannelInboundHandler[String]() {
- override def channelRead0(ctx: ChannelHandlerContext, msg: String): Unit = {
- throw new Exception("this is an error")
- }
- }
-
- val channel = new EmbeddedChannel(exceptionHandler, blockServerHandler)
- assert(channel.isOpen)
- channel.writeInbound("a message to trigger the error")
- assert(!channel.isOpen)
- }
-}
diff --git a/core/src/test/scala/org/apache/spark/serializer/TestSerializer.scala b/core/src/test/scala/org/apache/spark/serializer/TestSerializer.scala
new file mode 100644
index 0000000000..0ade1bab18
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/serializer/TestSerializer.scala
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.serializer
+
+import java.io.{EOFException, OutputStream, InputStream}
+import java.nio.ByteBuffer
+
+import scala.reflect.ClassTag
+
+
+/**
+ * A serializer implementation that always returns a single element in a deserialization stream.
+ */
+class TestSerializer extends Serializer {
+ override def newInstance() = new TestSerializerInstance
+}
+
+
+class TestSerializerInstance extends SerializerInstance {
+ override def serialize[T: ClassTag](t: T): ByteBuffer = ???
+
+ override def serializeStream(s: OutputStream): SerializationStream = ???
+
+ override def deserializeStream(s: InputStream) = new TestDeserializationStream
+
+ override def deserialize[T: ClassTag](bytes: ByteBuffer): T = ???
+
+ override def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = ???
+}
+
+
+class TestDeserializationStream extends DeserializationStream {
+
+ private var count = 0
+
+ override def readObject[T: ClassTag](): T = {
+ count += 1
+ if (count == 2) {
+ throw new EOFException
+ }
+ new Object().asInstanceOf[T]
+ }
+
+ override def close(): Unit = {}
+}
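The stream above is intentionally degenerate: the first readObject() succeeds and the second throws EOFException. A usage sketch, assuming DeserializationStream.asIterator treats EOFException as end-of-stream (as Spark's Serializer API does); the wrapper object and the empty input stream are illustrative only:

```scala
// Usage sketch (not part of the patch): TestSerializer's stream yields exactly one element
// and then signals EOF on the second readObject call.
import java.io.ByteArrayInputStream

import org.apache.spark.serializer.TestSerializer

object TestSerializerSketch {
  def main(args: Array[String]): Unit = {
    val instance = new TestSerializer().newInstance()
    // The stream contents are ignored by TestSerializerInstance, so empty input is enough.
    val stream = instance.deserializeStream(new ByteArrayInputStream(new Array[Byte](0)))
    val it = stream.asIterator
    assert(it.hasNext)   // the single synthesized element
    it.next()
    assert(!it.hasNext)  // the second readObject throws EOFException, ending the iterator
  }
}
```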
diff --git a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala
index ba47fe5e25..6790388f96 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala
@@ -25,7 +25,7 @@ import org.scalatest.FunSuite
import org.apache.spark.{SparkEnv, SparkContext, LocalSparkContext, SparkConf}
import org.apache.spark.executor.ShuffleWriteMetrics
-import org.apache.spark.network.{FileSegmentManagedBuffer, ManagedBuffer}
+import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
import org.apache.spark.serializer.JavaSerializer
import org.apache.spark.shuffle.FileShuffleBlockManager
import org.apache.spark.storage.{ShuffleBlockId, FileSegment}
@@ -36,9 +36,9 @@ class HashShuffleManagerSuite extends FunSuite with LocalSparkContext {
private def checkSegments(expected: FileSegment, buffer: ManagedBuffer) {
assert(buffer.isInstanceOf[FileSegmentManagedBuffer])
val segment = buffer.asInstanceOf[FileSegmentManagedBuffer]
- assert(expected.file.getCanonicalPath === segment.file.getCanonicalPath)
- assert(expected.offset === segment.offset)
- assert(expected.length === segment.length)
+ assert(expected.file.getCanonicalPath === segment.getFile.getCanonicalPath)
+ assert(expected.offset === segment.getOffset)
+ assert(expected.length === segment.getLength)
}
test("consolidated shuffle can write to shuffle group without messing existing offsets/lengths") {
diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
index a8c049d749..4e502cf65e 100644
--- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
@@ -17,6 +17,10 @@
package org.apache.spark.storage
+import java.util.concurrent.Semaphore
+
+import scala.concurrent.future
+import scala.concurrent.ExecutionContext.Implicits.global
import org.apache.spark.{TaskContextImpl, TaskContext}
import org.apache.spark.network.{BlockFetchingListener, BlockTransferService}
@@ -27,38 +31,64 @@ import org.mockito.stubbing.Answer
import org.scalatest.FunSuite
+import org.apache.spark.{SparkConf, TaskContext}
+import org.apache.spark.network._
+import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.serializer.TestSerializer
+
class ShuffleBlockFetcherIteratorSuite extends FunSuite {
+ // Some of the tests are quite tricky because we are testing the cleanup behavior
+ // in the presence of faults.
- test("handle local read failures in BlockManager") {
+ /** Creates a mock [[BlockTransferService]] that returns data from the given map. */
+ private def createMockTransfer(data: Map[BlockId, ManagedBuffer]): BlockTransferService = {
val transfer = mock(classOf[BlockTransferService])
- val blockManager = mock(classOf[BlockManager])
- doReturn(BlockManagerId("test-client", "test-client", 1)).when(blockManager).blockManagerId
-
- val blIds = Array[BlockId](
- ShuffleBlockId(0,0,0),
- ShuffleBlockId(0,1,0),
- ShuffleBlockId(0,2,0),
- ShuffleBlockId(0,3,0),
- ShuffleBlockId(0,4,0))
-
- val optItr = mock(classOf[Option[Iterator[Any]]])
- val answer = new Answer[Option[Iterator[Any]]] {
- override def answer(invocation: InvocationOnMock) = Option[Iterator[Any]] {
- throw new Exception
+ when(transfer.fetchBlocks(any(), any(), any(), any())).thenAnswer(new Answer[Unit] {
+ override def answer(invocation: InvocationOnMock): Unit = {
+ val blocks = invocation.getArguments()(2).asInstanceOf[Seq[String]]
+ val listener = invocation.getArguments()(3).asInstanceOf[BlockFetchingListener]
+
+ for (blockId <- blocks) {
+ if (data.contains(BlockId(blockId))) {
+ listener.onBlockFetchSuccess(blockId, data(BlockId(blockId)))
+ } else {
+ listener.onBlockFetchFailure(blockId, new BlockNotFoundException(blockId))
+ }
+ }
}
+ })
+ transfer
+ }
+
+ private val conf = new SparkConf
+
+ test("successful 3 local reads + 2 remote reads") {
+ val blockManager = mock(classOf[BlockManager])
+ val localBmId = BlockManagerId("test-client", "test-client", 1)
+ doReturn(localBmId).when(blockManager).blockManagerId
+
+    // Make sure blockManager.getBlockData returns the local blocks
+ val localBlocks = Map[BlockId, ManagedBuffer](
+ ShuffleBlockId(0, 0, 0) -> mock(classOf[ManagedBuffer]),
+ ShuffleBlockId(0, 1, 0) -> mock(classOf[ManagedBuffer]),
+ ShuffleBlockId(0, 2, 0) -> mock(classOf[ManagedBuffer]))
+ localBlocks.foreach { case (blockId, buf) =>
+ doReturn(buf).when(blockManager).getBlockData(meq(blockId))
}
- // 3rd block is going to fail
- doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(0)), any())
- doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(1)), any())
- doAnswer(answer).when(blockManager).getLocalShuffleFromDisk(meq(blIds(2)), any())
- doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(3)), any())
- doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(4)), any())
+    // Set up the remote blocks that the mock transfer service will return
+ val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2)
+ val remoteBlocks = Map[BlockId, ManagedBuffer](
+ ShuffleBlockId(0, 3, 0) -> mock(classOf[ManagedBuffer]),
+ ShuffleBlockId(0, 4, 0) -> mock(classOf[ManagedBuffer])
+ )
+
+ val transfer = createMockTransfer(remoteBlocks)
- val bmId = BlockManagerId("test-client", "test-client", 1)
val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
- (bmId, blIds.map(blId => (blId, 1.asInstanceOf[Long])).toSeq)
+ (localBmId, localBlocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq),
+ (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)
)
val iterator = new ShuffleBlockFetcherIterator(
@@ -66,118 +96,145 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite {
transfer,
blockManager,
blocksByAddress,
- null,
+ new TestSerializer,
48 * 1024 * 1024)
- // Without exhausting the iterator, the iterator should be lazy and not call
- // getLocalShuffleFromDisk.
- verify(blockManager, times(0)).getLocalShuffleFromDisk(any(), any())
-
- assert(iterator.hasNext, "iterator should have 5 elements but actually has no elements")
- // the 2nd element of the tuple returned by iterator.next should be defined when
- // fetching successfully
- assert(iterator.next()._2.isDefined,
- "1st element should be defined but is not actually defined")
- verify(blockManager, times(1)).getLocalShuffleFromDisk(any(), any())
-
- assert(iterator.hasNext, "iterator should have 5 elements but actually has 1 element")
- assert(iterator.next()._2.isDefined,
- "2nd element should be defined but is not actually defined")
- verify(blockManager, times(2)).getLocalShuffleFromDisk(any(), any())
-
- assert(iterator.hasNext, "iterator should have 5 elements but actually has 2 elements")
- // 3rd fetch should be failed
- intercept[Exception] {
- iterator.next()
+ // 3 local blocks fetched in initialization
+ verify(blockManager, times(3)).getBlockData(any())
+
+ for (i <- 0 until 5) {
+ assert(iterator.hasNext, s"iterator should have 5 elements but actually has $i elements")
+ val (blockId, subIterator) = iterator.next()
+ assert(subIterator.isDefined,
+ s"iterator should have 5 elements defined but actually has $i elements")
+
+ // Make sure we release the buffer once the iterator is exhausted.
+ val mockBuf = localBlocks.getOrElse(blockId, remoteBlocks(blockId))
+ verify(mockBuf, times(0)).release()
+ subIterator.get.foreach(_ => Unit) // exhaust the iterator
+ verify(mockBuf, times(1)).release()
}
- verify(blockManager, times(3)).getLocalShuffleFromDisk(any(), any())
+
+    // 3 local blocks, and 2 remote blocks
+    // (both remote blocks come from the same block manager, so there is only one fetchBlocks call)
+ verify(blockManager, times(3)).getBlockData(any())
+ verify(transfer, times(1)).fetchBlocks(any(), any(), any(), any())
}
- test("handle local read successes") {
- val transfer = mock(classOf[BlockTransferService])
+ test("release current unexhausted buffer in case the task completes early") {
val blockManager = mock(classOf[BlockManager])
- doReturn(BlockManagerId("test-client", "test-client", 1)).when(blockManager).blockManagerId
-
- val blIds = Array[BlockId](
- ShuffleBlockId(0,0,0),
- ShuffleBlockId(0,1,0),
- ShuffleBlockId(0,2,0),
- ShuffleBlockId(0,3,0),
- ShuffleBlockId(0,4,0))
+ val localBmId = BlockManagerId("test-client", "test-client", 1)
+ doReturn(localBmId).when(blockManager).blockManagerId
+
+    // Set up the remote blocks that the mock transfer service will return
+ val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2)
+ val blocks = Map[BlockId, ManagedBuffer](
+ ShuffleBlockId(0, 0, 0) -> mock(classOf[ManagedBuffer]),
+ ShuffleBlockId(0, 1, 0) -> mock(classOf[ManagedBuffer]),
+ ShuffleBlockId(0, 2, 0) -> mock(classOf[ManagedBuffer])
+ )
- val optItr = mock(classOf[Option[Iterator[Any]]])
+ // Semaphore to coordinate event sequence in two different threads.
+ val sem = new Semaphore(0)
- // All blocks should be fetched successfully
- doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(0)), any())
- doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(1)), any())
- doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(2)), any())
- doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(3)), any())
- doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(4)), any())
+ val transfer = mock(classOf[BlockTransferService])
+ when(transfer.fetchBlocks(any(), any(), any(), any())).thenAnswer(new Answer[Unit] {
+ override def answer(invocation: InvocationOnMock): Unit = {
+ val listener = invocation.getArguments()(3).asInstanceOf[BlockFetchingListener]
+ future {
+ // Return the first two blocks, and wait till task completion before returning the 3rd one
+ listener.onBlockFetchSuccess(
+ ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0)))
+ listener.onBlockFetchSuccess(
+ ShuffleBlockId(0, 1, 0).toString, blocks(ShuffleBlockId(0, 1, 0)))
+ sem.acquire()
+ listener.onBlockFetchSuccess(
+ ShuffleBlockId(0, 2, 0).toString, blocks(ShuffleBlockId(0, 2, 0)))
+ }
+ }
+ })
- val bmId = BlockManagerId("test-client", "test-client", 1)
val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
- (bmId, blIds.map(blId => (blId, 1.asInstanceOf[Long])).toSeq)
- )
+ (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq))
+ val taskContext = new TaskContextImpl(0, 0, 0)
val iterator = new ShuffleBlockFetcherIterator(
- new TaskContextImpl(0, 0, 0),
+ taskContext,
transfer,
blockManager,
blocksByAddress,
- null,
+ new TestSerializer,
48 * 1024 * 1024)
- // Without exhausting the iterator, the iterator should be lazy and not call getLocalShuffleFromDisk.
- verify(blockManager, times(0)).getLocalShuffleFromDisk(any(), any())
-
- assert(iterator.hasNext, "iterator should have 5 elements but actually has no elements")
- assert(iterator.next()._2.isDefined,
- "All elements should be defined but 1st element is not actually defined")
- assert(iterator.hasNext, "iterator should have 5 elements but actually has 1 element")
- assert(iterator.next()._2.isDefined,
- "All elements should be defined but 2nd element is not actually defined")
- assert(iterator.hasNext, "iterator should have 5 elements but actually has 2 elements")
- assert(iterator.next()._2.isDefined,
- "All elements should be defined but 3rd element is not actually defined")
- assert(iterator.hasNext, "iterator should have 5 elements but actually has 3 elements")
- assert(iterator.next()._2.isDefined,
- "All elements should be defined but 4th element is not actually defined")
- assert(iterator.hasNext, "iterator should have 5 elements but actually has 4 elements")
- assert(iterator.next()._2.isDefined,
- "All elements should be defined but 5th element is not actually defined")
-
- verify(blockManager, times(5)).getLocalShuffleFromDisk(any(), any())
+ // Exhaust the first block, and then it should be released.
+ iterator.next()._2.get.foreach(_ => Unit)
+ verify(blocks(ShuffleBlockId(0, 0, 0)), times(1)).release()
+
+ // Get the 2nd block but do not exhaust the iterator
+ val subIter = iterator.next()._2.get
+
+    // Complete the task; the unexhausted 2nd block buffer should then be released
+ verify(blocks(ShuffleBlockId(0, 1, 0)), times(0)).release()
+ taskContext.markTaskCompleted()
+ verify(blocks(ShuffleBlockId(0, 1, 0)), times(1)).release()
+
+ // The 3rd block should not be retained because the iterator is already in zombie state
+ sem.release()
+ verify(blocks(ShuffleBlockId(0, 2, 0)), times(0)).retain()
+ verify(blocks(ShuffleBlockId(0, 2, 0)), times(0)).release()
}
- test("handle remote fetch failures in BlockTransferService") {
+ test("fail all blocks if any of the remote request fails") {
+ val blockManager = mock(classOf[BlockManager])
+ val localBmId = BlockManagerId("test-client", "test-client", 1)
+ doReturn(localBmId).when(blockManager).blockManagerId
+
+    // Set up the remote blocks that the mock transfer service will return
+ val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2)
+ val blocks = Map[BlockId, ManagedBuffer](
+ ShuffleBlockId(0, 0, 0) -> mock(classOf[ManagedBuffer]),
+ ShuffleBlockId(0, 1, 0) -> mock(classOf[ManagedBuffer]),
+ ShuffleBlockId(0, 2, 0) -> mock(classOf[ManagedBuffer])
+ )
+
+ // Semaphore to coordinate event sequence in two different threads.
+ val sem = new Semaphore(0)
+
val transfer = mock(classOf[BlockTransferService])
when(transfer.fetchBlocks(any(), any(), any(), any())).thenAnswer(new Answer[Unit] {
override def answer(invocation: InvocationOnMock): Unit = {
val listener = invocation.getArguments()(3).asInstanceOf[BlockFetchingListener]
- listener.onBlockFetchFailure(new Exception("blah"))
+ future {
+ // Return the first block, and then fail.
+ listener.onBlockFetchSuccess(
+ ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0)))
+ listener.onBlockFetchFailure(
+ ShuffleBlockId(0, 1, 0).toString, new BlockNotFoundException("blah"))
+ listener.onBlockFetchFailure(
+ ShuffleBlockId(0, 2, 0).toString, new BlockNotFoundException("blah"))
+ sem.release()
+ }
}
})
- val blockManager = mock(classOf[BlockManager])
-
- when(blockManager.blockManagerId).thenReturn(BlockManagerId("test-client", "test-client", 1))
-
- val blId1 = ShuffleBlockId(0, 0, 0)
- val blId2 = ShuffleBlockId(0, 1, 0)
- val bmId = BlockManagerId("test-server", "test-server", 1)
val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
- (bmId, Seq((blId1, 1L), (blId2, 1L))))
+ (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq))
+ val taskContext = new TaskContextImpl(0, 0, 0)
val iterator = new ShuffleBlockFetcherIterator(
- new TaskContextImpl(0, 0, 0),
+ taskContext,
transfer,
blockManager,
blocksByAddress,
- null,
+ new TestSerializer,
48 * 1024 * 1024)
- iterator.foreach { case (_, iterOption) =>
- assert(!iterOption.isDefined)
- }
+ // Continue only after the mock calls onBlockFetchFailure
+ sem.acquire()
+
+    // The first block should be defined; the last two should not be (due to the fetch failures)
+ assert(iterator.next()._2.isDefined === true)
+ assert(iterator.next()._2.isDefined === false)
+ assert(iterator.next()._2.isDefined === false)
}
}
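The new tests coordinate asynchronous BlockFetchingListener callbacks, fired by the mocked transfer service from a scala.concurrent future, with assertions on the test thread via a Semaphore created with zero permits. A stand-alone sketch of that sequencing pattern, with all names illustrative rather than taken from the patch:

```scala
// Minimal sketch of the Semaphore-based sequencing used by the tests above: a background
// future delivers its "callbacks", then releases a permit; the main thread blocks on
// acquire() so its assertions only run after delivery has finished.
import java.util.concurrent.Semaphore

import scala.concurrent.future
import scala.concurrent.ExecutionContext.Implicits.global

object SemaphoreSequencingSketch {
  def main(args: Array[String]): Unit = {
    val sem = new Semaphore(0)          // 0 permits: acquire() blocks until release()
    @volatile var callbacksDelivered = false

    // Stand-in for the mocked fetchBlocks: callbacks are delivered asynchronously.
    future {
      callbacksDelivered = true
      sem.release()                     // signal the test thread that delivery is done
    }

    sem.acquire()                       // continue only after the callbacks have fired
    assert(callbacksDelivered)
  }
}
```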