Diffstat (limited to 'network')
-rw-r--r--  network/common/pom.xml  94
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/TransportContext.java  117
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java  154
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java  71
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java  76
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/buffer/NioManagedBuffer.java  75
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/client/ChunkFetchFailureException.java  31
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/client/ChunkReceivedCallback.java  47
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java  30
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/client/TransportClient.java  159
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java  182
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java  167
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java  76
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java  66
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java  80
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/protocol/Encodable.java  41
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/protocol/Message.java  58
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java  70
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java  80
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/protocol/RequestMessage.java  25
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/protocol/ResponseMessage.java  25
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java  74
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java  81
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java  72
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/protocol/StreamChunkId.java  73
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/server/DefaultStreamManager.java  104
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java  36
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java  38
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/server/StreamManager.java  52
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java  96
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java  162
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/server/TransportServer.java  121
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/util/ConfigProvider.java  52
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/util/IOMode.java  27
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java  38
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java  102
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/util/TransportConf.java  61
-rw-r--r--  network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java  217
-rw-r--r--  network/common/src/test/java/org/apache/spark/network/NoOpRpcHandler.java  28
-rw-r--r--  network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java  86
-rw-r--r--  network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java  175
-rw-r--r--  network/common/src/test/java/org/apache/spark/network/SystemPropertyConfigProvider.java  34
-rw-r--r--  network/common/src/test/java/org/apache/spark/network/TestManagedBuffer.java  104
-rw-r--r--  network/common/src/test/java/org/apache/spark/network/TestUtils.java  30
-rw-r--r--  network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java  102
-rw-r--r--  network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java  115
46 files changed, 3804 insertions(+), 0 deletions(-)
diff --git a/network/common/pom.xml b/network/common/pom.xml
new file mode 100644
index 0000000000..e3b7e32870
--- /dev/null
+++ b/network/common/pom.xml
@@ -0,0 +1,94 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+ ~ Licensed to the Apache Software Foundation (ASF) under one or more
+ ~ contributor license agreements. See the NOTICE file distributed with
+ ~ this work for additional information regarding copyright ownership.
+ ~ The ASF licenses this file to You under the Apache License, Version 2.0
+ ~ (the "License"); you may not use this file except in compliance with
+ ~ the License. You may obtain a copy of the License at
+ ~
+ ~ http://www.apache.org/licenses/LICENSE-2.0
+ ~
+ ~ Unless required by applicable law or agreed to in writing, software
+ ~ distributed under the License is distributed on an "AS IS" BASIS,
+ ~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ ~ See the License for the specific language governing permissions and
+ ~ limitations under the License.
+ -->
+
+<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+ xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+ <modelVersion>4.0.0</modelVersion>
+ <parent>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-parent</artifactId>
+ <version>1.2.0-SNAPSHOT</version>
+ <relativePath>../../pom.xml</relativePath>
+ </parent>
+
+ <groupId>org.apache.spark</groupId>
+ <artifactId>network</artifactId>
+ <packaging>jar</packaging>
+ <name>Shuffle Streaming Service</name>
+ <url>http://spark.apache.org/</url>
+ <properties>
+ <sbt.project.name>network</sbt.project.name>
+ </properties>
+
+ <dependencies>
+ <!-- Core dependencies -->
+ <dependency>
+ <groupId>io.netty</groupId>
+ <artifactId>netty-all</artifactId>
+ </dependency>
+ <dependency>
+ <groupId>org.slf4j</groupId>
+ <artifactId>slf4j-api</artifactId>
+ </dependency>
+
+ <!-- Provided dependencies -->
+ <dependency>
+ <groupId>com.google.guava</groupId>
+ <artifactId>guava</artifactId>
+ <scope>provided</scope>
+ </dependency>
+
+ <!-- Test dependencies -->
+ <dependency>
+ <groupId>junit</groupId>
+ <artifactId>junit</artifactId>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>log4j</groupId>
+ <artifactId>log4j</artifactId>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.mockito</groupId>
+ <artifactId>mockito-all</artifactId>
+ <scope>test</scope>
+ </dependency>
+ </dependencies>
+
+
+ <build>
+ <outputDirectory>target/java/classes</outputDirectory>
+ <testOutputDirectory>target/java/test-classes</testOutputDirectory>
+ <plugins>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-surefire-plugin</artifactId>
+ <version>2.17</version>
+ <configuration>
+ <skipTests>false</skipTests>
+ <includes>
+ <include>**/Test*.java</include>
+ <include>**/*Test.java</include>
+ <include>**/*Suite.java</include>
+ </includes>
+ </configuration>
+ </plugin>
+ </plugins>
+ </build>
+</project>
diff --git a/network/common/src/main/java/org/apache/spark/network/TransportContext.java b/network/common/src/main/java/org/apache/spark/network/TransportContext.java
new file mode 100644
index 0000000000..854aa6685f
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/TransportContext.java
@@ -0,0 +1,117 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network;
+
+import io.netty.channel.Channel;
+import io.netty.channel.socket.SocketChannel;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.client.TransportClientFactory;
+import org.apache.spark.network.client.TransportResponseHandler;
+import org.apache.spark.network.protocol.MessageDecoder;
+import org.apache.spark.network.protocol.MessageEncoder;
+import org.apache.spark.network.server.RpcHandler;
+import org.apache.spark.network.server.TransportChannelHandler;
+import org.apache.spark.network.server.TransportRequestHandler;
+import org.apache.spark.network.server.TransportServer;
+import org.apache.spark.network.server.StreamManager;
+import org.apache.spark.network.util.NettyUtils;
+import org.apache.spark.network.util.TransportConf;
+
+/**
+ * Contains the context to create a {@link TransportServer}, {@link TransportClientFactory}, and to
+ * set up Netty Channel pipelines with a {@link org.apache.spark.network.server.TransportChannelHandler}.
+ *
+ * The TransportClient provides two communication protocols: control-plane RPCs and data-plane
+ * "chunk fetching". The handling of the RPCs is performed outside the scope of the
+ * TransportContext (i.e., by a user-provided handler), and that handler is responsible for setting
+ * up streams which can be streamed through the data plane in chunks using zero-copy IO.
+ *
+ * The TransportServer and TransportClientFactory both create a TransportChannelHandler for each
+ * channel. As each TransportChannelHandler contains a TransportClient, this enables server
+ * processes to send messages back to the client on an existing channel.
+ */
+public class TransportContext {
+ private final Logger logger = LoggerFactory.getLogger(TransportContext.class);
+
+ private final TransportConf conf;
+ private final StreamManager streamManager;
+ private final RpcHandler rpcHandler;
+
+ private final MessageEncoder encoder;
+ private final MessageDecoder decoder;
+
+ public TransportContext(TransportConf conf, StreamManager streamManager, RpcHandler rpcHandler) {
+ this.conf = conf;
+ this.streamManager = streamManager;
+ this.rpcHandler = rpcHandler;
+ this.encoder = new MessageEncoder();
+ this.decoder = new MessageDecoder();
+ }
+
+ public TransportClientFactory createClientFactory() {
+ return new TransportClientFactory(this);
+ }
+
+ public TransportServer createServer() {
+ return new TransportServer(this);
+ }
+
+ /**
+ * Initializes a client or server Netty Channel Pipeline which encodes/decodes messages and
+ * has a {@link org.apache.spark.network.server.TransportChannelHandler} to handle request or
+ * response messages.
+ *
+ * @return Returns the created TransportChannelHandler, which includes a TransportClient that can
+ * be used to communicate on this channel. The TransportClient is directly associated with a
+ * ChannelHandler to ensure all users of the same channel get the same TransportClient object.
+ */
+ public TransportChannelHandler initializePipeline(SocketChannel channel) {
+ try {
+ TransportChannelHandler channelHandler = createChannelHandler(channel);
+ channel.pipeline()
+ .addLast("encoder", encoder)
+ .addLast("frameDecoder", NettyUtils.createFrameDecoder())
+ .addLast("decoder", decoder)
+ // NOTE: Chunks are currently guaranteed to be returned in the order of request, but this
+ // would require more logic to guarantee if this were not part of the same event loop.
+ .addLast("handler", channelHandler);
+ return channelHandler;
+ } catch (RuntimeException e) {
+ logger.error("Error while initializing Netty pipeline", e);
+ throw e;
+ }
+ }
+
+ /**
+ * Creates the server- and client-side handler which is used to handle both RequestMessages and
+ * ResponseMessages. The channel is expected to have been successfully created, though certain
+ * properties (such as the remoteAddress()) may not be available yet.
+ */
+ private TransportChannelHandler createChannelHandler(Channel channel) {
+ TransportResponseHandler responseHandler = new TransportResponseHandler(channel);
+ TransportClient client = new TransportClient(channel, responseHandler);
+ TransportRequestHandler requestHandler = new TransportRequestHandler(channel, client,
+ streamManager, rpcHandler);
+ return new TransportChannelHandler(client, responseHandler, requestHandler);
+ }
+
+ public TransportConf getConf() { return conf; }
+}
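
A minimal sketch of how these pieces wire together, using classes added elsewhere in this patch (DefaultStreamManager, NoOpRpcHandler, SystemPropertyConfigProvider). The TransportConf constructor and a getPort() accessor on TransportServer are assumed here, since those files' bodies are not reproduced above:

    // Sketch only: TransportConf(ConfigProvider) and server.getPort() are assumed,
    // and createClient can throw TimeoutException if the connect does not complete.
    TransportConf conf = new TransportConf(new SystemPropertyConfigProvider());
    TransportContext context =
        new TransportContext(conf, new DefaultStreamManager(), new NoOpRpcHandler());

    TransportServer server = context.createServer();                  // server side
    TransportClientFactory factory = context.createClientFactory();   // client side
    TransportClient client = factory.createClient("localhost", server.getPort());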
diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java b/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java
new file mode 100644
index 0000000000..89ed79bc63
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java
@@ -0,0 +1,154 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.buffer;
+
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.RandomAccessFile;
+import java.nio.ByteBuffer;
+import java.nio.channels.FileChannel;
+
+import com.google.common.base.Objects;
+import com.google.common.io.ByteStreams;
+import io.netty.channel.DefaultFileRegion;
+
+import org.apache.spark.network.util.JavaUtils;
+
+/**
+ * A {@link ManagedBuffer} backed by a segment in a file.
+ */
+public final class FileSegmentManagedBuffer extends ManagedBuffer {
+
+ /**
+ * Memory mapping is expensive and can destabilize the JVM (SPARK-1145, SPARK-3889).
+   * Avoid memory mapping unless there is a good reason to use it.
+ */
+ // TODO: Make this configurable
+ private static final long MIN_MEMORY_MAP_BYTES = 2 * 1024 * 1024;
+
+ private final File file;
+ private final long offset;
+ private final long length;
+
+ public FileSegmentManagedBuffer(File file, long offset, long length) {
+ this.file = file;
+ this.offset = offset;
+ this.length = length;
+ }
+
+ @Override
+ public long size() {
+ return length;
+ }
+
+ @Override
+ public ByteBuffer nioByteBuffer() throws IOException {
+ FileChannel channel = null;
+ try {
+ channel = new RandomAccessFile(file, "r").getChannel();
+ // Just copy the buffer if it's sufficiently small, as memory mapping has a high overhead.
+ if (length < MIN_MEMORY_MAP_BYTES) {
+ ByteBuffer buf = ByteBuffer.allocate((int) length);
+ channel.position(offset);
+ while (buf.remaining() != 0) {
+ if (channel.read(buf) == -1) {
+ throw new IOException(String.format("Reached EOF before filling buffer\n" +
+ "offset=%s\nfile=%s\nbuf.remaining=%s",
+ offset, file.getAbsoluteFile(), buf.remaining()));
+ }
+ }
+ buf.flip();
+ return buf;
+ } else {
+ return channel.map(FileChannel.MapMode.READ_ONLY, offset, length);
+ }
+ } catch (IOException e) {
+ try {
+ if (channel != null) {
+ long size = channel.size();
+ throw new IOException("Error in reading " + this + " (actual file length " + size + ")",
+ e);
+ }
+ } catch (IOException ignored) {
+ // ignore
+ }
+ throw new IOException("Error in opening " + this, e);
+ } finally {
+ JavaUtils.closeQuietly(channel);
+ }
+ }
+
+ @Override
+ public InputStream createInputStream() throws IOException {
+ FileInputStream is = null;
+ try {
+ is = new FileInputStream(file);
+ ByteStreams.skipFully(is, offset);
+ return ByteStreams.limit(is, length);
+ } catch (IOException e) {
+ try {
+ if (is != null) {
+ long size = file.length();
+ throw new IOException("Error in reading " + this + " (actual file length " + size + ")",
+ e);
+ }
+ } catch (IOException ignored) {
+ // ignore
+ } finally {
+ JavaUtils.closeQuietly(is);
+ }
+ throw new IOException("Error in opening " + this, e);
+ } catch (RuntimeException e) {
+ JavaUtils.closeQuietly(is);
+ throw e;
+ }
+ }
+
+ @Override
+ public ManagedBuffer retain() {
+ return this;
+ }
+
+ @Override
+ public ManagedBuffer release() {
+ return this;
+ }
+
+ @Override
+ public Object convertToNetty() throws IOException {
+ FileChannel fileChannel = new FileInputStream(file).getChannel();
+ return new DefaultFileRegion(fileChannel, offset, length);
+ }
+
+ public File getFile() { return file; }
+
+ public long getOffset() { return offset; }
+
+ public long getLength() { return length; }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("file", file)
+ .add("offset", offset)
+ .add("length", length)
+ .toString();
+ }
+}
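
For illustration, a short sketch of reading a segment through this class; the file path is a placeholder. A 1024-byte segment falls below MIN_MEMORY_MAP_BYTES, so nioByteBuffer() copies into a heap buffer instead of memory-mapping:

    // Expose bytes [100, 1124) of a file.
    ManagedBuffer segment = new FileSegmentManagedBuffer(new File("/tmp/shuffle.data"), 100, 1024);
    ByteBuffer data = segment.nioByteBuffer();     // copied, not mmapped, at this size
    // Or stream it; ByteStreams is Guava, already a dependency of this module.
    InputStream in = segment.createInputStream();
    byte[] bytes = new byte[(int) segment.size()];
    ByteStreams.readFully(in, bytes);
    in.close();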
diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java b/network/common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java
new file mode 100644
index 0000000000..a415db593a
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.buffer;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.ByteBuffer;
+
+/**
+ * This abstract class provides an immutable view of data in the form of bytes. Each concrete
+ * implementation specifies how the data is provided:
+ *
+ * - {@link FileSegmentManagedBuffer}: data backed by part of a file
+ * - {@link NioManagedBuffer}: data backed by a NIO ByteBuffer
+ * - {@link NettyManagedBuffer}: data backed by a Netty ByteBuf
+ *
+ * The concrete buffer implementation might be managed outside the JVM garbage collector.
+ * For example, in the case of {@link NettyManagedBuffer}, the buffers are reference counted.
+ * In that case, if the buffer is going to be passed around to a different thread, retain/release
+ * should be called.
+ */
+public abstract class ManagedBuffer {
+
+ /** Number of bytes of the data. */
+ public abstract long size();
+
+ /**
+ * Exposes this buffer's data as an NIO ByteBuffer. Changing the position and limit of the
+ * returned ByteBuffer should not affect the content of this buffer.
+ */
+ // TODO: Deprecate this, usage may require expensive memory mapping or allocation.
+ public abstract ByteBuffer nioByteBuffer() throws IOException;
+
+ /**
+ * Exposes this buffer's data as an InputStream. The underlying implementation does not
+ * necessarily check for the length of bytes read, so the caller is responsible for making sure
+ * it does not go over the limit.
+ */
+ public abstract InputStream createInputStream() throws IOException;
+
+ /**
+ * Increment the reference count by one if applicable.
+ */
+ public abstract ManagedBuffer retain();
+
+ /**
+   * If applicable, decrement the reference count by one and deallocate the buffer if the
+   * reference count reaches zero.
+ */
+ public abstract ManagedBuffer release();
+
+ /**
+   * Convert the buffer into a Netty object, used to write the data out.
+ */
+ public abstract Object convertToNetty() throws IOException;
+}
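
A sketch of the retain/release hand-off described above, for a reference-counted implementation; nettyByteBuf, executor and process() are hypothetical placeholders:

    final ManagedBuffer buf = new NettyManagedBuffer(nettyByteBuf).retain();
    executor.execute(new Runnable() {
      @Override
      public void run() {
        try {
          process(buf.nioByteBuffer());   // consume on the other thread
        } catch (IOException e) {
          // handle or log
        } finally {
          buf.release();                  // balance the retain() above
        }
      }
    });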
diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java b/network/common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java
new file mode 100644
index 0000000000..c806bfa45b
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.buffer;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.ByteBuffer;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.ByteBufInputStream;
+
+/**
+ * A {@link ManagedBuffer} backed by a Netty {@link ByteBuf}.
+ */
+public final class NettyManagedBuffer extends ManagedBuffer {
+ private final ByteBuf buf;
+
+ public NettyManagedBuffer(ByteBuf buf) {
+ this.buf = buf;
+ }
+
+ @Override
+ public long size() {
+ return buf.readableBytes();
+ }
+
+ @Override
+ public ByteBuffer nioByteBuffer() throws IOException {
+ return buf.nioBuffer();
+ }
+
+ @Override
+ public InputStream createInputStream() throws IOException {
+ return new ByteBufInputStream(buf);
+ }
+
+ @Override
+ public ManagedBuffer retain() {
+ buf.retain();
+ return this;
+ }
+
+ @Override
+ public ManagedBuffer release() {
+ buf.release();
+ return this;
+ }
+
+ @Override
+ public Object convertToNetty() throws IOException {
+ return buf.duplicate();
+ }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("buf", buf)
+ .toString();
+ }
+}
diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/NioManagedBuffer.java b/network/common/src/main/java/org/apache/spark/network/buffer/NioManagedBuffer.java
new file mode 100644
index 0000000000..f55b884bc4
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/buffer/NioManagedBuffer.java
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.buffer;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.ByteBuffer;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBufInputStream;
+import io.netty.buffer.Unpooled;
+
+/**
+ * A {@link ManagedBuffer} backed by {@link ByteBuffer}.
+ */
+public final class NioManagedBuffer extends ManagedBuffer {
+ private final ByteBuffer buf;
+
+ public NioManagedBuffer(ByteBuffer buf) {
+ this.buf = buf;
+ }
+
+ @Override
+ public long size() {
+ return buf.remaining();
+ }
+
+ @Override
+ public ByteBuffer nioByteBuffer() throws IOException {
+ return buf.duplicate();
+ }
+
+ @Override
+ public InputStream createInputStream() throws IOException {
+ return new ByteBufInputStream(Unpooled.wrappedBuffer(buf));
+ }
+
+ @Override
+ public ManagedBuffer retain() {
+ return this;
+ }
+
+ @Override
+ public ManagedBuffer release() {
+ return this;
+ }
+
+ @Override
+ public Object convertToNetty() throws IOException {
+ return Unpooled.wrappedBuffer(buf);
+ }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("buf", buf)
+ .toString();
+ }
+}
+
diff --git a/network/common/src/main/java/org/apache/spark/network/client/ChunkFetchFailureException.java b/network/common/src/main/java/org/apache/spark/network/client/ChunkFetchFailureException.java
new file mode 100644
index 0000000000..1fbdcd6780
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/client/ChunkFetchFailureException.java
@@ -0,0 +1,31 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.client;
+
+/**
+ * General exception caused by a remote exception while fetching a chunk.
+ */
+public class ChunkFetchFailureException extends RuntimeException {
+ public ChunkFetchFailureException(String errorMsg, Throwable cause) {
+ super(errorMsg, cause);
+ }
+
+ public ChunkFetchFailureException(String errorMsg) {
+ super(errorMsg);
+ }
+}
diff --git a/network/common/src/main/java/org/apache/spark/network/client/ChunkReceivedCallback.java b/network/common/src/main/java/org/apache/spark/network/client/ChunkReceivedCallback.java
new file mode 100644
index 0000000000..519e6cb470
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/client/ChunkReceivedCallback.java
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.client;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+
+/**
+ * Callback for the result of a single chunk fetch. For a single stream, the callbacks are
+ * guaranteed to be called by the same thread in the same order as the requests for chunks were
+ * made.
+ *
+ * Note that if a general stream failure occurs, all outstanding chunk requests may be failed.
+ */
+public interface ChunkReceivedCallback {
+ /**
+ * Called upon receipt of a particular chunk.
+ *
+ * The given buffer will initially have a refcount of 1, but will be release()'d as soon as this
+ * call returns. You must therefore either retain() the buffer or copy its contents before
+ * returning.
+ */
+ void onSuccess(int chunkIndex, ManagedBuffer buffer);
+
+ /**
+ * Called upon failure to fetch a particular chunk. Note that this may actually be called due
+ * to failure to fetch a prior chunk in this stream.
+ *
+ * After receiving a failure, the stream may or may not be valid. The client should not assume
+ * that the server's side of the stream has been closed.
+ */
+ void onFailure(int chunkIndex, Throwable e);
+}
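
A minimal callback honoring that contract: because the buffer is release()'d as soon as onSuccess returns, the implementation either copies the bytes out (as in this sketch) or retain()s the buffer:

    ChunkReceivedCallback callback = new ChunkReceivedCallback() {
      @Override
      public void onSuccess(int chunkIndex, ManagedBuffer buffer) {
        try {
          ByteBuffer copy = ByteBuffer.allocate((int) buffer.size());
          copy.put(buffer.nioByteBuffer());
          copy.flip();
          // hand 'copy' downstream; no retain() needed once the bytes are copied
        } catch (IOException e) {
          onFailure(chunkIndex, e);
        }
      }

      @Override
      public void onFailure(int chunkIndex, Throwable e) {
        // may have been caused by a failure on an earlier chunk of the same stream
      }
    };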
diff --git a/network/common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java b/network/common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java
new file mode 100644
index 0000000000..6ec960d795
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.client;
+
+/**
+ * Callback for the result of a single RPC. This will be invoked once with either success or
+ * failure.
+ */
+public interface RpcResponseCallback {
+ /** Successful serialized result from server. */
+ void onSuccess(byte[] response);
+
+ /** Exception either propagated from server or raised on client side. */
+ void onFailure(Throwable e);
+}
diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
new file mode 100644
index 0000000000..b1732fcde2
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
@@ -0,0 +1,159 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.client;
+
+import java.io.Closeable;
+import java.util.UUID;
+import java.util.concurrent.TimeUnit;
+
+import com.google.common.base.Preconditions;
+import io.netty.channel.Channel;
+import io.netty.channel.ChannelFuture;
+import io.netty.channel.ChannelFutureListener;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.protocol.ChunkFetchRequest;
+import org.apache.spark.network.protocol.RpcRequest;
+import org.apache.spark.network.protocol.StreamChunkId;
+import org.apache.spark.network.util.NettyUtils;
+
+/**
+ * Client for fetching consecutive chunks of a pre-negotiated stream. This API is intended to allow
+ * efficient transfer of a large amount of data, broken up into chunks with size ranging from
+ * hundreds of KB to a few MB.
+ *
+ * Note that while this client deals with the fetching of chunks from a stream (i.e., data plane),
+ * the actual setup of the streams is done outside the scope of the transport layer. The convenience
+ * method "sendRpc" is provided to enable control plane communication between the client and server
+ * to perform this setup.
+ *
+ * For example, a typical workflow might be:
+ *   client.sendRpc(new OpenFile("/foo")) --> returns StreamId = 100
+ *   client.fetchChunk(streamId = 100, chunkIndex = 0, callback)
+ *   client.fetchChunk(streamId = 100, chunkIndex = 1, callback)
+ *   ...
+ *   client.sendRpc(new CloseStream(100))
+ *
+ * Construct an instance of TransportClient using {@link TransportClientFactory}. A single
+ * TransportClient may be used for multiple streams, but any given stream must be restricted to a
+ * single client, in order to avoid out-of-order responses.
+ *
+ * NB: This class is used to make requests to the server, while {@link TransportResponseHandler} is
+ * responsible for handling responses from the server.
+ *
+ * Concurrency: thread safe and can be called from multiple threads.
+ */
+public class TransportClient implements Closeable {
+ private final Logger logger = LoggerFactory.getLogger(TransportClient.class);
+
+ private final Channel channel;
+ private final TransportResponseHandler handler;
+
+ public TransportClient(Channel channel, TransportResponseHandler handler) {
+ this.channel = Preconditions.checkNotNull(channel);
+ this.handler = Preconditions.checkNotNull(handler);
+ }
+
+ public boolean isActive() {
+ return channel.isOpen() || channel.isActive();
+ }
+
+ /**
+ * Requests a single chunk from the remote side, from the pre-negotiated streamId.
+ *
+ * Chunk indices go from 0 onwards. It is valid to request the same chunk multiple times, though
+ * some streams may not support this.
+ *
+ * Multiple fetchChunk requests may be outstanding simultaneously, and the chunks are guaranteed
+ * to be returned in the same order that they were requested, assuming only a single
+ * TransportClient is used to fetch the chunks.
+ *
+ * @param streamId Identifier that refers to a stream in the remote StreamManager. This should
+ * be agreed upon by client and server beforehand.
+ * @param chunkIndex 0-based index of the chunk to fetch
+ * @param callback Callback invoked upon successful receipt of chunk, or upon any failure.
+ */
+ public void fetchChunk(
+ long streamId,
+ final int chunkIndex,
+ final ChunkReceivedCallback callback) {
+ final String serverAddr = NettyUtils.getRemoteAddress(channel);
+ final long startTime = System.currentTimeMillis();
+ logger.debug("Sending fetch chunk request {} to {}", chunkIndex, serverAddr);
+
+ final StreamChunkId streamChunkId = new StreamChunkId(streamId, chunkIndex);
+ handler.addFetchRequest(streamChunkId, callback);
+
+ channel.writeAndFlush(new ChunkFetchRequest(streamChunkId)).addListener(
+ new ChannelFutureListener() {
+ @Override
+ public void operationComplete(ChannelFuture future) throws Exception {
+ if (future.isSuccess()) {
+ long timeTaken = System.currentTimeMillis() - startTime;
+ logger.trace("Sending request {} to {} took {} ms", streamChunkId, serverAddr,
+ timeTaken);
+ } else {
+ String errorMsg = String.format("Failed to send request %s to %s: %s", streamChunkId,
+ serverAddr, future.cause());
+ logger.error(errorMsg, future.cause());
+ handler.removeFetchRequest(streamChunkId);
+ callback.onFailure(chunkIndex, new RuntimeException(errorMsg, future.cause()));
+ channel.close();
+ }
+ }
+ });
+ }
+
+ /**
+ * Sends an opaque message to the RpcHandler on the server-side. The callback will be invoked
+ * with the server's response or upon any failure.
+ */
+ public void sendRpc(byte[] message, final RpcResponseCallback callback) {
+ final String serverAddr = NettyUtils.getRemoteAddress(channel);
+ final long startTime = System.currentTimeMillis();
+ logger.trace("Sending RPC to {}", serverAddr);
+
+ final long requestId = UUID.randomUUID().getLeastSignificantBits();
+ handler.addRpcRequest(requestId, callback);
+
+ channel.writeAndFlush(new RpcRequest(requestId, message)).addListener(
+ new ChannelFutureListener() {
+ @Override
+ public void operationComplete(ChannelFuture future) throws Exception {
+ if (future.isSuccess()) {
+ long timeTaken = System.currentTimeMillis() - startTime;
+ logger.trace("Sending request {} to {} took {} ms", requestId, serverAddr, timeTaken);
+ } else {
+ String errorMsg = String.format("Failed to send RPC %s to %s: %s", requestId,
+ serverAddr, future.cause());
+ logger.error(errorMsg, future.cause());
+ handler.removeRpcRequest(requestId);
+ callback.onFailure(new RuntimeException(errorMsg, future.cause()));
+ channel.close();
+ }
+ }
+ });
+ }
+
+ @Override
+ public void close() {
+    // close is a local operation and should finish within milliseconds; timeout just to be safe
+ channel.close().awaitUninterruptibly(10, TimeUnit.SECONDS);
+ }
+}
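
The workflow from the class comment, spelled out against the actual API; openFileBytes, parseStreamId and chunkCallback are application-defined placeholders, since stream setup is outside the transport layer's scope:

    client.sendRpc(openFileBytes, new RpcResponseCallback() {
      @Override
      public void onSuccess(byte[] response) {
        long streamId = parseStreamId(response);        // hypothetical decoder
        client.fetchChunk(streamId, 0, chunkCallback);  // chunks return in request order
        client.fetchChunk(streamId, 1, chunkCallback);
      }

      @Override
      public void onFailure(Throwable e) {
        // the stream could not be opened
      }
    });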
diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
new file mode 100644
index 0000000000..10eb9ef7a0
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
@@ -0,0 +1,182 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.client;
+
+import java.io.Closeable;
+import java.lang.reflect.Field;
+import java.net.InetSocketAddress;
+import java.net.SocketAddress;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.TimeoutException;
+import java.util.concurrent.atomic.AtomicReference;
+
+import io.netty.bootstrap.Bootstrap;
+import io.netty.buffer.PooledByteBufAllocator;
+import io.netty.channel.Channel;
+import io.netty.channel.ChannelFuture;
+import io.netty.channel.ChannelInitializer;
+import io.netty.channel.ChannelOption;
+import io.netty.channel.EventLoopGroup;
+import io.netty.channel.socket.SocketChannel;
+import io.netty.util.internal.PlatformDependent;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.TransportContext;
+import org.apache.spark.network.server.TransportChannelHandler;
+import org.apache.spark.network.util.IOMode;
+import org.apache.spark.network.util.NettyUtils;
+import org.apache.spark.network.util.TransportConf;
+
+/**
+ * Factory for creating {@link TransportClient}s via createClient().
+ *
+ * The factory maintains a connection pool to other hosts and should return the same
+ * {@link TransportClient} for the same remote host. It also shares a single worker thread pool for
+ * all {@link TransportClient}s.
+ */
+public class TransportClientFactory implements Closeable {
+ private final Logger logger = LoggerFactory.getLogger(TransportClientFactory.class);
+
+ private final TransportContext context;
+ private final TransportConf conf;
+ private final ConcurrentHashMap<SocketAddress, TransportClient> connectionPool;
+
+ private final Class<? extends Channel> socketChannelClass;
+ private final EventLoopGroup workerGroup;
+
+ public TransportClientFactory(TransportContext context) {
+ this.context = context;
+ this.conf = context.getConf();
+ this.connectionPool = new ConcurrentHashMap<SocketAddress, TransportClient>();
+
+ IOMode ioMode = IOMode.valueOf(conf.ioMode());
+ this.socketChannelClass = NettyUtils.getClientChannelClass(ioMode);
+ // TODO: Make thread pool name configurable.
+ this.workerGroup = NettyUtils.createEventLoop(ioMode, conf.clientThreads(), "shuffle-client");
+ }
+
+ /**
+   * Create a new {@link TransportClient} connecting to the given remote host / port.
+ *
+ * This blocks until a connection is successfully established.
+ *
+ * Concurrency: This method is safe to call from multiple threads.
+ */
+ public TransportClient createClient(String remoteHost, int remotePort) throws TimeoutException {
+ // Get connection from the connection pool first.
+ // If it is not found or not active, create a new one.
+ final InetSocketAddress address = new InetSocketAddress(remoteHost, remotePort);
+ TransportClient cachedClient = connectionPool.get(address);
+ if (cachedClient != null && cachedClient.isActive()) {
+ return cachedClient;
+ } else if (cachedClient != null) {
+ connectionPool.remove(address, cachedClient); // Remove inactive clients.
+ }
+
+ logger.debug("Creating new connection to " + address);
+
+ Bootstrap bootstrap = new Bootstrap();
+ bootstrap.group(workerGroup)
+ .channel(socketChannelClass)
+ // Disable Nagle's Algorithm since we don't want packets to wait
+ .option(ChannelOption.TCP_NODELAY, true)
+ .option(ChannelOption.SO_KEEPALIVE, true)
+ .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, conf.connectionTimeoutMs());
+
+ // Use pooled buffers to reduce temporary buffer allocation
+ bootstrap.option(ChannelOption.ALLOCATOR, createPooledByteBufAllocator());
+
+ final AtomicReference<TransportClient> client = new AtomicReference<TransportClient>();
+
+ bootstrap.handler(new ChannelInitializer<SocketChannel>() {
+ @Override
+ public void initChannel(SocketChannel ch) {
+ TransportChannelHandler clientHandler = context.initializePipeline(ch);
+ client.set(clientHandler.getClient());
+ }
+ });
+
+ // Connect to the remote server
+ ChannelFuture cf = bootstrap.connect(address);
+ if (!cf.awaitUninterruptibly(conf.connectionTimeoutMs())) {
+ throw new TimeoutException(
+ String.format("Connecting to %s timed out (%s ms)", address, conf.connectionTimeoutMs()));
+ } else if (cf.cause() != null) {
+ throw new RuntimeException(String.format("Failed to connect to %s", address), cf.cause());
+ }
+
+ // Successful connection
+ assert client.get() != null : "Channel future completed successfully with null client";
+ TransportClient oldClient = connectionPool.putIfAbsent(address, client.get());
+ if (oldClient == null) {
+ return client.get();
+ } else {
+      logger.debug("Two clients were created concurrently; the second one will be disposed.");
+ client.get().close();
+ return oldClient;
+ }
+ }
+
+ /** Close all connections in the connection pool, and shutdown the worker thread pool. */
+ @Override
+ public void close() {
+ for (TransportClient client : connectionPool.values()) {
+ try {
+ client.close();
+ } catch (RuntimeException e) {
+ logger.warn("Ignoring exception during close", e);
+ }
+ }
+ connectionPool.clear();
+
+ if (workerGroup != null) {
+ workerGroup.shutdownGracefully();
+ }
+ }
+
+ /**
+   * Create a pooled ByteBuf allocator but disable the thread-local cache. Thread-local caches
+ * are disabled because the ByteBufs are allocated by the event loop thread, but released by the
+ * executor thread rather than the event loop thread. Those thread-local caches actually delay
+ * the recycling of buffers, leading to larger memory usage.
+ */
+ private PooledByteBufAllocator createPooledByteBufAllocator() {
+ return new PooledByteBufAllocator(
+ PlatformDependent.directBufferPreferred(),
+ getPrivateStaticField("DEFAULT_NUM_HEAP_ARENA"),
+ getPrivateStaticField("DEFAULT_NUM_DIRECT_ARENA"),
+ getPrivateStaticField("DEFAULT_PAGE_SIZE"),
+ getPrivateStaticField("DEFAULT_MAX_ORDER"),
+ 0, // tinyCacheSize
+ 0, // smallCacheSize
+ 0 // normalCacheSize
+ );
+ }
+
+ /** Used to get defaults from Netty's private static fields. */
+ private int getPrivateStaticField(String name) {
+ try {
+ Field f = PooledByteBufAllocator.DEFAULT.getClass().getDeclaredField(name);
+ f.setAccessible(true);
+ return f.getInt(null);
+ } catch (Exception e) {
+ throw new RuntimeException(e);
+ }
+ }
+}
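
A usage sketch of the pooling behavior described in the class comment; host and port are placeholders:

    TransportClientFactory factory = context.createClientFactory();
    try {
      TransportClient a = factory.createClient("host1", 8080);
      TransportClient b = factory.createClient("host1", 8080);
      assert a == b;      // same pooled instance while the connection stays active
    } catch (TimeoutException e) {
      // connect did not complete within the configured connection timeout
    } finally {
      factory.close();    // closes pooled clients and shuts down the event loop
    }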
diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java
new file mode 100644
index 0000000000..d8965590b3
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java
@@ -0,0 +1,167 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.client;
+
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+
+import com.google.common.annotations.VisibleForTesting;
+import io.netty.channel.Channel;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.protocol.ChunkFetchFailure;
+import org.apache.spark.network.protocol.ChunkFetchSuccess;
+import org.apache.spark.network.protocol.ResponseMessage;
+import org.apache.spark.network.protocol.RpcFailure;
+import org.apache.spark.network.protocol.RpcResponse;
+import org.apache.spark.network.protocol.StreamChunkId;
+import org.apache.spark.network.server.MessageHandler;
+import org.apache.spark.network.util.NettyUtils;
+
+/**
+ * Handler that processes server responses, in response to requests issued from a
+ * {@link TransportClient}. It works by tracking the list of outstanding requests (and their callbacks).
+ *
+ * Concurrency: thread safe and can be called from multiple threads.
+ */
+public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
+ private final Logger logger = LoggerFactory.getLogger(TransportResponseHandler.class);
+
+ private final Channel channel;
+
+ private final Map<StreamChunkId, ChunkReceivedCallback> outstandingFetches;
+
+ private final Map<Long, RpcResponseCallback> outstandingRpcs;
+
+ public TransportResponseHandler(Channel channel) {
+ this.channel = channel;
+ this.outstandingFetches = new ConcurrentHashMap<StreamChunkId, ChunkReceivedCallback>();
+ this.outstandingRpcs = new ConcurrentHashMap<Long, RpcResponseCallback>();
+ }
+
+ public void addFetchRequest(StreamChunkId streamChunkId, ChunkReceivedCallback callback) {
+ outstandingFetches.put(streamChunkId, callback);
+ }
+
+ public void removeFetchRequest(StreamChunkId streamChunkId) {
+ outstandingFetches.remove(streamChunkId);
+ }
+
+ public void addRpcRequest(long requestId, RpcResponseCallback callback) {
+ outstandingRpcs.put(requestId, callback);
+ }
+
+ public void removeRpcRequest(long requestId) {
+ outstandingRpcs.remove(requestId);
+ }
+
+ /**
+ * Fire the failure callback for all outstanding requests. This is called when we have an
+   * uncaught exception or premature connection termination.
+ */
+ private void failOutstandingRequests(Throwable cause) {
+ for (Map.Entry<StreamChunkId, ChunkReceivedCallback> entry : outstandingFetches.entrySet()) {
+ entry.getValue().onFailure(entry.getKey().chunkIndex, cause);
+ }
+ for (Map.Entry<Long, RpcResponseCallback> entry : outstandingRpcs.entrySet()) {
+ entry.getValue().onFailure(cause);
+ }
+
+ // It's OK if new fetches appear, as they will fail immediately.
+ outstandingFetches.clear();
+ outstandingRpcs.clear();
+ }
+
+ @Override
+ public void channelUnregistered() {
+ if (numOutstandingRequests() > 0) {
+ String remoteAddress = NettyUtils.getRemoteAddress(channel);
+ logger.error("Still have {} requests outstanding when connection from {} is closed",
+ numOutstandingRequests(), remoteAddress);
+ failOutstandingRequests(new RuntimeException("Connection from " + remoteAddress + " closed"));
+ }
+ }
+
+ @Override
+ public void exceptionCaught(Throwable cause) {
+ if (numOutstandingRequests() > 0) {
+ String remoteAddress = NettyUtils.getRemoteAddress(channel);
+ logger.error("Still have {} requests outstanding when connection from {} is closed",
+ numOutstandingRequests(), remoteAddress);
+ failOutstandingRequests(cause);
+ }
+ }
+
+ @Override
+ public void handle(ResponseMessage message) {
+ String remoteAddress = NettyUtils.getRemoteAddress(channel);
+ if (message instanceof ChunkFetchSuccess) {
+ ChunkFetchSuccess resp = (ChunkFetchSuccess) message;
+ ChunkReceivedCallback listener = outstandingFetches.get(resp.streamChunkId);
+ if (listener == null) {
+ logger.warn("Ignoring response for block {} from {} since it is not outstanding",
+ resp.streamChunkId, remoteAddress);
+ resp.buffer.release();
+ } else {
+ outstandingFetches.remove(resp.streamChunkId);
+ listener.onSuccess(resp.streamChunkId.chunkIndex, resp.buffer);
+ resp.buffer.release();
+ }
+ } else if (message instanceof ChunkFetchFailure) {
+ ChunkFetchFailure resp = (ChunkFetchFailure) message;
+ ChunkReceivedCallback listener = outstandingFetches.get(resp.streamChunkId);
+ if (listener == null) {
+ logger.warn("Ignoring response for block {} from {} ({}) since it is not outstanding",
+ resp.streamChunkId, remoteAddress, resp.errorString);
+ } else {
+ outstandingFetches.remove(resp.streamChunkId);
+ listener.onFailure(resp.streamChunkId.chunkIndex, new ChunkFetchFailureException(
+ "Failure while fetching " + resp.streamChunkId + ": " + resp.errorString));
+ }
+ } else if (message instanceof RpcResponse) {
+ RpcResponse resp = (RpcResponse) message;
+ RpcResponseCallback listener = outstandingRpcs.get(resp.requestId);
+ if (listener == null) {
+ logger.warn("Ignoring response for RPC {} from {} ({} bytes) since it is not outstanding",
+ resp.requestId, remoteAddress, resp.response.length);
+ } else {
+ outstandingRpcs.remove(resp.requestId);
+ listener.onSuccess(resp.response);
+ }
+ } else if (message instanceof RpcFailure) {
+ RpcFailure resp = (RpcFailure) message;
+ RpcResponseCallback listener = outstandingRpcs.get(resp.requestId);
+ if (listener == null) {
+ logger.warn("Ignoring response for RPC {} from {} ({}) since it is not outstanding",
+ resp.requestId, remoteAddress, resp.errorString);
+ } else {
+ outstandingRpcs.remove(resp.requestId);
+ listener.onFailure(new RuntimeException(resp.errorString));
+ }
+ } else {
+ throw new IllegalStateException("Unknown response type: " + message.type());
+ }
+ }
+
+ /** Returns total number of outstanding requests (fetch requests + rpcs) */
+ @VisibleForTesting
+ public int numOutstandingRequests() {
+ return outstandingFetches.size() + outstandingRpcs.size();
+ }
+}
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java
new file mode 100644
index 0000000000..152af98ced
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import com.google.common.base.Charsets;
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+
+/**
+ * Response to {@link ChunkFetchRequest} when there is an error fetching the chunk.
+ */
+public final class ChunkFetchFailure implements ResponseMessage {
+ public final StreamChunkId streamChunkId;
+ public final String errorString;
+
+ public ChunkFetchFailure(StreamChunkId streamChunkId, String errorString) {
+ this.streamChunkId = streamChunkId;
+ this.errorString = errorString;
+ }
+
+ @Override
+ public Type type() { return Type.ChunkFetchFailure; }
+
+ @Override
+ public int encodedLength() {
+ return streamChunkId.encodedLength() + 4 + errorString.getBytes(Charsets.UTF_8).length;
+ }
+
+ @Override
+ public void encode(ByteBuf buf) {
+ streamChunkId.encode(buf);
+ byte[] errorBytes = errorString.getBytes(Charsets.UTF_8);
+ buf.writeInt(errorBytes.length);
+ buf.writeBytes(errorBytes);
+ }
+
+ public static ChunkFetchFailure decode(ByteBuf buf) {
+ StreamChunkId streamChunkId = StreamChunkId.decode(buf);
+ int numErrorStringBytes = buf.readInt();
+ byte[] errorBytes = new byte[numErrorStringBytes];
+ buf.readBytes(errorBytes);
+ return new ChunkFetchFailure(streamChunkId, new String(errorBytes, Charsets.UTF_8));
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (other instanceof ChunkFetchFailure) {
+ ChunkFetchFailure o = (ChunkFetchFailure) other;
+ return streamChunkId.equals(o.streamChunkId) && errorString.equals(o.errorString);
+ }
+ return false;
+ }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("streamChunkId", streamChunkId)
+ .add("errorString", errorString)
+ .toString();
+ }
+}
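
A round-trip sketch of the wire format above, mirroring what MessageEncoder and MessageDecoder do per frame (minus the frame-length header and type tag):

    ChunkFetchFailure msg = new ChunkFetchFailure(new StreamChunkId(100L, 0), "disk error");
    ByteBuf buf = Unpooled.buffer(msg.encodedLength());
    msg.encode(buf);                    // streamChunkId, error length, UTF-8 error bytes
    ChunkFetchFailure decoded = ChunkFetchFailure.decode(buf);
    assert msg.equals(decoded);
    buf.release();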
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java
new file mode 100644
index 0000000000..980947cf13
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+
+/**
+ * Request to fetch a single chunk of a stream. This will correspond to a single
+ * {@link org.apache.spark.network.protocol.ResponseMessage} (either success or failure).
+ */
+public final class ChunkFetchRequest implements RequestMessage {
+ public final StreamChunkId streamChunkId;
+
+ public ChunkFetchRequest(StreamChunkId streamChunkId) {
+ this.streamChunkId = streamChunkId;
+ }
+
+ @Override
+ public Type type() { return Type.ChunkFetchRequest; }
+
+ @Override
+ public int encodedLength() {
+ return streamChunkId.encodedLength();
+ }
+
+ @Override
+ public void encode(ByteBuf buf) {
+ streamChunkId.encode(buf);
+ }
+
+ public static ChunkFetchRequest decode(ByteBuf buf) {
+ return new ChunkFetchRequest(StreamChunkId.decode(buf));
+ }
+
+  @Override
+  public int hashCode() {
+    return streamChunkId.hashCode();
+  }
+
+  @Override
+  public boolean equals(Object other) {
+    if (other instanceof ChunkFetchRequest) {
+      ChunkFetchRequest o = (ChunkFetchRequest) other;
+      return streamChunkId.equals(o.streamChunkId);
+    }
+    return false;
+  }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("streamChunkId", streamChunkId)
+ .toString();
+ }
+}
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java
new file mode 100644
index 0000000000..ff4936470c
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.buffer.NettyManagedBuffer;
+
+/**
+ * Response to {@link ChunkFetchRequest} when a chunk exists and has been successfully fetched.
+ *
+ * Note that the server-side encoding of this message does NOT include the buffer itself, as this
+ * may be written by Netty in a more efficient manner (i.e., zero-copy write).
+ * Similarly, the client-side decoding will reuse the Netty ByteBuf as the buffer.
+ */
+public final class ChunkFetchSuccess implements ResponseMessage {
+ public final StreamChunkId streamChunkId;
+ public final ManagedBuffer buffer;
+
+ public ChunkFetchSuccess(StreamChunkId streamChunkId, ManagedBuffer buffer) {
+ this.streamChunkId = streamChunkId;
+ this.buffer = buffer;
+ }
+
+ @Override
+ public Type type() { return Type.ChunkFetchSuccess; }
+
+ @Override
+ public int encodedLength() {
+ return streamChunkId.encodedLength();
+ }
+
+ /** Encoding does NOT include 'buffer' itself. See {@link MessageEncoder}. */
+ @Override
+ public void encode(ByteBuf buf) {
+ streamChunkId.encode(buf);
+ }
+
+ /** Decoding uses the given ByteBuf as our data, and will retain() it. */
+ public static ChunkFetchSuccess decode(ByteBuf buf) {
+ StreamChunkId streamChunkId = StreamChunkId.decode(buf);
+ buf.retain();
+ NettyManagedBuffer managedBuf = new NettyManagedBuffer(buf.duplicate());
+ return new ChunkFetchSuccess(streamChunkId, managedBuf);
+ }
+
+  @Override
+  public int hashCode() {
+    return Objects.hashCode(streamChunkId, buffer);
+  }
+
+  @Override
+  public boolean equals(Object other) {
+    if (other instanceof ChunkFetchSuccess) {
+      ChunkFetchSuccess o = (ChunkFetchSuccess) other;
+      return streamChunkId.equals(o.streamChunkId) && buffer.equals(o.buffer);
+    }
+    return false;
+  }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("streamChunkId", streamChunkId)
+ .add("buffer", buffer)
+ .toString();
+ }
+}
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/Encodable.java b/network/common/src/main/java/org/apache/spark/network/protocol/Encodable.java
new file mode 100644
index 0000000000..b4e299471b
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/Encodable.java
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import io.netty.buffer.ByteBuf;
+
+/**
+ * Interface for an object which can be encoded into a ByteBuf. Multiple Encodable objects are
+ * stored in a single, pre-allocated ByteBuf, so Encodables must also provide their length.
+ *
+ * Encodable objects should provide a static "decode(ByteBuf)" method which is invoked by
+ * {@link MessageDecoder}. During decoding, if the object uses the ByteBuf as its data (rather than
+ * just copying data from it), then you must retain() the ByteBuf.
+ *
+ * Additionally, when adding a new Encodable Message, add it to {@link Message.Type}.
+ */
+public interface Encodable {
+ /** Number of bytes of the encoded form of this object. */
+ int encodedLength();
+
+ /**
+ * Serializes this object by writing into the given ByteBuf.
+ * This method must write exactly encodedLength() bytes.
+ */
+ void encode(ByteBuf buf);
+}
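
As an illustration of this contract, a hypothetical fixed-size message (the Ping class and its field are invented for this sketch; it is not part of the patch):

import io.netty.buffer.ByteBuf;

public final class Ping implements Encodable {
  public final long nonce;

  public Ping(long nonce) { this.nonce = nonce; }

  /** Must match exactly what encode() writes: one long. */
  @Override
  public int encodedLength() { return 8; }

  @Override
  public void encode(ByteBuf buf) { buf.writeLong(nonce); }

  /** The static decode() counterpart that MessageDecoder would invoke. */
  public static Ping decode(ByteBuf buf) { return new Ping(buf.readLong()); }
}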
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/Message.java b/network/common/src/main/java/org/apache/spark/network/protocol/Message.java
new file mode 100644
index 0000000000..d568370125
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/Message.java
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import io.netty.buffer.ByteBuf;
+
+/** An on-the-wire transmittable message. */
+public interface Message extends Encodable {
+ /** Used to identify this request type. */
+ Type type();
+
+ /** Preceding every serialized Message is its type, which allows us to deserialize it. */
+ public static enum Type implements Encodable {
+ ChunkFetchRequest(0), ChunkFetchSuccess(1), ChunkFetchFailure(2),
+ RpcRequest(3), RpcResponse(4), RpcFailure(5);
+
+ private final byte id;
+
+ private Type(int id) {
+ assert id < 128 : "Cannot have more than 128 message types";
+ this.id = (byte) id;
+ }
+
+ public byte id() { return id; }
+
+ @Override public int encodedLength() { return 1; }
+
+ @Override public void encode(ByteBuf buf) { buf.writeByte(id); }
+
+ public static Type decode(ByteBuf buf) {
+ byte id = buf.readByte();
+ switch (id) {
+ case 0: return ChunkFetchRequest;
+ case 1: return ChunkFetchSuccess;
+ case 2: return ChunkFetchFailure;
+ case 3: return RpcRequest;
+ case 4: return RpcResponse;
+ case 5: return RpcFailure;
+ default: throw new IllegalArgumentException("Unknown message type: " + id);
+ }
+ }
+ }
+}
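
A sketch of how the type byte drives deserialization (assumes the protocol classes above; Unpooled is used only for illustration):

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;

public class TypeDispatchSketch {
  public static void main(String[] args) {
    ByteBuf buf = Unpooled.buffer();
    Message msg = new ChunkFetchRequest(new StreamChunkId(7L, 0));
    msg.type().encode(buf);  // 1 byte: id 0 (ChunkFetchRequest)
    msg.encode(buf);         // 12 bytes: streamId + chunkIndex
    // A decoder first reads the type, then dispatches to the matching decode().
    assert Message.Type.decode(buf) == Message.Type.ChunkFetchRequest;
    assert ChunkFetchRequest.decode(buf).equals(msg);
    buf.release();
  }
}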
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java b/network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java
new file mode 100644
index 0000000000..81f8d7f963
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import java.util.List;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.channel.ChannelHandler;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.handler.codec.MessageToMessageDecoder;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Decoder used by the client side to decode server-to-client responses.
+ * This decoder is stateless so it is safe to be shared by multiple threads.
+ */
+@ChannelHandler.Sharable
+public final class MessageDecoder extends MessageToMessageDecoder<ByteBuf> {
+
+  private final Logger logger = LoggerFactory.getLogger(MessageDecoder.class);
+
+  @Override
+ public void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) {
+ Message.Type msgType = Message.Type.decode(in);
+ Message decoded = decode(msgType, in);
+ assert decoded.type() == msgType;
+ logger.trace("Received message " + msgType + ": " + decoded);
+ out.add(decoded);
+ }
+
+ private Message decode(Message.Type msgType, ByteBuf in) {
+ switch (msgType) {
+ case ChunkFetchRequest:
+ return ChunkFetchRequest.decode(in);
+
+ case ChunkFetchSuccess:
+ return ChunkFetchSuccess.decode(in);
+
+ case ChunkFetchFailure:
+ return ChunkFetchFailure.decode(in);
+
+ case RpcRequest:
+ return RpcRequest.decode(in);
+
+ case RpcResponse:
+ return RpcResponse.decode(in);
+
+ case RpcFailure:
+ return RpcFailure.decode(in);
+
+ default:
+ throw new IllegalArgumentException("Unexpected message type: " + msgType);
+ }
+ }
+}
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java b/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java
new file mode 100644
index 0000000000..4cb8becc3e
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import java.util.List;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.channel.ChannelHandler;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.handler.codec.MessageToMessageEncoder;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Encoder used by the server side to encode server-to-client responses.
+ * This encoder is stateless so it is safe to be shared by multiple threads.
+ */
+@ChannelHandler.Sharable
+public final class MessageEncoder extends MessageToMessageEncoder<Message> {
+
+ private final Logger logger = LoggerFactory.getLogger(MessageEncoder.class);
+
+  /**
+ * Encodes a Message by invoking its encode() method. For non-data messages, we will add one
+ * ByteBuf to 'out' containing the total frame length, the message type, and the message itself.
+ * In the case of a ChunkFetchSuccess, we will also add the ManagedBuffer corresponding to the
+ * data to 'out', in order to enable zero-copy transfer.
+ */
+ @Override
+ public void encode(ChannelHandlerContext ctx, Message in, List<Object> out) {
+ Object body = null;
+ long bodyLength = 0;
+
+ // Only ChunkFetchSuccesses have data besides the header.
+ // The body is used in order to enable zero-copy transfer for the payload.
+ if (in instanceof ChunkFetchSuccess) {
+ ChunkFetchSuccess resp = (ChunkFetchSuccess) in;
+ try {
+ bodyLength = resp.buffer.size();
+ body = resp.buffer.convertToNetty();
+ } catch (Exception e) {
+      // Re-encode this message as a ChunkFetchFailure.
+ logger.error(String.format("Error opening block %s for client %s",
+ resp.streamChunkId, ctx.channel().remoteAddress()), e);
+ encode(ctx, new ChunkFetchFailure(resp.streamChunkId, e.getMessage()), out);
+ return;
+ }
+ }
+
+ Message.Type msgType = in.type();
+ // All messages have the frame length, message type, and message itself.
+ int headerLength = 8 + msgType.encodedLength() + in.encodedLength();
+ long frameLength = headerLength + bodyLength;
+ ByteBuf header = ctx.alloc().buffer(headerLength);
+ header.writeLong(frameLength);
+ msgType.encode(header);
+ in.encode(header);
+ assert header.writableBytes() == 0;
+
+ out.add(header);
+ if (body != null && bodyLength > 0) {
+ out.add(body);
+ }
+ }
+}
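
The resulting frame layout is [frame length: 8 bytes][type: 1 byte][message: encodedLength() bytes][optional zero-copy body]. A small sketch verifying this with Netty's EmbeddedChannel (a test utility; not part of the patch):

import io.netty.buffer.ByteBuf;
import io.netty.channel.embedded.EmbeddedChannel;

public class FramingSketch {
  public static void main(String[] args) {
    EmbeddedChannel channel = new EmbeddedChannel(new MessageEncoder());
    RpcResponse resp = new RpcResponse(1L, new byte[] {1, 2, 3});
    channel.writeOutbound(resp);

    ByteBuf frame = (ByteBuf) channel.readOutbound();
    // The frame length counts the whole frame, including these 8 bytes themselves.
    assert frame.readLong() == 8 + 1 + resp.encodedLength();
    assert Message.Type.decode(frame) == Message.Type.RpcResponse;
    frame.release();
  }
}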
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/RequestMessage.java b/network/common/src/main/java/org/apache/spark/network/protocol/RequestMessage.java
new file mode 100644
index 0000000000..31b15bb17a
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/RequestMessage.java
@@ -0,0 +1,25 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+/** Messages from the client to the server. */
+public interface RequestMessage extends Message {
+ // token interface
+}
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ResponseMessage.java b/network/common/src/main/java/org/apache/spark/network/protocol/ResponseMessage.java
new file mode 100644
index 0000000000..6edffd11cf
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/ResponseMessage.java
@@ -0,0 +1,25 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+/** Messages from the server to the client. */
+public interface ResponseMessage extends Message {
+ // token interface
+}
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java b/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java
new file mode 100644
index 0000000000..e239d4ffbd
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import com.google.common.base.Charsets;
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+
+/** Response to {@link RpcRequest} for a failed RPC. */
+public final class RpcFailure implements ResponseMessage {
+ public final long requestId;
+ public final String errorString;
+
+ public RpcFailure(long requestId, String errorString) {
+ this.requestId = requestId;
+ this.errorString = errorString;
+ }
+
+ @Override
+ public Type type() { return Type.RpcFailure; }
+
+ @Override
+ public int encodedLength() {
+ return 8 + 4 + errorString.getBytes(Charsets.UTF_8).length;
+ }
+
+ @Override
+ public void encode(ByteBuf buf) {
+ buf.writeLong(requestId);
+ byte[] errorBytes = errorString.getBytes(Charsets.UTF_8);
+ buf.writeInt(errorBytes.length);
+ buf.writeBytes(errorBytes);
+ }
+
+ public static RpcFailure decode(ByteBuf buf) {
+ long requestId = buf.readLong();
+ int numErrorStringBytes = buf.readInt();
+ byte[] errorBytes = new byte[numErrorStringBytes];
+ buf.readBytes(errorBytes);
+ return new RpcFailure(requestId, new String(errorBytes, Charsets.UTF_8));
+ }
+
+  @Override
+  public int hashCode() {
+    return Objects.hashCode(requestId, errorString);
+  }
+
+  @Override
+  public boolean equals(Object other) {
+    if (other instanceof RpcFailure) {
+      RpcFailure o = (RpcFailure) other;
+      return requestId == o.requestId && errorString.equals(o.errorString);
+    }
+    return false;
+  }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("requestId", requestId)
+ .add("errorString", errorString)
+ .toString();
+ }
+}
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java b/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java
new file mode 100644
index 0000000000..099e934ae0
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import java.util.Arrays;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+
+/**
+ * A generic RPC which is handled by a remote {@link org.apache.spark.network.server.RpcHandler}.
+ * This will correspond to a single
+ * {@link org.apache.spark.network.protocol.ResponseMessage} (either success or failure).
+ */
+public final class RpcRequest implements RequestMessage {
+ /** Used to link an RPC request with its response. */
+ public final long requestId;
+
+ /** Serialized message to send to remote RpcHandler. */
+ public final byte[] message;
+
+ public RpcRequest(long requestId, byte[] message) {
+ this.requestId = requestId;
+ this.message = message;
+ }
+
+ @Override
+ public Type type() { return Type.RpcRequest; }
+
+ @Override
+ public int encodedLength() {
+ return 8 + 4 + message.length;
+ }
+
+ @Override
+ public void encode(ByteBuf buf) {
+ buf.writeLong(requestId);
+ buf.writeInt(message.length);
+ buf.writeBytes(message);
+ }
+
+ public static RpcRequest decode(ByteBuf buf) {
+ long requestId = buf.readLong();
+ int messageLen = buf.readInt();
+ byte[] message = new byte[messageLen];
+ buf.readBytes(message);
+ return new RpcRequest(requestId, message);
+ }
+
+  @Override
+  public int hashCode() {
+    return Objects.hashCode(requestId, Arrays.hashCode(message));
+  }
+
+  @Override
+  public boolean equals(Object other) {
+    if (other instanceof RpcRequest) {
+      RpcRequest o = (RpcRequest) other;
+      return requestId == o.requestId && Arrays.equals(message, o.message);
+    }
+    return false;
+  }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("requestId", requestId)
+ .add("message", message)
+ .toString();
+ }
+}
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java b/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java
new file mode 100644
index 0000000000..ed47947832
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import java.util.Arrays;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+
+/** Response to {@link RpcRequest} for a successful RPC. */
+public final class RpcResponse implements ResponseMessage {
+ public final long requestId;
+ public final byte[] response;
+
+ public RpcResponse(long requestId, byte[] response) {
+ this.requestId = requestId;
+ this.response = response;
+ }
+
+ @Override
+ public Type type() { return Type.RpcResponse; }
+
+ @Override
+ public int encodedLength() { return 8 + 4 + response.length; }
+
+ @Override
+ public void encode(ByteBuf buf) {
+ buf.writeLong(requestId);
+ buf.writeInt(response.length);
+ buf.writeBytes(response);
+ }
+
+ public static RpcResponse decode(ByteBuf buf) {
+ long requestId = buf.readLong();
+ int responseLen = buf.readInt();
+ byte[] response = new byte[responseLen];
+ buf.readBytes(response);
+ return new RpcResponse(requestId, response);
+ }
+
+  @Override
+  public int hashCode() {
+    return Objects.hashCode(requestId, Arrays.hashCode(response));
+  }
+
+  @Override
+  public boolean equals(Object other) {
+    if (other instanceof RpcResponse) {
+      RpcResponse o = (RpcResponse) other;
+      return requestId == o.requestId && Arrays.equals(response, o.response);
+    }
+    return false;
+  }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("requestId", requestId)
+ .add("response", response)
+ .toString();
+ }
+}
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/StreamChunkId.java b/network/common/src/main/java/org/apache/spark/network/protocol/StreamChunkId.java
new file mode 100644
index 0000000000..d46a263884
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/StreamChunkId.java
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+
+/**
+ * Encapsulates a request for a particular chunk of a stream.
+ */
+public final class StreamChunkId implements Encodable {
+ public final long streamId;
+ public final int chunkIndex;
+
+ public StreamChunkId(long streamId, int chunkIndex) {
+ this.streamId = streamId;
+ this.chunkIndex = chunkIndex;
+ }
+
+ @Override
+ public int encodedLength() {
+ return 8 + 4;
+ }
+
+  @Override
+  public void encode(ByteBuf buffer) {
+ buffer.writeLong(streamId);
+ buffer.writeInt(chunkIndex);
+ }
+
+ public static StreamChunkId decode(ByteBuf buffer) {
+ assert buffer.readableBytes() >= 8 + 4;
+ long streamId = buffer.readLong();
+ int chunkIndex = buffer.readInt();
+ return new StreamChunkId(streamId, chunkIndex);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hashCode(streamId, chunkIndex);
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (other instanceof StreamChunkId) {
+ StreamChunkId o = (StreamChunkId) other;
+ return streamId == o.streamId && chunkIndex == o.chunkIndex;
+ }
+ return false;
+ }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("streamId", streamId)
+ .add("chunkIndex", chunkIndex)
+ .toString();
+ }
+}
diff --git a/network/common/src/main/java/org/apache/spark/network/server/DefaultStreamManager.java b/network/common/src/main/java/org/apache/spark/network/server/DefaultStreamManager.java
new file mode 100644
index 0000000000..9688705569
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/server/DefaultStreamManager.java
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.server;
+
+import java.util.Iterator;
+import java.util.Map;
+import java.util.Random;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.atomic.AtomicLong;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+
+/**
+ * StreamManager which allows registration of an Iterator of ManagedBuffers, whose buffers are
+ * individually fetched as chunks by the client.
+ */
+public class DefaultStreamManager extends StreamManager {
+ private final Logger logger = LoggerFactory.getLogger(DefaultStreamManager.class);
+
+ private final AtomicLong nextStreamId;
+ private final Map<Long, StreamState> streams;
+
+ /** State of a single stream. */
+ private static class StreamState {
+ final Iterator<ManagedBuffer> buffers;
+
+ // Used to keep track of the index of the buffer that the user has retrieved, just to ensure
+ // that the caller only requests each chunk one at a time, in order.
+ int curChunk = 0;
+
+ StreamState(Iterator<ManagedBuffer> buffers) {
+ this.buffers = buffers;
+ }
+ }
+
+ public DefaultStreamManager() {
+    // For debugging purposes, start with a random stream id to help identify different streams.
+ // This does not need to be globally unique, only unique to this class.
+ nextStreamId = new AtomicLong((long) new Random().nextInt(Integer.MAX_VALUE) * 1000);
+ streams = new ConcurrentHashMap<Long, StreamState>();
+ }
+
+ @Override
+  public ManagedBuffer getChunk(long streamId, int chunkIndex) {
+    StreamState state = streams.get(streamId);
+    if (state == null) {
+      throw new IllegalStateException("Requested chunk from nonexistent stream: " + streamId);
+    } else if (chunkIndex != state.curChunk) {
+      throw new IllegalStateException(String.format(
+        "Received out-of-order chunk index %s (expected %s)", chunkIndex, state.curChunk));
+    } else if (!state.buffers.hasNext()) {
+      throw new IllegalStateException(String.format(
+        "Requested chunk index beyond end %s", chunkIndex));
+    }
+ state.curChunk += 1;
+ ManagedBuffer nextChunk = state.buffers.next();
+
+ if (!state.buffers.hasNext()) {
+ logger.trace("Removing stream id {}", streamId);
+ streams.remove(streamId);
+ }
+
+ return nextChunk;
+ }
+
+ @Override
+ public void connectionTerminated(long streamId) {
+ // Release all remaining buffers.
+ StreamState state = streams.remove(streamId);
+ if (state != null && state.buffers != null) {
+ while (state.buffers.hasNext()) {
+ state.buffers.next().release();
+ }
+ }
+ }
+
+ /**
+ * Registers a stream of ManagedBuffers which are served as individual chunks one at a time to
+ * callers. Each ManagedBuffer will be release()'d after it is transferred on the wire. If a
+ * client connection is closed before the iterator is fully drained, then the remaining buffers
+ * will all be release()'d.
+ */
+ public long registerStream(Iterator<ManagedBuffer> buffers) {
+ long myStreamId = nextStreamId.getAndIncrement();
+ streams.put(myStreamId, new StreamState(buffers));
+ return myStreamId;
+ }
+}
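
A usage sketch (not part of the patch): registering two in-memory buffers and fetching them strictly in order. NioManagedBuffer wrapping a ByteBuffer is assumed from this patch's buffer package:

import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.Iterator;

import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.buffer.NioManagedBuffer;

public class StreamRegistrationSketch {
  public static void main(String[] args) {
    DefaultStreamManager manager = new DefaultStreamManager();
    Iterator<ManagedBuffer> buffers = Arrays.<ManagedBuffer>asList(
      new NioManagedBuffer(ByteBuffer.wrap(new byte[] {1})),
      new NioManagedBuffer(ByteBuffer.wrap(new byte[] {2}))).iterator();

    long streamId = manager.registerStream(buffers);
    manager.getChunk(streamId, 0);  // chunks must be requested in order...
    manager.getChunk(streamId, 1);  // ...and the stream is dropped after the last one
  }
}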
diff --git a/network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java b/network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java
new file mode 100644
index 0000000000..b80c15106e
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.server;
+
+import org.apache.spark.network.protocol.Message;
+
+/**
+ * Handles either request or response messages coming off of Netty. A MessageHandler instance
+ * is associated with a single Netty Channel (though it may have multiple clients on the same
+ * Channel).
+ */
+public abstract class MessageHandler<T extends Message> {
+ /** Handles the receipt of a single message. */
+ public abstract void handle(T message);
+
+ /** Invoked when an exception was caught on the Channel. */
+ public abstract void exceptionCaught(Throwable cause);
+
+ /** Invoked when the channel this MessageHandler is on has been unregistered. */
+ public abstract void channelUnregistered();
+}
diff --git a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java
new file mode 100644
index 0000000000..f54a696b8f
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.server;
+
+import org.apache.spark.network.client.RpcResponseCallback;
+import org.apache.spark.network.client.TransportClient;
+
+/**
+ * Handler for sendRPC() messages sent by {@link org.apache.spark.network.client.TransportClient}s.
+ */
+public interface RpcHandler {
+ /**
+ * Receive a single RPC message. Any exception thrown while in this method will be sent back to
+ * the client in string form as a standard RPC failure.
+ *
+ * @param client A channel client which enables the handler to make requests back to the sender
+ * of this RPC.
+ * @param message The serialized bytes of the RPC.
+ * @param callback Callback which should be invoked exactly once upon success or failure of the
+ * RPC.
+ */
+ void receive(TransportClient client, byte[] message, RpcResponseCallback callback);
+}
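
A minimal implementation sketch: an RpcHandler that echoes each payload back to the sender (invented for illustration):

import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.TransportClient;

public class EchoRpcHandler implements RpcHandler {
  @Override
  public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) {
    // Per the contract above, invoke the callback exactly once.
    callback.onSuccess(message);
  }
}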
diff --git a/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java b/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java
new file mode 100644
index 0000000000..5a9a14a180
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.server;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+
+/**
+ * The StreamManager is used to fetch individual chunks from a stream. This is used in
+ * {@link TransportRequestHandler} in order to respond to fetchChunk() requests. Creation of the
+ * stream is outside the scope of the transport layer, but a given stream is guaranteed to be read
+ * by only one client connection, meaning that getChunk() for a particular stream will be called
+ * serially and that once the connection associated with the stream is closed, that stream will
+ * never be used again.
+ */
+public abstract class StreamManager {
+ /**
+ * Called in response to a fetchChunk() request. The returned buffer will be passed as-is to the
+ * client. A single stream will be associated with a single TCP connection, so this method
+ * will not be called in parallel for a particular stream.
+ *
+ * Chunks may be requested in any order, and requests may be repeated, but it is not required
+ * that implementations support this behavior.
+ *
+ * The returned ManagedBuffer will be release()'d after being written to the network.
+ *
+ * @param streamId id of a stream that has been previously registered with the StreamManager.
+ * @param chunkIndex 0-indexed chunk of the stream that's requested
+ */
+ public abstract ManagedBuffer getChunk(long streamId, int chunkIndex);
+
+ /**
+ * Indicates that the TCP connection that was tied to the given stream has been terminated. After
+ * this occurs, we are guaranteed not to read from the stream again, so any state can be cleaned
+ * up.
+ */
+ public void connectionTerminated(long streamId) { }
+}
diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java
new file mode 100644
index 0000000000..e491367fa4
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.server;
+
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.SimpleChannelInboundHandler;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.client.TransportResponseHandler;
+import org.apache.spark.network.protocol.Message;
+import org.apache.spark.network.protocol.RequestMessage;
+import org.apache.spark.network.protocol.ResponseMessage;
+import org.apache.spark.network.util.NettyUtils;
+
+/**
+ * The single Transport-level Channel handler which is used for delegating requests to the
+ * {@link TransportRequestHandler} and responses to the {@link TransportResponseHandler}.
+ *
+ * All channels created in the transport layer are bidirectional. When the Client initiates a Netty
+ * Channel with a RequestMessage (which gets handled by the Server's RequestHandler), the Server
+ * will produce a ResponseMessage (handled by the Client's ResponseHandler). However, the Server
+ * also gets a handle on the same Channel, so it may then begin to send RequestMessages to the
+ * Client.
+ * This means that the Client also needs a RequestHandler and the Server needs a ResponseHandler,
+ * for the Client's responses to the Server's requests.
+ */
+public class TransportChannelHandler extends SimpleChannelInboundHandler<Message> {
+ private final Logger logger = LoggerFactory.getLogger(TransportChannelHandler.class);
+
+ private final TransportClient client;
+ private final TransportResponseHandler responseHandler;
+ private final TransportRequestHandler requestHandler;
+
+ public TransportChannelHandler(
+ TransportClient client,
+ TransportResponseHandler responseHandler,
+ TransportRequestHandler requestHandler) {
+ this.client = client;
+ this.responseHandler = responseHandler;
+ this.requestHandler = requestHandler;
+ }
+
+ public TransportClient getClient() {
+ return client;
+ }
+
+ @Override
+ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
+ logger.warn("Exception in connection from " + NettyUtils.getRemoteAddress(ctx.channel()),
+ cause);
+ requestHandler.exceptionCaught(cause);
+ responseHandler.exceptionCaught(cause);
+ ctx.close();
+ }
+
+ @Override
+ public void channelUnregistered(ChannelHandlerContext ctx) throws Exception {
+ try {
+ requestHandler.channelUnregistered();
+ } catch (RuntimeException e) {
+ logger.error("Exception from request handler while unregistering channel", e);
+ }
+ try {
+ responseHandler.channelUnregistered();
+ } catch (RuntimeException e) {
+ logger.error("Exception from response handler while unregistering channel", e);
+ }
+ super.channelUnregistered(ctx);
+ }
+
+ @Override
+ public void channelRead0(ChannelHandlerContext ctx, Message request) {
+ if (request instanceof RequestMessage) {
+ requestHandler.handle((RequestMessage) request);
+ } else {
+ responseHandler.handle((ResponseMessage) request);
+ }
+ }
+}
diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
new file mode 100644
index 0000000000..352f865935
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
@@ -0,0 +1,162 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.server;
+
+import java.util.Set;
+
+import com.google.common.base.Throwables;
+import com.google.common.collect.Sets;
+import io.netty.channel.Channel;
+import io.netty.channel.ChannelFuture;
+import io.netty.channel.ChannelFutureListener;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.client.RpcResponseCallback;
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.protocol.Encodable;
+import org.apache.spark.network.protocol.RequestMessage;
+import org.apache.spark.network.protocol.ChunkFetchRequest;
+import org.apache.spark.network.protocol.RpcRequest;
+import org.apache.spark.network.protocol.ChunkFetchFailure;
+import org.apache.spark.network.protocol.ChunkFetchSuccess;
+import org.apache.spark.network.protocol.RpcFailure;
+import org.apache.spark.network.protocol.RpcResponse;
+import org.apache.spark.network.util.NettyUtils;
+
+/**
+ * A handler that processes requests from clients and writes chunk data back. Each handler is
+ * attached to a single Netty channel, and keeps track of which streams have been fetched via this
+ * channel, in order to clean them up if the channel is terminated (see #channelUnregistered).
+ *
+ * The messages should have been processed by the pipeline setup by {@link TransportServer}.
+ */
+public class TransportRequestHandler extends MessageHandler<RequestMessage> {
+ private final Logger logger = LoggerFactory.getLogger(TransportRequestHandler.class);
+
+ /** The Netty channel that this handler is associated with. */
+ private final Channel channel;
+
+ /** Client on the same channel allowing us to talk back to the requester. */
+ private final TransportClient reverseClient;
+
+  /** Returns each chunk of a stream. */
+ private final StreamManager streamManager;
+
+ /** Handles all RPC messages. */
+ private final RpcHandler rpcHandler;
+
+  /** Set of all stream ids that have been read on this handler, used for cleanup. */
+ private final Set<Long> streamIds;
+
+ public TransportRequestHandler(
+ Channel channel,
+ TransportClient reverseClient,
+ StreamManager streamManager,
+ RpcHandler rpcHandler) {
+ this.channel = channel;
+ this.reverseClient = reverseClient;
+ this.streamManager = streamManager;
+ this.rpcHandler = rpcHandler;
+ this.streamIds = Sets.newHashSet();
+ }
+
+ @Override
+  public void exceptionCaught(Throwable cause) {
+    // Exceptions are logged, and the channel closed, by TransportChannelHandler.
+  }
+
+ @Override
+ public void channelUnregistered() {
+ // Inform the StreamManager that these streams will no longer be read from.
+ for (long streamId : streamIds) {
+ streamManager.connectionTerminated(streamId);
+ }
+ }
+
+ @Override
+ public void handle(RequestMessage request) {
+ if (request instanceof ChunkFetchRequest) {
+ processFetchRequest((ChunkFetchRequest) request);
+ } else if (request instanceof RpcRequest) {
+ processRpcRequest((RpcRequest) request);
+ } else {
+ throw new IllegalArgumentException("Unknown request type: " + request);
+ }
+ }
+
+ private void processFetchRequest(final ChunkFetchRequest req) {
+ final String client = NettyUtils.getRemoteAddress(channel);
+ streamIds.add(req.streamChunkId.streamId);
+
+ logger.trace("Received req from {} to fetch block {}", client, req.streamChunkId);
+
+ ManagedBuffer buf;
+ try {
+ buf = streamManager.getChunk(req.streamChunkId.streamId, req.streamChunkId.chunkIndex);
+ } catch (Exception e) {
+ logger.error(String.format(
+ "Error opening block %s for request from %s", req.streamChunkId, client), e);
+ respond(new ChunkFetchFailure(req.streamChunkId, Throwables.getStackTraceAsString(e)));
+ return;
+ }
+
+ respond(new ChunkFetchSuccess(req.streamChunkId, buf));
+ }
+
+ private void processRpcRequest(final RpcRequest req) {
+ try {
+ rpcHandler.receive(reverseClient, req.message, new RpcResponseCallback() {
+ @Override
+ public void onSuccess(byte[] response) {
+ respond(new RpcResponse(req.requestId, response));
+ }
+
+ @Override
+ public void onFailure(Throwable e) {
+ respond(new RpcFailure(req.requestId, Throwables.getStackTraceAsString(e)));
+ }
+ });
+ } catch (Exception e) {
+ logger.error("Error while invoking RpcHandler#receive() on RPC id " + req.requestId, e);
+ respond(new RpcFailure(req.requestId, Throwables.getStackTraceAsString(e)));
+ }
+ }
+
+ /**
+ * Responds to a single message with some Encodable object. If a failure occurs while sending,
+ * it will be logged and the channel closed.
+ */
+ private void respond(final Encodable result) {
+ final String remoteAddress = channel.remoteAddress().toString();
+ channel.writeAndFlush(result).addListener(
+ new ChannelFutureListener() {
+ @Override
+ public void operationComplete(ChannelFuture future) throws Exception {
+ if (future.isSuccess()) {
+ logger.trace(String.format("Sent result %s to client %s", result, remoteAddress));
+ } else {
+ logger.error(String.format("Error sending result %s to %s; closing connection",
+ result, remoteAddress), future.cause());
+ channel.close();
+ }
+ }
+ }
+ );
+ }
+}
diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java b/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java
new file mode 100644
index 0000000000..243070750d
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.server;
+
+import java.io.Closeable;
+import java.net.InetSocketAddress;
+import java.util.concurrent.TimeUnit;
+
+import io.netty.bootstrap.ServerBootstrap;
+import io.netty.buffer.PooledByteBufAllocator;
+import io.netty.channel.ChannelFuture;
+import io.netty.channel.ChannelInitializer;
+import io.netty.channel.ChannelOption;
+import io.netty.channel.EventLoopGroup;
+import io.netty.channel.socket.SocketChannel;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.TransportContext;
+import org.apache.spark.network.util.IOMode;
+import org.apache.spark.network.util.NettyUtils;
+import org.apache.spark.network.util.TransportConf;
+
+/**
+ * Server for the efficient, low-level streaming service.
+ */
+public class TransportServer implements Closeable {
+ private final Logger logger = LoggerFactory.getLogger(TransportServer.class);
+
+ private final TransportContext context;
+ private final TransportConf conf;
+
+ private ServerBootstrap bootstrap;
+ private ChannelFuture channelFuture;
+ private int port = -1;
+
+ public TransportServer(TransportContext context) {
+ this.context = context;
+ this.conf = context.getConf();
+
+ init();
+ }
+
+ public int getPort() {
+ if (port == -1) {
+ throw new IllegalStateException("Server not initialized");
+ }
+ return port;
+ }
+
+  private void init() {
+ IOMode ioMode = IOMode.valueOf(conf.ioMode());
+ EventLoopGroup bossGroup =
+ NettyUtils.createEventLoop(ioMode, conf.serverThreads(), "shuffle-server");
+ EventLoopGroup workerGroup = bossGroup;
+
+ bootstrap = new ServerBootstrap()
+ .group(bossGroup, workerGroup)
+ .channel(NettyUtils.getServerChannelClass(ioMode))
+ .option(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT)
+ .childOption(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT);
+
+ if (conf.backLog() > 0) {
+ bootstrap.option(ChannelOption.SO_BACKLOG, conf.backLog());
+ }
+
+ if (conf.receiveBuf() > 0) {
+ bootstrap.childOption(ChannelOption.SO_RCVBUF, conf.receiveBuf());
+ }
+
+ if (conf.sendBuf() > 0) {
+ bootstrap.childOption(ChannelOption.SO_SNDBUF, conf.sendBuf());
+ }
+
+ bootstrap.childHandler(new ChannelInitializer<SocketChannel>() {
+ @Override
+ protected void initChannel(SocketChannel ch) throws Exception {
+ context.initializePipeline(ch);
+ }
+ });
+
+ channelFuture = bootstrap.bind(new InetSocketAddress(conf.serverPort()));
+ channelFuture.syncUninterruptibly();
+
+ port = ((InetSocketAddress) channelFuture.channel().localAddress()).getPort();
+ logger.debug("Shuffle server started on port :" + port);
+ }
+
+ @Override
+ public void close() {
+ if (channelFuture != null) {
+      // close is a local operation and should finish within milliseconds; timeout just to be safe
+ channelFuture.channel().close().awaitUninterruptibly(10, TimeUnit.SECONDS);
+ channelFuture = null;
+ }
+ if (bootstrap != null && bootstrap.group() != null) {
+ bootstrap.group().shutdownGracefully();
+ }
+ if (bootstrap != null && bootstrap.childGroup() != null) {
+ bootstrap.childGroup().shutdownGracefully();
+ }
+ bootstrap = null;
+ }
+
+}
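
A wiring sketch for standing up a server, reusing the EchoRpcHandler sketch above. The TransportContext constructor and its createServer() factory are assumed from this patch's TransportContext (not shown in this excerpt), so treat the exact signatures as illustrative:

import java.util.NoSuchElementException;

import org.apache.spark.network.TransportContext;
import org.apache.spark.network.util.ConfigProvider;
import org.apache.spark.network.util.TransportConf;

public class ServerWiringSketch {
  public static void main(String[] args) {
    // Fall back to every default by never resolving a key.
    ConfigProvider provider = new ConfigProvider() {
      @Override
      public String get(String name) { throw new NoSuchElementException(name); }
    };
    TransportConf conf = new TransportConf(provider);

    // Assumed shape: TransportContext(conf, streamManager, rpcHandler).createServer()
    TransportContext context =
      new TransportContext(conf, new DefaultStreamManager(), new EchoRpcHandler());
    TransportServer server = context.createServer();
    System.out.println("Bound to port " + server.getPort());
  }
}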
diff --git a/network/common/src/main/java/org/apache/spark/network/util/ConfigProvider.java b/network/common/src/main/java/org/apache/spark/network/util/ConfigProvider.java
new file mode 100644
index 0000000000..d944d9da1c
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/util/ConfigProvider.java
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.util;
+
+import java.util.NoSuchElementException;
+
+/**
+ * Provides a mechanism for constructing a {@link TransportConf} from an underlying
+ * configuration source.
+ */
+public abstract class ConfigProvider {
+ /** Obtains the value of the given config, throws NoSuchElementException if it doesn't exist. */
+ public abstract String get(String name);
+
+ public String get(String name, String defaultValue) {
+ try {
+ return get(name);
+ } catch (NoSuchElementException e) {
+ return defaultValue;
+ }
+ }
+
+ public int getInt(String name, int defaultValue) {
+ return Integer.parseInt(get(name, Integer.toString(defaultValue)));
+ }
+
+ public long getLong(String name, long defaultValue) {
+ return Long.parseLong(get(name, Long.toString(defaultValue)));
+ }
+
+ public double getDouble(String name, double defaultValue) {
+ return Double.parseDouble(get(name, Double.toString(defaultValue)));
+ }
+
+ public boolean getBoolean(String name, boolean defaultValue) {
+ return Boolean.parseBoolean(get(name, Boolean.toString(defaultValue)));
+ }
+}
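
A map-backed implementation sketch, handy for tests (invented for illustration):

import java.util.Map;
import java.util.NoSuchElementException;

public class MapConfigProvider extends ConfigProvider {
  private final Map<String, String> config;

  public MapConfigProvider(Map<String, String> config) {
    this.config = config;
  }

  @Override
  public String get(String name) {
    String value = config.get(name);
    if (value == null) {
      throw new NoSuchElementException(name);
    }
    return value;
  }
}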
diff --git a/network/common/src/main/java/org/apache/spark/network/util/IOMode.java b/network/common/src/main/java/org/apache/spark/network/util/IOMode.java
new file mode 100644
index 0000000000..6b208d95bb
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/util/IOMode.java
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.util;
+
+/**
+ * Selector for which form of low-level IO we should use.
+ * NIO is always available, while EPOLL is only available on Linux.
+ */
+public enum IOMode {
+ NIO, EPOLL
+}
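Callers resolve the configured string to this enum themselves. A small sketch under the assumption of a TransportConf (added later in this patch), whose ioMode() already upper-cases the value:

import org.apache.spark.network.util.IOMode;
import org.apache.spark.network.util.TransportConf;

public class IOModeExample {
  // valueOf throws IllegalArgumentException for anything other than
  // "NIO" or "EPOLL", so a misconfigured spark.shuffle.io.mode fails fast.
  public static IOMode resolve(TransportConf conf) {
    return IOMode.valueOf(conf.ioMode());
  }
}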
diff --git a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java
new file mode 100644
index 0000000000..32ba3f5b07
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.util;
+
+import java.io.Closeable;
+import java.io.IOException;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+public class JavaUtils {
+ private static final Logger logger = LoggerFactory.getLogger(JavaUtils.class);
+
+  /** Closes the given object, logging (rather than rethrowing) any IOException. */
+  public static void closeQuietly(Closeable closeable) {
+    try {
+      if (closeable != null) {
+        closeable.close();
+      }
+    } catch (IOException e) {
+      logger.error("IOException while closing Closeable", e);
+    }
+  }
+}
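The helper is intended for finally blocks, where a secondary failure from close() should not mask the primary result. A minimal sketch (the class and method names here are illustrative only):

import java.io.FileInputStream;
import java.io.IOException;

import org.apache.spark.network.util.JavaUtils;

public class CloseQuietlyExample {
  /** Returns the first byte of the file (or -1 if empty); close() failures are only logged. */
  public static int readFirstByte(String path) throws IOException {
    FileInputStream in = new FileInputStream(path);
    try {
      return in.read();
    } finally {
      JavaUtils.closeQuietly(in);
    }
  }
}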
diff --git a/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java b/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java
new file mode 100644
index 0000000000..b187234119
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.util;
+
+import java.util.concurrent.ThreadFactory;
+
+import com.google.common.util.concurrent.ThreadFactoryBuilder;
+import io.netty.channel.Channel;
+import io.netty.channel.EventLoopGroup;
+import io.netty.channel.ServerChannel;
+import io.netty.channel.epoll.EpollEventLoopGroup;
+import io.netty.channel.epoll.EpollServerSocketChannel;
+import io.netty.channel.epoll.EpollSocketChannel;
+import io.netty.channel.nio.NioEventLoopGroup;
+import io.netty.channel.socket.nio.NioServerSocketChannel;
+import io.netty.channel.socket.nio.NioSocketChannel;
+import io.netty.handler.codec.ByteToMessageDecoder;
+import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
+
+/**
+ * Utilities for creating various Netty constructs based on whether we're using EPOLL or NIO.
+ */
+public class NettyUtils {
+ /** Creates a Netty EventLoopGroup based on the IOMode. */
+ public static EventLoopGroup createEventLoop(IOMode mode, int numThreads, String threadPrefix) {
+
+ ThreadFactory threadFactory = new ThreadFactoryBuilder()
+ .setDaemon(true)
+ .setNameFormat(threadPrefix + "-%d")
+ .build();
+
+ switch (mode) {
+ case NIO:
+ return new NioEventLoopGroup(numThreads, threadFactory);
+ case EPOLL:
+ return new EpollEventLoopGroup(numThreads, threadFactory);
+ default:
+ throw new IllegalArgumentException("Unknown io mode: " + mode);
+ }
+ }
+
+ /** Returns the correct (client) SocketChannel class based on IOMode. */
+ public static Class<? extends Channel> getClientChannelClass(IOMode mode) {
+ switch (mode) {
+ case NIO:
+ return NioSocketChannel.class;
+ case EPOLL:
+ return EpollSocketChannel.class;
+ default:
+ throw new IllegalArgumentException("Unknown io mode: " + mode);
+ }
+ }
+
+ /** Returns the correct ServerSocketChannel class based on IOMode. */
+ public static Class<? extends ServerChannel> getServerChannelClass(IOMode mode) {
+ switch (mode) {
+ case NIO:
+ return NioServerSocketChannel.class;
+ case EPOLL:
+ return EpollServerSocketChannel.class;
+ default:
+ throw new IllegalArgumentException("Unknown io mode: " + mode);
+ }
+ }
+
+ /**
+ * Creates a LengthFieldBasedFrameDecoder where the first 8 bytes are the length of the frame.
+ * This is used before all decoders.
+ */
+ public static ByteToMessageDecoder createFrameDecoder() {
+    // maxFrameLength = Integer.MAX_VALUE (just under 2 GB)
+ // lengthFieldOffset = 0
+ // lengthFieldLength = 8
+ // lengthAdjustment = -8, i.e. exclude the 8 byte length itself
+ // initialBytesToStrip = 8, i.e. strip out the length field itself
+ return new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 8, -8, 8);
+ }
+
+  /** Returns the remote address on the channel or "<unknown remote>" if none exists. */
+ public static String getRemoteAddress(Channel channel) {
+ if (channel != null && channel.remoteAddress() != null) {
+ return channel.remoteAddress().toString();
+ }
+ return "<unknown remote>";
+ }
+}
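These helpers are designed to be used together so that a single IOMode value consistently selects both the event-loop implementation and the matching channel class. A minimal client-side sketch under that assumption (the real wiring in this patch lives in TransportClientFactory; the pipeline below is deliberately incomplete):

import io.netty.bootstrap.Bootstrap;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.socket.SocketChannel;

import org.apache.spark.network.util.IOMode;
import org.apache.spark.network.util.NettyUtils;

public class NettyUtilsExample {
  public static Bootstrap clientBootstrap(IOMode ioMode, int numThreads) {
    EventLoopGroup workerGroup = NettyUtils.createEventLoop(ioMode, numThreads, "shuffle-client");
    return new Bootstrap()
      .group(workerGroup)
      .channel(NettyUtils.getClientChannelClass(ioMode))
      .option(ChannelOption.TCP_NODELAY, true)
      .handler(new ChannelInitializer<SocketChannel>() {
        @Override
        protected void initChannel(SocketChannel ch) {
          // A real pipeline would add the message encoder/decoder as well;
          // createFrameDecoder() must come first so the downstream decoder
          // always sees exactly one complete frame at a time.
          ch.pipeline().addLast(NettyUtils.createFrameDecoder());
        }
      });
  }
}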
diff --git a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java
new file mode 100644
index 0000000000..80f65d9803
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.util;
+
+/**
+ * A central location that tracks all the settings we expose to users.
+ */
+public class TransportConf {
+ private final ConfigProvider conf;
+
+ public TransportConf(ConfigProvider conf) {
+ this.conf = conf;
+ }
+
+ /** Port the server listens on. Default to a random port. */
+ public int serverPort() { return conf.getInt("spark.shuffle.io.port", 0); }
+
+ /** IO mode: nio or epoll */
+ public String ioMode() { return conf.get("spark.shuffle.io.mode", "NIO").toUpperCase(); }
+
+  /** Connection timeout in milliseconds; the config value is given in seconds. Default 120 secs. */
+ public int connectionTimeoutMs() {
+ return conf.getInt("spark.shuffle.io.connectionTimeout", 120) * 1000;
+ }
+
+  /** Requested maximum length of the queue of incoming connections. Default -1, i.e. use the OS default. */
+ public int backLog() { return conf.getInt("spark.shuffle.io.backLog", -1); }
+
+ /** Number of threads used in the server thread pool. Default to 0, which is 2x#cores. */
+ public int serverThreads() { return conf.getInt("spark.shuffle.io.serverThreads", 0); }
+
+ /** Number of threads used in the client thread pool. Default to 0, which is 2x#cores. */
+ public int clientThreads() { return conf.getInt("spark.shuffle.io.clientThreads", 0); }
+
+ /**
+ * Receive buffer size (SO_RCVBUF).
+ * Note: the optimal size for receive buffer and send buffer should be
+ * latency * network_bandwidth.
+ * Assuming latency = 1ms, network_bandwidth = 10Gbps
+ * buffer size should be ~ 1.25MB
+ */
+ public int receiveBuf() { return conf.getInt("spark.shuffle.io.receiveBuffer", -1); }
+
+ /** Send buffer size (SO_SNDBUF). */
+ public int sendBuf() { return conf.getInt("spark.shuffle.io.sendBuffer", -1); }
+}
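Tying the two pieces together: with a provider such as the hypothetical MapConfigProvider sketched after ConfigProvider above, the accessors read (and default) as follows. Expected output is noted in comments:

import java.util.HashMap;
import java.util.Map;

import org.apache.spark.network.util.TransportConf;

public class TransportConfExample {
  public static void main(String[] args) {
    Map<String, String> settings = new HashMap<String, String>();
    settings.put("spark.shuffle.io.mode", "epoll");
    settings.put("spark.shuffle.io.serverThreads", "8");

    // MapConfigProvider is the hypothetical provider sketched earlier.
    TransportConf conf = new TransportConf(new MapConfigProvider(settings));
    System.out.println(conf.ioMode());               // EPOLL (upper-cased)
    System.out.println(conf.serverThreads());        // 8
    System.out.println(conf.connectionTimeoutMs());  // 120000 (the 120 s default, in ms)
  }
}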
diff --git a/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java
new file mode 100644
index 0000000000..738dca9b6a
--- /dev/null
+++ b/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java
@@ -0,0 +1,217 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network;
+
+import java.io.File;
+import java.io.RandomAccessFile;
+import java.nio.ByteBuffer;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Random;
+import java.util.Set;
+import java.util.concurrent.Semaphore;
+import java.util.concurrent.TimeUnit;
+
+import com.google.common.collect.Lists;
+import com.google.common.collect.Sets;
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import static org.junit.Assert.*;
+
+import org.apache.spark.network.buffer.FileSegmentManagedBuffer;
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.buffer.NioManagedBuffer;
+import org.apache.spark.network.client.ChunkReceivedCallback;
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.client.TransportClientFactory;
+import org.apache.spark.network.server.TransportServer;
+import org.apache.spark.network.server.StreamManager;
+import org.apache.spark.network.util.TransportConf;
+
+public class ChunkFetchIntegrationSuite {
+ static final long STREAM_ID = 1;
+ static final int BUFFER_CHUNK_INDEX = 0;
+ static final int FILE_CHUNK_INDEX = 1;
+
+ static TransportServer server;
+ static TransportClientFactory clientFactory;
+ static StreamManager streamManager;
+ static File testFile;
+
+ static ManagedBuffer bufferChunk;
+ static ManagedBuffer fileChunk;
+
+ @BeforeClass
+ public static void setUp() throws Exception {
+ int bufSize = 100000;
+ final ByteBuffer buf = ByteBuffer.allocate(bufSize);
+ for (int i = 0; i < bufSize; i ++) {
+ buf.put((byte) i);
+ }
+ buf.flip();
+ bufferChunk = new NioManagedBuffer(buf);
+
+    testFile = File.createTempFile("shuffle-test-file", ".txt");
+ testFile.deleteOnExit();
+ RandomAccessFile fp = new RandomAccessFile(testFile, "rw");
+ byte[] fileContent = new byte[1024];
+ new Random().nextBytes(fileContent);
+ fp.write(fileContent);
+ fp.close();
+ fileChunk = new FileSegmentManagedBuffer(testFile, 10, testFile.length() - 25);
+
+ TransportConf conf = new TransportConf(new SystemPropertyConfigProvider());
+ streamManager = new StreamManager() {
+ @Override
+ public ManagedBuffer getChunk(long streamId, int chunkIndex) {
+ assertEquals(STREAM_ID, streamId);
+ if (chunkIndex == BUFFER_CHUNK_INDEX) {
+ return new NioManagedBuffer(buf);
+ } else if (chunkIndex == FILE_CHUNK_INDEX) {
+ return new FileSegmentManagedBuffer(testFile, 10, testFile.length() - 25);
+ } else {
+ throw new IllegalArgumentException("Invalid chunk index: " + chunkIndex);
+ }
+ }
+ };
+ TransportContext context = new TransportContext(conf, streamManager, new NoOpRpcHandler());
+ server = context.createServer();
+ clientFactory = context.createClientFactory();
+ }
+
+ @AfterClass
+ public static void tearDown() {
+ server.close();
+ clientFactory.close();
+ testFile.delete();
+ }
+
+ class FetchResult {
+ public Set<Integer> successChunks;
+ public Set<Integer> failedChunks;
+ public List<ManagedBuffer> buffers;
+
+ public void releaseBuffers() {
+ for (ManagedBuffer buffer : buffers) {
+ buffer.release();
+ }
+ }
+ }
+
+ private FetchResult fetchChunks(List<Integer> chunkIndices) throws Exception {
+ TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
+ final Semaphore sem = new Semaphore(0);
+
+ final FetchResult res = new FetchResult();
+ res.successChunks = Collections.synchronizedSet(new HashSet<Integer>());
+ res.failedChunks = Collections.synchronizedSet(new HashSet<Integer>());
+ res.buffers = Collections.synchronizedList(new LinkedList<ManagedBuffer>());
+
+ ChunkReceivedCallback callback = new ChunkReceivedCallback() {
+ @Override
+ public void onSuccess(int chunkIndex, ManagedBuffer buffer) {
+ buffer.retain();
+ res.successChunks.add(chunkIndex);
+ res.buffers.add(buffer);
+ sem.release();
+ }
+
+ @Override
+ public void onFailure(int chunkIndex, Throwable e) {
+ res.failedChunks.add(chunkIndex);
+ sem.release();
+ }
+ };
+
+ for (int chunkIndex : chunkIndices) {
+ client.fetchChunk(STREAM_ID, chunkIndex, callback);
+ }
+ if (!sem.tryAcquire(chunkIndices.size(), 5, TimeUnit.SECONDS)) {
+ fail("Timeout getting response from the server");
+ }
+ client.close();
+ return res;
+ }
+
+ @Test
+ public void fetchBufferChunk() throws Exception {
+ FetchResult res = fetchChunks(Lists.newArrayList(BUFFER_CHUNK_INDEX));
+ assertEquals(res.successChunks, Sets.newHashSet(BUFFER_CHUNK_INDEX));
+ assertTrue(res.failedChunks.isEmpty());
+ assertBufferListsEqual(res.buffers, Lists.newArrayList(bufferChunk));
+ res.releaseBuffers();
+ }
+
+ @Test
+ public void fetchFileChunk() throws Exception {
+ FetchResult res = fetchChunks(Lists.newArrayList(FILE_CHUNK_INDEX));
+ assertEquals(res.successChunks, Sets.newHashSet(FILE_CHUNK_INDEX));
+ assertTrue(res.failedChunks.isEmpty());
+ assertBufferListsEqual(res.buffers, Lists.newArrayList(fileChunk));
+ res.releaseBuffers();
+ }
+
+ @Test
+ public void fetchNonExistentChunk() throws Exception {
+ FetchResult res = fetchChunks(Lists.newArrayList(12345));
+ assertTrue(res.successChunks.isEmpty());
+ assertEquals(res.failedChunks, Sets.newHashSet(12345));
+ assertTrue(res.buffers.isEmpty());
+ }
+
+ @Test
+ public void fetchBothChunks() throws Exception {
+ FetchResult res = fetchChunks(Lists.newArrayList(BUFFER_CHUNK_INDEX, FILE_CHUNK_INDEX));
+ assertEquals(res.successChunks, Sets.newHashSet(BUFFER_CHUNK_INDEX, FILE_CHUNK_INDEX));
+ assertTrue(res.failedChunks.isEmpty());
+ assertBufferListsEqual(res.buffers, Lists.newArrayList(bufferChunk, fileChunk));
+ res.releaseBuffers();
+ }
+
+ @Test
+ public void fetchChunkAndNonExistent() throws Exception {
+ FetchResult res = fetchChunks(Lists.newArrayList(BUFFER_CHUNK_INDEX, 12345));
+ assertEquals(res.successChunks, Sets.newHashSet(BUFFER_CHUNK_INDEX));
+ assertEquals(res.failedChunks, Sets.newHashSet(12345));
+ assertBufferListsEqual(res.buffers, Lists.newArrayList(bufferChunk));
+ res.releaseBuffers();
+ }
+
+ private void assertBufferListsEqual(List<ManagedBuffer> list0, List<ManagedBuffer> list1)
+ throws Exception {
+ assertEquals(list0.size(), list1.size());
+ for (int i = 0; i < list0.size(); i ++) {
+ assertBuffersEqual(list0.get(i), list1.get(i));
+ }
+ }
+
+ private void assertBuffersEqual(ManagedBuffer buffer0, ManagedBuffer buffer1) throws Exception {
+ ByteBuffer nio0 = buffer0.nioByteBuffer();
+ ByteBuffer nio1 = buffer1.nioByteBuffer();
+
+ int len = nio0.remaining();
+ assertEquals(nio0.remaining(), nio1.remaining());
+ for (int i = 0; i < len; i ++) {
+ assertEquals(nio0.get(), nio1.get());
+ }
+ }
+}
diff --git a/network/common/src/test/java/org/apache/spark/network/NoOpRpcHandler.java b/network/common/src/test/java/org/apache/spark/network/NoOpRpcHandler.java
new file mode 100644
index 0000000000..7aa37efc58
--- /dev/null
+++ b/network/common/src/test/java/org/apache/spark/network/NoOpRpcHandler.java
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network;
+
+import org.apache.spark.network.client.RpcResponseCallback;
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.server.RpcHandler;
+
+/** Test RpcHandler which always returns a zero-sized success. */
+public class NoOpRpcHandler implements RpcHandler {
+ @Override
+ public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) {
+ callback.onSuccess(new byte[0]);
+ }
+}
diff --git a/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java b/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java
new file mode 100644
index 0000000000..43dc0cf8c7
--- /dev/null
+++ b/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network;
+
+import io.netty.channel.embedded.EmbeddedChannel;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+
+import org.apache.spark.network.protocol.Message;
+import org.apache.spark.network.protocol.StreamChunkId;
+import org.apache.spark.network.protocol.ChunkFetchRequest;
+import org.apache.spark.network.protocol.ChunkFetchFailure;
+import org.apache.spark.network.protocol.ChunkFetchSuccess;
+import org.apache.spark.network.protocol.RpcRequest;
+import org.apache.spark.network.protocol.RpcFailure;
+import org.apache.spark.network.protocol.RpcResponse;
+import org.apache.spark.network.protocol.MessageDecoder;
+import org.apache.spark.network.protocol.MessageEncoder;
+import org.apache.spark.network.util.NettyUtils;
+
+public class ProtocolSuite {
+ private void testServerToClient(Message msg) {
+ EmbeddedChannel serverChannel = new EmbeddedChannel(new MessageEncoder());
+ serverChannel.writeOutbound(msg);
+
+ EmbeddedChannel clientChannel = new EmbeddedChannel(
+ NettyUtils.createFrameDecoder(), new MessageDecoder());
+
+ while (!serverChannel.outboundMessages().isEmpty()) {
+ clientChannel.writeInbound(serverChannel.readOutbound());
+ }
+
+ assertEquals(1, clientChannel.inboundMessages().size());
+ assertEquals(msg, clientChannel.readInbound());
+ }
+
+ private void testClientToServer(Message msg) {
+ EmbeddedChannel clientChannel = new EmbeddedChannel(new MessageEncoder());
+ clientChannel.writeOutbound(msg);
+
+ EmbeddedChannel serverChannel = new EmbeddedChannel(
+ NettyUtils.createFrameDecoder(), new MessageDecoder());
+
+ while (!clientChannel.outboundMessages().isEmpty()) {
+ serverChannel.writeInbound(clientChannel.readOutbound());
+ }
+
+ assertEquals(1, serverChannel.inboundMessages().size());
+ assertEquals(msg, serverChannel.readInbound());
+ }
+
+ @Test
+ public void requests() {
+ testClientToServer(new ChunkFetchRequest(new StreamChunkId(1, 2)));
+ testClientToServer(new RpcRequest(12345, new byte[0]));
+ testClientToServer(new RpcRequest(12345, new byte[100]));
+ }
+
+ @Test
+ public void responses() {
+ testServerToClient(new ChunkFetchSuccess(new StreamChunkId(1, 2), new TestManagedBuffer(10)));
+ testServerToClient(new ChunkFetchSuccess(new StreamChunkId(1, 2), new TestManagedBuffer(0)));
+ testServerToClient(new ChunkFetchFailure(new StreamChunkId(1, 2), "this is an error"));
+ testServerToClient(new ChunkFetchFailure(new StreamChunkId(1, 2), ""));
+ testServerToClient(new RpcResponse(12345, new byte[0]));
+ testServerToClient(new RpcResponse(12345, new byte[1000]));
+ testServerToClient(new RpcFailure(0, "this is an error"));
+ testServerToClient(new RpcFailure(0, ""));
+ }
+}
diff --git a/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java
new file mode 100644
index 0000000000..9f216dd2d7
--- /dev/null
+++ b/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java
@@ -0,0 +1,175 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network;
+
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.Iterator;
+import java.util.Set;
+import java.util.concurrent.Semaphore;
+import java.util.concurrent.TimeUnit;
+
+import com.google.common.base.Charsets;
+import com.google.common.collect.Sets;
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import static org.junit.Assert.*;
+
+import org.apache.spark.network.client.RpcResponseCallback;
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.client.TransportClientFactory;
+import org.apache.spark.network.server.DefaultStreamManager;
+import org.apache.spark.network.server.RpcHandler;
+import org.apache.spark.network.server.TransportServer;
+import org.apache.spark.network.util.TransportConf;
+
+public class RpcIntegrationSuite {
+ static TransportServer server;
+ static TransportClientFactory clientFactory;
+ static RpcHandler rpcHandler;
+
+ @BeforeClass
+ public static void setUp() throws Exception {
+ TransportConf conf = new TransportConf(new SystemPropertyConfigProvider());
+ rpcHandler = new RpcHandler() {
+ @Override
+ public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) {
+ String msg = new String(message, Charsets.UTF_8);
+ String[] parts = msg.split("/");
+ if (parts[0].equals("hello")) {
+ callback.onSuccess(("Hello, " + parts[1] + "!").getBytes(Charsets.UTF_8));
+ } else if (parts[0].equals("return error")) {
+ callback.onFailure(new RuntimeException("Returned: " + parts[1]));
+ } else if (parts[0].equals("throw error")) {
+ throw new RuntimeException("Thrown: " + parts[1]);
+ }
+ }
+ };
+ TransportContext context = new TransportContext(conf, new DefaultStreamManager(), rpcHandler);
+ server = context.createServer();
+ clientFactory = context.createClientFactory();
+ }
+
+ @AfterClass
+ public static void tearDown() {
+ server.close();
+ clientFactory.close();
+ }
+
+ class RpcResult {
+ public Set<String> successMessages;
+ public Set<String> errorMessages;
+ }
+
+  private RpcResult sendRPC(String... commands) throws Exception {
+ TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
+ final Semaphore sem = new Semaphore(0);
+
+ final RpcResult res = new RpcResult();
+ res.successMessages = Collections.synchronizedSet(new HashSet<String>());
+ res.errorMessages = Collections.synchronizedSet(new HashSet<String>());
+
+ RpcResponseCallback callback = new RpcResponseCallback() {
+ @Override
+ public void onSuccess(byte[] message) {
+ res.successMessages.add(new String(message, Charsets.UTF_8));
+ sem.release();
+ }
+
+ @Override
+ public void onFailure(Throwable e) {
+ res.errorMessages.add(e.getMessage());
+ sem.release();
+ }
+ };
+
+ for (String command : commands) {
+ client.sendRpc(command.getBytes(Charsets.UTF_8), callback);
+ }
+
+ if (!sem.tryAcquire(commands.length, 5, TimeUnit.SECONDS)) {
+ fail("Timeout getting response from the server");
+ }
+ client.close();
+ return res;
+ }
+
+ @Test
+ public void singleRPC() throws Exception {
+ RpcResult res = sendRPC("hello/Aaron");
+ assertEquals(res.successMessages, Sets.newHashSet("Hello, Aaron!"));
+ assertTrue(res.errorMessages.isEmpty());
+ }
+
+ @Test
+ public void doubleRPC() throws Exception {
+ RpcResult res = sendRPC("hello/Aaron", "hello/Reynold");
+ assertEquals(res.successMessages, Sets.newHashSet("Hello, Aaron!", "Hello, Reynold!"));
+ assertTrue(res.errorMessages.isEmpty());
+ }
+
+ @Test
+ public void returnErrorRPC() throws Exception {
+ RpcResult res = sendRPC("return error/OK");
+ assertTrue(res.successMessages.isEmpty());
+ assertErrorsContain(res.errorMessages, Sets.newHashSet("Returned: OK"));
+ }
+
+ @Test
+ public void throwErrorRPC() throws Exception {
+ RpcResult res = sendRPC("throw error/uh-oh");
+ assertTrue(res.successMessages.isEmpty());
+ assertErrorsContain(res.errorMessages, Sets.newHashSet("Thrown: uh-oh"));
+ }
+
+ @Test
+ public void doubleTrouble() throws Exception {
+ RpcResult res = sendRPC("return error/OK", "throw error/uh-oh");
+ assertTrue(res.successMessages.isEmpty());
+ assertErrorsContain(res.errorMessages, Sets.newHashSet("Returned: OK", "Thrown: uh-oh"));
+ }
+
+ @Test
+ public void sendSuccessAndFailure() throws Exception {
+ RpcResult res = sendRPC("hello/Bob", "throw error/the", "hello/Builder", "return error/!");
+ assertEquals(res.successMessages, Sets.newHashSet("Hello, Bob!", "Hello, Builder!"));
+ assertErrorsContain(res.errorMessages, Sets.newHashSet("Thrown: the", "Returned: !"));
+ }
+
+ private void assertErrorsContain(Set<String> errors, Set<String> contains) {
+ assertEquals(contains.size(), errors.size());
+
+ Set<String> remainingErrors = Sets.newHashSet(errors);
+ for (String contain : contains) {
+ Iterator<String> it = remainingErrors.iterator();
+ boolean foundMatch = false;
+ while (it.hasNext()) {
+ if (it.next().contains(contain)) {
+ it.remove();
+ foundMatch = true;
+ break;
+ }
+ }
+ assertTrue("Could not find error containing " + contain + "; errors: " + errors, foundMatch);
+ }
+
+ assertTrue(remainingErrors.isEmpty());
+ }
+}
diff --git a/network/common/src/test/java/org/apache/spark/network/SystemPropertyConfigProvider.java b/network/common/src/test/java/org/apache/spark/network/SystemPropertyConfigProvider.java
new file mode 100644
index 0000000000..f4e0a2426a
--- /dev/null
+++ b/network/common/src/test/java/org/apache/spark/network/SystemPropertyConfigProvider.java
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network;
+
+import java.util.NoSuchElementException;
+
+import org.apache.spark.network.util.ConfigProvider;
+
+/** Uses System properties to obtain config values. */
+public class SystemPropertyConfigProvider extends ConfigProvider {
+ @Override
+ public String get(String name) {
+ String value = System.getProperty(name);
+ if (value == null) {
+ throw new NoSuchElementException(name);
+ }
+ return value;
+ }
+}
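Because get() reads System.getProperty at call time, any property set before the corresponding accessor is invoked takes effect. A small usage sketch, with expected output noted in comments:

import org.apache.spark.network.SystemPropertyConfigProvider;
import org.apache.spark.network.util.TransportConf;

public class SystemPropertyConfigExample {
  public static void main(String[] args) {
    System.setProperty("spark.shuffle.io.mode", "nio");
    TransportConf conf = new TransportConf(new SystemPropertyConfigProvider());
    System.out.println(conf.ioMode());     // NIO
    System.out.println(conf.serverPort()); // 0 (property unset, so the random-port default applies)
  }
}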
diff --git a/network/common/src/test/java/org/apache/spark/network/TestManagedBuffer.java b/network/common/src/test/java/org/apache/spark/network/TestManagedBuffer.java
new file mode 100644
index 0000000000..38113a918f
--- /dev/null
+++ b/network/common/src/test/java/org/apache/spark/network/TestManagedBuffer.java
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.ByteBuffer;
+
+import com.google.common.base.Preconditions;
+import io.netty.buffer.Unpooled;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.buffer.NettyManagedBuffer;
+
+/**
+ * A ManagedBuffer implementation that contains 0, 1, 2, 3, ..., (len-1).
+ *
+ * Used for testing.
+ */
+public class TestManagedBuffer extends ManagedBuffer {
+
+ private final int len;
+ private NettyManagedBuffer underlying;
+
+ public TestManagedBuffer(int len) {
+ Preconditions.checkArgument(len <= Byte.MAX_VALUE);
+ this.len = len;
+ byte[] byteArray = new byte[len];
+ for (int i = 0; i < len; i ++) {
+ byteArray[i] = (byte) i;
+ }
+ this.underlying = new NettyManagedBuffer(Unpooled.wrappedBuffer(byteArray));
+ }
+
+ @Override
+ public long size() {
+ return underlying.size();
+ }
+
+ @Override
+ public ByteBuffer nioByteBuffer() throws IOException {
+ return underlying.nioByteBuffer();
+ }
+
+ @Override
+ public InputStream createInputStream() throws IOException {
+ return underlying.createInputStream();
+ }
+
+ @Override
+ public ManagedBuffer retain() {
+ underlying.retain();
+ return this;
+ }
+
+ @Override
+ public ManagedBuffer release() {
+ underlying.release();
+ return this;
+ }
+
+ @Override
+ public Object convertToNetty() throws IOException {
+ return underlying.convertToNetty();
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (other instanceof ManagedBuffer) {
+ try {
+ ByteBuffer nioBuf = ((ManagedBuffer) other).nioByteBuffer();
+ if (nioBuf.remaining() != len) {
+ return false;
+ } else {
+ for (int i = 0; i < len; i ++) {
+ if (nioBuf.get() != i) {
+ return false;
+ }
+ }
+ return true;
+ }
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ }
+ return false;
+ }
+}
diff --git a/network/common/src/test/java/org/apache/spark/network/TestUtils.java b/network/common/src/test/java/org/apache/spark/network/TestUtils.java
new file mode 100644
index 0000000000..56a2b805f1
--- /dev/null
+++ b/network/common/src/test/java/org/apache/spark/network/TestUtils.java
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network;
+
+import java.net.InetAddress;
+
+public class TestUtils {
+ public static String getLocalHost() {
+ try {
+ return InetAddress.getLocalHost().getHostAddress();
+ } catch (Exception e) {
+ throw new RuntimeException(e);
+ }
+ }
+}
diff --git a/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java b/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java
new file mode 100644
index 0000000000..3ef964616f
--- /dev/null
+++ b/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network;
+
+import java.util.concurrent.TimeoutException;
+
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.client.TransportClientFactory;
+import org.apache.spark.network.server.DefaultStreamManager;
+import org.apache.spark.network.server.RpcHandler;
+import org.apache.spark.network.server.TransportServer;
+import org.apache.spark.network.server.StreamManager;
+import org.apache.spark.network.util.JavaUtils;
+import org.apache.spark.network.util.TransportConf;
+
+public class TransportClientFactorySuite {
+ private TransportConf conf;
+ private TransportContext context;
+ private TransportServer server1;
+ private TransportServer server2;
+
+ @Before
+ public void setUp() {
+ conf = new TransportConf(new SystemPropertyConfigProvider());
+ StreamManager streamManager = new DefaultStreamManager();
+ RpcHandler rpcHandler = new NoOpRpcHandler();
+ context = new TransportContext(conf, streamManager, rpcHandler);
+ server1 = context.createServer();
+ server2 = context.createServer();
+ }
+
+ @After
+ public void tearDown() {
+ JavaUtils.closeQuietly(server1);
+ JavaUtils.closeQuietly(server2);
+ }
+
+ @Test
+ public void createAndReuseBlockClients() throws TimeoutException {
+ TransportClientFactory factory = context.createClientFactory();
+ TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort());
+ TransportClient c2 = factory.createClient(TestUtils.getLocalHost(), server1.getPort());
+ TransportClient c3 = factory.createClient(TestUtils.getLocalHost(), server2.getPort());
+ assertTrue(c1.isActive());
+ assertTrue(c3.isActive());
+ assertTrue(c1 == c2);
+ assertTrue(c1 != c3);
+ factory.close();
+ }
+
+ @Test
+ public void neverReturnInactiveClients() throws Exception {
+ TransportClientFactory factory = context.createClientFactory();
+ TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort());
+ c1.close();
+
+ long start = System.currentTimeMillis();
+ while (c1.isActive() && (System.currentTimeMillis() - start) < 3000) {
+ Thread.sleep(10);
+ }
+ assertFalse(c1.isActive());
+
+ TransportClient c2 = factory.createClient(TestUtils.getLocalHost(), server1.getPort());
+ assertFalse(c1 == c2);
+ assertTrue(c2.isActive());
+ factory.close();
+ }
+
+ @Test
+ public void closeBlockClientsWithFactory() throws TimeoutException {
+ TransportClientFactory factory = context.createClientFactory();
+ TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort());
+ TransportClient c2 = factory.createClient(TestUtils.getLocalHost(), server2.getPort());
+ assertTrue(c1.isActive());
+ assertTrue(c2.isActive());
+ factory.close();
+ assertFalse(c1.isActive());
+ assertFalse(c2.isActive());
+ }
+}
diff --git a/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java b/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java
new file mode 100644
index 0000000000..17a03ebe88
--- /dev/null
+++ b/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java
@@ -0,0 +1,115 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network;
+
+import io.netty.channel.local.LocalChannel;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+import static org.mockito.Matchers.any;
+import static org.mockito.Matchers.eq;
+import static org.mockito.Mockito.*;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.client.ChunkReceivedCallback;
+import org.apache.spark.network.client.RpcResponseCallback;
+import org.apache.spark.network.client.TransportResponseHandler;
+import org.apache.spark.network.protocol.ChunkFetchFailure;
+import org.apache.spark.network.protocol.ChunkFetchSuccess;
+import org.apache.spark.network.protocol.RpcFailure;
+import org.apache.spark.network.protocol.RpcResponse;
+import org.apache.spark.network.protocol.StreamChunkId;
+
+public class TransportResponseHandlerSuite {
+ @Test
+ public void handleSuccessfulFetch() {
+ StreamChunkId streamChunkId = new StreamChunkId(1, 0);
+
+ TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel());
+ ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class);
+ handler.addFetchRequest(streamChunkId, callback);
+ assertEquals(1, handler.numOutstandingRequests());
+
+ handler.handle(new ChunkFetchSuccess(streamChunkId, new TestManagedBuffer(123)));
+ verify(callback, times(1)).onSuccess(eq(0), (ManagedBuffer) any());
+ assertEquals(0, handler.numOutstandingRequests());
+ }
+
+ @Test
+ public void handleFailedFetch() {
+ StreamChunkId streamChunkId = new StreamChunkId(1, 0);
+ TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel());
+ ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class);
+ handler.addFetchRequest(streamChunkId, callback);
+ assertEquals(1, handler.numOutstandingRequests());
+
+ handler.handle(new ChunkFetchFailure(streamChunkId, "some error msg"));
+ verify(callback, times(1)).onFailure(eq(0), (Throwable) any());
+ assertEquals(0, handler.numOutstandingRequests());
+ }
+
+ @Test
+ public void clearAllOutstandingRequests() {
+ TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel());
+ ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class);
+ handler.addFetchRequest(new StreamChunkId(1, 0), callback);
+ handler.addFetchRequest(new StreamChunkId(1, 1), callback);
+ handler.addFetchRequest(new StreamChunkId(1, 2), callback);
+ assertEquals(3, handler.numOutstandingRequests());
+
+ handler.handle(new ChunkFetchSuccess(new StreamChunkId(1, 0), new TestManagedBuffer(12)));
+ handler.exceptionCaught(new Exception("duh duh duhhhh"));
+
+    // the exception should fail the two still-outstanding fetches, (1, 1) and (1, 2)
+ verify(callback, times(1)).onSuccess(eq(0), (ManagedBuffer) any());
+ verify(callback, times(1)).onFailure(eq(1), (Throwable) any());
+ verify(callback, times(1)).onFailure(eq(2), (Throwable) any());
+ assertEquals(0, handler.numOutstandingRequests());
+ }
+
+ @Test
+ public void handleSuccessfulRPC() {
+ TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel());
+ RpcResponseCallback callback = mock(RpcResponseCallback.class);
+ handler.addRpcRequest(12345, callback);
+ assertEquals(1, handler.numOutstandingRequests());
+
+ handler.handle(new RpcResponse(54321, new byte[7])); // should be ignored
+ assertEquals(1, handler.numOutstandingRequests());
+
+ byte[] arr = new byte[10];
+ handler.handle(new RpcResponse(12345, arr));
+ verify(callback, times(1)).onSuccess(eq(arr));
+ assertEquals(0, handler.numOutstandingRequests());
+ }
+
+ @Test
+ public void handleFailedRPC() {
+ TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel());
+ RpcResponseCallback callback = mock(RpcResponseCallback.class);
+ handler.addRpcRequest(12345, callback);
+ assertEquals(1, handler.numOutstandingRequests());
+
+ handler.handle(new RpcFailure(54321, "uh-oh!")); // should be ignored
+ assertEquals(1, handler.numOutstandingRequests());
+
+ handler.handle(new RpcFailure(12345, "oh no"));
+ verify(callback, times(1)).onFailure((Throwable) any());
+ assertEquals(0, handler.numOutstandingRequests());
+ }
+}