author     Reynold Xin <rxin@apache.org>        2014-10-29 11:27:07 -0700
committer  Reynold Xin <rxin@databricks.com>    2014-10-29 11:27:07 -0700
commit     dff015533dd7b01b5e392f1ac5f3837e0a65f3f4 (patch)
tree       4314afec5436543c277a04d034a94aac200a23d1 /network/common
parent     51ce997355465fc5c29d0e49b92f9bae0bab90ed (diff)
[SPARK-3453] Netty-based BlockTransferService, extracted from Spark core
This PR encapsulates #2330, which is itself a continuation of #2240.

The first goal of this PR is to provide an alternate, simpler implementation of the ConnectionManager which is based on Netty. In addition to this goal, however, we want to resolve [SPARK-3796](https://issues.apache.org/jira/browse/SPARK-3796), which calls for a standalone shuffle service that can be integrated into the YARN NodeManager, the Standalone Worker, or run on its own. This PR takes the first step in this direction by ensuring that the actual Netty service is as small as possible and extracted from Spark core. Given this, we should be able to construct a standalone jar which can be included in other JVMs without incurring significant dependency or runtime issues. The actual work to ensure that such a standalone shuffle service works in Spark is left for a future PR, however.

In order to minimize dependencies and allow the service to be long-running (possibly much longer-running than Spark, and possibly having to support multiple versions of Spark simultaneously), the entire service has been ported to Java, where we have full control over the binary compatibility of the components and do not depend on the Scala runtime or version.

These issues have been addressed by folding in #2330:

SPARK-3453: Refactor Netty module to use BlockTransferService interface
SPARK-3018: Release all buffers upon task completion/failure
SPARK-3002: Create a connection pool and reuse clients across different threads
SPARK-3017: Integration tests and unit tests for connection failures
SPARK-3049: Make sure client doesn't block when server/connection has error(s)
SPARK-3502: SO_RCVBUF and SO_SNDBUF should be bootstrap childOption, not option
SPARK-3503: Disable thread local cache in PooledByteBufAllocator

TODO before mergeable:
- [x] Implement uploadBlock()
- [x] Unit tests for RPC side of code
- [x] Performance testing (see comments [here](https://github.com/apache/spark/pull/2753#issuecomment-59475022))
- [x] Turn OFF by default (currently on for unit testing)

Author: Reynold Xin <rxin@apache.org>
Author: Aaron Davidson <aaron@databricks.com>
Author: cocoatomo <cocoatomo77@gmail.com>
Author: Patrick Wendell <pwendell@gmail.com>
Author: Prashant Sharma <prashant.s@imaginea.com>
Author: Davies Liu <davies.liu@gmail.com>
Author: Anand Avati <avati@redhat.com>

Closes #2753 from aarondav/netty and squashes the following commits:

cadfd28 [Aaron Davidson] Turn netty off by default
d7be11b [Aaron Davidson] Turn netty on by default
4a204b8 [Aaron Davidson] Fail block fetches if client connection fails
2b0d1c0 [Aaron Davidson] 100ch
0c5bca2 [Aaron Davidson] Merge branch 'master' of https://github.com/apache/spark into netty
14e37f7 [Aaron Davidson] Address Reynold's comments
8dfcceb [Aaron Davidson] Merge branch 'master' of https://github.com/apache/spark into netty
322dfc1 [Aaron Davidson] Address Reynold's comments, including major rename
e5675a4 [Aaron Davidson] Fail outstanding RPCs as well
ccd4959 [Aaron Davidson] Don't throw exception if client immediately fails
9da0bc1 [Aaron Davidson] Add RPC unit tests
d236dfd [Aaron Davidson] Remove no-op serializer :)
7b7a26c [Aaron Davidson] Fix Nio compile issue
dd420fd [Aaron Davidson] Merge branch 'master' of https://github.com/apache/spark into netty-test
939f276 [Aaron Davidson] Attempt to make comm. bidirectional
aa58f67 [cocoatomo] [SPARK-3909][PySpark][Doc] A corrupted format in Sphinx documents and building warnings
8dc1ded [cocoatomo] [SPARK-3867][PySpark] ./python/run-tests failed when it run with Python 2.6 and unittest2 is not installed
5b5dbe6 [Prashant Sharma] [SPARK-2924] Required by scala 2.11, only one fun/ctor amongst overriden alternatives, can have default argument(s).
2c5d9dc [Patrick Wendell] HOTFIX: Fix build issue with Akka 2.3.4 upgrade.
020691e [Davies Liu] [SPARK-3886] [PySpark] use AutoBatchedSerializer by default
ae4083a [Anand Avati] [SPARK-2805] Upgrade Akka to 2.3.4
29c6dcf [Aaron Davidson] [SPARK-3453] Netty-based BlockTransferService, extracted from Spark core
f7e7568 [Reynold Xin] Fixed spark.shuffle.io.receiveBuffer setting.
5d98ce3 [Reynold Xin] Flip buffer.
f6c220d [Reynold Xin] Merge with latest master.
407e59a [Reynold Xin] Fix style violation.
a0518c7 [Reynold Xin] Implemented block uploads.
4b18db2 [Reynold Xin] Copy the buffer in fetchBlockSync.
bec4ea2 [Reynold Xin] Removed OIO and added num threads settings.
1bdd7ee [Reynold Xin] Fixed tests.
d68f328 [Reynold Xin] Logging close() in case close() fails.
f63fb4c [Reynold Xin] Add more debug message.
6afc435 [Reynold Xin] Added logging.
c066309 [Reynold Xin] Implement java.io.Closeable interface.
519d64d [Reynold Xin] Mark private package visibility and MimaExcludes.
f0a16e9 [Reynold Xin] Fixed test hanging.
14323a5 [Reynold Xin] Removed BlockManager.getLocalShuffleFromDisk.
b2f3281 [Reynold Xin] Added connection pooling.
d23ed7b [Reynold Xin] Incorporated feedback from Norman: - use same pool for boss and worker - remove ioratio - disable caching of byte buf allocator - childoption sendbuf/receivebuf - fire exception through pipeline
9e0cb87 [Reynold Xin] Fixed BlockClientHandlerSuite
5cd33d7 [Reynold Xin] Fixed style violation.
cb589ec [Reynold Xin] Added more test cases covering cleanup when fault happens in ShuffleBlockFetcherIteratorSuite
1be4e8e [Reynold Xin] Shorten NioManagedBuffer and NettyManagedBuffer class names.
108c9ed [Reynold Xin] Forgot to add TestSerializer to the commit list.
b5c8d1f [Reynold Xin] Fixed ShuffleBlockFetcherIteratorSuite.
064747b [Reynold Xin] Reference count buffers and clean them up properly.
2b44cf1 [Reynold Xin] Added more documentation.
1760d32 [Reynold Xin] Use Epoll.isAvailable in BlockServer as well.
165eab1 [Reynold Xin] [SPARK-3453] Refactor Netty module to use BlockTransferService.
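The "Turn OFF by default" item refers to a runtime switch between this Netty transport and the legacy NIO ConnectionManager. As a minimal sketch of what opting in looks like — the exact key at this commit is an assumption on my part; released Spark 1.2 exposes it as spark.shuffle.blockTransferService:

```java
import org.apache.spark.SparkConf;

public class TransportSelection {
  public static void main(String[] args) {
    // Assumed config key: select the Netty-based transport instead of the
    // legacy NIO ConnectionManager ("nio" keeps the old code path).
    SparkConf conf = new SparkConf().set("spark.shuffle.blockTransferService", "netty");
    System.out.println(conf.get("spark.shuffle.blockTransferService"));
  }
}
```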
Diffstat (limited to 'network/common')
-rw-r--r--  network/common/pom.xml                                                                     |  94
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/TransportContext.java                | 117
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java | 154
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java            |  71
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java       |  76
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/buffer/NioManagedBuffer.java         |  75
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/client/ChunkFetchFailureException.java | 31
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/client/ChunkReceivedCallback.java    |  47
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java      |  30
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/client/TransportClient.java          | 159
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java   | 182
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java | 167
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java      |  76
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java      |  66
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java      |  80
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/protocol/Encodable.java              |  41
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/protocol/Message.java                |  58
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java         |  70
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java         |  80
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/protocol/RequestMessage.java         |  25
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/protocol/ResponseMessage.java        |  25
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java             |  74
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java             |  81
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java            |  72
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/protocol/StreamChunkId.java          |  73
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/server/DefaultStreamManager.java     | 104
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java           |  36
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java               |  38
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/server/StreamManager.java            |  52
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java  |  96
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java  | 162
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/server/TransportServer.java          | 121
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/util/ConfigProvider.java             |  52
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/util/IOMode.java                     |  27
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java                  |  38
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java                 | 102
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/util/TransportConf.java              |  61
-rw-r--r--  network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java      | 217
-rw-r--r--  network/common/src/test/java/org/apache/spark/network/NoOpRpcHandler.java                  |  28
-rw-r--r--  network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java                   |  86
-rw-r--r--  network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java             | 175
-rw-r--r--  network/common/src/test/java/org/apache/spark/network/SystemPropertyConfigProvider.java    |  34
-rw-r--r--  network/common/src/test/java/org/apache/spark/network/TestManagedBuffer.java               | 104
-rw-r--r--  network/common/src/test/java/org/apache/spark/network/TestUtils.java                       |  30
-rw-r--r--  network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java     | 102
-rw-r--r--  network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java   | 115
46 files changed, 3804 insertions(+), 0 deletions(-)
diff --git a/network/common/pom.xml b/network/common/pom.xml
new file mode 100644
index 0000000000..e3b7e32870
--- /dev/null
+++ b/network/common/pom.xml
@@ -0,0 +1,94 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+ ~ Licensed to the Apache Software Foundation (ASF) under one or more
+ ~ contributor license agreements. See the NOTICE file distributed with
+ ~ this work for additional information regarding copyright ownership.
+ ~ The ASF licenses this file to You under the Apache License, Version 2.0
+ ~ (the "License"); you may not use this file except in compliance with
+ ~ the License. You may obtain a copy of the License at
+ ~
+ ~ http://www.apache.org/licenses/LICENSE-2.0
+ ~
+ ~ Unless required by applicable law or agreed to in writing, software
+ ~ distributed under the License is distributed on an "AS IS" BASIS,
+ ~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ ~ See the License for the specific language governing permissions and
+ ~ limitations under the License.
+ -->
+
+<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+ xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+ <modelVersion>4.0.0</modelVersion>
+ <parent>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-parent</artifactId>
+ <version>1.2.0-SNAPSHOT</version>
+ <relativePath>../../pom.xml</relativePath>
+ </parent>
+
+ <groupId>org.apache.spark</groupId>
+ <artifactId>network</artifactId>
+ <packaging>jar</packaging>
+ <name>Shuffle Streaming Service</name>
+ <url>http://spark.apache.org/</url>
+ <properties>
+ <sbt.project.name>network</sbt.project.name>
+ </properties>
+
+ <dependencies>
+ <!-- Core dependencies -->
+ <dependency>
+ <groupId>io.netty</groupId>
+ <artifactId>netty-all</artifactId>
+ </dependency>
+ <dependency>
+ <groupId>org.slf4j</groupId>
+ <artifactId>slf4j-api</artifactId>
+ </dependency>
+
+ <!-- Provided dependencies -->
+ <dependency>
+ <groupId>com.google.guava</groupId>
+ <artifactId>guava</artifactId>
+ <scope>provided</scope>
+ </dependency>
+
+ <!-- Test dependencies -->
+ <dependency>
+ <groupId>junit</groupId>
+ <artifactId>junit</artifactId>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>log4j</groupId>
+ <artifactId>log4j</artifactId>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.mockito</groupId>
+ <artifactId>mockito-all</artifactId>
+ <scope>test</scope>
+ </dependency>
+ </dependencies>
+
+
+ <build>
+ <outputDirectory>target/java/classes</outputDirectory>
+ <testOutputDirectory>target/java/test-classes</testOutputDirectory>
+ <plugins>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-surefire-plugin</artifactId>
+ <version>2.17</version>
+ <configuration>
+ <skipTests>false</skipTests>
+ <includes>
+ <include>**/Test*.java</include>
+ <include>**/*Test.java</include>
+ <include>**/*Suite.java</include>
+ </includes>
+ </configuration>
+ </plugin>
+ </plugins>
+ </build>
+</project>
diff --git a/network/common/src/main/java/org/apache/spark/network/TransportContext.java b/network/common/src/main/java/org/apache/spark/network/TransportContext.java
new file mode 100644
index 0000000000..854aa6685f
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/TransportContext.java
@@ -0,0 +1,117 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network;
+
+import io.netty.channel.Channel;
+import io.netty.channel.socket.SocketChannel;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.client.TransportClientFactory;
+import org.apache.spark.network.client.TransportResponseHandler;
+import org.apache.spark.network.protocol.MessageDecoder;
+import org.apache.spark.network.protocol.MessageEncoder;
+import org.apache.spark.network.server.RpcHandler;
+import org.apache.spark.network.server.TransportChannelHandler;
+import org.apache.spark.network.server.TransportRequestHandler;
+import org.apache.spark.network.server.TransportServer;
+import org.apache.spark.network.server.StreamManager;
+import org.apache.spark.network.util.NettyUtils;
+import org.apache.spark.network.util.TransportConf;
+
+/**
+ * Contains the context to create a {@link TransportServer}, {@link TransportClientFactory}, and to
+ * set up Netty Channel pipelines with a {@link org.apache.spark.network.server.TransportChannelHandler}.
+ *
+ * The TransportClient provides two communication protocols: control-plane RPCs and data-plane
+ * "chunk fetching". The handling of RPCs is performed outside the scope of the TransportContext
+ * (i.e., by a user-provided handler), and it is that handler which is responsible for setting up
+ * streams which can then be fetched through the data plane in chunks using zero-copy IO.
+ *
+ * The TransportServer and TransportClientFactory both create a TransportChannelHandler for each
+ * channel. As each TransportChannelHandler contains a TransportClient, this enables server
+ * processes to send messages back to the client on an existing channel.
+ */
+public class TransportContext {
+ private final Logger logger = LoggerFactory.getLogger(TransportContext.class);
+
+ private final TransportConf conf;
+ private final StreamManager streamManager;
+ private final RpcHandler rpcHandler;
+
+ private final MessageEncoder encoder;
+ private final MessageDecoder decoder;
+
+ public TransportContext(TransportConf conf, StreamManager streamManager, RpcHandler rpcHandler) {
+ this.conf = conf;
+ this.streamManager = streamManager;
+ this.rpcHandler = rpcHandler;
+ this.encoder = new MessageEncoder();
+ this.decoder = new MessageDecoder();
+ }
+
+ public TransportClientFactory createClientFactory() {
+ return new TransportClientFactory(this);
+ }
+
+ public TransportServer createServer() {
+ return new TransportServer(this);
+ }
+
+ /**
+ * Initializes a client or server Netty Channel Pipeline which encodes/decodes messages and
+ * has a {@link org.apache.spark.network.server.TransportChannelHandler} to handle request or
+ * response messages.
+ *
+ * @return Returns the created TransportChannelHandler, which includes a TransportClient that can
+ * be used to communicate on this channel. The TransportClient is directly associated with a
+ * ChannelHandler to ensure all users of the same channel get the same TransportClient object.
+ */
+ public TransportChannelHandler initializePipeline(SocketChannel channel) {
+ try {
+ TransportChannelHandler channelHandler = createChannelHandler(channel);
+ channel.pipeline()
+ .addLast("encoder", encoder)
+ .addLast("frameDecoder", NettyUtils.createFrameDecoder())
+ .addLast("decoder", decoder)
+ // NOTE: Chunks are currently guaranteed to be returned in the order of request, but this
+ // would require more logic to guarantee if this were not part of the same event loop.
+ .addLast("handler", channelHandler);
+ return channelHandler;
+ } catch (RuntimeException e) {
+ logger.error("Error while initializing Netty pipeline", e);
+ throw e;
+ }
+ }
+
+ /**
+ * Creates the server- and client-side handler which is used to handle both RequestMessages and
+ * ResponseMessages. The channel is expected to have been successfully created, though certain
+ * properties (such as the remoteAddress()) may not be available yet.
+ */
+ private TransportChannelHandler createChannelHandler(Channel channel) {
+ TransportResponseHandler responseHandler = new TransportResponseHandler(channel);
+ TransportClient client = new TransportClient(channel, responseHandler);
+ TransportRequestHandler requestHandler = new TransportRequestHandler(channel, client,
+ streamManager, rpcHandler);
+ return new TransportChannelHandler(client, responseHandler, requestHandler);
+ }
+
+ public TransportConf getConf() { return conf; }
+}
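
To make the wiring above concrete, here is a minimal end-to-end sketch using only types from this patch. DefaultStreamManager, NoOpRpcHandler, SystemPropertyConfigProvider and TestUtils all appear later in this diff; the TransportConf constructor signature and server.getPort() are assumptions, since TransportConf and TransportServer internals are not shown here. Exception handling is elided: createClient() may throw TimeoutException.

```java
// Minimal sketch, mirroring the setup that ChunkFetchIntegrationSuite performs below.
public static void wireUp() throws Exception {
  TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); // assumed ctor
  TransportContext context =
      new TransportContext(conf, new DefaultStreamManager(), new NoOpRpcHandler());

  TransportServer server = context.createServer();                 // binds a Netty server socket
  TransportClientFactory factory = context.createClientFactory();
  TransportClient client =
      factory.createClient(TestUtils.getLocalHost(), server.getPort()); // assumed accessor

  // ... client.sendRpc(...) / client.fetchChunk(...) against the server ...

  client.close();
  factory.close(); // closes pooled clients and shuts down the worker event loop
  server.close();  // assumed Closeable, like the client and the factory
}
```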
diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java b/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java
new file mode 100644
index 0000000000..89ed79bc63
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java
@@ -0,0 +1,154 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.buffer;
+
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.RandomAccessFile;
+import java.nio.ByteBuffer;
+import java.nio.channels.FileChannel;
+
+import com.google.common.base.Objects;
+import com.google.common.io.ByteStreams;
+import io.netty.channel.DefaultFileRegion;
+
+import org.apache.spark.network.util.JavaUtils;
+
+/**
+ * A {@link ManagedBuffer} backed by a segment in a file.
+ */
+public final class FileSegmentManagedBuffer extends ManagedBuffer {
+
+ /**
+ * Memory mapping is expensive and can destabilize the JVM (SPARK-1145, SPARK-3889).
+ * Avoid memory mapping unless there is a good reason to use it.
+ */
+ // TODO: Make this configurable
+ private static final long MIN_MEMORY_MAP_BYTES = 2 * 1024 * 1024;
+
+ private final File file;
+ private final long offset;
+ private final long length;
+
+ public FileSegmentManagedBuffer(File file, long offset, long length) {
+ this.file = file;
+ this.offset = offset;
+ this.length = length;
+ }
+
+ @Override
+ public long size() {
+ return length;
+ }
+
+ @Override
+ public ByteBuffer nioByteBuffer() throws IOException {
+ FileChannel channel = null;
+ try {
+ channel = new RandomAccessFile(file, "r").getChannel();
+ // Just copy the buffer if it's sufficiently small, as memory mapping has a high overhead.
+ if (length < MIN_MEMORY_MAP_BYTES) {
+ ByteBuffer buf = ByteBuffer.allocate((int) length);
+ channel.position(offset);
+ while (buf.remaining() != 0) {
+ if (channel.read(buf) == -1) {
+ throw new IOException(String.format("Reached EOF before filling buffer\n" +
+ "offset=%s\nfile=%s\nbuf.remaining=%s",
+ offset, file.getAbsoluteFile(), buf.remaining()));
+ }
+ }
+ buf.flip();
+ return buf;
+ } else {
+ return channel.map(FileChannel.MapMode.READ_ONLY, offset, length);
+ }
+ } catch (IOException e) {
+ try {
+ if (channel != null) {
+ long size = channel.size();
+ throw new IOException("Error in reading " + this + " (actual file length " + size + ")",
+ e);
+ }
+ } catch (IOException ignored) {
+ // ignore
+ }
+ throw new IOException("Error in opening " + this, e);
+ } finally {
+ JavaUtils.closeQuietly(channel);
+ }
+ }
+
+ @Override
+ public InputStream createInputStream() throws IOException {
+ FileInputStream is = null;
+ try {
+ is = new FileInputStream(file);
+ ByteStreams.skipFully(is, offset);
+ return ByteStreams.limit(is, length);
+ } catch (IOException e) {
+ try {
+ if (is != null) {
+ long size = file.length();
+ throw new IOException("Error in reading " + this + " (actual file length " + size + ")",
+ e);
+ }
+ } catch (IOException ignored) {
+ // ignore
+ } finally {
+ JavaUtils.closeQuietly(is);
+ }
+ throw new IOException("Error in opening " + this, e);
+ } catch (RuntimeException e) {
+ JavaUtils.closeQuietly(is);
+ throw e;
+ }
+ }
+
+ @Override
+ public ManagedBuffer retain() {
+ return this;
+ }
+
+ @Override
+ public ManagedBuffer release() {
+ return this;
+ }
+
+ @Override
+ public Object convertToNetty() throws IOException {
+ FileChannel fileChannel = new FileInputStream(file).getChannel();
+ return new DefaultFileRegion(fileChannel, offset, length);
+ }
+
+ public File getFile() { return file; }
+
+ public long getOffset() { return offset; }
+
+ public long getLength() { return length; }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("file", file)
+ .add("offset", offset)
+ .add("length", length)
+ .toString();
+ }
+}
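
A short usage sketch of the copy-versus-mmap behavior implemented above; the file path is hypothetical:

```java
public static void readSegment() throws IOException {
  File file = new File("/tmp/shuffle_0_0_0.data");                 // hypothetical shuffle file
  ManagedBuffer buf = new FileSegmentManagedBuffer(file, 0, 1024);

  // 1024 bytes < MIN_MEMORY_MAP_BYTES (2 MB), so nioByteBuffer() copies into a heap
  // buffer; a segment of 2 MB or more would be memory-mapped instead.
  ByteBuffer bytes = buf.nioByteBuffer();

  // Streams exactly this segment: skipFully(offset), then limit(length).
  InputStream in = buf.createInputStream();
}
```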
diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java b/network/common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java
new file mode 100644
index 0000000000..a415db593a
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.buffer;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.ByteBuffer;
+
+/**
+ * This interface provides an immutable view for data in the form of bytes. The implementation
+ * should specify how the data is provided:
+ *
+ * - {@link FileSegmentManagedBuffer}: data backed by part of a file
+ * - {@link NioManagedBuffer}: data backed by a NIO ByteBuffer
+ * - {@link NettyManagedBuffer}: data backed by a Netty ByteBuf
+ *
+ * The concrete buffer implementation might be managed outside the JVM garbage collector.
+ * For example, in the case of {@link NettyManagedBuffer}, the buffers are reference counted.
+ * In that case, if the buffer is going to be passed around to a different thread, retain/release
+ * should be called.
+ */
+public abstract class ManagedBuffer {
+
+ /** Number of bytes of the data. */
+ public abstract long size();
+
+ /**
+ * Exposes this buffer's data as an NIO ByteBuffer. Changing the position and limit of the
+ * returned ByteBuffer should not affect the content of this buffer.
+ */
+ // TODO: Deprecate this, usage may require expensive memory mapping or allocation.
+ public abstract ByteBuffer nioByteBuffer() throws IOException;
+
+ /**
+ * Exposes this buffer's data as an InputStream. The underlying implementation does not
+ * necessarily check for the length of bytes read, so the caller is responsible for making sure
+ * it does not go over the limit.
+ */
+ public abstract InputStream createInputStream() throws IOException;
+
+ /**
+ * Increment the reference count by one if applicable.
+ */
+ public abstract ManagedBuffer retain();
+
+ /**
+ * If applicable, decrement the reference count by one and deallocates the buffer if the
+ * reference count reaches zero.
+ */
+ public abstract ManagedBuffer release();
+
+ /**
+ * Convert the buffer into a Netty object, used to write the data out.
+ */
+ public abstract Object convertToNetty() throws IOException;
+}
diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java b/network/common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java
new file mode 100644
index 0000000000..c806bfa45b
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.buffer;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.ByteBuffer;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.ByteBufInputStream;
+
+/**
+ * A {@link ManagedBuffer} backed by a Netty {@link ByteBuf}.
+ */
+public final class NettyManagedBuffer extends ManagedBuffer {
+ private final ByteBuf buf;
+
+ public NettyManagedBuffer(ByteBuf buf) {
+ this.buf = buf;
+ }
+
+ @Override
+ public long size() {
+ return buf.readableBytes();
+ }
+
+ @Override
+ public ByteBuffer nioByteBuffer() throws IOException {
+ return buf.nioBuffer();
+ }
+
+ @Override
+ public InputStream createInputStream() throws IOException {
+ return new ByteBufInputStream(buf);
+ }
+
+ @Override
+ public ManagedBuffer retain() {
+ buf.retain();
+ return this;
+ }
+
+ @Override
+ public ManagedBuffer release() {
+ buf.release();
+ return this;
+ }
+
+ @Override
+ public Object convertToNetty() throws IOException {
+ return buf.duplicate();
+ }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("buf", buf)
+ .toString();
+ }
+}
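
A sketch of the reference-counting contract described in the ManagedBuffer javadoc, as it applies to this Netty-backed implementation; the executor and the consume() helper are assumptions for illustration:

```java
public static void handOff(java.util.concurrent.Executor executor) {
  final ManagedBuffer buf =
      new NettyManagedBuffer(Unpooled.wrappedBuffer(new byte[] { 1, 2, 3 }));

  buf.retain();                        // bump the ByteBuf refcount before the handoff
  executor.execute(new Runnable() {
    @Override
    public void run() {
      try {
        consume(buf);                  // hypothetical consumer on another thread
      } finally {
        buf.release();                 // drop the handoff reference
      }
    }
  });
  buf.release();                       // drop the original reference; the ByteBuf is
                                       // deallocated once both references are released
}
```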
diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/NioManagedBuffer.java b/network/common/src/main/java/org/apache/spark/network/buffer/NioManagedBuffer.java
new file mode 100644
index 0000000000..f55b884bc4
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/buffer/NioManagedBuffer.java
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.buffer;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.ByteBuffer;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBufInputStream;
+import io.netty.buffer.Unpooled;
+
+/**
+ * A {@link ManagedBuffer} backed by {@link ByteBuffer}.
+ */
+public final class NioManagedBuffer extends ManagedBuffer {
+ private final ByteBuffer buf;
+
+ public NioManagedBuffer(ByteBuffer buf) {
+ this.buf = buf;
+ }
+
+ @Override
+ public long size() {
+ return buf.remaining();
+ }
+
+ @Override
+ public ByteBuffer nioByteBuffer() throws IOException {
+ return buf.duplicate();
+ }
+
+ @Override
+ public InputStream createInputStream() throws IOException {
+ return new ByteBufInputStream(Unpooled.wrappedBuffer(buf));
+ }
+
+ @Override
+ public ManagedBuffer retain() {
+ return this;
+ }
+
+ @Override
+ public ManagedBuffer release() {
+ return this;
+ }
+
+ @Override
+ public Object convertToNetty() throws IOException {
+ return Unpooled.wrappedBuffer(buf);
+ }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("buf", buf)
+ .toString();
+ }
+}
+
diff --git a/network/common/src/main/java/org/apache/spark/network/client/ChunkFetchFailureException.java b/network/common/src/main/java/org/apache/spark/network/client/ChunkFetchFailureException.java
new file mode 100644
index 0000000000..1fbdcd6780
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/client/ChunkFetchFailureException.java
@@ -0,0 +1,31 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.client;
+
+/**
+ * General exception caused by a remote exception while fetching a chunk.
+ */
+public class ChunkFetchFailureException extends RuntimeException {
+ public ChunkFetchFailureException(String errorMsg, Throwable cause) {
+ super(errorMsg, cause);
+ }
+
+ public ChunkFetchFailureException(String errorMsg) {
+ super(errorMsg);
+ }
+}
diff --git a/network/common/src/main/java/org/apache/spark/network/client/ChunkReceivedCallback.java b/network/common/src/main/java/org/apache/spark/network/client/ChunkReceivedCallback.java
new file mode 100644
index 0000000000..519e6cb470
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/client/ChunkReceivedCallback.java
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.client;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+
+/**
+ * Callback for the result of a single chunk fetch. For a single stream, the callbacks are
+ * guaranteed to be called by the same thread in the same order as the requests for chunks were
+ * made.
+ *
+ * Note that if a general stream failure occurs, all outstanding chunk requests may be failed.
+ */
+public interface ChunkReceivedCallback {
+ /**
+ * Called upon receipt of a particular chunk.
+ *
+ * The given buffer will initially have a refcount of 1, but will be release()'d as soon as this
+ * call returns. You must therefore either retain() the buffer or copy its contents before
+ * returning.
+ */
+ void onSuccess(int chunkIndex, ManagedBuffer buffer);
+
+ /**
+ * Called upon failure to fetch a particular chunk. Note that this may actually be called due
+ * to failure to fetch a prior chunk in this stream.
+ *
+ * After receiving a failure, the stream may or may not be valid. The client should not assume
+ * that the server's side of the stream has been closed.
+ */
+ void onFailure(int chunkIndex, Throwable e);
+}
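
For instance, a callback that honors the retain-or-copy rule above might look like this sketch; 'chunks' (an assumed Map) and 'failure' (an assumed AtomicReference<Throwable>) are collectors introduced for illustration:

```java
ChunkReceivedCallback callback = new ChunkReceivedCallback() {
  @Override
  public void onSuccess(int chunkIndex, ManagedBuffer buffer) {
    // The caller release()s 'buffer' as soon as we return, so retain a reference first.
    chunks.put(chunkIndex, buffer.retain());
  }

  @Override
  public void onFailure(int chunkIndex, Throwable e) {
    // May be triggered by this chunk or by an earlier chunk failing on the same stream.
    failure.compareAndSet(null, e);
  }
};
```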
diff --git a/network/common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java b/network/common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java
new file mode 100644
index 0000000000..6ec960d795
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.client;
+
+/**
+ * Callback for the result of a single RPC. This will be invoked once with either success or
+ * failure.
+ */
+public interface RpcResponseCallback {
+ /** Successful serialized result from server. */
+ void onSuccess(byte[] response);
+
+ /** Exception either propagated from server or raised on client side. */
+ void onFailure(Throwable e);
+}
diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
new file mode 100644
index 0000000000..b1732fcde2
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
@@ -0,0 +1,159 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.client;
+
+import java.io.Closeable;
+import java.util.UUID;
+import java.util.concurrent.TimeUnit;
+
+import com.google.common.base.Preconditions;
+import io.netty.channel.Channel;
+import io.netty.channel.ChannelFuture;
+import io.netty.channel.ChannelFutureListener;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.protocol.ChunkFetchRequest;
+import org.apache.spark.network.protocol.RpcRequest;
+import org.apache.spark.network.protocol.StreamChunkId;
+import org.apache.spark.network.util.NettyUtils;
+
+/**
+ * Client for fetching consecutive chunks of a pre-negotiated stream. This API is intended to allow
+ * efficient transfer of a large amount of data, broken up into chunks with size ranging from
+ * hundreds of KB to a few MB.
+ *
+ * Note that while this client deals with the fetching of chunks from a stream (i.e., data plane),
+ * the actual setup of the streams is done outside the scope of the transport layer. The convenience
+ * method "sendRPC" is provided to enable control plane communication between the client and server
+ * to perform this setup.
+ *
+ * For example, a typical workflow might be:
+ * client.sendRpc(new OpenFile("/foo")) --> returns StreamId = 100
+ * client.fetchChunk(streamId = 100, chunkIndex = 0, callback)
+ * client.fetchChunk(streamId = 100, chunkIndex = 1, callback)
+ * ...
+ * client.sendRpc(new CloseStream(100))
+ *
+ * Construct an instance of TransportClient using {@link TransportClientFactory}. A single
+ * TransportClient may be used for multiple streams, but any given stream must be restricted to a
+ * single client, in order to avoid out-of-order responses.
+ *
+ * NB: This class is used to make requests to the server, while {@link TransportResponseHandler} is
+ * responsible for handling responses from the server.
+ *
+ * Concurrency: thread safe and can be called from multiple threads.
+ */
+public class TransportClient implements Closeable {
+ private final Logger logger = LoggerFactory.getLogger(TransportClient.class);
+
+ private final Channel channel;
+ private final TransportResponseHandler handler;
+
+ public TransportClient(Channel channel, TransportResponseHandler handler) {
+ this.channel = Preconditions.checkNotNull(channel);
+ this.handler = Preconditions.checkNotNull(handler);
+ }
+
+ public boolean isActive() {
+ return channel.isOpen() || channel.isActive();
+ }
+
+ /**
+ * Requests a single chunk from the remote side, from the pre-negotiated streamId.
+ *
+ * Chunk indices go from 0 onwards. It is valid to request the same chunk multiple times, though
+ * some streams may not support this.
+ *
+ * Multiple fetchChunk requests may be outstanding simultaneously, and the chunks are guaranteed
+ * to be returned in the same order that they were requested, assuming only a single
+ * TransportClient is used to fetch the chunks.
+ *
+ * @param streamId Identifier that refers to a stream in the remote StreamManager. This should
+ * be agreed upon by client and server beforehand.
+ * @param chunkIndex 0-based index of the chunk to fetch
+ * @param callback Callback invoked upon successful receipt of chunk, or upon any failure.
+ */
+ public void fetchChunk(
+ long streamId,
+ final int chunkIndex,
+ final ChunkReceivedCallback callback) {
+ final String serverAddr = NettyUtils.getRemoteAddress(channel);
+ final long startTime = System.currentTimeMillis();
+ logger.debug("Sending fetch chunk request {} to {}", chunkIndex, serverAddr);
+
+ final StreamChunkId streamChunkId = new StreamChunkId(streamId, chunkIndex);
+ handler.addFetchRequest(streamChunkId, callback);
+
+ channel.writeAndFlush(new ChunkFetchRequest(streamChunkId)).addListener(
+ new ChannelFutureListener() {
+ @Override
+ public void operationComplete(ChannelFuture future) throws Exception {
+ if (future.isSuccess()) {
+ long timeTaken = System.currentTimeMillis() - startTime;
+ logger.trace("Sending request {} to {} took {} ms", streamChunkId, serverAddr,
+ timeTaken);
+ } else {
+ String errorMsg = String.format("Failed to send request %s to %s: %s", streamChunkId,
+ serverAddr, future.cause());
+ logger.error(errorMsg, future.cause());
+ handler.removeFetchRequest(streamChunkId);
+ callback.onFailure(chunkIndex, new RuntimeException(errorMsg, future.cause()));
+ channel.close();
+ }
+ }
+ });
+ }
+
+ /**
+ * Sends an opaque message to the RpcHandler on the server-side. The callback will be invoked
+ * with the server's response or upon any failure.
+ */
+ public void sendRpc(byte[] message, final RpcResponseCallback callback) {
+ final String serverAddr = NettyUtils.getRemoteAddress(channel);
+ final long startTime = System.currentTimeMillis();
+ logger.trace("Sending RPC to {}", serverAddr);
+
+ final long requestId = UUID.randomUUID().getLeastSignificantBits();
+ handler.addRpcRequest(requestId, callback);
+
+ channel.writeAndFlush(new RpcRequest(requestId, message)).addListener(
+ new ChannelFutureListener() {
+ @Override
+ public void operationComplete(ChannelFuture future) throws Exception {
+ if (future.isSuccess()) {
+ long timeTaken = System.currentTimeMillis() - startTime;
+ logger.trace("Sending request {} to {} took {} ms", requestId, serverAddr, timeTaken);
+ } else {
+ String errorMsg = String.format("Failed to send RPC %s to %s: %s", requestId,
+ serverAddr, future.cause());
+ logger.error(errorMsg, future.cause());
+ handler.removeRpcRequest(requestId);
+ callback.onFailure(new RuntimeException(errorMsg, future.cause()));
+ channel.close();
+ }
+ }
+ });
+ }
+
+ @Override
+ public void close() {
+ // close is a local operation and should finish within milliseconds; timeout just to be safe
+ channel.close().awaitUninterruptibly(10, TimeUnit.SECONDS);
+ }
+}
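
A sketch of the control-plane/data-plane workflow from the class comment above. The RPC payloads and stream-id handling are application-defined, so openFileRequestBytes, decodeStreamId() and chunkCallback below are illustrative assumptions:

```java
client.sendRpc(openFileRequestBytes, new RpcResponseCallback() { // assumed payload
  @Override
  public void onSuccess(byte[] response) {
    long streamId = decodeStreamId(response);       // hypothetical decoder
    client.fetchChunk(streamId, 0, chunkCallback);  // chunks return in request order
    client.fetchChunk(streamId, 1, chunkCallback);
  }

  @Override
  public void onFailure(Throwable e) {
    logger.error("Stream setup failed", e);
  }
});
```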
diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
new file mode 100644
index 0000000000..10eb9ef7a0
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
@@ -0,0 +1,182 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.client;
+
+import java.io.Closeable;
+import java.lang.reflect.Field;
+import java.net.InetSocketAddress;
+import java.net.SocketAddress;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.TimeoutException;
+import java.util.concurrent.atomic.AtomicReference;
+
+import io.netty.bootstrap.Bootstrap;
+import io.netty.buffer.PooledByteBufAllocator;
+import io.netty.channel.Channel;
+import io.netty.channel.ChannelFuture;
+import io.netty.channel.ChannelInitializer;
+import io.netty.channel.ChannelOption;
+import io.netty.channel.EventLoopGroup;
+import io.netty.channel.socket.SocketChannel;
+import io.netty.util.internal.PlatformDependent;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.TransportContext;
+import org.apache.spark.network.server.TransportChannelHandler;
+import org.apache.spark.network.util.IOMode;
+import org.apache.spark.network.util.NettyUtils;
+import org.apache.spark.network.util.TransportConf;
+
+/**
+ * Factory for creating {@link TransportClient}s by using createClient.
+ *
+ * The factory maintains a connection pool to other hosts and should return the same
+ * {@link TransportClient} for the same remote host. It also shares a single worker thread pool for
+ * all {@link TransportClient}s.
+ */
+public class TransportClientFactory implements Closeable {
+ private final Logger logger = LoggerFactory.getLogger(TransportClientFactory.class);
+
+ private final TransportContext context;
+ private final TransportConf conf;
+ private final ConcurrentHashMap<SocketAddress, TransportClient> connectionPool;
+
+ private final Class<? extends Channel> socketChannelClass;
+ private final EventLoopGroup workerGroup;
+
+ public TransportClientFactory(TransportContext context) {
+ this.context = context;
+ this.conf = context.getConf();
+ this.connectionPool = new ConcurrentHashMap<SocketAddress, TransportClient>();
+
+ IOMode ioMode = IOMode.valueOf(conf.ioMode());
+ this.socketChannelClass = NettyUtils.getClientChannelClass(ioMode);
+ // TODO: Make thread pool name configurable.
+ this.workerGroup = NettyUtils.createEventLoop(ioMode, conf.clientThreads(), "shuffle-client");
+ }
+
+ /**
+ * Create a {@link TransportClient} connecting to the given remote host / port, reusing an
+ * already-pooled client if one is active for that address.
+ *
+ * This blocks until a connection is successfully established.
+ *
+ * Concurrency: This method is safe to call from multiple threads.
+ */
+ public TransportClient createClient(String remoteHost, int remotePort) throws TimeoutException {
+ // Get connection from the connection pool first.
+ // If it is not found or not active, create a new one.
+ final InetSocketAddress address = new InetSocketAddress(remoteHost, remotePort);
+ TransportClient cachedClient = connectionPool.get(address);
+ if (cachedClient != null && cachedClient.isActive()) {
+ return cachedClient;
+ } else if (cachedClient != null) {
+ connectionPool.remove(address, cachedClient); // Remove inactive clients.
+ }
+
+ logger.debug("Creating new connection to " + address);
+
+ Bootstrap bootstrap = new Bootstrap();
+ bootstrap.group(workerGroup)
+ .channel(socketChannelClass)
+ // Disable Nagle's Algorithm since we don't want packets to wait
+ .option(ChannelOption.TCP_NODELAY, true)
+ .option(ChannelOption.SO_KEEPALIVE, true)
+ .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, conf.connectionTimeoutMs());
+
+ // Use pooled buffers to reduce temporary buffer allocation
+ bootstrap.option(ChannelOption.ALLOCATOR, createPooledByteBufAllocator());
+
+ final AtomicReference<TransportClient> client = new AtomicReference<TransportClient>();
+
+ bootstrap.handler(new ChannelInitializer<SocketChannel>() {
+ @Override
+ public void initChannel(SocketChannel ch) {
+ TransportChannelHandler clientHandler = context.initializePipeline(ch);
+ client.set(clientHandler.getClient());
+ }
+ });
+
+ // Connect to the remote server
+ ChannelFuture cf = bootstrap.connect(address);
+ if (!cf.awaitUninterruptibly(conf.connectionTimeoutMs())) {
+ throw new TimeoutException(
+ String.format("Connecting to %s timed out (%s ms)", address, conf.connectionTimeoutMs()));
+ } else if (cf.cause() != null) {
+ throw new RuntimeException(String.format("Failed to connect to %s", address), cf.cause());
+ }
+
+ // Successful connection
+ assert client.get() != null : "Channel future completed successfully with null client";
+ TransportClient oldClient = connectionPool.putIfAbsent(address, client.get());
+ if (oldClient == null) {
+ return client.get();
+ } else {
+ logger.debug("Two clients were created concurrently, second one will be disposed.");
+ client.get().close();
+ return oldClient;
+ }
+ }
+
+ /** Close all connections in the connection pool, and shutdown the worker thread pool. */
+ @Override
+ public void close() {
+ for (TransportClient client : connectionPool.values()) {
+ try {
+ client.close();
+ } catch (RuntimeException e) {
+ logger.warn("Ignoring exception during close", e);
+ }
+ }
+ connectionPool.clear();
+
+ if (workerGroup != null) {
+ workerGroup.shutdownGracefully();
+ }
+ }
+
+ /**
+ * Creates a pooled ByteBuf allocator but disables the thread-local cache. Thread-local caches
+ * are disabled because the ByteBufs are allocated by the event loop thread, but released by the
+ * executor thread rather than the event loop thread. Those thread-local caches actually delay
+ * the recycling of buffers, leading to larger memory usage.
+ */
+ private PooledByteBufAllocator createPooledByteBufAllocator() {
+ return new PooledByteBufAllocator(
+ PlatformDependent.directBufferPreferred(),
+ getPrivateStaticField("DEFAULT_NUM_HEAP_ARENA"),
+ getPrivateStaticField("DEFAULT_NUM_DIRECT_ARENA"),
+ getPrivateStaticField("DEFAULT_PAGE_SIZE"),
+ getPrivateStaticField("DEFAULT_MAX_ORDER"),
+ 0, // tinyCacheSize
+ 0, // smallCacheSize
+ 0 // normalCacheSize
+ );
+ }
+
+ /** Used to get defaults from Netty's private static fields. */
+ private int getPrivateStaticField(String name) {
+ try {
+ Field f = PooledByteBufAllocator.DEFAULT.getClass().getDeclaredField(name);
+ f.setAccessible(true);
+ return f.getInt(null);
+ } catch (Exception e) {
+ throw new RuntimeException(e);
+ }
+ }
+}
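
A brief sketch of the pooling behavior: as long as the underlying channel stays active, repeated createClient() calls for the same address resolve to one shared client. The host and port values here are assumptions:

```java
TransportClient c1 = factory.createClient("shuffle-host", 7337);
TransportClient c2 = factory.createClient("shuffle-host", 7337);
assert c1 == c2;  // one pooled client per remote address while it remains active
```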
diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java
new file mode 100644
index 0000000000..d8965590b3
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java
@@ -0,0 +1,167 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.client;
+
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+
+import com.google.common.annotations.VisibleForTesting;
+import io.netty.channel.Channel;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.protocol.ChunkFetchFailure;
+import org.apache.spark.network.protocol.ChunkFetchSuccess;
+import org.apache.spark.network.protocol.ResponseMessage;
+import org.apache.spark.network.protocol.RpcFailure;
+import org.apache.spark.network.protocol.RpcResponse;
+import org.apache.spark.network.protocol.StreamChunkId;
+import org.apache.spark.network.server.MessageHandler;
+import org.apache.spark.network.util.NettyUtils;
+
+/**
+ * Handler that processes server responses, in response to requests issued from a
+ * {@link TransportClient}. It works by tracking the list of outstanding requests (and their callbacks).
+ *
+ * Concurrency: thread safe and can be called from multiple threads.
+ */
+public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
+ private final Logger logger = LoggerFactory.getLogger(TransportResponseHandler.class);
+
+ private final Channel channel;
+
+ private final Map<StreamChunkId, ChunkReceivedCallback> outstandingFetches;
+
+ private final Map<Long, RpcResponseCallback> outstandingRpcs;
+
+ public TransportResponseHandler(Channel channel) {
+ this.channel = channel;
+ this.outstandingFetches = new ConcurrentHashMap<StreamChunkId, ChunkReceivedCallback>();
+ this.outstandingRpcs = new ConcurrentHashMap<Long, RpcResponseCallback>();
+ }
+
+ public void addFetchRequest(StreamChunkId streamChunkId, ChunkReceivedCallback callback) {
+ outstandingFetches.put(streamChunkId, callback);
+ }
+
+ public void removeFetchRequest(StreamChunkId streamChunkId) {
+ outstandingFetches.remove(streamChunkId);
+ }
+
+ public void addRpcRequest(long requestId, RpcResponseCallback callback) {
+ outstandingRpcs.put(requestId, callback);
+ }
+
+ public void removeRpcRequest(long requestId) {
+ outstandingRpcs.remove(requestId);
+ }
+
+ /**
+ * Fire the failure callback for all outstanding requests. This is called when we have an
+ * uncaught exception or premature connection termination.
+ */
+ private void failOutstandingRequests(Throwable cause) {
+ for (Map.Entry<StreamChunkId, ChunkReceivedCallback> entry : outstandingFetches.entrySet()) {
+ entry.getValue().onFailure(entry.getKey().chunkIndex, cause);
+ }
+ for (Map.Entry<Long, RpcResponseCallback> entry : outstandingRpcs.entrySet()) {
+ entry.getValue().onFailure(cause);
+ }
+
+ // It's OK if new fetches appear, as they will fail immediately.
+ outstandingFetches.clear();
+ outstandingRpcs.clear();
+ }
+
+ @Override
+ public void channelUnregistered() {
+ if (numOutstandingRequests() > 0) {
+ String remoteAddress = NettyUtils.getRemoteAddress(channel);
+ logger.error("Still have {} requests outstanding when connection from {} is closed",
+ numOutstandingRequests(), remoteAddress);
+ failOutstandingRequests(new RuntimeException("Connection from " + remoteAddress + " closed"));
+ }
+ }
+
+ @Override
+ public void exceptionCaught(Throwable cause) {
+ if (numOutstandingRequests() > 0) {
+ String remoteAddress = NettyUtils.getRemoteAddress(channel);
+ logger.error("Still have {} requests outstanding when connection from {} is closed",
+ numOutstandingRequests(), remoteAddress);
+ failOutstandingRequests(cause);
+ }
+ }
+
+ @Override
+ public void handle(ResponseMessage message) {
+ String remoteAddress = NettyUtils.getRemoteAddress(channel);
+ if (message instanceof ChunkFetchSuccess) {
+ ChunkFetchSuccess resp = (ChunkFetchSuccess) message;
+ ChunkReceivedCallback listener = outstandingFetches.get(resp.streamChunkId);
+ if (listener == null) {
+ logger.warn("Ignoring response for block {} from {} since it is not outstanding",
+ resp.streamChunkId, remoteAddress);
+ resp.buffer.release();
+ } else {
+ outstandingFetches.remove(resp.streamChunkId);
+ listener.onSuccess(resp.streamChunkId.chunkIndex, resp.buffer);
+ resp.buffer.release();
+ }
+ } else if (message instanceof ChunkFetchFailure) {
+ ChunkFetchFailure resp = (ChunkFetchFailure) message;
+ ChunkReceivedCallback listener = outstandingFetches.get(resp.streamChunkId);
+ if (listener == null) {
+ logger.warn("Ignoring response for block {} from {} ({}) since it is not outstanding",
+ resp.streamChunkId, remoteAddress, resp.errorString);
+ } else {
+ outstandingFetches.remove(resp.streamChunkId);
+ listener.onFailure(resp.streamChunkId.chunkIndex, new ChunkFetchFailureException(
+ "Failure while fetching " + resp.streamChunkId + ": " + resp.errorString));
+ }
+ } else if (message instanceof RpcResponse) {
+ RpcResponse resp = (RpcResponse) message;
+ RpcResponseCallback listener = outstandingRpcs.get(resp.requestId);
+ if (listener == null) {
+ logger.warn("Ignoring response for RPC {} from {} ({} bytes) since it is not outstanding",
+ resp.requestId, remoteAddress, resp.response.length);
+ } else {
+ outstandingRpcs.remove(resp.requestId);
+ listener.onSuccess(resp.response);
+ }
+ } else if (message instanceof RpcFailure) {
+ RpcFailure resp = (RpcFailure) message;
+ RpcResponseCallback listener = outstandingRpcs.get(resp.requestId);
+ if (listener == null) {
+ logger.warn("Ignoring response for RPC {} from {} ({}) since it is not outstanding",
+ resp.requestId, remoteAddress, resp.errorString);
+ } else {
+ outstandingRpcs.remove(resp.requestId);
+ listener.onFailure(new RuntimeException(resp.errorString));
+ }
+ } else {
+ throw new IllegalStateException("Unknown response type: " + message.type());
+ }
+ }
+
+ /** Returns total number of outstanding requests (fetch requests + rpcs) */
+ @VisibleForTesting
+ public int numOutstandingRequests() {
+ return outstandingFetches.size() + outstandingRpcs.size();
+ }
+}
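
For orientation, a minimal sketch of how this handler pairs callbacks with requests. The wiring normally lives in TransportClient, and the channel is assumed to carry the pipeline set up by TransportContext; the stream/chunk identifiers are placeholders:

    import io.netty.channel.Channel;

    import org.apache.spark.network.buffer.ManagedBuffer;
    import org.apache.spark.network.client.ChunkReceivedCallback;
    import org.apache.spark.network.client.TransportResponseHandler;
    import org.apache.spark.network.protocol.ChunkFetchRequest;
    import org.apache.spark.network.protocol.StreamChunkId;

    public class FetchSketch {
      // Register the callback before writing the request, so a response (or a
      // connection failure) always finds an outstanding entry to complete.
      static void fetch(Channel channel, TransportResponseHandler handler,
          long streamId, int chunkIndex) {
        StreamChunkId id = new StreamChunkId(streamId, chunkIndex);
        handler.addFetchRequest(id, new ChunkReceivedCallback() {
          @Override
          public void onSuccess(int idx, ManagedBuffer buffer) {
            // The handler releases the buffer after this returns; retain it
            // if it must outlive the callback.
          }

          @Override
          public void onFailure(int idx, Throwable e) {
            // Fired on ChunkFetchFailure or when the connection drops.
          }
        });
        channel.writeAndFlush(new ChunkFetchRequest(id));
      }
    }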
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java
new file mode 100644
index 0000000000..152af98ced
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import com.google.common.base.Charsets;
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+
+/**
+ * Response to {@link ChunkFetchRequest} when there is an error fetching the chunk.
+ */
+public final class ChunkFetchFailure implements ResponseMessage {
+ public final StreamChunkId streamChunkId;
+ public final String errorString;
+
+ public ChunkFetchFailure(StreamChunkId streamChunkId, String errorString) {
+ this.streamChunkId = streamChunkId;
+ this.errorString = errorString;
+ }
+
+ @Override
+ public Type type() { return Type.ChunkFetchFailure; }
+
+ @Override
+ public int encodedLength() {
+ return streamChunkId.encodedLength() + 4 + errorString.getBytes(Charsets.UTF_8).length;
+ }
+
+ @Override
+ public void encode(ByteBuf buf) {
+ streamChunkId.encode(buf);
+ byte[] errorBytes = errorString.getBytes(Charsets.UTF_8);
+ buf.writeInt(errorBytes.length);
+ buf.writeBytes(errorBytes);
+ }
+
+ public static ChunkFetchFailure decode(ByteBuf buf) {
+ StreamChunkId streamChunkId = StreamChunkId.decode(buf);
+ int numErrorStringBytes = buf.readInt();
+ byte[] errorBytes = new byte[numErrorStringBytes];
+ buf.readBytes(errorBytes);
+ return new ChunkFetchFailure(streamChunkId, new String(errorBytes, Charsets.UTF_8));
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (other instanceof ChunkFetchFailure) {
+ ChunkFetchFailure o = (ChunkFetchFailure) other;
+ return streamChunkId.equals(o.streamChunkId) && errorString.equals(o.errorString);
+ }
+ return false;
+ }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("streamChunkId", streamChunkId)
+ .add("errorString", errorString)
+ .toString();
+ }
+}
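
The error string travels length-prefixed in UTF-8, so encode() and decode() are exact mirrors. A small round-trip sketch (using Netty's Unpooled allocator, which is not part of this patch); RpcRequest, RpcResponse, and RpcFailure below round-trip the same way:

    import io.netty.buffer.ByteBuf;
    import io.netty.buffer.Unpooled;

    import org.apache.spark.network.protocol.ChunkFetchFailure;
    import org.apache.spark.network.protocol.StreamChunkId;

    public class ChunkFetchFailureRoundTrip {
      public static void main(String[] args) {
        ChunkFetchFailure msg =
            new ChunkFetchFailure(new StreamChunkId(42L, 0), "block not found");

        // encodedLength() sizes the buffer exactly; encode() must fill it.
        ByteBuf buf = Unpooled.buffer(msg.encodedLength());
        msg.encode(buf);

        // decode() consumes the same bytes and reproduces an equal message.
        ChunkFetchFailure decoded = ChunkFetchFailure.decode(buf);
        assert msg.equals(decoded);
        buf.release();
      }
    }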
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java
new file mode 100644
index 0000000000..980947cf13
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+
+/**
+ * Request to fetch a single chunk of a stream. This will correspond to a single
+ * {@link org.apache.spark.network.protocol.ResponseMessage} (either success or failure).
+ */
+public final class ChunkFetchRequest implements RequestMessage {
+ public final StreamChunkId streamChunkId;
+
+ public ChunkFetchRequest(StreamChunkId streamChunkId) {
+ this.streamChunkId = streamChunkId;
+ }
+
+ @Override
+ public Type type() { return Type.ChunkFetchRequest; }
+
+ @Override
+ public int encodedLength() {
+ return streamChunkId.encodedLength();
+ }
+
+ @Override
+ public void encode(ByteBuf buf) {
+ streamChunkId.encode(buf);
+ }
+
+ public static ChunkFetchRequest decode(ByteBuf buf) {
+ return new ChunkFetchRequest(StreamChunkId.decode(buf));
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (other instanceof ChunkFetchRequest) {
+ ChunkFetchRequest o = (ChunkFetchRequest) other;
+ return streamChunkId.equals(o.streamChunkId);
+ }
+ return false;
+ }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("streamChunkId", streamChunkId)
+ .toString();
+ }
+}
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java
new file mode 100644
index 0000000000..ff4936470c
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.buffer.NettyManagedBuffer;
+
+/**
+ * Response to {@link ChunkFetchRequest} when a chunk exists and has been successfully fetched.
+ *
+ * Note that the server-side encoding of this message does NOT include the buffer itself, as this
+ * may be written by Netty in a more efficient manner (i.e., zero-copy write).
+ * Similarly, the client-side decoding will reuse the Netty ByteBuf as the buffer.
+ */
+public final class ChunkFetchSuccess implements ResponseMessage {
+ public final StreamChunkId streamChunkId;
+ public final ManagedBuffer buffer;
+
+ public ChunkFetchSuccess(StreamChunkId streamChunkId, ManagedBuffer buffer) {
+ this.streamChunkId = streamChunkId;
+ this.buffer = buffer;
+ }
+
+ @Override
+ public Type type() { return Type.ChunkFetchSuccess; }
+
+ @Override
+ public int encodedLength() {
+ return streamChunkId.encodedLength();
+ }
+
+ /** Encoding does NOT include 'buffer' itself. See {@link MessageEncoder}. */
+ @Override
+ public void encode(ByteBuf buf) {
+ streamChunkId.encode(buf);
+ }
+
+ /** Decoding uses the given ByteBuf as our data, and will retain() it. */
+ public static ChunkFetchSuccess decode(ByteBuf buf) {
+ StreamChunkId streamChunkId = StreamChunkId.decode(buf);
+ buf.retain();
+ NettyManagedBuffer managedBuf = new NettyManagedBuffer(buf.duplicate());
+ return new ChunkFetchSuccess(streamChunkId, managedBuf);
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (other instanceof ChunkFetchSuccess) {
+ ChunkFetchSuccess o = (ChunkFetchSuccess) other;
+ return streamChunkId.equals(o.streamChunkId) && buffer.equals(o.buffer);
+ }
+ return false;
+ }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("streamChunkId", streamChunkId)
+ .add("buffer", buffer)
+ .toString();
+ }
+}
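
A sketch of the buffer lifecycle implied above: decode() retains the frame's ByteBuf and wraps a duplicate, so the chunk outlives the frame, and the consumer owns the matching release(). The helper name here is illustrative:

    import io.netty.buffer.ByteBuf;

    import org.apache.spark.network.protocol.ChunkFetchSuccess;

    public class DecodeLifecycle {
      static long consume(ByteBuf frame) {
        ChunkFetchSuccess msg = ChunkFetchSuccess.decode(frame); // retains 'frame'
        try {
          return msg.buffer.size(); // read the chunk while the buffer is valid
        } finally {
          msg.buffer.release(); // balances the retain() performed in decode()
        }
      }
    }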
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/Encodable.java b/network/common/src/main/java/org/apache/spark/network/protocol/Encodable.java
new file mode 100644
index 0000000000..b4e299471b
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/Encodable.java
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import io.netty.buffer.ByteBuf;
+
+/**
+ * Interface for an object which can be encoded into a ByteBuf. Multiple Encodable objects are
+ * stored in a single, pre-allocated ByteBuf, so Encodables must also provide their length.
+ *
+ * Encodable objects should provide a static "decode(ByteBuf)" method which is invoked by
+ * {@link MessageDecoder}. During decoding, if the object uses the ByteBuf as its data (rather than
+ * just copying data from it), then you must retain() the ByteBuf.
+ *
+ * Additionally, when adding a new Encodable Message, add it to {@link Message.Type}.
+ */
+public interface Encodable {
+ /** Number of bytes of the encoded form of this object. */
+ int encodedLength();
+
+ /**
+ * Serializes this object by writing into the given ByteBuf.
+ * This method must write exactly encodedLength() bytes.
+ */
+ void encode(ByteBuf buf);
+}
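
As a concrete instance of this contract, a hypothetical fixed-width Encodable (not part of this patch) whose static decode() mirrors encode(), as MessageDecoder expects:

    import io.netty.buffer.ByteBuf;

    import org.apache.spark.network.protocol.Encodable;

    /** Hypothetical example: a fixed-width pair that encodes exactly 12 bytes. */
    public final class ExecutorSlot implements Encodable {
      public final long executorId;
      public final int slot;

      public ExecutorSlot(long executorId, int slot) {
        this.executorId = executorId;
        this.slot = slot;
      }

      @Override
      public int encodedLength() { return 8 + 4; }

      @Override
      public void encode(ByteBuf buf) {
        buf.writeLong(executorId);
        buf.writeInt(slot);
      }

      public static ExecutorSlot decode(ByteBuf buf) {
        // Fixed-width fields copy out of the ByteBuf, so no retain() is needed.
        return new ExecutorSlot(buf.readLong(), buf.readInt());
      }
    }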
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/Message.java b/network/common/src/main/java/org/apache/spark/network/protocol/Message.java
new file mode 100644
index 0000000000..d568370125
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/Message.java
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import io.netty.buffer.ByteBuf;
+
+/** An on-the-wire transmittable message. */
+public interface Message extends Encodable {
+ /** Used to identify this request type. */
+ Type type();
+
+ /** Preceding every serialized Message is its type, which allows us to deserialize it. */
+ public static enum Type implements Encodable {
+ ChunkFetchRequest(0), ChunkFetchSuccess(1), ChunkFetchFailure(2),
+ RpcRequest(3), RpcResponse(4), RpcFailure(5);
+
+ private final byte id;
+
+ private Type(int id) {
+ assert id < 128 : "Cannot have more than 128 message types";
+ this.id = (byte) id;
+ }
+
+ public byte id() { return id; }
+
+ @Override public int encodedLength() { return 1; }
+
+ @Override public void encode(ByteBuf buf) { buf.writeByte(id); }
+
+ public static Type decode(ByteBuf buf) {
+ byte id = buf.readByte();
+ switch (id) {
+ case 0: return ChunkFetchRequest;
+ case 1: return ChunkFetchSuccess;
+ case 2: return ChunkFetchFailure;
+ case 3: return RpcRequest;
+ case 4: return RpcResponse;
+ case 5: return RpcFailure;
+ default: throw new IllegalArgumentException("Unknown message type: " + id);
+ }
+ }
+ }
+}
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java b/network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java
new file mode 100644
index 0000000000..81f8d7f963
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import java.util.List;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.channel.ChannelHandler;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.handler.codec.MessageToMessageDecoder;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Decoder used by the client side to decode server-to-client responses.
+ * This decoder is stateless so it is safe to be shared by multiple threads.
+ */
+@ChannelHandler.Sharable
+public final class MessageDecoder extends MessageToMessageDecoder<ByteBuf> {
+
+ private final Logger logger = LoggerFactory.getLogger(MessageDecoder.class);
+
+ @Override
+ public void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) {
+ Message.Type msgType = Message.Type.decode(in);
+ Message decoded = decode(msgType, in);
+ assert decoded.type() == msgType;
+ logger.trace("Received message " + msgType + ": " + decoded);
+ out.add(decoded);
+ }
+
+ private Message decode(Message.Type msgType, ByteBuf in) {
+ switch (msgType) {
+ case ChunkFetchRequest:
+ return ChunkFetchRequest.decode(in);
+
+ case ChunkFetchSuccess:
+ return ChunkFetchSuccess.decode(in);
+
+ case ChunkFetchFailure:
+ return ChunkFetchFailure.decode(in);
+
+ case RpcRequest:
+ return RpcRequest.decode(in);
+
+ case RpcResponse:
+ return RpcResponse.decode(in);
+
+ case RpcFailure:
+ return RpcFailure.decode(in);
+
+ default:
+ throw new IllegalArgumentException("Unexpected message type: " + msgType);
+ }
+ }
+}
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java b/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java
new file mode 100644
index 0000000000..4cb8becc3e
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import java.util.List;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.channel.ChannelHandler;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.handler.codec.MessageToMessageEncoder;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Encoder used by the server side to encode server-to-client responses.
+ * This encoder is stateless so it is safe to be shared by multiple threads.
+ */
+@ChannelHandler.Sharable
+public final class MessageEncoder extends MessageToMessageEncoder<Message> {
+
+ private final Logger logger = LoggerFactory.getLogger(MessageEncoder.class);
+
+ /**
+ * Encodes a Message by invoking its encode() method. For non-data messages, we will add one
+ * ByteBuf to 'out' containing the total frame length, the message type, and the message itself.
+ * In the case of a ChunkFetchSuccess, we will also add the ManagedBuffer corresponding to the
+ * data to 'out', in order to enable zero-copy transfer.
+ */
+ @Override
+ public void encode(ChannelHandlerContext ctx, Message in, List<Object> out) {
+ Object body = null;
+ long bodyLength = 0;
+
+ // Only ChunkFetchSuccesses have data besides the header.
+ // The body is used in order to enable zero-copy transfer for the payload.
+ if (in instanceof ChunkFetchSuccess) {
+ ChunkFetchSuccess resp = (ChunkFetchSuccess) in;
+ try {
+ bodyLength = resp.buffer.size();
+ body = resp.buffer.convertToNetty();
+ } catch (Exception e) {
+ // Re-encode this message as a ChunkFetchFailure.
+ logger.error(String.format("Error opening block %s for client %s",
+ resp.streamChunkId, ctx.channel().remoteAddress()), e);
+ encode(ctx, new ChunkFetchFailure(resp.streamChunkId, e.getMessage()), out);
+ return;
+ }
+ }
+
+ Message.Type msgType = in.type();
+ // All messages have the frame length, message type, and message itself.
+ int headerLength = 8 + msgType.encodedLength() + in.encodedLength();
+ long frameLength = headerLength + bodyLength;
+ ByteBuf header = ctx.alloc().buffer(headerLength);
+ header.writeLong(frameLength);
+ msgType.encode(header);
+ in.encode(header);
+ assert header.writableBytes() == 0;
+
+ out.add(header);
+ if (body != null && bodyLength > 0) {
+ out.add(body);
+ }
+ }
+}
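
A worked example of the frame layout above: the 8-byte frame-length field counts itself, then the 1-byte type, the encoded message, and any zero-copy body. The class name below is illustrative:

    import org.apache.spark.network.protocol.RpcRequest;

    public class FrameLengthExample {
      public static void main(String[] args) {
        // An RpcRequest carrying a 16-byte payload:
        RpcRequest req = new RpcRequest(7L, new byte[16]);

        // header = frame length (8) + type (1) + requestId (8) + len (4) + 16 = 37
        int headerLength = 8 + req.type().encodedLength() + req.encodedLength();
        System.out.println(headerLength); // 37; frameLength is also 37 (no body)

        // For a ChunkFetchSuccess with a 1 MB chunk, the header would be
        // 8 + 1 + 12 = 21 bytes and frameLength 21 + 1048576, with the chunk
        // itself written separately to enable zero-copy transfer.
      }
    }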
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/RequestMessage.java b/network/common/src/main/java/org/apache/spark/network/protocol/RequestMessage.java
new file mode 100644
index 0000000000..31b15bb17a
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/RequestMessage.java
@@ -0,0 +1,25 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+/** Messages from the client to the server. */
+public interface RequestMessage extends Message {
+ // token interface
+}
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ResponseMessage.java b/network/common/src/main/java/org/apache/spark/network/protocol/ResponseMessage.java
new file mode 100644
index 0000000000..6edffd11cf
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/ResponseMessage.java
@@ -0,0 +1,25 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+/** Messages from the server to the client. */
+public interface ResponseMessage extends Message {
+ // token interface
+}
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java b/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java
new file mode 100644
index 0000000000..e239d4ffbd
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import com.google.common.base.Charsets;
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+
+/** Response to {@link RpcRequest} for a failed RPC. */
+public final class RpcFailure implements ResponseMessage {
+ public final long requestId;
+ public final String errorString;
+
+ public RpcFailure(long requestId, String errorString) {
+ this.requestId = requestId;
+ this.errorString = errorString;
+ }
+
+ @Override
+ public Type type() { return Type.RpcFailure; }
+
+ @Override
+ public int encodedLength() {
+ return 8 + 4 + errorString.getBytes(Charsets.UTF_8).length;
+ }
+
+ @Override
+ public void encode(ByteBuf buf) {
+ buf.writeLong(requestId);
+ byte[] errorBytes = errorString.getBytes(Charsets.UTF_8);
+ buf.writeInt(errorBytes.length);
+ buf.writeBytes(errorBytes);
+ }
+
+ public static RpcFailure decode(ByteBuf buf) {
+ long requestId = buf.readLong();
+ int numErrorStringBytes = buf.readInt();
+ byte[] errorBytes = new byte[numErrorStringBytes];
+ buf.readBytes(errorBytes);
+ return new RpcFailure(requestId, new String(errorBytes, Charsets.UTF_8));
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (other instanceof RpcFailure) {
+ RpcFailure o = (RpcFailure) other;
+ return requestId == o.requestId && errorString.equals(o.errorString);
+ }
+ return false;
+ }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("requestId", requestId)
+ .add("errorString", errorString)
+ .toString();
+ }
+}
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java b/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java
new file mode 100644
index 0000000000..099e934ae0
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import java.util.Arrays;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+
+/**
+ * A generic RPC which is handled by a remote {@link org.apache.spark.network.server.RpcHandler}.
+ * This will correspond to a single
+ * {@link org.apache.spark.network.protocol.ResponseMessage} (either success or failure).
+ */
+public final class RpcRequest implements RequestMessage {
+ /** Used to link an RPC request with its response. */
+ public final long requestId;
+
+ /** Serialized message to send to remote RpcHandler. */
+ public final byte[] message;
+
+ public RpcRequest(long requestId, byte[] message) {
+ this.requestId = requestId;
+ this.message = message;
+ }
+
+ @Override
+ public Type type() { return Type.RpcRequest; }
+
+ @Override
+ public int encodedLength() {
+ return 8 + 4 + message.length;
+ }
+
+ @Override
+ public void encode(ByteBuf buf) {
+ buf.writeLong(requestId);
+ buf.writeInt(message.length);
+ buf.writeBytes(message);
+ }
+
+ public static RpcRequest decode(ByteBuf buf) {
+ long requestId = buf.readLong();
+ int messageLen = buf.readInt();
+ byte[] message = new byte[messageLen];
+ buf.readBytes(message);
+ return new RpcRequest(requestId, message);
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (other instanceof RpcRequest) {
+ RpcRequest o = (RpcRequest) other;
+ return requestId == o.requestId && Arrays.equals(message, o.message);
+ }
+ return false;
+ }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("requestId", requestId)
+ .add("message", message)
+ .toString();
+ }
+}
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java b/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java
new file mode 100644
index 0000000000..ed47947832
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import java.util.Arrays;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+
+/** Response to {@link RpcRequest} for a successful RPC. */
+public final class RpcResponse implements ResponseMessage {
+ public final long requestId;
+ public final byte[] response;
+
+ public RpcResponse(long requestId, byte[] response) {
+ this.requestId = requestId;
+ this.response = response;
+ }
+
+ @Override
+ public Type type() { return Type.RpcResponse; }
+
+ @Override
+ public int encodedLength() { return 8 + 4 + response.length; }
+
+ @Override
+ public void encode(ByteBuf buf) {
+ buf.writeLong(requestId);
+ buf.writeInt(response.length);
+ buf.writeBytes(response);
+ }
+
+ public static RpcResponse decode(ByteBuf buf) {
+ long requestId = buf.readLong();
+ int responseLen = buf.readInt();
+ byte[] response = new byte[responseLen];
+ buf.readBytes(response);
+ return new RpcResponse(requestId, response);
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (other instanceof RpcResponse) {
+ RpcResponse o = (RpcResponse) other;
+ return requestId == o.requestId && Arrays.equals(response, o.response);
+ }
+ return false;
+ }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("requestId", requestId)
+ .add("response", response)
+ .toString();
+ }
+}
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/StreamChunkId.java b/network/common/src/main/java/org/apache/spark/network/protocol/StreamChunkId.java
new file mode 100644
index 0000000000..d46a263884
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/StreamChunkId.java
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+
+/**
+ * Encapsulates a request for a particular chunk of a stream.
+ */
+public final class StreamChunkId implements Encodable {
+ public final long streamId;
+ public final int chunkIndex;
+
+ public StreamChunkId(long streamId, int chunkIndex) {
+ this.streamId = streamId;
+ this.chunkIndex = chunkIndex;
+ }
+
+ @Override
+ public int encodedLength() {
+ return 8 + 4;
+ }
+
+ @Override
+ public void encode(ByteBuf buffer) {
+ buffer.writeLong(streamId);
+ buffer.writeInt(chunkIndex);
+ }
+
+ public static StreamChunkId decode(ByteBuf buffer) {
+ assert buffer.readableBytes() >= 8 + 4;
+ long streamId = buffer.readLong();
+ int chunkIndex = buffer.readInt();
+ return new StreamChunkId(streamId, chunkIndex);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hashCode(streamId, chunkIndex);
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (other instanceof StreamChunkId) {
+ StreamChunkId o = (StreamChunkId) other;
+ return streamId == o.streamId && chunkIndex == o.chunkIndex;
+ }
+ return false;
+ }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("streamId", streamId)
+ .add("chunkIndex", chunkIndex)
+ .toString();
+ }
+}
diff --git a/network/common/src/main/java/org/apache/spark/network/server/DefaultStreamManager.java b/network/common/src/main/java/org/apache/spark/network/server/DefaultStreamManager.java
new file mode 100644
index 0000000000..9688705569
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/server/DefaultStreamManager.java
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.server;
+
+import java.util.Iterator;
+import java.util.Map;
+import java.util.Random;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.atomic.AtomicLong;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+
+/**
+ * StreamManager which allows registration of an Iterator<ManagedBuffer>, whose elements are
+ * individually fetched as chunks by the client.
+ */
+public class DefaultStreamManager extends StreamManager {
+ private final Logger logger = LoggerFactory.getLogger(DefaultStreamManager.class);
+
+ private final AtomicLong nextStreamId;
+ private final Map<Long, StreamState> streams;
+
+ /** State of a single stream. */
+ private static class StreamState {
+ final Iterator<ManagedBuffer> buffers;
+
+ // Used to keep track of the index of the buffer that the user has retrieved, just to ensure
+ // that the caller only requests each chunk one at a time, in order.
+ int curChunk = 0;
+
+ StreamState(Iterator<ManagedBuffer> buffers) {
+ this.buffers = buffers;
+ }
+ }
+
+ public DefaultStreamManager() {
+ // For debugging purposes, start with a random stream id to help identify different streams.
+ // This does not need to be globally unique, only unique to this class.
+ nextStreamId = new AtomicLong((long) new Random().nextInt(Integer.MAX_VALUE) * 1000);
+ streams = new ConcurrentHashMap<Long, StreamState>();
+ }
+
+ @Override
+ public ManagedBuffer getChunk(long streamId, int chunkIndex) {
+ StreamState state = streams.get(streamId);
+ if (chunkIndex != state.curChunk) {
+ throw new IllegalStateException(String.format(
+ "Received out-of-order chunk index %s (expected %s)", chunkIndex, state.curChunk));
+ } else if (!state.buffers.hasNext()) {
+ throw new IllegalStateException(String.format(
+ "Requested chunk index %s beyond end of stream", chunkIndex));
+ }
+ state.curChunk += 1;
+ ManagedBuffer nextChunk = state.buffers.next();
+
+ if (!state.buffers.hasNext()) {
+ logger.trace("Removing stream id {}", streamId);
+ streams.remove(streamId);
+ }
+
+ return nextChunk;
+ }
+
+ @Override
+ public void connectionTerminated(long streamId) {
+ // Release all remaining buffers.
+ StreamState state = streams.remove(streamId);
+ if (state != null && state.buffers != null) {
+ while (state.buffers.hasNext()) {
+ state.buffers.next().release();
+ }
+ }
+ }
+
+ /**
+ * Registers a stream of ManagedBuffers which are served as individual chunks one at a time to
+ * callers. Each ManagedBuffer will be release()'d after it is transferred on the wire. If a
+ * client connection is closed before the iterator is fully drained, then the remaining buffers
+ * will all be release()'d.
+ */
+ public long registerStream(Iterator<ManagedBuffer> buffers) {
+ long myStreamId = nextStreamId.getAndIncrement();
+ streams.put(myStreamId, new StreamState(buffers));
+ return myStreamId;
+ }
+}
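
A usage sketch under the contract above; in practice the buffers are file-backed, and the returned stream id would be handed to the remote client (e.g. inside an RPC) so it can issue fetchChunk() calls:

    import java.util.Arrays;
    import java.util.Iterator;

    import org.apache.spark.network.buffer.ManagedBuffer;
    import org.apache.spark.network.server.DefaultStreamManager;

    public class StreamRegistration {
      static long serve(DefaultStreamManager manager, ManagedBuffer[] chunks) {
        Iterator<ManagedBuffer> it = Arrays.asList(chunks).iterator();
        long streamId = manager.registerStream(it);

        // Clients must now fetch chunks 0, 1, 2, ... strictly in order;
        // getChunk() throws on out-of-order requests and drops the stream
        // once the iterator is drained.
        return streamId;
      }
    }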
diff --git a/network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java b/network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java
new file mode 100644
index 0000000000..b80c15106e
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.server;
+
+import org.apache.spark.network.protocol.Message;
+
+/**
+ * Handles either request or response messages coming off of Netty. A MessageHandler instance
+ * is associated with a single Netty Channel (though it may have multiple clients on the same
+ * Channel).
+ */
+public abstract class MessageHandler<T extends Message> {
+ /** Handles the receipt of a single message. */
+ public abstract void handle(T message);
+
+ /** Invoked when an exception was caught on the Channel. */
+ public abstract void exceptionCaught(Throwable cause);
+
+ /** Invoked when the channel this MessageHandler is on has been unregistered. */
+ public abstract void channelUnregistered();
+}
diff --git a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java
new file mode 100644
index 0000000000..f54a696b8f
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.server;
+
+import org.apache.spark.network.client.RpcResponseCallback;
+import org.apache.spark.network.client.TransportClient;
+
+/**
+ * Handler for sendRPC() messages sent by {@link org.apache.spark.network.client.TransportClient}s.
+ */
+public interface RpcHandler {
+ /**
+ * Receive a single RPC message. Any exception thrown while in this method will be sent back to
+ * the client in string form as a standard RPC failure.
+ *
+ * @param client A channel client which enables the handler to make requests back to the sender
+ * of this RPC.
+ * @param message The serialized bytes of the RPC.
+ * @param callback Callback which should be invoked exactly once upon success or failure of the
+ * RPC.
+ */
+ void receive(TransportClient client, byte[] message, RpcResponseCallback callback);
+}
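
A minimal implementation sketch: a hypothetical handler that echoes every payload back, invoking the callback exactly once as the contract requires:

    import org.apache.spark.network.client.RpcResponseCallback;
    import org.apache.spark.network.client.TransportClient;
    import org.apache.spark.network.server.RpcHandler;

    /** Hypothetical handler: echoes every RPC payload back to the sender. */
    public class EchoRpcHandler implements RpcHandler {
      @Override
      public void receive(TransportClient client, byte[] message,
          RpcResponseCallback callback) {
        // Returning the bytes produces an RpcResponse; throwing here would
        // instead be converted into an RpcFailure by TransportRequestHandler.
        callback.onSuccess(message);
      }
    }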
diff --git a/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java b/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java
new file mode 100644
index 0000000000..5a9a14a180
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.server;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+
+/**
+ * The StreamManager is used to fetch individual chunks from a stream. This is used in
+ * {@link TransportRequestHandler} in order to respond to fetchChunk() requests. Creation of the
+ * stream is outside the scope of the transport layer, but a given stream is guaranteed to be read
+ * by only one client connection, meaning that getChunk() for a particular stream will be called
+ * serially and that once the connection associated with the stream is closed, that stream will
+ * never be used again.
+ */
+public abstract class StreamManager {
+ /**
+ * Called in response to a fetchChunk() request. The returned buffer will be passed as-is to the
+ * client. A single stream will be associated with a single TCP connection, so this method
+ * will not be called in parallel for a particular stream.
+ *
+ * Chunks may be requested in any order, and requests may be repeated, but it is not required
+ * that implementations support this behavior.
+ *
+ * The returned ManagedBuffer will be release()'d after being written to the network.
+ *
+ * @param streamId id of a stream that has been previously registered with the StreamManager.
+ * @param chunkIndex 0-indexed chunk of the stream that's requested
+ */
+ public abstract ManagedBuffer getChunk(long streamId, int chunkIndex);
+
+ /**
+ * Indicates that the TCP connection that was tied to the given stream has been terminated. After
+ * this occurs, we are guaranteed not to read from the stream again, so any state can be cleaned
+ * up.
+ */
+ public void connectionTerminated(long streamId) { }
+}
diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java
new file mode 100644
index 0000000000..e491367fa4
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.server;
+
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.SimpleChannelInboundHandler;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.client.TransportResponseHandler;
+import org.apache.spark.network.protocol.Message;
+import org.apache.spark.network.protocol.RequestMessage;
+import org.apache.spark.network.protocol.ResponseMessage;
+import org.apache.spark.network.util.NettyUtils;
+
+/**
+ * The single Transport-level Channel handler which is used for delegating requests to the
+ * {@link TransportRequestHandler} and responses to the {@link TransportResponseHandler}.
+ *
+ * All channels created in the transport layer are bidirectional. When the Client initiates a Netty
+ * Channel with a RequestMessage (which gets handled by the Server's RequestHandler), the Server
+ * will produce a ResponseMessage (handled by the Client's ResponseHandler). However, the Server
+ * also gets a handle on the same Channel, so it may then begin to send RequestMessages to the
+ * Client.
+ * This means that the Client also needs a RequestHandler and the Server needs a ResponseHandler,
+ * for the Client's responses to the Server's requests.
+ */
+public class TransportChannelHandler extends SimpleChannelInboundHandler<Message> {
+ private final Logger logger = LoggerFactory.getLogger(TransportChannelHandler.class);
+
+ private final TransportClient client;
+ private final TransportResponseHandler responseHandler;
+ private final TransportRequestHandler requestHandler;
+
+ public TransportChannelHandler(
+ TransportClient client,
+ TransportResponseHandler responseHandler,
+ TransportRequestHandler requestHandler) {
+ this.client = client;
+ this.responseHandler = responseHandler;
+ this.requestHandler = requestHandler;
+ }
+
+ public TransportClient getClient() {
+ return client;
+ }
+
+ @Override
+ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
+ logger.warn("Exception in connection from " + NettyUtils.getRemoteAddress(ctx.channel()),
+ cause);
+ requestHandler.exceptionCaught(cause);
+ responseHandler.exceptionCaught(cause);
+ ctx.close();
+ }
+
+ @Override
+ public void channelUnregistered(ChannelHandlerContext ctx) throws Exception {
+ try {
+ requestHandler.channelUnregistered();
+ } catch (RuntimeException e) {
+ logger.error("Exception from request handler while unregistering channel", e);
+ }
+ try {
+ responseHandler.channelUnregistered();
+ } catch (RuntimeException e) {
+ logger.error("Exception from response handler while unregistering channel", e);
+ }
+ super.channelUnregistered(ctx);
+ }
+
+ @Override
+ public void channelRead0(ChannelHandlerContext ctx, Message request) {
+ if (request instanceof RequestMessage) {
+ requestHandler.handle((RequestMessage) request);
+ } else {
+ responseHandler.handle((ResponseMessage) request);
+ }
+ }
+}
diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
new file mode 100644
index 0000000000..352f865935
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
@@ -0,0 +1,162 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.server;
+
+import java.util.Set;
+
+import com.google.common.base.Throwables;
+import com.google.common.collect.Sets;
+import io.netty.channel.Channel;
+import io.netty.channel.ChannelFuture;
+import io.netty.channel.ChannelFutureListener;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.client.RpcResponseCallback;
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.protocol.Encodable;
+import org.apache.spark.network.protocol.RequestMessage;
+import org.apache.spark.network.protocol.ChunkFetchRequest;
+import org.apache.spark.network.protocol.RpcRequest;
+import org.apache.spark.network.protocol.ChunkFetchFailure;
+import org.apache.spark.network.protocol.ChunkFetchSuccess;
+import org.apache.spark.network.protocol.RpcFailure;
+import org.apache.spark.network.protocol.RpcResponse;
+import org.apache.spark.network.util.NettyUtils;
+
+/**
+ * A handler that processes requests from clients and writes chunk data back. Each handler is
+ * attached to a single Netty channel, and keeps track of which streams have been fetched via this
+ * channel, in order to clean them up if the channel is terminated (see #channelUnregistered).
+ *
+ * The messages should have been processed by the pipeline setup by {@link TransportServer}.
+ */
+public class TransportRequestHandler extends MessageHandler<RequestMessage> {
+ private final Logger logger = LoggerFactory.getLogger(TransportRequestHandler.class);
+
+ /** The Netty channel that this handler is associated with. */
+ private final Channel channel;
+
+ /** Client on the same channel allowing us to talk back to the requester. */
+ private final TransportClient reverseClient;
+
+ /** Returns each chunk of a registered stream, one at a time. */
+ private final StreamManager streamManager;
+
+ /** Handles all RPC messages. */
+ private final RpcHandler rpcHandler;
+
+ /** List of all stream ids that have been read on this handler, used for cleanup. */
+ private final Set<Long> streamIds;
+
+ public TransportRequestHandler(
+ Channel channel,
+ TransportClient reverseClient,
+ StreamManager streamManager,
+ RpcHandler rpcHandler) {
+ this.channel = channel;
+ this.reverseClient = reverseClient;
+ this.streamManager = streamManager;
+ this.rpcHandler = rpcHandler;
+ this.streamIds = Sets.newHashSet();
+ }
+
+ @Override
+ public void exceptionCaught(Throwable cause) {
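+ // No action needed on the request side; stream state is cleaned up in
+ // channelUnregistered() once the channel is actually torn down.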
+ }
+
+ @Override
+ public void channelUnregistered() {
+ // Inform the StreamManager that these streams will no longer be read from.
+ for (long streamId : streamIds) {
+ streamManager.connectionTerminated(streamId);
+ }
+ }
+
+ @Override
+ public void handle(RequestMessage request) {
+ if (request instanceof ChunkFetchRequest) {
+ processFetchRequest((ChunkFetchRequest) request);
+ } else if (request instanceof RpcRequest) {
+ processRpcRequest((RpcRequest) request);
+ } else {
+ throw new IllegalArgumentException("Unknown request type: " + request);
+ }
+ }
+
+ private void processFetchRequest(final ChunkFetchRequest req) {
+ final String client = NettyUtils.getRemoteAddress(channel);
+ streamIds.add(req.streamChunkId.streamId);
+
+ logger.trace("Received req from {} to fetch block {}", client, req.streamChunkId);
+
+ ManagedBuffer buf;
+ try {
+ buf = streamManager.getChunk(req.streamChunkId.streamId, req.streamChunkId.chunkIndex);
+ } catch (Exception e) {
+ logger.error(String.format(
+ "Error opening block %s for request from %s", req.streamChunkId, client), e);
+ respond(new ChunkFetchFailure(req.streamChunkId, Throwables.getStackTraceAsString(e)));
+ return;
+ }
+
+ respond(new ChunkFetchSuccess(req.streamChunkId, buf));
+ }
+
+ private void processRpcRequest(final RpcRequest req) {
+ try {
+ rpcHandler.receive(reverseClient, req.message, new RpcResponseCallback() {
+ @Override
+ public void onSuccess(byte[] response) {
+ respond(new RpcResponse(req.requestId, response));
+ }
+
+ @Override
+ public void onFailure(Throwable e) {
+ respond(new RpcFailure(req.requestId, Throwables.getStackTraceAsString(e)));
+ }
+ });
+ } catch (Exception e) {
+ logger.error("Error while invoking RpcHandler#receive() on RPC id " + req.requestId, e);
+ respond(new RpcFailure(req.requestId, Throwables.getStackTraceAsString(e)));
+ }
+ }
+
+ /**
+ * Responds to a single message with some Encodable object. If a failure occurs while sending,
+ * it will be logged and the channel closed.
+ */
+ private void respond(final Encodable result) {
+ final String remoteAddress = channel.remoteAddress().toString();
+ channel.writeAndFlush(result).addListener(
+ new ChannelFutureListener() {
+ @Override
+ public void operationComplete(ChannelFuture future) throws Exception {
+ if (future.isSuccess()) {
+ logger.trace(String.format("Sent result %s to client %s", result, remoteAddress));
+ } else {
+ logger.error(String.format("Error sending result %s to %s; closing connection",
+ result, remoteAddress), future.cause());
+ channel.close();
+ }
+ }
+ }
+ );
+ }
+}
diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java b/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java
new file mode 100644
index 0000000000..243070750d
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.server;
+
+import java.io.Closeable;
+import java.net.InetSocketAddress;
+import java.util.concurrent.TimeUnit;
+
+import io.netty.bootstrap.ServerBootstrap;
+import io.netty.buffer.PooledByteBufAllocator;
+import io.netty.channel.ChannelFuture;
+import io.netty.channel.ChannelInitializer;
+import io.netty.channel.ChannelOption;
+import io.netty.channel.EventLoopGroup;
+import io.netty.channel.socket.SocketChannel;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.TransportContext;
+import org.apache.spark.network.util.IOMode;
+import org.apache.spark.network.util.NettyUtils;
+import org.apache.spark.network.util.TransportConf;
+
+/**
+ * Server for the efficient, low-level streaming service.
+ */
+public class TransportServer implements Closeable {
+ private final Logger logger = LoggerFactory.getLogger(TransportServer.class);
+
+ private final TransportContext context;
+ private final TransportConf conf;
+
+ private ServerBootstrap bootstrap;
+ private ChannelFuture channelFuture;
+ private int port = -1;
+
+ public TransportServer(TransportContext context) {
+ this.context = context;
+ this.conf = context.getConf();
+
+ init();
+ }
+
+ public int getPort() {
+ if (port == -1) {
+ throw new IllegalStateException("Server not initialized");
+ }
+ return port;
+ }
+
+ private void init() {
+
+ IOMode ioMode = IOMode.valueOf(conf.ioMode());
+ EventLoopGroup bossGroup =
+ NettyUtils.createEventLoop(ioMode, conf.serverThreads(), "shuffle-server");
+ EventLoopGroup workerGroup = bossGroup;
+
+ bootstrap = new ServerBootstrap()
+ .group(bossGroup, workerGroup)
+ .channel(NettyUtils.getServerChannelClass(ioMode))
+ .option(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT)
+ .childOption(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT);
+
+ if (conf.backLog() > 0) {
+ bootstrap.option(ChannelOption.SO_BACKLOG, conf.backLog());
+ }
+
+ if (conf.receiveBuf() > 0) {
+ bootstrap.childOption(ChannelOption.SO_RCVBUF, conf.receiveBuf());
+ }
+
+ if (conf.sendBuf() > 0) {
+ bootstrap.childOption(ChannelOption.SO_SNDBUF, conf.sendBuf());
+ }
+
+ bootstrap.childHandler(new ChannelInitializer<SocketChannel>() {
+ @Override
+ protected void initChannel(SocketChannel ch) throws Exception {
+ context.initializePipeline(ch);
+ }
+ });
+
+ channelFuture = bootstrap.bind(new InetSocketAddress(conf.serverPort()));
+ channelFuture.syncUninterruptibly();
+
+ port = ((InetSocketAddress) channelFuture.channel().localAddress()).getPort();
+ logger.debug("Shuffle server started on port :" + port);
+ }
+
+ @Override
+ public void close() {
+ if (channelFuture != null) {
+ // close is a local operation and should finish within milliseconds; timeout just to be safe
+ channelFuture.channel().close().awaitUninterruptibly(10, TimeUnit.SECONDS);
+ channelFuture = null;
+ }
+ if (bootstrap != null && bootstrap.group() != null) {
+ bootstrap.group().shutdownGracefully();
+ }
+ if (bootstrap != null && bootstrap.childGroup() != null) {
+ bootstrap.childGroup().shutdownGracefully();
+ }
+ bootstrap = null;
+ }
+
+}
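
As the test suites later in this patch show, a TransportServer is never constructed directly; it is obtained from a TransportContext. A minimal bootstrap sketch, assuming the SystemPropertyConfigProvider and NoOpRpcHandler test helpers defined below (ServerBootstrapDemo itself is illustrative):

    package org.apache.spark.network;

    import org.apache.spark.network.server.DefaultStreamManager;
    import org.apache.spark.network.server.TransportServer;
    import org.apache.spark.network.util.TransportConf;

    public class ServerBootstrapDemo {
      public static void main(String[] args) {
        TransportConf conf = new TransportConf(new SystemPropertyConfigProvider());
        TransportContext context = new TransportContext(conf, new DefaultStreamManager(), new NoOpRpcHandler());
        // Binds to spark.shuffle.io.port, or an ephemeral port by default.
        TransportServer server = context.createServer();
        System.out.println("Listening on port " + server.getPort());
        server.close();
      }
    }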
diff --git a/network/common/src/main/java/org/apache/spark/network/util/ConfigProvider.java b/network/common/src/main/java/org/apache/spark/network/util/ConfigProvider.java
new file mode 100644
index 0000000000..d944d9da1c
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/util/ConfigProvider.java
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.util;
+
+import java.util.NoSuchElementException;
+
+/**
+ * Provides a mechanism for constructing a {@link TransportConf} from some underlying
+ * configuration source.
+ */
+public abstract class ConfigProvider {
+ /** Obtains the value of the given config, throws NoSuchElementException if it doesn't exist. */
+ public abstract String get(String name);
+
+ public String get(String name, String defaultValue) {
+ try {
+ return get(name);
+ } catch (NoSuchElementException e) {
+ return defaultValue;
+ }
+ }
+
+ public int getInt(String name, int defaultValue) {
+ return Integer.parseInt(get(name, Integer.toString(defaultValue)));
+ }
+
+ public long getLong(String name, long defaultValue) {
+ return Long.parseLong(get(name, Long.toString(defaultValue)));
+ }
+
+ public double getDouble(String name, double defaultValue) {
+ return Double.parseDouble(get(name, Double.toString(defaultValue)));
+ }
+
+ public boolean getBoolean(String name, boolean defaultValue) {
+ return Boolean.parseBoolean(get(name, Boolean.toString(defaultValue)));
+ }
+}
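
Since ConfigProvider is the sole extension point for configuration, embedding applications supply their own subclass. For most standalone uses a map-backed provider suffices; the following sketch is illustrative (MapConfigProvider is not part of this patch):

    import java.util.Map;
    import java.util.NoSuchElementException;

    import org.apache.spark.network.util.ConfigProvider;

    /** Illustrative provider backed by an in-memory map. */
    public class MapConfigProvider extends ConfigProvider {
      private final Map<String, String> config;

      public MapConfigProvider(Map<String, String> config) {
        this.config = config;
      }

      @Override
      public String get(String name) {
        String value = config.get(name);
        if (value == null) {
          throw new NoSuchElementException(name);
        }
        return value;
      }
    }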
diff --git a/network/common/src/main/java/org/apache/spark/network/util/IOMode.java b/network/common/src/main/java/org/apache/spark/network/util/IOMode.java
new file mode 100644
index 0000000000..6b208d95bb
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/util/IOMode.java
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.util;
+
+/**
+ * Selector for which form of low-level IO we should use.
+ * NIO is always available, while EPOLL is only available on Linux.
+ * Callers that want automatic fallback should check Epoll.isAvailable() before choosing EPOLL.
+ */
+public enum IOMode {
+ NIO, EPOLL
+}
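
Since EPOLL is Linux-only, a caller that wants automatic fallback can probe Netty directly before picking a mode; the patch itself leaves the choice to the spark.shuffle.io.mode setting. A short sketch, using NettyUtils.createEventLoop() defined below:

    import io.netty.channel.EventLoopGroup;
    import io.netty.channel.epoll.Epoll;

    // Prefer epoll where the native transport is actually loadable; fall back to NIO otherwise.
    IOMode mode = Epoll.isAvailable() ? IOMode.EPOLL : IOMode.NIO;
    EventLoopGroup group = NettyUtils.createEventLoop(mode, 8, "shuffle-demo");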
diff --git a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java
new file mode 100644
index 0000000000..32ba3f5b07
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.util;
+
+import java.io.Closeable;
+import java.io.IOException;
+
+import com.google.common.io.Closeables;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+public class JavaUtils {
+ private static final Logger logger = LoggerFactory.getLogger(JavaUtils.class);
+
+ /** Closes the given object, logging (but not propagating) any IOException. */
+ public static void closeQuietly(Closeable closeable) {
+ try {
+ closeable.close();
+ } catch (IOException e) {
+ logger.error("IOException should not have been thrown.", e);
+ }
+ }
+}
diff --git a/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java b/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java
new file mode 100644
index 0000000000..b187234119
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.util;
+
+import java.util.concurrent.ThreadFactory;
+
+import com.google.common.util.concurrent.ThreadFactoryBuilder;
+import io.netty.channel.Channel;
+import io.netty.channel.EventLoopGroup;
+import io.netty.channel.ServerChannel;
+import io.netty.channel.epoll.Epoll;
+import io.netty.channel.epoll.EpollEventLoopGroup;
+import io.netty.channel.epoll.EpollServerSocketChannel;
+import io.netty.channel.epoll.EpollSocketChannel;
+import io.netty.channel.nio.NioEventLoopGroup;
+import io.netty.channel.socket.nio.NioServerSocketChannel;
+import io.netty.channel.socket.nio.NioSocketChannel;
+import io.netty.handler.codec.ByteToMessageDecoder;
+import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
+
+/**
+ * Utilities for creating various Netty constructs based on whether we're using EPOLL or NIO.
+ */
+public class NettyUtils {
+ /** Creates a Netty EventLoopGroup based on the IOMode. */
+ public static EventLoopGroup createEventLoop(IOMode mode, int numThreads, String threadPrefix) {
+
+ ThreadFactory threadFactory = new ThreadFactoryBuilder()
+ .setDaemon(true)
+ .setNameFormat(threadPrefix + "-%d")
+ .build();
+
+ switch (mode) {
+ case NIO:
+ return new NioEventLoopGroup(numThreads, threadFactory);
+ case EPOLL:
+ return new EpollEventLoopGroup(numThreads, threadFactory);
+ default:
+ throw new IllegalArgumentException("Unknown io mode: " + mode);
+ }
+ }
+
+ /** Returns the correct (client) SocketChannel class based on IOMode. */
+ public static Class<? extends Channel> getClientChannelClass(IOMode mode) {
+ switch (mode) {
+ case NIO:
+ return NioSocketChannel.class;
+ case EPOLL:
+ return EpollSocketChannel.class;
+ default:
+ throw new IllegalArgumentException("Unknown io mode: " + mode);
+ }
+ }
+
+ /** Returns the correct ServerSocketChannel class based on IOMode. */
+ public static Class<? extends ServerChannel> getServerChannelClass(IOMode mode) {
+ switch (mode) {
+ case NIO:
+ return NioServerSocketChannel.class;
+ case EPOLL:
+ return EpollServerSocketChannel.class;
+ default:
+ throw new IllegalArgumentException("Unknown io mode: " + mode);
+ }
+ }
+
+ /**
+ * Creates a LengthFieldBasedFrameDecoder where the first 8 bytes are the length of the frame.
+ * This is used before all decoders.
+ */
+ public static ByteToMessageDecoder createFrameDecoder() {
+ // maxFrameLength = 2G
+ // lengthFieldOffset = 0
+ // lengthFieldLength = 8
+ // lengthAdjustment = -8, i.e. exclude the 8 byte length itself
+ // initialBytesToStrip = 8, i.e. strip out the length field itself
+ return new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 8, -8, 8);
+ }
+
+ /** Returns the remote address on the channel, or "&lt;unknown remote&gt;" if none exists. */
+ public static String getRemoteAddress(Channel channel) {
+ if (channel != null && channel.remoteAddress() != null) {
+ return channel.remoteAddress().toString();
+ }
+ return "<unknown remote>";
+ }
+}
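
The frame decoder parameters encode a simple wire convention: every frame begins with an 8-byte length field that counts itself, so a frame carrying N payload bytes has a length field of N + 8 (hence lengthAdjustment = -8 and initialBytesToStrip = 8). A sketch of the matching write side, purely for illustration (the actual encoding lives in MessageEncoder):

    import io.netty.buffer.ByteBuf;
    import io.netty.buffer.Unpooled;

    byte[] payload = new byte[] { 1, 2, 3 };
    ByteBuf frame = Unpooled.buffer(8 + payload.length);
    frame.writeLong(8 + payload.length); // length field includes its own 8 bytes
    frame.writeBytes(payload);
    // createFrameDecoder() reads the field, strips the 8-byte prefix, and emits just the payload.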
diff --git a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java
new file mode 100644
index 0000000000..80f65d9803
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.util;
+
+/**
+ * A central location that tracks all the settings we expose to users.
+ */
+public class TransportConf {
+ private final ConfigProvider conf;
+
+ public TransportConf(ConfigProvider conf) {
+ this.conf = conf;
+ }
+
+ /** Port the server listens on. Defaults to a random port. */
+ public int serverPort() { return conf.getInt("spark.shuffle.io.port", 0); }
+
+ /** IO mode: nio or epoll */
+ public String ioMode() { return conf.get("spark.shuffle.io.mode", "NIO").toUpperCase(); }
+
+ /** Connection timeout, configured in seconds but returned in milliseconds. Defaults to 120 seconds. */
+ public int connectionTimeoutMs() {
+ return conf.getInt("spark.shuffle.io.connectionTimeout", 120) * 1000;
+ }
+
+ /** Requested maximum length of the queue of incoming connections. If -1 (the default), SO_BACKLOG is not set and the OS default applies. */
+ public int backLog() { return conf.getInt("spark.shuffle.io.backLog", -1); }
+
+ /** Number of threads used in the server thread pool. Defaults to 0, i.e. 2x the number of cores. */
+ public int serverThreads() { return conf.getInt("spark.shuffle.io.serverThreads", 0); }
+
+ /** Number of threads used in the client thread pool. Defaults to 0, i.e. 2x the number of cores. */
+ public int clientThreads() { return conf.getInt("spark.shuffle.io.clientThreads", 0); }
+
+ /**
+ * Receive buffer size (SO_RCVBUF).
+ * Note: the optimal size for the receive and send buffers is the
+ * bandwidth-delay product: latency * network_bandwidth.
+ * Assuming latency = 1ms and network_bandwidth = 10Gbps,
+ * the buffer size should be ~1.25MB.
+ */
+ public int receiveBuf() { return conf.getInt("spark.shuffle.io.receiveBuffer", -1); }
+
+ /** Send buffer size (SO_SNDBUF). */
+ public int sendBuf() { return conf.getInt("spark.shuffle.io.sendBuffer", -1); }
+}
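
Because every setting is read through a ConfigProvider, the test suites below drive the transport entirely via JVM system properties. A short sketch using the SystemPropertyConfigProvider defined in the test sources:

    System.setProperty("spark.shuffle.io.mode", "NIO");
    System.setProperty("spark.shuffle.io.connectionTimeout", "30"); // seconds
    TransportConf conf = new TransportConf(new SystemPropertyConfigProvider());
    assert conf.connectionTimeoutMs() == 30 * 1000;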
diff --git a/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java
new file mode 100644
index 0000000000..738dca9b6a
--- /dev/null
+++ b/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java
@@ -0,0 +1,217 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network;
+
+import java.io.File;
+import java.io.RandomAccessFile;
+import java.nio.ByteBuffer;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Random;
+import java.util.Set;
+import java.util.concurrent.Semaphore;
+import java.util.concurrent.TimeUnit;
+
+import com.google.common.collect.Lists;
+import com.google.common.collect.Sets;
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import static org.junit.Assert.*;
+
+import org.apache.spark.network.buffer.FileSegmentManagedBuffer;
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.buffer.NioManagedBuffer;
+import org.apache.spark.network.client.ChunkReceivedCallback;
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.client.TransportClientFactory;
+import org.apache.spark.network.server.TransportServer;
+import org.apache.spark.network.server.StreamManager;
+import org.apache.spark.network.util.TransportConf;
+
+public class ChunkFetchIntegrationSuite {
+ static final long STREAM_ID = 1;
+ static final int BUFFER_CHUNK_INDEX = 0;
+ static final int FILE_CHUNK_INDEX = 1;
+
+ static TransportServer server;
+ static TransportClientFactory clientFactory;
+ static StreamManager streamManager;
+ static File testFile;
+
+ static ManagedBuffer bufferChunk;
+ static ManagedBuffer fileChunk;
+
+ @BeforeClass
+ public static void setUp() throws Exception {
+ int bufSize = 100000;
+ final ByteBuffer buf = ByteBuffer.allocate(bufSize);
+ for (int i = 0; i < bufSize; i ++) {
+ buf.put((byte) i);
+ }
+ buf.flip();
+ bufferChunk = new NioManagedBuffer(buf);
+
+ testFile = File.createTempFile("shuffle-test-file", "txt");
+ testFile.deleteOnExit();
+ RandomAccessFile fp = new RandomAccessFile(testFile, "rw");
+ byte[] fileContent = new byte[1024];
+ new Random().nextBytes(fileContent);
+ fp.write(fileContent);
+ fp.close();
+ fileChunk = new FileSegmentManagedBuffer(testFile, 10, testFile.length() - 25);
+
+ TransportConf conf = new TransportConf(new SystemPropertyConfigProvider());
+ streamManager = new StreamManager() {
+ @Override
+ public ManagedBuffer getChunk(long streamId, int chunkIndex) {
+ assertEquals(STREAM_ID, streamId);
+ if (chunkIndex == BUFFER_CHUNK_INDEX) {
+ return new NioManagedBuffer(buf);
+ } else if (chunkIndex == FILE_CHUNK_INDEX) {
+ return new FileSegmentManagedBuffer(testFile, 10, testFile.length() - 25);
+ } else {
+ throw new IllegalArgumentException("Invalid chunk index: " + chunkIndex);
+ }
+ }
+ };
+ TransportContext context = new TransportContext(conf, streamManager, new NoOpRpcHandler());
+ server = context.createServer();
+ clientFactory = context.createClientFactory();
+ }
+
+ @AfterClass
+ public static void tearDown() {
+ server.close();
+ clientFactory.close();
+ testFile.delete();
+ }
+
+ class FetchResult {
+ public Set<Integer> successChunks;
+ public Set<Integer> failedChunks;
+ public List<ManagedBuffer> buffers;
+
+ public void releaseBuffers() {
+ for (ManagedBuffer buffer : buffers) {
+ buffer.release();
+ }
+ }
+ }
+
+ private FetchResult fetchChunks(List<Integer> chunkIndices) throws Exception {
+ TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
+ final Semaphore sem = new Semaphore(0);
+
+ final FetchResult res = new FetchResult();
+ res.successChunks = Collections.synchronizedSet(new HashSet<Integer>());
+ res.failedChunks = Collections.synchronizedSet(new HashSet<Integer>());
+ res.buffers = Collections.synchronizedList(new LinkedList<ManagedBuffer>());
+
+ ChunkReceivedCallback callback = new ChunkReceivedCallback() {
+ @Override
+ public void onSuccess(int chunkIndex, ManagedBuffer buffer) {
+ buffer.retain();
+ res.successChunks.add(chunkIndex);
+ res.buffers.add(buffer);
+ sem.release();
+ }
+
+ @Override
+ public void onFailure(int chunkIndex, Throwable e) {
+ res.failedChunks.add(chunkIndex);
+ sem.release();
+ }
+ };
+
+ for (int chunkIndex : chunkIndices) {
+ client.fetchChunk(STREAM_ID, chunkIndex, callback);
+ }
+ if (!sem.tryAcquire(chunkIndices.size(), 5, TimeUnit.SECONDS)) {
+ fail("Timeout getting response from the server");
+ }
+ client.close();
+ return res;
+ }
+
+ @Test
+ public void fetchBufferChunk() throws Exception {
+ FetchResult res = fetchChunks(Lists.newArrayList(BUFFER_CHUNK_INDEX));
+ assertEquals(res.successChunks, Sets.newHashSet(BUFFER_CHUNK_INDEX));
+ assertTrue(res.failedChunks.isEmpty());
+ assertBufferListsEqual(res.buffers, Lists.newArrayList(bufferChunk));
+ res.releaseBuffers();
+ }
+
+ @Test
+ public void fetchFileChunk() throws Exception {
+ FetchResult res = fetchChunks(Lists.newArrayList(FILE_CHUNK_INDEX));
+ assertEquals(res.successChunks, Sets.newHashSet(FILE_CHUNK_INDEX));
+ assertTrue(res.failedChunks.isEmpty());
+ assertBufferListsEqual(res.buffers, Lists.newArrayList(fileChunk));
+ res.releaseBuffers();
+ }
+
+ @Test
+ public void fetchNonExistentChunk() throws Exception {
+ FetchResult res = fetchChunks(Lists.newArrayList(12345));
+ assertTrue(res.successChunks.isEmpty());
+ assertEquals(res.failedChunks, Sets.newHashSet(12345));
+ assertTrue(res.buffers.isEmpty());
+ }
+
+ @Test
+ public void fetchBothChunks() throws Exception {
+ FetchResult res = fetchChunks(Lists.newArrayList(BUFFER_CHUNK_INDEX, FILE_CHUNK_INDEX));
+ assertEquals(res.successChunks, Sets.newHashSet(BUFFER_CHUNK_INDEX, FILE_CHUNK_INDEX));
+ assertTrue(res.failedChunks.isEmpty());
+ assertBufferListsEqual(res.buffers, Lists.newArrayList(bufferChunk, fileChunk));
+ res.releaseBuffers();
+ }
+
+ @Test
+ public void fetchChunkAndNonExistent() throws Exception {
+ FetchResult res = fetchChunks(Lists.newArrayList(BUFFER_CHUNK_INDEX, 12345));
+ assertEquals(res.successChunks, Sets.newHashSet(BUFFER_CHUNK_INDEX));
+ assertEquals(res.failedChunks, Sets.newHashSet(12345));
+ assertBufferListsEqual(res.buffers, Lists.newArrayList(bufferChunk));
+ res.releaseBuffers();
+ }
+
+ private void assertBufferListsEqual(List<ManagedBuffer> list0, List<ManagedBuffer> list1)
+ throws Exception {
+ assertEquals(list0.size(), list1.size());
+ for (int i = 0; i < list0.size(); i ++) {
+ assertBuffersEqual(list0.get(i), list1.get(i));
+ }
+ }
+
+ private void assertBuffersEqual(ManagedBuffer buffer0, ManagedBuffer buffer1) throws Exception {
+ ByteBuffer nio0 = buffer0.nioByteBuffer();
+ ByteBuffer nio1 = buffer1.nioByteBuffer();
+
+ int len = nio0.remaining();
+ assertEquals(nio0.remaining(), nio1.remaining());
+ for (int i = 0; i < len; i ++) {
+ assertEquals(nio0.get(), nio1.get());
+ }
+ }
+}
diff --git a/network/common/src/test/java/org/apache/spark/network/NoOpRpcHandler.java b/network/common/src/test/java/org/apache/spark/network/NoOpRpcHandler.java
new file mode 100644
index 0000000000..7aa37efc58
--- /dev/null
+++ b/network/common/src/test/java/org/apache/spark/network/NoOpRpcHandler.java
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network;
+
+import org.apache.spark.network.client.RpcResponseCallback;
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.server.RpcHandler;
+
+/** Test RpcHandler which always returns a zero-sized success. */
+public class NoOpRpcHandler implements RpcHandler {
+ @Override
+ public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) {
+ callback.onSuccess(new byte[0]);
+ }
+}
diff --git a/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java b/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java
new file mode 100644
index 0000000000..43dc0cf8c7
--- /dev/null
+++ b/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network;
+
+import io.netty.channel.embedded.EmbeddedChannel;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+
+import org.apache.spark.network.protocol.Message;
+import org.apache.spark.network.protocol.StreamChunkId;
+import org.apache.spark.network.protocol.ChunkFetchRequest;
+import org.apache.spark.network.protocol.ChunkFetchFailure;
+import org.apache.spark.network.protocol.ChunkFetchSuccess;
+import org.apache.spark.network.protocol.RpcRequest;
+import org.apache.spark.network.protocol.RpcFailure;
+import org.apache.spark.network.protocol.RpcResponse;
+import org.apache.spark.network.protocol.MessageDecoder;
+import org.apache.spark.network.protocol.MessageEncoder;
+import org.apache.spark.network.util.NettyUtils;
+
+public class ProtocolSuite {
+ private void testServerToClient(Message msg) {
+ EmbeddedChannel serverChannel = new EmbeddedChannel(new MessageEncoder());
+ serverChannel.writeOutbound(msg);
+
+ EmbeddedChannel clientChannel = new EmbeddedChannel(
+ NettyUtils.createFrameDecoder(), new MessageDecoder());
+
+ while (!serverChannel.outboundMessages().isEmpty()) {
+ clientChannel.writeInbound(serverChannel.readOutbound());
+ }
+
+ assertEquals(1, clientChannel.inboundMessages().size());
+ assertEquals(msg, clientChannel.readInbound());
+ }
+
+ private void testClientToServer(Message msg) {
+ EmbeddedChannel clientChannel = new EmbeddedChannel(new MessageEncoder());
+ clientChannel.writeOutbound(msg);
+
+ EmbeddedChannel serverChannel = new EmbeddedChannel(
+ NettyUtils.createFrameDecoder(), new MessageDecoder());
+
+ while (!clientChannel.outboundMessages().isEmpty()) {
+ serverChannel.writeInbound(clientChannel.readOutbound());
+ }
+
+ assertEquals(1, serverChannel.inboundMessages().size());
+ assertEquals(msg, serverChannel.readInbound());
+ }
+
+ @Test
+ public void requests() {
+ testClientToServer(new ChunkFetchRequest(new StreamChunkId(1, 2)));
+ testClientToServer(new RpcRequest(12345, new byte[0]));
+ testClientToServer(new RpcRequest(12345, new byte[100]));
+ }
+
+ @Test
+ public void responses() {
+ testServerToClient(new ChunkFetchSuccess(new StreamChunkId(1, 2), new TestManagedBuffer(10)));
+ testServerToClient(new ChunkFetchSuccess(new StreamChunkId(1, 2), new TestManagedBuffer(0)));
+ testServerToClient(new ChunkFetchFailure(new StreamChunkId(1, 2), "this is an error"));
+ testServerToClient(new ChunkFetchFailure(new StreamChunkId(1, 2), ""));
+ testServerToClient(new RpcResponse(12345, new byte[0]));
+ testServerToClient(new RpcResponse(12345, new byte[1000]));
+ testServerToClient(new RpcFailure(0, "this is an error"));
+ testServerToClient(new RpcFailure(0, ""));
+ }
+}
diff --git a/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java
new file mode 100644
index 0000000000..9f216dd2d7
--- /dev/null
+++ b/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java
@@ -0,0 +1,175 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network;
+
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.Iterator;
+import java.util.Set;
+import java.util.concurrent.Semaphore;
+import java.util.concurrent.TimeUnit;
+
+import com.google.common.base.Charsets;
+import com.google.common.collect.Sets;
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import static org.junit.Assert.*;
+
+import org.apache.spark.network.client.RpcResponseCallback;
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.client.TransportClientFactory;
+import org.apache.spark.network.server.DefaultStreamManager;
+import org.apache.spark.network.server.RpcHandler;
+import org.apache.spark.network.server.TransportServer;
+import org.apache.spark.network.util.TransportConf;
+
+public class RpcIntegrationSuite {
+ static TransportServer server;
+ static TransportClientFactory clientFactory;
+ static RpcHandler rpcHandler;
+
+ @BeforeClass
+ public static void setUp() throws Exception {
+ TransportConf conf = new TransportConf(new SystemPropertyConfigProvider());
+ rpcHandler = new RpcHandler() {
+ @Override
+ public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) {
+ String msg = new String(message, Charsets.UTF_8);
+ String[] parts = msg.split("/");
+ if (parts[0].equals("hello")) {
+ callback.onSuccess(("Hello, " + parts[1] + "!").getBytes(Charsets.UTF_8));
+ } else if (parts[0].equals("return error")) {
+ callback.onFailure(new RuntimeException("Returned: " + parts[1]));
+ } else if (parts[0].equals("throw error")) {
+ throw new RuntimeException("Thrown: " + parts[1]);
+ }
+ }
+ };
+ TransportContext context = new TransportContext(conf, new DefaultStreamManager(), rpcHandler);
+ server = context.createServer();
+ clientFactory = context.createClientFactory();
+ }
+
+ @AfterClass
+ public static void tearDown() {
+ server.close();
+ clientFactory.close();
+ }
+
+ class RpcResult {
+ public Set<String> successMessages;
+ public Set<String> errorMessages;
+ }
+
+ private RpcResult sendRPC(String... commands) throws Exception {
+ TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
+ final Semaphore sem = new Semaphore(0);
+
+ final RpcResult res = new RpcResult();
+ res.successMessages = Collections.synchronizedSet(new HashSet<String>());
+ res.errorMessages = Collections.synchronizedSet(new HashSet<String>());
+
+ RpcResponseCallback callback = new RpcResponseCallback() {
+ @Override
+ public void onSuccess(byte[] message) {
+ res.successMessages.add(new String(message, Charsets.UTF_8));
+ sem.release();
+ }
+
+ @Override
+ public void onFailure(Throwable e) {
+ res.errorMessages.add(e.getMessage());
+ sem.release();
+ }
+ };
+
+ for (String command : commands) {
+ client.sendRpc(command.getBytes(Charsets.UTF_8), callback);
+ }
+
+ if (!sem.tryAcquire(commands.length, 5, TimeUnit.SECONDS)) {
+ fail("Timeout getting response from the server");
+ }
+ client.close();
+ return res;
+ }
+
+ @Test
+ public void singleRPC() throws Exception {
+ RpcResult res = sendRPC("hello/Aaron");
+ assertEquals(res.successMessages, Sets.newHashSet("Hello, Aaron!"));
+ assertTrue(res.errorMessages.isEmpty());
+ }
+
+ @Test
+ public void doubleRPC() throws Exception {
+ RpcResult res = sendRPC("hello/Aaron", "hello/Reynold");
+ assertEquals(res.successMessages, Sets.newHashSet("Hello, Aaron!", "Hello, Reynold!"));
+ assertTrue(res.errorMessages.isEmpty());
+ }
+
+ @Test
+ public void returnErrorRPC() throws Exception {
+ RpcResult res = sendRPC("return error/OK");
+ assertTrue(res.successMessages.isEmpty());
+ assertErrorsContain(res.errorMessages, Sets.newHashSet("Returned: OK"));
+ }
+
+ @Test
+ public void throwErrorRPC() throws Exception {
+ RpcResult res = sendRPC("throw error/uh-oh");
+ assertTrue(res.successMessages.isEmpty());
+ assertErrorsContain(res.errorMessages, Sets.newHashSet("Thrown: uh-oh"));
+ }
+
+ @Test
+ public void doubleTrouble() throws Exception {
+ RpcResult res = sendRPC("return error/OK", "throw error/uh-oh");
+ assertTrue(res.successMessages.isEmpty());
+ assertErrorsContain(res.errorMessages, Sets.newHashSet("Returned: OK", "Thrown: uh-oh"));
+ }
+
+ @Test
+ public void sendSuccessAndFailure() throws Exception {
+ RpcResult res = sendRPC("hello/Bob", "throw error/the", "hello/Builder", "return error/!");
+ assertEquals(res.successMessages, Sets.newHashSet("Hello, Bob!", "Hello, Builder!"));
+ assertErrorsContain(res.errorMessages, Sets.newHashSet("Thrown: the", "Returned: !"));
+ }
+
+ private void assertErrorsContain(Set<String> errors, Set<String> contains) {
+ assertEquals(contains.size(), errors.size());
+
+ Set<String> remainingErrors = Sets.newHashSet(errors);
+ for (String contain : contains) {
+ Iterator<String> it = remainingErrors.iterator();
+ boolean foundMatch = false;
+ while (it.hasNext()) {
+ if (it.next().contains(contain)) {
+ it.remove();
+ foundMatch = true;
+ break;
+ }
+ }
+ assertTrue("Could not find error containing " + contain + "; errors: " + errors, foundMatch);
+ }
+
+ assertTrue(remainingErrors.isEmpty());
+ }
+}
diff --git a/network/common/src/test/java/org/apache/spark/network/SystemPropertyConfigProvider.java b/network/common/src/test/java/org/apache/spark/network/SystemPropertyConfigProvider.java
new file mode 100644
index 0000000000..f4e0a2426a
--- /dev/null
+++ b/network/common/src/test/java/org/apache/spark/network/SystemPropertyConfigProvider.java
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network;
+
+import java.util.NoSuchElementException;
+
+import org.apache.spark.network.util.ConfigProvider;
+
+/** Uses System properties to obtain config values. */
+public class SystemPropertyConfigProvider extends ConfigProvider {
+ @Override
+ public String get(String name) {
+ String value = System.getProperty(name);
+ if (value == null) {
+ throw new NoSuchElementException(name);
+ }
+ return value;
+ }
+}
diff --git a/network/common/src/test/java/org/apache/spark/network/TestManagedBuffer.java b/network/common/src/test/java/org/apache/spark/network/TestManagedBuffer.java
new file mode 100644
index 0000000000..38113a918f
--- /dev/null
+++ b/network/common/src/test/java/org/apache/spark/network/TestManagedBuffer.java
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.ByteBuffer;
+
+import com.google.common.base.Preconditions;
+import io.netty.buffer.Unpooled;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.buffer.NettyManagedBuffer;
+
+/**
+ * A ManagedBuffer implementation that contains 0, 1, 2, 3, ..., (len-1).
+ *
+ * Used for testing.
+ */
+public class TestManagedBuffer extends ManagedBuffer {
+
+ private final int len;
+ private NettyManagedBuffer underlying;
+
+ public TestManagedBuffer(int len) {
+ Preconditions.checkArgument(len <= Byte.MAX_VALUE);
+ this.len = len;
+ byte[] byteArray = new byte[len];
+ for (int i = 0; i < len; i ++) {
+ byteArray[i] = (byte) i;
+ }
+ this.underlying = new NettyManagedBuffer(Unpooled.wrappedBuffer(byteArray));
+ }
+
+
+ @Override
+ public long size() {
+ return underlying.size();
+ }
+
+ @Override
+ public ByteBuffer nioByteBuffer() throws IOException {
+ return underlying.nioByteBuffer();
+ }
+
+ @Override
+ public InputStream createInputStream() throws IOException {
+ return underlying.createInputStream();
+ }
+
+ @Override
+ public ManagedBuffer retain() {
+ underlying.retain();
+ return this;
+ }
+
+ @Override
+ public ManagedBuffer release() {
+ underlying.release();
+ return this;
+ }
+
+ @Override
+ public Object convertToNetty() throws IOException {
+ return underlying.convertToNetty();
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (other instanceof ManagedBuffer) {
+ try {
+ ByteBuffer nioBuf = ((ManagedBuffer) other).nioByteBuffer();
+ if (nioBuf.remaining() != len) {
+ return false;
+ } else {
+ for (int i = 0; i < len; i ++) {
+ if (nioBuf.get() != i) {
+ return false;
+ }
+ }
+ return true;
+ }
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ }
+ return false;
+ }
+}
diff --git a/network/common/src/test/java/org/apache/spark/network/TestUtils.java b/network/common/src/test/java/org/apache/spark/network/TestUtils.java
new file mode 100644
index 0000000000..56a2b805f1
--- /dev/null
+++ b/network/common/src/test/java/org/apache/spark/network/TestUtils.java
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network;
+
+import java.net.InetAddress;
+
+public class TestUtils {
+ public static String getLocalHost() {
+ try {
+ return InetAddress.getLocalHost().getHostAddress();
+ } catch (Exception e) {
+ throw new RuntimeException(e);
+ }
+ }
+}
diff --git a/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java b/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java
new file mode 100644
index 0000000000..3ef964616f
--- /dev/null
+++ b/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network;
+
+import java.util.concurrent.TimeoutException;
+
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.client.TransportClientFactory;
+import org.apache.spark.network.server.DefaultStreamManager;
+import org.apache.spark.network.server.RpcHandler;
+import org.apache.spark.network.server.TransportServer;
+import org.apache.spark.network.server.StreamManager;
+import org.apache.spark.network.util.JavaUtils;
+import org.apache.spark.network.util.TransportConf;
+
+public class TransportClientFactorySuite {
+ private TransportConf conf;
+ private TransportContext context;
+ private TransportServer server1;
+ private TransportServer server2;
+
+ @Before
+ public void setUp() {
+ conf = new TransportConf(new SystemPropertyConfigProvider());
+ StreamManager streamManager = new DefaultStreamManager();
+ RpcHandler rpcHandler = new NoOpRpcHandler();
+ context = new TransportContext(conf, streamManager, rpcHandler);
+ server1 = context.createServer();
+ server2 = context.createServer();
+ }
+
+ @After
+ public void tearDown() {
+ JavaUtils.closeQuietly(server1);
+ JavaUtils.closeQuietly(server2);
+ }
+
+ @Test
+ public void createAndReuseBlockClients() throws TimeoutException {
+ TransportClientFactory factory = context.createClientFactory();
+ TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort());
+ TransportClient c2 = factory.createClient(TestUtils.getLocalHost(), server1.getPort());
+ TransportClient c3 = factory.createClient(TestUtils.getLocalHost(), server2.getPort());
+ assertTrue(c1.isActive());
+ assertTrue(c3.isActive());
+ assertTrue(c1 == c2);
+ assertTrue(c1 != c3);
+ factory.close();
+ }
+
+ @Test
+ public void neverReturnInactiveClients() throws Exception {
+ TransportClientFactory factory = context.createClientFactory();
+ TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort());
+ c1.close();
+
+ long start = System.currentTimeMillis();
+ while (c1.isActive() && (System.currentTimeMillis() - start) < 3000) {
+ Thread.sleep(10);
+ }
+ assertFalse(c1.isActive());
+
+ TransportClient c2 = factory.createClient(TestUtils.getLocalHost(), server1.getPort());
+ assertFalse(c1 == c2);
+ assertTrue(c2.isActive());
+ factory.close();
+ }
+
+ @Test
+ public void closeBlockClientsWithFactory() throws TimeoutException {
+ TransportClientFactory factory = context.createClientFactory();
+ TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort());
+ TransportClient c2 = factory.createClient(TestUtils.getLocalHost(), server2.getPort());
+ assertTrue(c1.isActive());
+ assertTrue(c2.isActive());
+ factory.close();
+ assertFalse(c1.isActive());
+ assertFalse(c2.isActive());
+ }
+}
diff --git a/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java b/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java
new file mode 100644
index 0000000000..17a03ebe88
--- /dev/null
+++ b/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java
@@ -0,0 +1,115 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network;
+
+import io.netty.channel.local.LocalChannel;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+import static org.mockito.Matchers.any;
+import static org.mockito.Matchers.eq;
+import static org.mockito.Mockito.*;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.client.ChunkReceivedCallback;
+import org.apache.spark.network.client.RpcResponseCallback;
+import org.apache.spark.network.client.TransportResponseHandler;
+import org.apache.spark.network.protocol.ChunkFetchFailure;
+import org.apache.spark.network.protocol.ChunkFetchSuccess;
+import org.apache.spark.network.protocol.RpcFailure;
+import org.apache.spark.network.protocol.RpcResponse;
+import org.apache.spark.network.protocol.StreamChunkId;
+
+public class TransportResponseHandlerSuite {
+ @Test
+ public void handleSuccessfulFetch() {
+ StreamChunkId streamChunkId = new StreamChunkId(1, 0);
+
+ TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel());
+ ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class);
+ handler.addFetchRequest(streamChunkId, callback);
+ assertEquals(1, handler.numOutstandingRequests());
+
+ handler.handle(new ChunkFetchSuccess(streamChunkId, new TestManagedBuffer(123)));
+ verify(callback, times(1)).onSuccess(eq(0), (ManagedBuffer) any());
+ assertEquals(0, handler.numOutstandingRequests());
+ }
+
+ @Test
+ public void handleFailedFetch() {
+ StreamChunkId streamChunkId = new StreamChunkId(1, 0);
+ TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel());
+ ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class);
+ handler.addFetchRequest(streamChunkId, callback);
+ assertEquals(1, handler.numOutstandingRequests());
+
+ handler.handle(new ChunkFetchFailure(streamChunkId, "some error msg"));
+ verify(callback, times(1)).onFailure(eq(0), (Throwable) any());
+ assertEquals(0, handler.numOutstandingRequests());
+ }
+
+ @Test
+ public void clearAllOutstandingRequests() {
+ TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel());
+ ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class);
+ handler.addFetchRequest(new StreamChunkId(1, 0), callback);
+ handler.addFetchRequest(new StreamChunkId(1, 1), callback);
+ handler.addFetchRequest(new StreamChunkId(1, 2), callback);
+ assertEquals(3, handler.numOutstandingRequests());
+
+ handler.handle(new ChunkFetchSuccess(new StreamChunkId(1, 0), new TestManagedBuffer(12)));
+ handler.exceptionCaught(new Exception("duh duh duhhhh"));
+
+ // the exception should fail the two remaining outstanding fetches (chunk indices 1 and 2)
+ verify(callback, times(1)).onSuccess(eq(0), (ManagedBuffer) any());
+ verify(callback, times(1)).onFailure(eq(1), (Throwable) any());
+ verify(callback, times(1)).onFailure(eq(2), (Throwable) any());
+ assertEquals(0, handler.numOutstandingRequests());
+ }
+
+ @Test
+ public void handleSuccessfulRPC() {
+ TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel());
+ RpcResponseCallback callback = mock(RpcResponseCallback.class);
+ handler.addRpcRequest(12345, callback);
+ assertEquals(1, handler.numOutstandingRequests());
+
+ handler.handle(new RpcResponse(54321, new byte[7])); // should be ignored
+ assertEquals(1, handler.numOutstandingRequests());
+
+ byte[] arr = new byte[10];
+ handler.handle(new RpcResponse(12345, arr));
+ verify(callback, times(1)).onSuccess(eq(arr));
+ assertEquals(0, handler.numOutstandingRequests());
+ }
+
+ @Test
+ public void handleFailedRPC() {
+ TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel());
+ RpcResponseCallback callback = mock(RpcResponseCallback.class);
+ handler.addRpcRequest(12345, callback);
+ assertEquals(1, handler.numOutstandingRequests());
+
+ handler.handle(new RpcFailure(54321, "uh-oh!")); // should be ignored
+ assertEquals(1, handler.numOutstandingRequests());
+
+ handler.handle(new RpcFailure(12345, "oh no"));
+ verify(callback, times(1)).onFailure((Throwable) any());
+ assertEquals(0, handler.numOutstandingRequests());
+ }
+}