author     Reynold Xin <rxin@databricks.com>    2016-02-28 17:25:07 -0800
committer  Reynold Xin <rxin@databricks.com>    2016-02-28 17:25:07 -0800
commit     9e01dcc6446f8648e61062f8afe62589b9d4b5ab (patch)
tree       ae3c7015e950de315a490ce58c181671bfd12907 /common/network-common
parent     cca79fad66c4315b0ed6de59fd87700a540e6646 (diff)
[SPARK-13529][BUILD] Move network/* modules into common/network-*
## What changes were proposed in this pull request?

As the title says, this moves the three modules currently in network/ into common/network-*. This removes one top level, non-user-facing folder.

## How was this patch tested?

Compilation and existing tests. We should run both SBT and Maven.

Author: Reynold Xin <rxin@databricks.com>

Closes #11409 from rxin/SPARK-13529.
Diffstat (limited to 'common/network-common')
-rw-r--r--  common/network-common/pom.xml  103
-rw-r--r--  common/network-common/src/main/java/org/apache/spark/network/TransportContext.java  166
-rw-r--r--  common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java  154
-rw-r--r--  common/network-common/src/main/java/org/apache/spark/network/buffer/LazyFileRegion.java  111
-rw-r--r--  common/network-common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java  75
-rw-r--r--  common/network-common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java  76
-rw-r--r--  common/network-common/src/main/java/org/apache/spark/network/buffer/NioManagedBuffer.java  75
-rw-r--r--  common/network-common/src/main/java/org/apache/spark/network/client/ChunkFetchFailureException.java  31
-rw-r--r--  common/network-common/src/main/java/org/apache/spark/network/client/ChunkReceivedCallback.java  47
-rw-r--r--  common/network-common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java  32
-rw-r--r--  common/network-common/src/main/java/org/apache/spark/network/client/StreamCallback.java  40
-rw-r--r--  common/network-common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java  86
-rw-r--r--  common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java  321
-rw-r--r--  common/network-common/src/main/java/org/apache/spark/network/client/TransportClientBootstrap.java  34
-rw-r--r--  common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java  264
-rw-r--r--  common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java  251
-rw-r--r--  common/network-common/src/main/java/org/apache/spark/network/protocol/AbstractMessage.java  54
-rw-r--r--  common/network-common/src/main/java/org/apache/spark/network/protocol/AbstractResponseMessage.java  32
-rw-r--r--  common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java  76
-rw-r--r--  common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java  71
-rw-r--r--  common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java  89
-rw-r--r--  common/network-common/src/main/java/org/apache/spark/network/protocol/Encodable.java  41
-rw-r--r--  common/network-common/src/main/java/org/apache/spark/network/protocol/Encoders.java  92
-rw-r--r--  common/network-common/src/main/java/org/apache/spark/network/protocol/Message.java  73
-rw-r--r--  common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java  82
-rw-r--r--  common/network-common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java  93
-rw-r--r--  common/network-common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java  135
-rw-r--r--  common/network-common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java  80
-rw-r--r--  common/network-common/src/main/java/org/apache/spark/network/protocol/RequestMessage.java  25
-rw-r--r--  common/network-common/src/main/java/org/apache/spark/network/protocol/ResponseMessage.java  25
-rw-r--r--  common/network-common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java  74
-rw-r--r--  common/network-common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java  87
-rw-r--r--  common/network-common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java  87
-rw-r--r--  common/network-common/src/main/java/org/apache/spark/network/protocol/StreamChunkId.java  73
-rw-r--r--  common/network-common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java  80
-rw-r--r--  common/network-common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java  78
-rw-r--r--  common/network-common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java  92
-rw-r--r--  common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java  109
-rw-r--r--  common/network-common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java  291
-rw-r--r--  common/network-common/src/main/java/org/apache/spark/network/sasl/SaslEncryptionBackend.java  33
-rw-r--r--  common/network-common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java  78
-rw-r--r--  common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java  158
-rw-r--r--  common/network-common/src/main/java/org/apache/spark/network/sasl/SaslServerBootstrap.java  49
-rw-r--r--  common/network-common/src/main/java/org/apache/spark/network/sasl/SecretKeyHolder.java  35
-rw-r--r--  common/network-common/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java  162
-rw-r--r--  common/network-common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java  200
-rw-r--r--  common/network-common/src/main/java/org/apache/spark/network/server/MessageHandler.java  39
-rw-r--r--  common/network-common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java  40
-rw-r--r--  common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java  143
-rw-r--r--  common/network-common/src/main/java/org/apache/spark/network/server/RpcHandler.java  100
-rw-r--r--  common/network-common/src/main/java/org/apache/spark/network/server/StreamManager.java  86
-rw-r--r--  common/network-common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java  163
-rw-r--r--  common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java  209
-rw-r--r--  common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java  151
-rw-r--r--  common/network-common/src/main/java/org/apache/spark/network/server/TransportServerBootstrap.java  36
-rw-r--r--  common/network-common/src/main/java/org/apache/spark/network/util/ByteArrayWritableChannel.java  69
-rw-r--r--  common/network-common/src/main/java/org/apache/spark/network/util/ByteUnit.java  67
-rw-r--r--  common/network-common/src/main/java/org/apache/spark/network/util/ConfigProvider.java  52
-rw-r--r--  common/network-common/src/main/java/org/apache/spark/network/util/IOMode.java  27
-rw-r--r--  common/network-common/src/main/java/org/apache/spark/network/util/JavaUtils.java  303
-rw-r--r--  common/network-common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java  105
-rw-r--r--  common/network-common/src/main/java/org/apache/spark/network/util/MapConfigProvider.java  41
-rw-r--r--  common/network-common/src/main/java/org/apache/spark/network/util/NettyUtils.java  139
-rw-r--r--  common/network-common/src/main/java/org/apache/spark/network/util/SystemPropertyConfigProvider.java  34
-rw-r--r--  common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java  169
-rw-r--r--  common/network-common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java  227
-rw-r--r--  common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java  244
-rw-r--r--  common/network-common/src/test/java/org/apache/spark/network/ProtocolSuite.java  127
-rw-r--r--  common/network-common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java  288
-rw-r--r--  common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java  215
-rw-r--r--  common/network-common/src/test/java/org/apache/spark/network/StreamSuite.java  349
-rw-r--r--  common/network-common/src/test/java/org/apache/spark/network/TestManagedBuffer.java  109
-rw-r--r--  common/network-common/src/test/java/org/apache/spark/network/TestUtils.java  30
-rw-r--r--  common/network-common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java  214
-rw-r--r--  common/network-common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java  146
-rw-r--r--  common/network-common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java  157
-rw-r--r--  common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java  476
-rw-r--r--  common/network-common/src/test/java/org/apache/spark/network/server/OneForOneStreamManagerSuite.java  50
-rw-r--r--  common/network-common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java  258
-rw-r--r--  common/network-common/src/test/resources/log4j.properties  27
80 files changed, 9410 insertions, 0 deletions
diff --git a/common/network-common/pom.xml b/common/network-common/pom.xml
new file mode 100644
index 0000000000..bd507c2cb6
--- /dev/null
+++ b/common/network-common/pom.xml
@@ -0,0 +1,103 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+ ~ Licensed to the Apache Software Foundation (ASF) under one or more
+ ~ contributor license agreements. See the NOTICE file distributed with
+ ~ this work for additional information regarding copyright ownership.
+ ~ The ASF licenses this file to You under the Apache License, Version 2.0
+ ~ (the "License"); you may not use this file except in compliance with
+ ~ the License. You may obtain a copy of the License at
+ ~
+ ~ http://www.apache.org/licenses/LICENSE-2.0
+ ~
+ ~ Unless required by applicable law or agreed to in writing, software
+ ~ distributed under the License is distributed on an "AS IS" BASIS,
+ ~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ ~ See the License for the specific language governing permissions and
+ ~ limitations under the License.
+ -->
+
+<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+ xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+ <modelVersion>4.0.0</modelVersion>
+ <parent>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-parent_2.11</artifactId>
+ <version>2.0.0-SNAPSHOT</version>
+ <relativePath>../../pom.xml</relativePath>
+ </parent>
+
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-network-common_2.11</artifactId>
+ <packaging>jar</packaging>
+ <name>Spark Project Networking</name>
+ <url>http://spark.apache.org/</url>
+ <properties>
+ <sbt.project.name>network-common</sbt.project.name>
+ </properties>
+
+ <dependencies>
+ <!-- Core dependencies -->
+ <dependency>
+ <groupId>io.netty</groupId>
+ <artifactId>netty-all</artifactId>
+ </dependency>
+
+ <!-- Provided dependencies -->
+ <dependency>
+ <groupId>org.slf4j</groupId>
+ <artifactId>slf4j-api</artifactId>
+ <scope>provided</scope>
+ </dependency>
+ <dependency>
+ <groupId>com.google.code.findbugs</groupId>
+ <artifactId>jsr305</artifactId>
+ </dependency>
+ <dependency>
+ <groupId>com.google.guava</groupId>
+ <artifactId>guava</artifactId>
+ <scope>compile</scope>
+ </dependency>
+
+ <!-- Test dependencies -->
+ <dependency>
+ <groupId>log4j</groupId>
+ <artifactId>log4j</artifactId>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-test-tags_${scala.binary.version}</artifactId>
+ </dependency>
+ <dependency>
+ <groupId>org.mockito</groupId>
+ <artifactId>mockito-core</artifactId>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.slf4j</groupId>
+ <artifactId>slf4j-log4j12</artifactId>
+ <scope>test</scope>
+ </dependency>
+ </dependencies>
+
+ <build>
+ <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
+ <testOutputDirectory>target/scala-${scala.binary.version}/test-classes</testOutputDirectory>
+ <plugins>
+ <!-- Create a test-jar so network-shuffle can depend on our test utilities. -->
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-jar-plugin</artifactId>
+ <executions>
+ <execution>
+ <id>test-jar-on-test-compile</id>
+ <phase>test-compile</phase>
+ <goals>
+ <goal>test-jar</goal>
+ </goals>
+ </execution>
+ </executions>
+ </plugin>
+ </plugins>
+ </build>
+</project>
diff --git a/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java b/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java
new file mode 100644
index 0000000000..238710d172
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java
@@ -0,0 +1,166 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network;
+
+import java.util.List;
+
+import com.google.common.collect.Lists;
+import io.netty.channel.Channel;
+import io.netty.channel.socket.SocketChannel;
+import io.netty.handler.timeout.IdleStateHandler;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.client.TransportClientBootstrap;
+import org.apache.spark.network.client.TransportClientFactory;
+import org.apache.spark.network.client.TransportResponseHandler;
+import org.apache.spark.network.protocol.MessageDecoder;
+import org.apache.spark.network.protocol.MessageEncoder;
+import org.apache.spark.network.server.RpcHandler;
+import org.apache.spark.network.server.TransportChannelHandler;
+import org.apache.spark.network.server.TransportRequestHandler;
+import org.apache.spark.network.server.TransportServer;
+import org.apache.spark.network.server.TransportServerBootstrap;
+import org.apache.spark.network.util.NettyUtils;
+import org.apache.spark.network.util.TransportConf;
+import org.apache.spark.network.util.TransportFrameDecoder;
+
+/**
+ * Contains the context to create a {@link TransportServer}, {@link TransportClientFactory}, and to
+ * set up Netty Channel pipelines with a {@link org.apache.spark.network.server.TransportChannelHandler}.
+ *
+ * The TransportClient provides two communication protocols: control-plane RPCs and data-plane
+ * "chunk fetching". The handling of the RPCs is performed outside the scope of the
+ * TransportContext (i.e., by a user-provided handler), and that handler is responsible for setting
+ * up streams which can then be fetched through the data plane in chunks using zero-copy IO.
+ *
+ * The TransportServer and TransportClientFactory both create a TransportChannelHandler for each
+ * channel. As each TransportChannelHandler contains a TransportClient, this enables server
+ * processes to send messages back to the client on an existing channel.
+ */
+public class TransportContext {
+ private final Logger logger = LoggerFactory.getLogger(TransportContext.class);
+
+ private final TransportConf conf;
+ private final RpcHandler rpcHandler;
+ private final boolean closeIdleConnections;
+
+ private final MessageEncoder encoder;
+ private final MessageDecoder decoder;
+
+ public TransportContext(TransportConf conf, RpcHandler rpcHandler) {
+ this(conf, rpcHandler, false);
+ }
+
+ public TransportContext(
+ TransportConf conf,
+ RpcHandler rpcHandler,
+ boolean closeIdleConnections) {
+ this.conf = conf;
+ this.rpcHandler = rpcHandler;
+ this.encoder = new MessageEncoder();
+ this.decoder = new MessageDecoder();
+ this.closeIdleConnections = closeIdleConnections;
+ }
+
+ /**
+ * Initializes a ClientFactory which runs the given TransportClientBootstraps prior to returning
+ * a new Client. Bootstraps will be executed synchronously, and must run successfully in order
+ * to create a Client.
+ */
+ public TransportClientFactory createClientFactory(List<TransportClientBootstrap> bootstraps) {
+ return new TransportClientFactory(this, bootstraps);
+ }
+
+ public TransportClientFactory createClientFactory() {
+ return createClientFactory(Lists.<TransportClientBootstrap>newArrayList());
+ }
+
+ /** Create a server which will attempt to bind to a specific port. */
+ public TransportServer createServer(int port, List<TransportServerBootstrap> bootstraps) {
+ return new TransportServer(this, null, port, rpcHandler, bootstraps);
+ }
+
+ /** Create a server which will attempt to bind to a specific host and port. */
+ public TransportServer createServer(
+ String host, int port, List<TransportServerBootstrap> bootstraps) {
+ return new TransportServer(this, host, port, rpcHandler, bootstraps);
+ }
+
+ /** Creates a new server, binding to any available ephemeral port. */
+ public TransportServer createServer(List<TransportServerBootstrap> bootstraps) {
+ return createServer(0, bootstraps);
+ }
+
+ public TransportServer createServer() {
+ return createServer(0, Lists.<TransportServerBootstrap>newArrayList());
+ }
+
+ public TransportChannelHandler initializePipeline(SocketChannel channel) {
+ return initializePipeline(channel, rpcHandler);
+ }
+
+ /**
+ * Initializes a client or server Netty Channel Pipeline which encodes/decodes messages and
+ * has a {@link org.apache.spark.network.server.TransportChannelHandler} to handle request or
+ * response messages.
+ *
+ * @param channel The channel to initialize.
+ * @param channelRpcHandler The RPC handler to use for the channel.
+ *
+ * @return Returns the created TransportChannelHandler, which includes a TransportClient that can
+ * be used to communicate on this channel. The TransportClient is directly associated with a
+ * ChannelHandler to ensure all users of the same channel get the same TransportClient object.
+ */
+ public TransportChannelHandler initializePipeline(
+ SocketChannel channel,
+ RpcHandler channelRpcHandler) {
+ try {
+ TransportChannelHandler channelHandler = createChannelHandler(channel, channelRpcHandler);
+ channel.pipeline()
+ .addLast("encoder", encoder)
+ .addLast(TransportFrameDecoder.HANDLER_NAME, NettyUtils.createFrameDecoder())
+ .addLast("decoder", decoder)
+ .addLast("idleStateHandler", new IdleStateHandler(0, 0, conf.connectionTimeoutMs() / 1000))
+      // NOTE: Chunks are currently guaranteed to be returned in the order they were requested,
+      // but guaranteeing this would require more logic if these handlers were not all part of
+      // the same event loop.
+ .addLast("handler", channelHandler);
+ return channelHandler;
+ } catch (RuntimeException e) {
+ logger.error("Error while initializing Netty pipeline", e);
+ throw e;
+ }
+ }
+
+ /**
+ * Creates the server- and client-side handler which is used to handle both RequestMessages and
+ * ResponseMessages. The channel is expected to have been successfully created, though certain
+ * properties (such as the remoteAddress()) may not be available yet.
+ */
+ private TransportChannelHandler createChannelHandler(Channel channel, RpcHandler rpcHandler) {
+ TransportResponseHandler responseHandler = new TransportResponseHandler(channel);
+ TransportClient client = new TransportClient(channel, responseHandler);
+ TransportRequestHandler requestHandler = new TransportRequestHandler(channel, client,
+ rpcHandler);
+ return new TransportChannelHandler(client, responseHandler, requestHandler,
+ conf.connectionTimeoutMs(), closeIdleConnections);
+ }
+
+ public TransportConf getConf() { return conf; }
+}
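As a rough sketch of how the context added above is meant to be used (purely illustrative, not part of this commit), the snippet below wires one TransportContext into both a server and a client factory. It assumes the NoOpRpcHandler, SystemPropertyConfigProvider and TransportConf(String module, ConfigProvider) utilities added elsewhere in this commit; the module name "test" and host "localhost" are placeholders.

import org.apache.spark.network.TransportContext;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.client.TransportClientFactory;
import org.apache.spark.network.server.NoOpRpcHandler;
import org.apache.spark.network.server.TransportServer;
import org.apache.spark.network.util.SystemPropertyConfigProvider;
import org.apache.spark.network.util.TransportConf;

public class TransportContextExample {
  public static void main(String[] args) throws Exception {
    // One TransportContext produces both the server and the client factory.
    TransportConf conf = new TransportConf("test", new SystemPropertyConfigProvider());
    TransportContext context = new TransportContext(conf, new NoOpRpcHandler());

    TransportServer server = context.createServer();               // binds an ephemeral port
    TransportClientFactory factory = context.createClientFactory();
    TransportClient client = factory.createClient("localhost", server.getPort());

    System.out.println("client active: " + client.isActive());

    client.close();
    factory.close();
    server.close();
  }
}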
diff --git a/common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java b/common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java
new file mode 100644
index 0000000000..844eff4f4c
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java
@@ -0,0 +1,154 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.buffer;
+
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.RandomAccessFile;
+import java.nio.ByteBuffer;
+import java.nio.channels.FileChannel;
+
+import com.google.common.base.Objects;
+import com.google.common.io.ByteStreams;
+import io.netty.channel.DefaultFileRegion;
+
+import org.apache.spark.network.util.JavaUtils;
+import org.apache.spark.network.util.LimitedInputStream;
+import org.apache.spark.network.util.TransportConf;
+
+/**
+ * A {@link ManagedBuffer} backed by a segment in a file.
+ */
+public final class FileSegmentManagedBuffer extends ManagedBuffer {
+ private final TransportConf conf;
+ private final File file;
+ private final long offset;
+ private final long length;
+
+ public FileSegmentManagedBuffer(TransportConf conf, File file, long offset, long length) {
+ this.conf = conf;
+ this.file = file;
+ this.offset = offset;
+ this.length = length;
+ }
+
+ @Override
+ public long size() {
+ return length;
+ }
+
+ @Override
+ public ByteBuffer nioByteBuffer() throws IOException {
+ FileChannel channel = null;
+ try {
+ channel = new RandomAccessFile(file, "r").getChannel();
+ // Just copy the buffer if it's sufficiently small, as memory mapping has a high overhead.
+ if (length < conf.memoryMapBytes()) {
+ ByteBuffer buf = ByteBuffer.allocate((int) length);
+ channel.position(offset);
+ while (buf.remaining() != 0) {
+ if (channel.read(buf) == -1) {
+ throw new IOException(String.format("Reached EOF before filling buffer\n" +
+ "offset=%s\nfile=%s\nbuf.remaining=%s",
+ offset, file.getAbsoluteFile(), buf.remaining()));
+ }
+ }
+ buf.flip();
+ return buf;
+ } else {
+ return channel.map(FileChannel.MapMode.READ_ONLY, offset, length);
+ }
+ } catch (IOException e) {
+ try {
+ if (channel != null) {
+ long size = channel.size();
+ throw new IOException("Error in reading " + this + " (actual file length " + size + ")",
+ e);
+ }
+ } catch (IOException ignored) {
+ // ignore
+ }
+ throw new IOException("Error in opening " + this, e);
+ } finally {
+ JavaUtils.closeQuietly(channel);
+ }
+ }
+
+ @Override
+ public InputStream createInputStream() throws IOException {
+ FileInputStream is = null;
+ try {
+ is = new FileInputStream(file);
+ ByteStreams.skipFully(is, offset);
+ return new LimitedInputStream(is, length);
+ } catch (IOException e) {
+ try {
+ if (is != null) {
+ long size = file.length();
+ throw new IOException("Error in reading " + this + " (actual file length " + size + ")",
+ e);
+ }
+ } catch (IOException ignored) {
+ // ignore
+ } finally {
+ JavaUtils.closeQuietly(is);
+ }
+ throw new IOException("Error in opening " + this, e);
+ } catch (RuntimeException e) {
+ JavaUtils.closeQuietly(is);
+ throw e;
+ }
+ }
+
+ @Override
+ public ManagedBuffer retain() {
+ return this;
+ }
+
+ @Override
+ public ManagedBuffer release() {
+ return this;
+ }
+
+ @Override
+ public Object convertToNetty() throws IOException {
+ if (conf.lazyFileDescriptor()) {
+ return new LazyFileRegion(file, offset, length);
+ } else {
+ FileChannel fileChannel = new FileInputStream(file).getChannel();
+ return new DefaultFileRegion(fileChannel, offset, length);
+ }
+ }
+
+ public File getFile() { return file; }
+
+ public long getOffset() { return offset; }
+
+ public long getLength() { return length; }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("file", file)
+ .add("offset", offset)
+ .add("length", length)
+ .toString();
+ }
+}
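A minimal, illustrative sketch of reading a file segment through this buffer follows; the file path, offset and length are placeholders, and the configuration setup reuses the util classes from this commit. Segments smaller than conf.memoryMapBytes() are copied into a heap buffer, larger ones are memory-mapped.

import java.io.File;
import java.nio.ByteBuffer;

import org.apache.spark.network.buffer.FileSegmentManagedBuffer;
import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.util.SystemPropertyConfigProvider;
import org.apache.spark.network.util.TransportConf;

public class FileSegmentExample {
  public static void main(String[] args) throws Exception {
    TransportConf conf = new TransportConf("test", new SystemPropertyConfigProvider());

    // View bytes [100, 100 + 4096) of the file as a managed buffer.
    ManagedBuffer buffer =
      new FileSegmentManagedBuffer(conf, new File("/tmp/data.bin"), 100, 4096);

    // nioByteBuffer() either copies or memory-maps the segment, depending on its size.
    ByteBuffer bytes = buffer.nioByteBuffer();
    System.out.println("read " + bytes.remaining() + " bytes of " + buffer.size());
  }
}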
diff --git a/common/network-common/src/main/java/org/apache/spark/network/buffer/LazyFileRegion.java b/common/network-common/src/main/java/org/apache/spark/network/buffer/LazyFileRegion.java
new file mode 100644
index 0000000000..162cf6da0d
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/buffer/LazyFileRegion.java
@@ -0,0 +1,111 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.buffer;
+
+import java.io.FileInputStream;
+import java.io.File;
+import java.io.IOException;
+import java.nio.channels.FileChannel;
+import java.nio.channels.WritableByteChannel;
+
+import com.google.common.base.Objects;
+import io.netty.channel.FileRegion;
+import io.netty.util.AbstractReferenceCounted;
+
+import org.apache.spark.network.util.JavaUtils;
+
+/**
+ * A FileRegion implementation that only creates the file descriptor when the region is being
+ * transferred. This cannot be used with Epoll because there is no native support for it.
+ *
+ * This is mostly copied from the DefaultFileRegion implementation in Netty. In the future, we
+ * should push this into Netty so the native Epoll transport can support this feature.
+ */
+public final class LazyFileRegion extends AbstractReferenceCounted implements FileRegion {
+
+ private final File file;
+ private final long position;
+ private final long count;
+
+ private FileChannel channel;
+
+ private long numBytesTransferred = 0L;
+
+ /**
+ * @param file file to transfer.
+ * @param position start position for the transfer.
+ * @param count number of bytes to transfer starting from position.
+ */
+ public LazyFileRegion(File file, long position, long count) {
+ this.file = file;
+ this.position = position;
+ this.count = count;
+ }
+
+ @Override
+ protected void deallocate() {
+ JavaUtils.closeQuietly(channel);
+ }
+
+ @Override
+ public long position() {
+ return position;
+ }
+
+ @Override
+ public long transfered() {
+ return numBytesTransferred;
+ }
+
+ @Override
+ public long count() {
+ return count;
+ }
+
+ @Override
+ public long transferTo(WritableByteChannel target, long position) throws IOException {
+ if (channel == null) {
+ channel = new FileInputStream(file).getChannel();
+ }
+
+ long count = this.count - position;
+ if (count < 0 || position < 0) {
+ throw new IllegalArgumentException(
+ "position out of range: " + position + " (expected: 0 - " + (count - 1) + ')');
+ }
+
+ if (count == 0) {
+ return 0L;
+ }
+
+ long written = channel.transferTo(this.position + position, count, target);
+ if (written > 0) {
+ numBytesTransferred += written;
+ }
+ return written;
+ }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("file", file)
+ .add("position", position)
+ .add("count", count)
+ .toString();
+ }
+}
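For illustration only (not part of this commit), a LazyFileRegion can also be driven outside of Netty by handing it any WritableByteChannel; the file paths below are placeholders, and the source file descriptor is only opened on the first transferTo() call.

import java.io.File;
import java.io.FileOutputStream;

import org.apache.spark.network.buffer.LazyFileRegion;

public class LazyFileRegionExample {
  public static void main(String[] args) throws Exception {
    // Transfer bytes [0, 1024) of a (placeholder) source file to a destination channel.
    LazyFileRegion region = new LazyFileRegion(new File("/tmp/data.bin"), 0, 1024);
    try (FileOutputStream out = new FileOutputStream("/tmp/copy.bin")) {
      long written = region.transferTo(out.getChannel(), 0);
      System.out.println("transferred " + written + " of " + region.count() + " bytes");
    } finally {
      region.release();   // deallocate() closes the lazily opened channel
    }
  }
}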
diff --git a/common/network-common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java b/common/network-common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java
new file mode 100644
index 0000000000..1861f8d7fd
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.buffer;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.ByteBuffer;
+
+/**
+ * This abstract class provides an immutable view for data in the form of bytes. The implementation
+ * should specify how the data is provided:
+ *
+ * - {@link FileSegmentManagedBuffer}: data backed by part of a file
+ * - {@link NioManagedBuffer}: data backed by a NIO ByteBuffer
+ * - {@link NettyManagedBuffer}: data backed by a Netty ByteBuf
+ *
+ * The concrete buffer implementation might be managed outside the JVM garbage collector.
+ * For example, in the case of {@link NettyManagedBuffer}, the buffers are reference counted.
+ * In that case, if the buffer is going to be passed around to a different thread, retain/release
+ * should be called.
+ */
+public abstract class ManagedBuffer {
+
+ /** Number of bytes of the data. */
+ public abstract long size();
+
+ /**
+ * Exposes this buffer's data as an NIO ByteBuffer. Changing the position and limit of the
+ * returned ByteBuffer should not affect the content of this buffer.
+ */
+ // TODO: Deprecate this, usage may require expensive memory mapping or allocation.
+ public abstract ByteBuffer nioByteBuffer() throws IOException;
+
+ /**
+ * Exposes this buffer's data as an InputStream. The underlying implementation does not
+ * necessarily check for the length of bytes read, so the caller is responsible for making sure
+ * it does not go over the limit.
+ */
+ public abstract InputStream createInputStream() throws IOException;
+
+ /**
+ * Increment the reference count by one if applicable.
+ */
+ public abstract ManagedBuffer retain();
+
+ /**
+ * If applicable, decrement the reference count by one and deallocates the buffer if the
+ * reference count reaches zero.
+ */
+ public abstract ManagedBuffer release();
+
+ /**
+   * Convert the buffer into a Netty object, used to write the data out. The return value is either
+ * a {@link io.netty.buffer.ByteBuf} or a {@link io.netty.channel.FileRegion}.
+ *
+ * If this method returns a ByteBuf, then that buffer's reference count will be incremented and
+ * the caller will be responsible for releasing this new reference.
+ */
+ public abstract Object convertToNetty() throws IOException;
+}
diff --git a/common/network-common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java b/common/network-common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java
new file mode 100644
index 0000000000..4c8802af7a
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.buffer;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.ByteBuffer;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.ByteBufInputStream;
+
+/**
+ * A {@link ManagedBuffer} backed by a Netty {@link ByteBuf}.
+ */
+public final class NettyManagedBuffer extends ManagedBuffer {
+ private final ByteBuf buf;
+
+ public NettyManagedBuffer(ByteBuf buf) {
+ this.buf = buf;
+ }
+
+ @Override
+ public long size() {
+ return buf.readableBytes();
+ }
+
+ @Override
+ public ByteBuffer nioByteBuffer() throws IOException {
+ return buf.nioBuffer();
+ }
+
+ @Override
+ public InputStream createInputStream() throws IOException {
+ return new ByteBufInputStream(buf);
+ }
+
+ @Override
+ public ManagedBuffer retain() {
+ buf.retain();
+ return this;
+ }
+
+ @Override
+ public ManagedBuffer release() {
+ buf.release();
+ return this;
+ }
+
+ @Override
+ public Object convertToNetty() throws IOException {
+ return buf.duplicate().retain();
+ }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("buf", buf)
+ .toString();
+ }
+}
diff --git a/common/network-common/src/main/java/org/apache/spark/network/buffer/NioManagedBuffer.java b/common/network-common/src/main/java/org/apache/spark/network/buffer/NioManagedBuffer.java
new file mode 100644
index 0000000000..631d767715
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/buffer/NioManagedBuffer.java
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.buffer;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.ByteBuffer;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBufInputStream;
+import io.netty.buffer.Unpooled;
+
+/**
+ * A {@link ManagedBuffer} backed by {@link ByteBuffer}.
+ */
+public class NioManagedBuffer extends ManagedBuffer {
+ private final ByteBuffer buf;
+
+ public NioManagedBuffer(ByteBuffer buf) {
+ this.buf = buf;
+ }
+
+ @Override
+ public long size() {
+ return buf.remaining();
+ }
+
+ @Override
+ public ByteBuffer nioByteBuffer() throws IOException {
+ return buf.duplicate();
+ }
+
+ @Override
+ public InputStream createInputStream() throws IOException {
+ return new ByteBufInputStream(Unpooled.wrappedBuffer(buf));
+ }
+
+ @Override
+ public ManagedBuffer retain() {
+ return this;
+ }
+
+ @Override
+ public ManagedBuffer release() {
+ return this;
+ }
+
+ @Override
+ public Object convertToNetty() throws IOException {
+ return Unpooled.wrappedBuffer(buf);
+ }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("buf", buf)
+ .toString();
+ }
+}
+
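To illustrate the retain/release contract described in ManagedBuffer, here is a small sketch (not part of this commit) comparing the Netty-backed and NIO-backed implementations; it uses only classes added above plus Netty's Unpooled allocator.

import java.nio.ByteBuffer;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;

import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.buffer.NettyManagedBuffer;
import org.apache.spark.network.buffer.NioManagedBuffer;

public class ManagedBufferExample {
  public static void main(String[] args) throws Exception {
    // Netty-backed buffers are reference counted: retain() before handing the buffer to
    // another thread, release() once done with it.
    ByteBuf byteBuf = Unpooled.wrappedBuffer(new byte[]{1, 2, 3, 4});
    ManagedBuffer nettyBuffer = new NettyManagedBuffer(byteBuf);
    nettyBuffer.retain();
    System.out.println("netty-backed size = " + nettyBuffer.size());
    nettyBuffer.release();
    nettyBuffer.release();   // drops the last reference; the underlying ByteBuf is deallocated

    // NIO-backed buffers wrap plain heap data, so retain()/release() are no-ops.
    ManagedBuffer nioBuffer = new NioManagedBuffer(ByteBuffer.wrap(new byte[]{5, 6, 7}));
    System.out.println("nio-backed size = " + nioBuffer.size());
  }
}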
diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/ChunkFetchFailureException.java b/common/network-common/src/main/java/org/apache/spark/network/client/ChunkFetchFailureException.java
new file mode 100644
index 0000000000..1fbdcd6780
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/client/ChunkFetchFailureException.java
@@ -0,0 +1,31 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.client;
+
+/**
+ * General exception caused by a remote exception while fetching a chunk.
+ */
+public class ChunkFetchFailureException extends RuntimeException {
+ public ChunkFetchFailureException(String errorMsg, Throwable cause) {
+ super(errorMsg, cause);
+ }
+
+ public ChunkFetchFailureException(String errorMsg) {
+ super(errorMsg);
+ }
+}
diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/ChunkReceivedCallback.java b/common/network-common/src/main/java/org/apache/spark/network/client/ChunkReceivedCallback.java
new file mode 100644
index 0000000000..519e6cb470
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/client/ChunkReceivedCallback.java
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.client;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+
+/**
+ * Callback for the result of a single chunk fetch. For a single stream, the callbacks are
+ * guaranteed to be called by the same thread in the same order as the requests for chunks were
+ * made.
+ *
+ * Note that if a general stream failure occurs, all outstanding chunk requests may be failed.
+ */
+public interface ChunkReceivedCallback {
+ /**
+ * Called upon receipt of a particular chunk.
+ *
+ * The given buffer will initially have a refcount of 1, but will be release()'d as soon as this
+ * call returns. You must therefore either retain() the buffer or copy its contents before
+ * returning.
+ */
+ void onSuccess(int chunkIndex, ManagedBuffer buffer);
+
+ /**
+ * Called upon failure to fetch a particular chunk. Note that this may actually be called due
+ * to failure to fetch a prior chunk in this stream.
+ *
+ * After receiving a failure, the stream may or may not be valid. The client should not assume
+ * that the server's side of the stream has been closed.
+ */
+ void onFailure(int chunkIndex, Throwable e);
+}
diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java b/common/network-common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java
new file mode 100644
index 0000000000..47e93f9846
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.client;
+
+import java.nio.ByteBuffer;
+
+/**
+ * Callback for the result of a single RPC. This will be invoked once with either success or
+ * failure.
+ */
+public interface RpcResponseCallback {
+ /** Successful serialized result from server. */
+ void onSuccess(ByteBuffer response);
+
+ /** Exception either propagated from server or raised on client side. */
+ void onFailure(Throwable e);
+}
diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/StreamCallback.java b/common/network-common/src/main/java/org/apache/spark/network/client/StreamCallback.java
new file mode 100644
index 0000000000..29e6a30dc1
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/client/StreamCallback.java
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.client;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+
+/**
+ * Callback for streaming data. Stream data will be offered to the {@link #onData(String, ByteBuffer)}
+ * method as it arrives. Once all the stream data is received, {@link #onComplete(String)} will be
+ * called.
+ * <p>
+ * The network library guarantees that a single thread will call these methods at a time, but
+ * different calls may be made by different threads.
+ */
+public interface StreamCallback {
+ /** Called upon receipt of stream data. */
+ void onData(String streamId, ByteBuffer buf) throws IOException;
+
+ /** Called when all data from the stream has been received. */
+ void onComplete(String streamId) throws IOException;
+
+ /** Called if there's an error reading data from the stream. */
+ void onFailure(String streamId, Throwable cause) throws IOException;
+}
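As a sketch of how the interface above might be implemented (illustrative only, the class name is a placeholder), the callback below simply accumulates the stream data in memory.

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.ByteBuffer;

import org.apache.spark.network.client.StreamCallback;

public class CollectingStreamCallback implements StreamCallback {
  private final ByteArrayOutputStream out = new ByteArrayOutputStream();

  @Override
  public void onData(String streamId, ByteBuffer buf) throws IOException {
    // Copy the chunk out; the buffer is only valid for the duration of this call.
    byte[] chunk = new byte[buf.remaining()];
    buf.get(chunk);
    out.write(chunk);
  }

  @Override
  public void onComplete(String streamId) throws IOException {
    System.out.println("stream " + streamId + " finished with " + out.size() + " bytes");
  }

  @Override
  public void onFailure(String streamId, Throwable cause) throws IOException {
    System.err.println("stream " + streamId + " failed: " + cause);
  }
}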
diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java b/common/network-common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java
new file mode 100644
index 0000000000..88ba3ccebd
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.client;
+
+import java.nio.ByteBuffer;
+import java.nio.channels.ClosedChannelException;
+
+import io.netty.buffer.ByteBuf;
+
+import org.apache.spark.network.util.TransportFrameDecoder;
+
+/**
+ * An interceptor that is registered with the frame decoder to feed stream data to a
+ * callback.
+ */
+class StreamInterceptor implements TransportFrameDecoder.Interceptor {
+
+ private final TransportResponseHandler handler;
+ private final String streamId;
+ private final long byteCount;
+ private final StreamCallback callback;
+
+ private volatile long bytesRead;
+
+ StreamInterceptor(
+ TransportResponseHandler handler,
+ String streamId,
+ long byteCount,
+ StreamCallback callback) {
+ this.handler = handler;
+ this.streamId = streamId;
+ this.byteCount = byteCount;
+ this.callback = callback;
+ this.bytesRead = 0;
+ }
+
+ @Override
+ public void exceptionCaught(Throwable cause) throws Exception {
+ handler.deactivateStream();
+ callback.onFailure(streamId, cause);
+ }
+
+ @Override
+ public void channelInactive() throws Exception {
+ handler.deactivateStream();
+ callback.onFailure(streamId, new ClosedChannelException());
+ }
+
+ @Override
+ public boolean handle(ByteBuf buf) throws Exception {
+ int toRead = (int) Math.min(buf.readableBytes(), byteCount - bytesRead);
+ ByteBuffer nioBuffer = buf.readSlice(toRead).nioBuffer();
+
+ int available = nioBuffer.remaining();
+ callback.onData(streamId, nioBuffer);
+ bytesRead += available;
+ if (bytesRead > byteCount) {
+ RuntimeException re = new IllegalStateException(String.format(
+ "Read too many bytes? Expected %d, but read %d.", byteCount, bytesRead));
+ callback.onFailure(streamId, re);
+ handler.deactivateStream();
+ throw re;
+ } else if (bytesRead == byteCount) {
+ handler.deactivateStream();
+ callback.onComplete(streamId);
+ }
+
+ return bytesRead != byteCount;
+ }
+
+}
diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java
new file mode 100644
index 0000000000..e15f096d36
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java
@@ -0,0 +1,321 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.client;
+
+import java.io.Closeable;
+import java.io.IOException;
+import java.net.SocketAddress;
+import java.nio.ByteBuffer;
+import java.util.UUID;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeUnit;
+import javax.annotation.Nullable;
+
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Objects;
+import com.google.common.base.Preconditions;
+import com.google.common.base.Throwables;
+import com.google.common.util.concurrent.SettableFuture;
+import io.netty.channel.Channel;
+import io.netty.channel.ChannelFuture;
+import io.netty.channel.ChannelFutureListener;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.buffer.NioManagedBuffer;
+import org.apache.spark.network.protocol.ChunkFetchRequest;
+import org.apache.spark.network.protocol.OneWayMessage;
+import org.apache.spark.network.protocol.RpcRequest;
+import org.apache.spark.network.protocol.StreamChunkId;
+import org.apache.spark.network.protocol.StreamRequest;
+import org.apache.spark.network.util.NettyUtils;
+
+/**
+ * Client for fetching consecutive chunks of a pre-negotiated stream. This API is intended to allow
+ * efficient transfer of a large amount of data, broken up into chunks with size ranging from
+ * hundreds of KB to a few MB.
+ *
+ * Note that while this client deals with the fetching of chunks from a stream (i.e., data plane),
+ * the actual setup of the streams is done outside the scope of the transport layer. The convenience
+ * method "sendRpc" is provided to enable control plane communication between the client and server
+ * to perform this setup.
+ *
+ * For example, a typical workflow might be:
+ * client.sendRPC(new OpenFile("/foo")) --&gt; returns StreamId = 100
+ * client.fetchChunk(streamId = 100, chunkIndex = 0, callback)
+ * client.fetchChunk(streamId = 100, chunkIndex = 1, callback)
+ * ...
+ * client.sendRPC(new CloseStream(100))
+ *
+ * Construct an instance of TransportClient using {@link TransportClientFactory}. A single
+ * TransportClient may be used for multiple streams, but any given stream must be restricted to a
+ * single client, in order to avoid out-of-order responses.
+ *
+ * NB: This class is used to make requests to the server, while {@link TransportResponseHandler} is
+ * responsible for handling responses from the server.
+ *
+ * Concurrency: thread safe and can be called from multiple threads.
+ */
+public class TransportClient implements Closeable {
+ private final Logger logger = LoggerFactory.getLogger(TransportClient.class);
+
+ private final Channel channel;
+ private final TransportResponseHandler handler;
+ @Nullable private String clientId;
+ private volatile boolean timedOut;
+
+ public TransportClient(Channel channel, TransportResponseHandler handler) {
+ this.channel = Preconditions.checkNotNull(channel);
+ this.handler = Preconditions.checkNotNull(handler);
+ this.timedOut = false;
+ }
+
+ public Channel getChannel() {
+ return channel;
+ }
+
+ public boolean isActive() {
+ return !timedOut && (channel.isOpen() || channel.isActive());
+ }
+
+ public SocketAddress getSocketAddress() {
+ return channel.remoteAddress();
+ }
+
+ /**
+ * Returns the ID used by the client to authenticate itself when authentication is enabled.
+ *
+ * @return The client ID, or null if authentication is disabled.
+ */
+ public String getClientId() {
+ return clientId;
+ }
+
+ /**
+ * Sets the authenticated client ID. This is meant to be used by the authentication layer.
+ *
+ * Trying to set a different client ID after it's been set will result in an exception.
+ */
+ public void setClientId(String id) {
+ Preconditions.checkState(clientId == null, "Client ID has already been set.");
+ this.clientId = id;
+ }
+
+ /**
+ * Requests a single chunk from the remote side, from the pre-negotiated streamId.
+ *
+ * Chunk indices go from 0 onwards. It is valid to request the same chunk multiple times, though
+ * some streams may not support this.
+ *
+ * Multiple fetchChunk requests may be outstanding simultaneously, and the chunks are guaranteed
+ * to be returned in the same order that they were requested, assuming only a single
+ * TransportClient is used to fetch the chunks.
+ *
+ * @param streamId Identifier that refers to a stream in the remote StreamManager. This should
+ * be agreed upon by client and server beforehand.
+ * @param chunkIndex 0-based index of the chunk to fetch
+ * @param callback Callback invoked upon successful receipt of chunk, or upon any failure.
+ */
+ public void fetchChunk(
+ long streamId,
+ final int chunkIndex,
+ final ChunkReceivedCallback callback) {
+ final String serverAddr = NettyUtils.getRemoteAddress(channel);
+ final long startTime = System.currentTimeMillis();
+ logger.debug("Sending fetch chunk request {} to {}", chunkIndex, serverAddr);
+
+ final StreamChunkId streamChunkId = new StreamChunkId(streamId, chunkIndex);
+ handler.addFetchRequest(streamChunkId, callback);
+
+ channel.writeAndFlush(new ChunkFetchRequest(streamChunkId)).addListener(
+ new ChannelFutureListener() {
+ @Override
+ public void operationComplete(ChannelFuture future) throws Exception {
+ if (future.isSuccess()) {
+ long timeTaken = System.currentTimeMillis() - startTime;
+ logger.trace("Sending request {} to {} took {} ms", streamChunkId, serverAddr,
+ timeTaken);
+ } else {
+ String errorMsg = String.format("Failed to send request %s to %s: %s", streamChunkId,
+ serverAddr, future.cause());
+ logger.error(errorMsg, future.cause());
+ handler.removeFetchRequest(streamChunkId);
+ channel.close();
+ try {
+ callback.onFailure(chunkIndex, new IOException(errorMsg, future.cause()));
+ } catch (Exception e) {
+ logger.error("Uncaught exception in RPC response callback handler!", e);
+ }
+ }
+ }
+ });
+ }
+
+ /**
+ * Request to stream the data with the given stream ID from the remote end.
+ *
+ * @param streamId The stream to fetch.
+ * @param callback Object to call with the stream data.
+ */
+ public void stream(final String streamId, final StreamCallback callback) {
+ final String serverAddr = NettyUtils.getRemoteAddress(channel);
+ final long startTime = System.currentTimeMillis();
+ logger.debug("Sending stream request for {} to {}", streamId, serverAddr);
+
+ // Need to synchronize here so that the callback is added to the queue and the RPC is
+ // written to the socket atomically, so that callbacks are called in the right order
+ // when responses arrive.
+ synchronized (this) {
+ handler.addStreamCallback(callback);
+ channel.writeAndFlush(new StreamRequest(streamId)).addListener(
+ new ChannelFutureListener() {
+ @Override
+ public void operationComplete(ChannelFuture future) throws Exception {
+ if (future.isSuccess()) {
+ long timeTaken = System.currentTimeMillis() - startTime;
+ logger.trace("Sending request for {} to {} took {} ms", streamId, serverAddr,
+ timeTaken);
+ } else {
+ String errorMsg = String.format("Failed to send request for %s to %s: %s", streamId,
+ serverAddr, future.cause());
+ logger.error(errorMsg, future.cause());
+ channel.close();
+ try {
+ callback.onFailure(streamId, new IOException(errorMsg, future.cause()));
+ } catch (Exception e) {
+ logger.error("Uncaught exception in RPC response callback handler!", e);
+ }
+ }
+ }
+ });
+ }
+ }
+
+ /**
+ * Sends an opaque message to the RpcHandler on the server-side. The callback will be invoked
+ * with the server's response or upon any failure.
+ *
+ * @param message The message to send.
+ * @param callback Callback to handle the RPC's reply.
+ * @return The RPC's id.
+ */
+ public long sendRpc(ByteBuffer message, final RpcResponseCallback callback) {
+ final String serverAddr = NettyUtils.getRemoteAddress(channel);
+ final long startTime = System.currentTimeMillis();
+ logger.trace("Sending RPC to {}", serverAddr);
+
+ final long requestId = Math.abs(UUID.randomUUID().getLeastSignificantBits());
+ handler.addRpcRequest(requestId, callback);
+
+ channel.writeAndFlush(new RpcRequest(requestId, new NioManagedBuffer(message))).addListener(
+ new ChannelFutureListener() {
+ @Override
+ public void operationComplete(ChannelFuture future) throws Exception {
+ if (future.isSuccess()) {
+ long timeTaken = System.currentTimeMillis() - startTime;
+ logger.trace("Sending request {} to {} took {} ms", requestId, serverAddr, timeTaken);
+ } else {
+ String errorMsg = String.format("Failed to send RPC %s to %s: %s", requestId,
+ serverAddr, future.cause());
+ logger.error(errorMsg, future.cause());
+ handler.removeRpcRequest(requestId);
+ channel.close();
+ try {
+ callback.onFailure(new IOException(errorMsg, future.cause()));
+ } catch (Exception e) {
+ logger.error("Uncaught exception in RPC response callback handler!", e);
+ }
+ }
+ }
+ });
+
+ return requestId;
+ }
+
+ /**
+ * Synchronously sends an opaque message to the RpcHandler on the server-side, waiting for up to
+ * a specified timeout for a response.
+ */
+ public ByteBuffer sendRpcSync(ByteBuffer message, long timeoutMs) {
+ final SettableFuture<ByteBuffer> result = SettableFuture.create();
+
+ sendRpc(message, new RpcResponseCallback() {
+ @Override
+ public void onSuccess(ByteBuffer response) {
+ result.set(response);
+ }
+
+ @Override
+ public void onFailure(Throwable e) {
+ result.setException(e);
+ }
+ });
+
+ try {
+ return result.get(timeoutMs, TimeUnit.MILLISECONDS);
+ } catch (ExecutionException e) {
+ throw Throwables.propagate(e.getCause());
+ } catch (Exception e) {
+ throw Throwables.propagate(e);
+ }
+ }
+
+ /**
+ * Sends an opaque message to the RpcHandler on the server-side. No reply is expected for the
+ * message, and no delivery guarantees are made.
+ *
+ * @param message The message to send.
+ */
+ public void send(ByteBuffer message) {
+ channel.writeAndFlush(new OneWayMessage(new NioManagedBuffer(message)));
+ }
+
+ /**
+ * Removes any state associated with the given RPC.
+ *
+ * @param requestId The RPC id returned by {@link #sendRpc(ByteBuffer, RpcResponseCallback)}.
+ */
+ public void removeRpcRequest(long requestId) {
+ handler.removeRpcRequest(requestId);
+ }
+
+ /** Mark this channel as having timed out. */
+ public void timeOut() {
+ this.timedOut = true;
+ }
+
+ @VisibleForTesting
+ public TransportResponseHandler getHandler() {
+ return handler;
+ }
+
+ @Override
+ public void close() {
+ // close is a local operation and should finish within milliseconds; timeout just to be safe
+ channel.close().awaitUninterruptibly(10, TimeUnit.SECONDS);
+ }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("remoteAddress", channel.remoteAddress())
+ .add("clientId", clientId)
+ .add("isActive", isActive())
+ .toString();
+ }
+}
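The asynchronous and synchronous RPC paths above can be combined as in the following sketch. It assumes a TransportClient obtained from a TransportClientFactory (shown further below) and an illustrative payload; the callback shape matches the RpcResponseCallback used by sendRpcSync above.

import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;

import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.TransportClient;

public class TransportClientRpcExample {

  // Hypothetical helper: one asynchronous RPC followed by one synchronous RPC.
  static void sendExamples(TransportClient client) {
    ByteBuffer payload = ByteBuffer.wrap("ping".getBytes(StandardCharsets.UTF_8));

    // Asynchronous path: the callback runs on a Netty event-loop thread.
    client.sendRpc(payload.duplicate(), new RpcResponseCallback() {
      @Override
      public void onSuccess(ByteBuffer response) {
        System.out.println("RPC reply: " + response.remaining() + " bytes");
      }

      @Override
      public void onFailure(Throwable e) {
        System.err.println("RPC failed: " + e);
      }
    });

    // Synchronous path: blocks for up to the given timeout waiting for the reply.
    ByteBuffer reply = client.sendRpcSync(payload.duplicate(), 5000 /* ms */);
    System.out.println("Synchronous reply: " + reply.remaining() + " bytes");
  }
}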
diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientBootstrap.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientBootstrap.java
new file mode 100644
index 0000000000..eaae2ee043
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientBootstrap.java
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.client;
+
+import io.netty.channel.Channel;
+
+/**
+ * A bootstrap which is executed on a TransportClient before it is returned to the user.
+ * This enables an initial exchange of information (e.g., SASL authentication tokens) on a once-per-
+ * connection basis.
+ *
+ * Since connections (and TransportClients) are reused as much as possible, it is generally
+ * reasonable to perform an expensive bootstrapping operation, as they often share a lifespan with
+ * the JVM itself.
+ */
+public interface TransportClientBootstrap {
+ /** Performs the bootstrapping operation, throwing an exception on failure. */
+ void doBootstrap(TransportClient client, Channel channel) throws RuntimeException;
+}
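As a concrete illustration, a hypothetical bootstrap might run a blocking handshake RPC on the freshly created client; any exception it throws aborts client creation in TransportClientFactory. The handshake token and timeout here are assumptions for the sketch only.

import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;

import io.netty.channel.Channel;

import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.client.TransportClientBootstrap;

/** Hypothetical once-per-connection handshake, for illustration only. */
public class HandshakeBootstrap implements TransportClientBootstrap {
  private static final long HANDSHAKE_TIMEOUT_MS = 10000;

  @Override
  public void doBootstrap(TransportClient client, Channel channel) {
    ByteBuffer token = ByteBuffer.wrap("hello".getBytes(StandardCharsets.UTF_8));
    // Throws (unchecked) if the handshake fails or times out, aborting client creation.
    client.sendRpcSync(token, HANDSHAKE_TIMEOUT_MS);
  }
}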
diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
new file mode 100644
index 0000000000..61bafc8380
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
@@ -0,0 +1,264 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.client;
+
+import java.io.Closeable;
+import java.io.IOException;
+import java.net.InetSocketAddress;
+import java.net.SocketAddress;
+import java.util.List;
+import java.util.Random;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.atomic.AtomicReference;
+
+import com.google.common.base.Preconditions;
+import com.google.common.base.Throwables;
+import com.google.common.collect.Lists;
+import io.netty.bootstrap.Bootstrap;
+import io.netty.buffer.PooledByteBufAllocator;
+import io.netty.channel.Channel;
+import io.netty.channel.ChannelFuture;
+import io.netty.channel.ChannelInitializer;
+import io.netty.channel.ChannelOption;
+import io.netty.channel.EventLoopGroup;
+import io.netty.channel.socket.SocketChannel;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.TransportContext;
+import org.apache.spark.network.server.TransportChannelHandler;
+import org.apache.spark.network.util.IOMode;
+import org.apache.spark.network.util.JavaUtils;
+import org.apache.spark.network.util.NettyUtils;
+import org.apache.spark.network.util.TransportConf;
+
+/**
+ * Factory for creating {@link TransportClient}s by using createClient.
+ *
+ * The factory maintains a connection pool to other hosts and should return the same
+ * TransportClient for the same remote host. It also shares a single worker thread pool for
+ * all TransportClients.
+ *
+ * TransportClients will be reused whenever possible. Prior to completing the creation of a new
+ * TransportClient, all given {@link TransportClientBootstrap}s will be run.
+ */
+public class TransportClientFactory implements Closeable {
+
+ /** A simple data structure to track the pool of clients between two peer nodes. */
+ private static class ClientPool {
+ TransportClient[] clients;
+ Object[] locks;
+
+ public ClientPool(int size) {
+ clients = new TransportClient[size];
+ locks = new Object[size];
+ for (int i = 0; i < size; i++) {
+ locks[i] = new Object();
+ }
+ }
+ }
+
+ private final Logger logger = LoggerFactory.getLogger(TransportClientFactory.class);
+
+ private final TransportContext context;
+ private final TransportConf conf;
+ private final List<TransportClientBootstrap> clientBootstraps;
+ private final ConcurrentHashMap<SocketAddress, ClientPool> connectionPool;
+
+ /** Random number generator for picking connections between peers. */
+ private final Random rand;
+ private final int numConnectionsPerPeer;
+
+ private final Class<? extends Channel> socketChannelClass;
+ private EventLoopGroup workerGroup;
+ private PooledByteBufAllocator pooledAllocator;
+
+ public TransportClientFactory(
+ TransportContext context,
+ List<TransportClientBootstrap> clientBootstraps) {
+ this.context = Preconditions.checkNotNull(context);
+ this.conf = context.getConf();
+ this.clientBootstraps = Lists.newArrayList(Preconditions.checkNotNull(clientBootstraps));
+ this.connectionPool = new ConcurrentHashMap<SocketAddress, ClientPool>();
+ this.numConnectionsPerPeer = conf.numConnectionsPerPeer();
+ this.rand = new Random();
+
+ IOMode ioMode = IOMode.valueOf(conf.ioMode());
+ this.socketChannelClass = NettyUtils.getClientChannelClass(ioMode);
+ // TODO: Make thread pool name configurable.
+ this.workerGroup = NettyUtils.createEventLoop(ioMode, conf.clientThreads(), "shuffle-client");
+ this.pooledAllocator = NettyUtils.createPooledByteBufAllocator(
+ conf.preferDirectBufs(), false /* allowCache */, conf.clientThreads());
+ }
+
+ /**
+ * Create a {@link TransportClient} connecting to the given remote host / port.
+ *
+ * We maintain an array of clients (size determined by spark.shuffle.io.numConnectionsPerPeer)
+ * and randomly picks one to use. If no client was previously created in the randomly selected
+ * spot, this function creates a new client and places it there.
+ *
+ * Prior to the creation of a new TransportClient, we will execute all
+ * {@link TransportClientBootstrap}s that are registered with this factory.
+ *
+ * This blocks until a connection is successfully established and fully bootstrapped.
+ *
+ * Concurrency: This method is safe to call from multiple threads.
+ */
+ public TransportClient createClient(String remoteHost, int remotePort) throws IOException {
+ // Get connection from the connection pool first.
+ // If it is not found or not active, create a new one.
+ final InetSocketAddress address = new InetSocketAddress(remoteHost, remotePort);
+
+ // Create the ClientPool if we don't have it yet.
+ ClientPool clientPool = connectionPool.get(address);
+ if (clientPool == null) {
+ connectionPool.putIfAbsent(address, new ClientPool(numConnectionsPerPeer));
+ clientPool = connectionPool.get(address);
+ }
+
+ int clientIndex = rand.nextInt(numConnectionsPerPeer);
+ TransportClient cachedClient = clientPool.clients[clientIndex];
+
+ if (cachedClient != null && cachedClient.isActive()) {
+ // Make sure that the channel will not timeout by updating the last use time of the
+ // handler. Then check that the client is still alive, in case it timed out before
+ // this code was able to update things.
+ TransportChannelHandler handler = cachedClient.getChannel().pipeline()
+ .get(TransportChannelHandler.class);
+ synchronized (handler) {
+ handler.getResponseHandler().updateTimeOfLastRequest();
+ }
+
+ if (cachedClient.isActive()) {
+ logger.trace("Returning cached connection to {}: {}", address, cachedClient);
+ return cachedClient;
+ }
+ }
+
+ // If we reach here, we don't have an existing connection open. Let's create a new one.
+ // Multiple threads might race here to create new connections. Keep only one of them active.
+ synchronized (clientPool.locks[clientIndex]) {
+ cachedClient = clientPool.clients[clientIndex];
+
+ if (cachedClient != null) {
+ if (cachedClient.isActive()) {
+ logger.trace("Returning cached connection to {}: {}", address, cachedClient);
+ return cachedClient;
+ } else {
+ logger.info("Found inactive connection to {}, creating a new one.", address);
+ }
+ }
+ clientPool.clients[clientIndex] = createClient(address);
+ return clientPool.clients[clientIndex];
+ }
+ }
+
+ /**
+ * Create a completely new {@link TransportClient} to the given remote host / port.
+ * This connection is not pooled.
+ *
+ * As with {@link #createClient(String, int)}, this method is blocking.
+ */
+ public TransportClient createUnmanagedClient(String remoteHost, int remotePort)
+ throws IOException {
+ final InetSocketAddress address = new InetSocketAddress(remoteHost, remotePort);
+ return createClient(address);
+ }
+
+ /** Create a completely new {@link TransportClient} to the remote address. */
+ private TransportClient createClient(InetSocketAddress address) throws IOException {
+ logger.debug("Creating new connection to " + address);
+
+ Bootstrap bootstrap = new Bootstrap();
+ bootstrap.group(workerGroup)
+ .channel(socketChannelClass)
+ // Disable Nagle's Algorithm since we don't want packets to wait
+ .option(ChannelOption.TCP_NODELAY, true)
+ .option(ChannelOption.SO_KEEPALIVE, true)
+ .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, conf.connectionTimeoutMs())
+ .option(ChannelOption.ALLOCATOR, pooledAllocator);
+
+ final AtomicReference<TransportClient> clientRef = new AtomicReference<TransportClient>();
+ final AtomicReference<Channel> channelRef = new AtomicReference<Channel>();
+
+ bootstrap.handler(new ChannelInitializer<SocketChannel>() {
+ @Override
+ public void initChannel(SocketChannel ch) {
+ TransportChannelHandler clientHandler = context.initializePipeline(ch);
+ clientRef.set(clientHandler.getClient());
+ channelRef.set(ch);
+ }
+ });
+
+ // Connect to the remote server
+ long preConnect = System.nanoTime();
+ ChannelFuture cf = bootstrap.connect(address);
+ if (!cf.awaitUninterruptibly(conf.connectionTimeoutMs())) {
+ throw new IOException(
+ String.format("Connecting to %s timed out (%s ms)", address, conf.connectionTimeoutMs()));
+ } else if (cf.cause() != null) {
+ throw new IOException(String.format("Failed to connect to %s", address), cf.cause());
+ }
+
+ TransportClient client = clientRef.get();
+ Channel channel = channelRef.get();
+ assert client != null : "Channel future completed successfully with null client";
+
+ // Execute any client bootstraps synchronously before marking the Client as successful.
+ long preBootstrap = System.nanoTime();
+ logger.debug("Connection to {} successful, running bootstraps...", address);
+ try {
+ for (TransportClientBootstrap clientBootstrap : clientBootstraps) {
+ clientBootstrap.doBootstrap(client, channel);
+ }
+ } catch (Exception e) { // catch non-RuntimeExceptions too as bootstrap may be written in Scala
+ long bootstrapTimeMs = (System.nanoTime() - preBootstrap) / 1000000;
+ logger.error("Exception while bootstrapping client after " + bootstrapTimeMs + " ms", e);
+ client.close();
+ throw Throwables.propagate(e);
+ }
+ long postBootstrap = System.nanoTime();
+
+ logger.debug("Successfully created connection to {} after {} ms ({} ms spent in bootstraps)",
+ address, (postBootstrap - preConnect) / 1000000, (postBootstrap - preBootstrap) / 1000000);
+
+ return client;
+ }
+
+ /** Close all connections in the connection pool, and shutdown the worker thread pool. */
+ @Override
+ public void close() {
+ // Go through all clients and close them if they are active.
+ for (ClientPool clientPool : connectionPool.values()) {
+ for (int i = 0; i < clientPool.clients.length; i++) {
+ TransportClient client = clientPool.clients[i];
+ if (client != null) {
+ clientPool.clients[i] = null;
+ JavaUtils.closeQuietly(client);
+ }
+ }
+ }
+ connectionPool.clear();
+
+ if (workerGroup != null) {
+ workerGroup.shutdownGracefully();
+ workerGroup = null;
+ }
+ }
+}
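Putting the pieces together, a caller might build a factory and request a pooled client as in the sketch below. The TransportContext (with its TransportConf and RpcHandler) is assumed to be configured elsewhere, and HandshakeBootstrap is the hypothetical bootstrap sketched above.

import java.io.IOException;
import java.util.Arrays;
import java.util.List;

import org.apache.spark.network.TransportContext;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.client.TransportClientBootstrap;
import org.apache.spark.network.client.TransportClientFactory;

public class ClientFactoryExample {

  // The TransportContext is assumed to be set up elsewhere in the application.
  static TransportClient connect(TransportContext context, String host, int port)
      throws IOException {
    List<TransportClientBootstrap> bootstraps =
        Arrays.<TransportClientBootstrap>asList(new HandshakeBootstrap());
    TransportClientFactory factory = new TransportClientFactory(context, bootstraps);

    // Blocks until the connection is established and every bootstrap has run; repeated
    // calls for the same host/port may return a client from the pool.
    return factory.createClient(host, port);
  }
}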
diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java
new file mode 100644
index 0000000000..f0e2004d2d
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java
@@ -0,0 +1,251 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.client;
+
+import java.io.IOException;
+import java.util.Map;
+import java.util.Queue;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ConcurrentLinkedQueue;
+import java.util.concurrent.atomic.AtomicLong;
+
+import com.google.common.annotations.VisibleForTesting;
+import io.netty.channel.Channel;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.protocol.ChunkFetchFailure;
+import org.apache.spark.network.protocol.ChunkFetchSuccess;
+import org.apache.spark.network.protocol.ResponseMessage;
+import org.apache.spark.network.protocol.RpcFailure;
+import org.apache.spark.network.protocol.RpcResponse;
+import org.apache.spark.network.protocol.StreamChunkId;
+import org.apache.spark.network.protocol.StreamFailure;
+import org.apache.spark.network.protocol.StreamResponse;
+import org.apache.spark.network.server.MessageHandler;
+import org.apache.spark.network.util.NettyUtils;
+import org.apache.spark.network.util.TransportFrameDecoder;
+
+/**
+ * Handler that processes server responses, in response to requests issued from a
+ * {@link TransportClient}. It works by tracking the list of outstanding requests (and their callbacks).
+ *
+ * Concurrency: thread safe and can be called from multiple threads.
+ */
+public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
+ private final Logger logger = LoggerFactory.getLogger(TransportResponseHandler.class);
+
+ private final Channel channel;
+
+ private final Map<StreamChunkId, ChunkReceivedCallback> outstandingFetches;
+
+ private final Map<Long, RpcResponseCallback> outstandingRpcs;
+
+ private final Queue<StreamCallback> streamCallbacks;
+ private volatile boolean streamActive;
+
+ /** Records the time (in system nanoseconds) that the last fetch or RPC request was sent. */
+ private final AtomicLong timeOfLastRequestNs;
+
+ public TransportResponseHandler(Channel channel) {
+ this.channel = channel;
+ this.outstandingFetches = new ConcurrentHashMap<StreamChunkId, ChunkReceivedCallback>();
+ this.outstandingRpcs = new ConcurrentHashMap<Long, RpcResponseCallback>();
+ this.streamCallbacks = new ConcurrentLinkedQueue<StreamCallback>();
+ this.timeOfLastRequestNs = new AtomicLong(0);
+ }
+
+ public void addFetchRequest(StreamChunkId streamChunkId, ChunkReceivedCallback callback) {
+ updateTimeOfLastRequest();
+ outstandingFetches.put(streamChunkId, callback);
+ }
+
+ public void removeFetchRequest(StreamChunkId streamChunkId) {
+ outstandingFetches.remove(streamChunkId);
+ }
+
+ public void addRpcRequest(long requestId, RpcResponseCallback callback) {
+ updateTimeOfLastRequest();
+ outstandingRpcs.put(requestId, callback);
+ }
+
+ public void removeRpcRequest(long requestId) {
+ outstandingRpcs.remove(requestId);
+ }
+
+ public void addStreamCallback(StreamCallback callback) {
+ timeOfLastRequestNs.set(System.nanoTime());
+ streamCallbacks.offer(callback);
+ }
+
+ @VisibleForTesting
+ public void deactivateStream() {
+ streamActive = false;
+ }
+
+ /**
+ * Fire the failure callback for all outstanding requests. This is called when we have an
+ * uncaught exception or premature connection termination.
+ */
+ private void failOutstandingRequests(Throwable cause) {
+ for (Map.Entry<StreamChunkId, ChunkReceivedCallback> entry : outstandingFetches.entrySet()) {
+ entry.getValue().onFailure(entry.getKey().chunkIndex, cause);
+ }
+ for (Map.Entry<Long, RpcResponseCallback> entry : outstandingRpcs.entrySet()) {
+ entry.getValue().onFailure(cause);
+ }
+
+ // It's OK if new fetches appear, as they will fail immediately.
+ outstandingFetches.clear();
+ outstandingRpcs.clear();
+ }
+
+ @Override
+ public void channelActive() {
+ }
+
+ @Override
+ public void channelInactive() {
+ if (numOutstandingRequests() > 0) {
+ String remoteAddress = NettyUtils.getRemoteAddress(channel);
+ logger.error("Still have {} requests outstanding when connection from {} is closed",
+ numOutstandingRequests(), remoteAddress);
+ failOutstandingRequests(new IOException("Connection from " + remoteAddress + " closed"));
+ }
+ }
+
+ @Override
+ public void exceptionCaught(Throwable cause) {
+ if (numOutstandingRequests() > 0) {
+ String remoteAddress = NettyUtils.getRemoteAddress(channel);
+ logger.error("Still have {} requests outstanding when connection from {} is closed",
+ numOutstandingRequests(), remoteAddress);
+ failOutstandingRequests(cause);
+ }
+ }
+
+ @Override
+ public void handle(ResponseMessage message) throws Exception {
+ String remoteAddress = NettyUtils.getRemoteAddress(channel);
+ if (message instanceof ChunkFetchSuccess) {
+ ChunkFetchSuccess resp = (ChunkFetchSuccess) message;
+ ChunkReceivedCallback listener = outstandingFetches.get(resp.streamChunkId);
+ if (listener == null) {
+ logger.warn("Ignoring response for block {} from {} since it is not outstanding",
+ resp.streamChunkId, remoteAddress);
+ resp.body().release();
+ } else {
+ outstandingFetches.remove(resp.streamChunkId);
+ listener.onSuccess(resp.streamChunkId.chunkIndex, resp.body());
+ resp.body().release();
+ }
+ } else if (message instanceof ChunkFetchFailure) {
+ ChunkFetchFailure resp = (ChunkFetchFailure) message;
+ ChunkReceivedCallback listener = outstandingFetches.get(resp.streamChunkId);
+ if (listener == null) {
+ logger.warn("Ignoring response for block {} from {} ({}) since it is not outstanding",
+ resp.streamChunkId, remoteAddress, resp.errorString);
+ } else {
+ outstandingFetches.remove(resp.streamChunkId);
+ listener.onFailure(resp.streamChunkId.chunkIndex, new ChunkFetchFailureException(
+ "Failure while fetching " + resp.streamChunkId + ": " + resp.errorString));
+ }
+ } else if (message instanceof RpcResponse) {
+ RpcResponse resp = (RpcResponse) message;
+ RpcResponseCallback listener = outstandingRpcs.get(resp.requestId);
+ if (listener == null) {
+ logger.warn("Ignoring response for RPC {} from {} ({} bytes) since it is not outstanding",
+ resp.requestId, remoteAddress, resp.body().size());
+ } else {
+ outstandingRpcs.remove(resp.requestId);
+ try {
+ listener.onSuccess(resp.body().nioByteBuffer());
+ } finally {
+ resp.body().release();
+ }
+ }
+ } else if (message instanceof RpcFailure) {
+ RpcFailure resp = (RpcFailure) message;
+ RpcResponseCallback listener = outstandingRpcs.get(resp.requestId);
+ if (listener == null) {
+ logger.warn("Ignoring response for RPC {} from {} ({}) since it is not outstanding",
+ resp.requestId, remoteAddress, resp.errorString);
+ } else {
+ outstandingRpcs.remove(resp.requestId);
+ listener.onFailure(new RuntimeException(resp.errorString));
+ }
+ } else if (message instanceof StreamResponse) {
+ StreamResponse resp = (StreamResponse) message;
+ StreamCallback callback = streamCallbacks.poll();
+ if (callback != null) {
+ if (resp.byteCount > 0) {
+ StreamInterceptor interceptor = new StreamInterceptor(this, resp.streamId, resp.byteCount,
+ callback);
+ try {
+ TransportFrameDecoder frameDecoder = (TransportFrameDecoder)
+ channel.pipeline().get(TransportFrameDecoder.HANDLER_NAME);
+ frameDecoder.setInterceptor(interceptor);
+ streamActive = true;
+ } catch (Exception e) {
+ logger.error("Error installing stream handler.", e);
+ deactivateStream();
+ }
+ } else {
+ try {
+ callback.onComplete(resp.streamId);
+ } catch (Exception e) {
+ logger.warn("Error in stream handler onComplete().", e);
+ }
+ }
+ } else {
+ logger.error("Could not find callback for StreamResponse.");
+ }
+ } else if (message instanceof StreamFailure) {
+ StreamFailure resp = (StreamFailure) message;
+ StreamCallback callback = streamCallbacks.poll();
+ if (callback != null) {
+ try {
+ callback.onFailure(resp.streamId, new RuntimeException(resp.error));
+ } catch (IOException ioe) {
+ logger.warn("Error in stream failure handler.", ioe);
+ }
+ } else {
+ logger.warn("Stream failure with unknown callback: {}", resp.error);
+ }
+ } else {
+ throw new IllegalStateException("Unknown response type: " + message.type());
+ }
+ }
+
+ /** Returns total number of outstanding requests (fetch requests + rpcs) */
+ public int numOutstandingRequests() {
+ return outstandingFetches.size() + outstandingRpcs.size() + streamCallbacks.size() +
+ (streamActive ? 1 : 0);
+ }
+
+ /** Returns the time in nanoseconds of when the last request was sent out. */
+ public long getTimeOfLastRequestNs() {
+ return timeOfLastRequestNs.get();
+ }
+
+ /** Updates the time of the last request to the current system time. */
+ public void updateTimeOfLastRequest() {
+ timeOfLastRequestNs.set(System.nanoTime());
+ }
+
+}
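For reference, a hedged sketch of a ChunkReceivedCallback as this handler drives it: onSuccess receives the chunk's ManagedBuffer, which the handler releases right after the callback returns, so the callback must retain() the buffer if it needs the bytes later.

import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.client.ChunkReceivedCallback;

/** Hypothetical callback that only logs the size of each fetched chunk. */
public class LoggingChunkCallback implements ChunkReceivedCallback {

  @Override
  public void onSuccess(int chunkIndex, ManagedBuffer buffer) {
    // The handler calls buffer.release() immediately after this method returns;
    // call buffer.retain() here if the data is consumed asynchronously.
    System.out.println("Fetched chunk " + chunkIndex + " (" + buffer.size() + " bytes)");
  }

  @Override
  public void onFailure(int chunkIndex, Throwable e) {
    System.err.println("Chunk " + chunkIndex + " failed: " + e);
  }
}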
diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/AbstractMessage.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/AbstractMessage.java
new file mode 100644
index 0000000000..2924218c2f
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/AbstractMessage.java
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import com.google.common.base.Objects;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+
+/**
+ * Abstract class for messages which optionally contain a body kept in a separate buffer.
+ */
+public abstract class AbstractMessage implements Message {
+ private final ManagedBuffer body;
+ private final boolean isBodyInFrame;
+
+ protected AbstractMessage() {
+ this(null, false);
+ }
+
+ protected AbstractMessage(ManagedBuffer body, boolean isBodyInFrame) {
+ this.body = body;
+ this.isBodyInFrame = isBodyInFrame;
+ }
+
+ @Override
+ public ManagedBuffer body() {
+ return body;
+ }
+
+ @Override
+ public boolean isBodyInFrame() {
+ return isBodyInFrame;
+ }
+
+ protected boolean equals(AbstractMessage other) {
+ return isBodyInFrame == other.isBodyInFrame && Objects.equal(body, other.body);
+ }
+
+}
diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/AbstractResponseMessage.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/AbstractResponseMessage.java
new file mode 100644
index 0000000000..c362c92fc4
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/AbstractResponseMessage.java
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+
+/**
+ * Abstract class for response messages.
+ */
+public abstract class AbstractResponseMessage extends AbstractMessage implements ResponseMessage {
+
+ protected AbstractResponseMessage(ManagedBuffer body, boolean isBodyInFrame) {
+ super(body, isBodyInFrame);
+ }
+
+ public abstract ResponseMessage createFailureResponse(String error);
+}
diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java
new file mode 100644
index 0000000000..7b28a9a969
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+
+/**
+ * Response to {@link ChunkFetchRequest} when there is an error fetching the chunk.
+ */
+public final class ChunkFetchFailure extends AbstractMessage implements ResponseMessage {
+ public final StreamChunkId streamChunkId;
+ public final String errorString;
+
+ public ChunkFetchFailure(StreamChunkId streamChunkId, String errorString) {
+ this.streamChunkId = streamChunkId;
+ this.errorString = errorString;
+ }
+
+ @Override
+ public Type type() { return Type.ChunkFetchFailure; }
+
+ @Override
+ public int encodedLength() {
+ return streamChunkId.encodedLength() + Encoders.Strings.encodedLength(errorString);
+ }
+
+ @Override
+ public void encode(ByteBuf buf) {
+ streamChunkId.encode(buf);
+ Encoders.Strings.encode(buf, errorString);
+ }
+
+ public static ChunkFetchFailure decode(ByteBuf buf) {
+ StreamChunkId streamChunkId = StreamChunkId.decode(buf);
+ String errorString = Encoders.Strings.decode(buf);
+ return new ChunkFetchFailure(streamChunkId, errorString);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hashCode(streamChunkId, errorString);
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (other instanceof ChunkFetchFailure) {
+ ChunkFetchFailure o = (ChunkFetchFailure) other;
+ return streamChunkId.equals(o.streamChunkId) && errorString.equals(o.errorString);
+ }
+ return false;
+ }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("streamChunkId", streamChunkId)
+ .add("errorString", errorString)
+ .toString();
+ }
+}
diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java
new file mode 100644
index 0000000000..26d063feb5
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+
+/**
+ * Request to fetch a single chunk of a stream. This will correspond to a single
+ * {@link org.apache.spark.network.protocol.ResponseMessage} (either success or failure).
+ */
+public final class ChunkFetchRequest extends AbstractMessage implements RequestMessage {
+ public final StreamChunkId streamChunkId;
+
+ public ChunkFetchRequest(StreamChunkId streamChunkId) {
+ this.streamChunkId = streamChunkId;
+ }
+
+ @Override
+ public Type type() { return Type.ChunkFetchRequest; }
+
+ @Override
+ public int encodedLength() {
+ return streamChunkId.encodedLength();
+ }
+
+ @Override
+ public void encode(ByteBuf buf) {
+ streamChunkId.encode(buf);
+ }
+
+ public static ChunkFetchRequest decode(ByteBuf buf) {
+ return new ChunkFetchRequest(StreamChunkId.decode(buf));
+ }
+
+ @Override
+ public int hashCode() {
+ return streamChunkId.hashCode();
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (other instanceof ChunkFetchRequest) {
+ ChunkFetchRequest o = (ChunkFetchRequest) other;
+ return streamChunkId.equals(o.streamChunkId);
+ }
+ return false;
+ }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("streamChunkId", streamChunkId)
+ .toString();
+ }
+}
diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java
new file mode 100644
index 0000000000..94c2ac9b20
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.buffer.NettyManagedBuffer;
+
+/**
+ * Response to {@link ChunkFetchRequest} when a chunk exists and has been successfully fetched.
+ *
+ * Note that the server-side encoding of this message does NOT include the buffer itself, as this
+ * may be written by Netty in a more efficient manner (i.e., zero-copy write).
+ * Similarly, the client-side decoding will reuse the Netty ByteBuf as the buffer.
+ */
+public final class ChunkFetchSuccess extends AbstractResponseMessage {
+ public final StreamChunkId streamChunkId;
+
+ public ChunkFetchSuccess(StreamChunkId streamChunkId, ManagedBuffer buffer) {
+ super(buffer, true);
+ this.streamChunkId = streamChunkId;
+ }
+
+ @Override
+ public Type type() { return Type.ChunkFetchSuccess; }
+
+ @Override
+ public int encodedLength() {
+ return streamChunkId.encodedLength();
+ }
+
+ /** Encoding does NOT include 'buffer' itself. See {@link MessageEncoder}. */
+ @Override
+ public void encode(ByteBuf buf) {
+ streamChunkId.encode(buf);
+ }
+
+ @Override
+ public ResponseMessage createFailureResponse(String error) {
+ return new ChunkFetchFailure(streamChunkId, error);
+ }
+
+ /** Decoding uses the given ByteBuf as our data, and will retain() it. */
+ public static ChunkFetchSuccess decode(ByteBuf buf) {
+ StreamChunkId streamChunkId = StreamChunkId.decode(buf);
+ buf.retain();
+ NettyManagedBuffer managedBuf = new NettyManagedBuffer(buf.duplicate());
+ return new ChunkFetchSuccess(streamChunkId, managedBuf);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hashCode(streamChunkId, body());
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (other instanceof ChunkFetchSuccess) {
+ ChunkFetchSuccess o = (ChunkFetchSuccess) other;
+ return streamChunkId.equals(o.streamChunkId) && super.equals(o);
+ }
+ return false;
+ }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("streamChunkId", streamChunkId)
+ .add("buffer", body())
+ .toString();
+ }
+}
diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/Encodable.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/Encodable.java
new file mode 100644
index 0000000000..b4e299471b
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/Encodable.java
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import io.netty.buffer.ByteBuf;
+
+/**
+ * Interface for an object which can be encoded into a ByteBuf. Multiple Encodable objects are
+ * stored in a single, pre-allocated ByteBuf, so Encodables must also provide their length.
+ *
+ * Encodable objects should provide a static "decode(ByteBuf)" method which is invoked by
+ * {@link MessageDecoder}. During decoding, if the object uses the ByteBuf as its data (rather than
+ * just copying data from it), then you must retain() the ByteBuf.
+ *
+ * Additionally, when adding a new Encodable Message, add it to {@link Message.Type}.
+ */
+public interface Encodable {
+ /** Number of bytes of the encoded form of this object. */
+ int encodedLength();
+
+ /**
+ * Serializes this object by writing into the given ByteBuf.
+ * This method must write exactly encodedLength() bytes.
+ */
+ void encode(ByteBuf buf);
+}
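To illustrate the contract, a hypothetical Encodable carrying a numeric id and a string could look like the sketch below; it follows the same static decode(ByteBuf) convention used by the protocol classes that follow (Encoders.Strings is defined in the next file).

import io.netty.buffer.ByteBuf;

import org.apache.spark.network.protocol.Encodable;
import org.apache.spark.network.protocol.Encoders;

/** Hypothetical Encodable, for illustration only: a long id plus a UTF-8 string. */
public final class TaggedName implements Encodable {
  public final long id;
  public final String name;

  public TaggedName(long id, String name) {
    this.id = id;
    this.name = name;
  }

  @Override
  public int encodedLength() {
    return 8 + Encoders.Strings.encodedLength(name);
  }

  @Override
  public void encode(ByteBuf buf) {
    buf.writeLong(id);
    Encoders.Strings.encode(buf, name);
  }

  /** Static decode method, as described above. */
  public static TaggedName decode(ByteBuf buf) {
    long id = buf.readLong();
    String name = Encoders.Strings.decode(buf);
    return new TaggedName(id, name);
  }
}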
diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/Encoders.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/Encoders.java
new file mode 100644
index 0000000000..9162d0b977
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/Encoders.java
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+
+import com.google.common.base.Charsets;
+import io.netty.buffer.ByteBuf;
+
+/** Provides a canonical set of Encoders for simple types. */
+public class Encoders {
+
+ /** Strings are encoded with their length followed by UTF-8 bytes. */
+ public static class Strings {
+ public static int encodedLength(String s) {
+ return 4 + s.getBytes(Charsets.UTF_8).length;
+ }
+
+ public static void encode(ByteBuf buf, String s) {
+ byte[] bytes = s.getBytes(Charsets.UTF_8);
+ buf.writeInt(bytes.length);
+ buf.writeBytes(bytes);
+ }
+
+ public static String decode(ByteBuf buf) {
+ int length = buf.readInt();
+ byte[] bytes = new byte[length];
+ buf.readBytes(bytes);
+ return new String(bytes, Charsets.UTF_8);
+ }
+ }
+
+ /** Byte arrays are encoded with their length followed by bytes. */
+ public static class ByteArrays {
+ public static int encodedLength(byte[] arr) {
+ return 4 + arr.length;
+ }
+
+ public static void encode(ByteBuf buf, byte[] arr) {
+ buf.writeInt(arr.length);
+ buf.writeBytes(arr);
+ }
+
+ public static byte[] decode(ByteBuf buf) {
+ int length = buf.readInt();
+ byte[] bytes = new byte[length];
+ buf.readBytes(bytes);
+ return bytes;
+ }
+ }
+
+ /** String arrays are encoded with the number of strings followed by per-String encoding. */
+ public static class StringArrays {
+ public static int encodedLength(String[] strings) {
+ int totalLength = 4;
+ for (String s : strings) {
+ totalLength += Strings.encodedLength(s);
+ }
+ return totalLength;
+ }
+
+ public static void encode(ByteBuf buf, String[] strings) {
+ buf.writeInt(strings.length);
+ for (String s : strings) {
+ Strings.encode(buf, s);
+ }
+ }
+
+ public static String[] decode(ByteBuf buf) {
+ int numStrings = buf.readInt();
+ String[] strings = new String[numStrings];
+ for (int i = 0; i < strings.length; i ++) {
+ strings[i] = Strings.decode(buf);
+ }
+ return strings;
+ }
+ }
+}
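A small round-trip sketch of the resulting layout: each string becomes a 4-byte length prefix followed by its UTF-8 bytes, and decode consumes exactly encodedLength() bytes.

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;

import org.apache.spark.network.protocol.Encoders;

public class EncodersRoundTrip {
  public static void main(String[] args) {
    String s = "shuffle-client";

    // 4-byte length prefix + UTF-8 payload.
    ByteBuf buf = Unpooled.buffer(Encoders.Strings.encodedLength(s));
    Encoders.Strings.encode(buf, s);

    String decoded = Encoders.Strings.decode(buf);
    assert decoded.equals(s);
    assert buf.readableBytes() == 0;  // nothing left over after decoding

    buf.release();
  }
}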
diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/Message.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/Message.java
new file mode 100644
index 0000000000..66f5b8b3a5
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/Message.java
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import io.netty.buffer.ByteBuf;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+
+/** An on-the-wire transmittable message. */
+public interface Message extends Encodable {
+ /** Used to identify this request type. */
+ Type type();
+
+ /** An optional body for the message. */
+ ManagedBuffer body();
+
+ /** Whether to include the body of the message in the same frame as the message. */
+ boolean isBodyInFrame();
+
+ /** Preceding every serialized Message is its type, which allows us to deserialize it. */
+ public static enum Type implements Encodable {
+ ChunkFetchRequest(0), ChunkFetchSuccess(1), ChunkFetchFailure(2),
+ RpcRequest(3), RpcResponse(4), RpcFailure(5),
+ StreamRequest(6), StreamResponse(7), StreamFailure(8),
+ OneWayMessage(9), User(-1);
+
+ private final byte id;
+
+ private Type(int id) {
+ assert id < 128 : "Cannot have more than 128 message types";
+ this.id = (byte) id;
+ }
+
+ public byte id() { return id; }
+
+ @Override public int encodedLength() { return 1; }
+
+ @Override public void encode(ByteBuf buf) { buf.writeByte(id); }
+
+ public static Type decode(ByteBuf buf) {
+ byte id = buf.readByte();
+ switch (id) {
+ case 0: return ChunkFetchRequest;
+ case 1: return ChunkFetchSuccess;
+ case 2: return ChunkFetchFailure;
+ case 3: return RpcRequest;
+ case 4: return RpcResponse;
+ case 5: return RpcFailure;
+ case 6: return StreamRequest;
+ case 7: return StreamResponse;
+ case 8: return StreamFailure;
+ case 9: return OneWayMessage;
+ case -1: throw new IllegalArgumentException("User type messages cannot be decoded.");
+ default: throw new IllegalArgumentException("Unknown message type: " + id);
+ }
+ }
+ }
+}
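A minimal sketch of the framing convention this enum implements: a single type byte is written ahead of each serialized message, and reading it back selects the decoder branch (see MessageDecoder below).

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;

import org.apache.spark.network.protocol.Message;

public class MessageTypeExample {
  public static void main(String[] args) {
    ByteBuf buf = Unpooled.buffer(1);

    // One byte identifying the message type precedes the encoded message.
    Message.Type.RpcRequest.encode(buf);

    // The same byte is read back to pick the matching decode() method.
    Message.Type decoded = Message.Type.decode(buf);
    assert decoded == Message.Type.RpcRequest;

    buf.release();
  }
}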
diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java
new file mode 100644
index 0000000000..074780f2b9
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import java.util.List;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.channel.ChannelHandler;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.handler.codec.MessageToMessageDecoder;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Decoder used by the client side to decode server-to-client responses.
+ * This decoder is stateless so it is safe to be shared by multiple threads.
+ */
+@ChannelHandler.Sharable
+public final class MessageDecoder extends MessageToMessageDecoder<ByteBuf> {
+
+ private final Logger logger = LoggerFactory.getLogger(MessageDecoder.class);
+ @Override
+ public void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) {
+ Message.Type msgType = Message.Type.decode(in);
+ Message decoded = decode(msgType, in);
+ assert decoded.type() == msgType;
+ logger.trace("Received message " + msgType + ": " + decoded);
+ out.add(decoded);
+ }
+
+ private Message decode(Message.Type msgType, ByteBuf in) {
+ switch (msgType) {
+ case ChunkFetchRequest:
+ return ChunkFetchRequest.decode(in);
+
+ case ChunkFetchSuccess:
+ return ChunkFetchSuccess.decode(in);
+
+ case ChunkFetchFailure:
+ return ChunkFetchFailure.decode(in);
+
+ case RpcRequest:
+ return RpcRequest.decode(in);
+
+ case RpcResponse:
+ return RpcResponse.decode(in);
+
+ case RpcFailure:
+ return RpcFailure.decode(in);
+
+ case OneWayMessage:
+ return OneWayMessage.decode(in);
+
+ case StreamRequest:
+ return StreamRequest.decode(in);
+
+ case StreamResponse:
+ return StreamResponse.decode(in);
+
+ case StreamFailure:
+ return StreamFailure.decode(in);
+
+ default:
+ throw new IllegalArgumentException("Unexpected message type: " + msgType);
+ }
+ }
+}
diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java
new file mode 100644
index 0000000000..664df57fec
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import java.util.List;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.channel.ChannelHandler;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.handler.codec.MessageToMessageEncoder;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Encoder used by the server side to encode server-to-client responses.
+ * This encoder is stateless so it is safe to be shared by multiple threads.
+ */
+@ChannelHandler.Sharable
+public final class MessageEncoder extends MessageToMessageEncoder<Message> {
+
+ private final Logger logger = LoggerFactory.getLogger(MessageEncoder.class);
+
+ /**
+ * Encodes a Message by invoking its encode() method. For non-data messages, we will add one
+ * ByteBuf to 'out' containing the total frame length, the message type, and the message itself.
+ * In the case of a ChunkFetchSuccess, we will also add the ManagedBuffer corresponding to the
+ * data to 'out', in order to enable zero-copy transfer.
+ */
+ @Override
+ public void encode(ChannelHandlerContext ctx, Message in, List<Object> out) throws Exception {
+ Object body = null;
+ long bodyLength = 0;
+ boolean isBodyInFrame = false;
+
+ // If the message has a body, take it out to enable zero-copy transfer for the payload.
+ if (in.body() != null) {
+ try {
+ bodyLength = in.body().size();
+ body = in.body().convertToNetty();
+ isBodyInFrame = in.isBodyInFrame();
+ } catch (Exception e) {
+ in.body().release();
+ if (in instanceof AbstractResponseMessage) {
+ AbstractResponseMessage resp = (AbstractResponseMessage) in;
+ // Re-encode this message as a failure response.
+ String error = e.getMessage() != null ? e.getMessage() : "null";
+ logger.error(String.format("Error processing %s for client %s",
+ in, ctx.channel().remoteAddress()), e);
+ encode(ctx, resp.createFailureResponse(error), out);
+ } else {
+ throw e;
+ }
+ return;
+ }
+ }
+
+ Message.Type msgType = in.type();
+ // All messages have the frame length, message type, and message itself. The frame length
+ // may optionally include the length of the body data, depending on what message is being
+ // sent.
+ int headerLength = 8 + msgType.encodedLength() + in.encodedLength();
+ long frameLength = headerLength + (isBodyInFrame ? bodyLength : 0);
+ ByteBuf header = ctx.alloc().heapBuffer(headerLength);
+ header.writeLong(frameLength);
+ msgType.encode(header);
+ in.encode(header);
+ assert header.writableBytes() == 0;
+
+ if (body != null) {
+ // We transfer ownership of the reference on in.body() to MessageWithHeader.
+ // This reference will be freed when MessageWithHeader.deallocate() is called.
+ out.add(new MessageWithHeader(in.body(), header, body, bodyLength));
+ } else {
+ out.add(header);
+ }
+ }
+
+}
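The frame arithmetic above can be checked with a small sketch: for a message with no separate body, the frame is an 8-byte length, the type byte, and the encoded message. StreamChunkId's (long streamId, int chunkIndex) constructor is an assumption here.

import org.apache.spark.network.protocol.ChunkFetchRequest;
import org.apache.spark.network.protocol.Message;
import org.apache.spark.network.protocol.StreamChunkId;

public class FrameLayoutExample {
  public static void main(String[] args) {
    Message msg = new ChunkFetchRequest(new StreamChunkId(7L, 0));

    // Mirrors MessageEncoder.encode(): 8-byte frame length + type byte + encoded message,
    // plus the body length only when the body travels in the same frame.
    int headerLength = 8 + msg.type().encodedLength() + msg.encodedLength();
    long bodyLength =
        (msg.body() != null && msg.isBodyInFrame()) ? msg.body().size() : 0;
    long frameLength = headerLength + bodyLength;

    System.out.println("header = " + headerLength + " bytes, frame = " + frameLength + " bytes");
  }
}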
diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java
new file mode 100644
index 0000000000..66227f96a1
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java
@@ -0,0 +1,135 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import java.io.IOException;
+import java.nio.channels.WritableByteChannel;
+import javax.annotation.Nullable;
+
+import com.google.common.base.Preconditions;
+import io.netty.buffer.ByteBuf;
+import io.netty.channel.FileRegion;
+import io.netty.util.AbstractReferenceCounted;
+import io.netty.util.ReferenceCountUtil;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+
+/**
+ * A wrapper message that holds two separate pieces (a header and a body).
+ *
+ * The header must be a ByteBuf, while the body can be a ByteBuf or a FileRegion.
+ */
+class MessageWithHeader extends AbstractReferenceCounted implements FileRegion {
+
+ @Nullable private final ManagedBuffer managedBuffer;
+ private final ByteBuf header;
+ private final int headerLength;
+ private final Object body;
+ private final long bodyLength;
+ private long totalBytesTransferred;
+
+ /**
+ * Construct a new MessageWithHeader.
+ *
+ * @param managedBuffer the {@link ManagedBuffer} that the message body came from. This needs to
+ * be passed in so that the buffer can be freed when this message is
+ * deallocated. Ownership of the caller's reference to this buffer is
+ * transferred to this class, so if the caller wants to continue to use the
+ * ManagedBuffer in other messages then they will need to call retain() on
+ * it before passing it to this constructor. This may be null if and only if
+ * `body` is a {@link FileRegion}.
+ * @param header the message header.
+ * @param body the message body. Must be either a {@link ByteBuf} or a {@link FileRegion}.
+ * @param bodyLength the length of the message body, in bytes.
+ */
+ MessageWithHeader(
+ @Nullable ManagedBuffer managedBuffer,
+ ByteBuf header,
+ Object body,
+ long bodyLength) {
+ Preconditions.checkArgument(body instanceof ByteBuf || body instanceof FileRegion,
+ "Body must be a ByteBuf or a FileRegion.");
+ this.managedBuffer = managedBuffer;
+ this.header = header;
+ this.headerLength = header.readableBytes();
+ this.body = body;
+ this.bodyLength = bodyLength;
+ }
+
+ @Override
+ public long count() {
+ return headerLength + bodyLength;
+ }
+
+ @Override
+ public long position() {
+ return 0;
+ }
+
+ @Override
+ public long transfered() {
+ return totalBytesTransferred;
+ }
+
+ /**
+ * This code is more complicated than you might expect because, to avoid busy waiting, a single
+ * MessageWithHeader may need to be written across multiple transferTo invocations.
+ *
+ * The contract is that the caller ensures `position` equals the total number of bytes
+ * transferred so far (i.e., the value returned by transfered()).
+ */
+ @Override
+ public long transferTo(final WritableByteChannel target, final long position) throws IOException {
+ Preconditions.checkArgument(position == totalBytesTransferred, "Invalid position.");
+ // Bytes written for header in this call.
+ long writtenHeader = 0;
+ if (header.readableBytes() > 0) {
+ writtenHeader = copyByteBuf(header, target);
+ totalBytesTransferred += writtenHeader;
+ if (header.readableBytes() > 0) {
+ return writtenHeader;
+ }
+ }
+
+ // Bytes written for body in this call.
+ long writtenBody = 0;
+ if (body instanceof FileRegion) {
+ writtenBody = ((FileRegion) body).transferTo(target, totalBytesTransferred - headerLength);
+ } else if (body instanceof ByteBuf) {
+ writtenBody = copyByteBuf((ByteBuf) body, target);
+ }
+ totalBytesTransferred += writtenBody;
+
+ return writtenHeader + writtenBody;
+ }
+
+ @Override
+ protected void deallocate() {
+ header.release();
+ ReferenceCountUtil.release(body);
+ if (managedBuffer != null) {
+ managedBuffer.release();
+ }
+ }
+
+ private int copyByteBuf(ByteBuf buf, WritableByteChannel target) throws IOException {
+ int written = target.write(buf.nioBuffer());
+ buf.skipBytes(written);
+ return written;
+ }
+}
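
A minimal sketch of the transferTo() contract from the caller's side. The helper class below is hypothetical (it is not part of this patch, and would have to live in org.apache.spark.network.protocol since MessageWithHeader is package-private); it simply drives the message the way Netty's event loop would, always passing back the number of bytes already transferred as the position.

    package org.apache.spark.network.protocol;

    import java.io.IOException;
    import java.nio.channels.WritableByteChannel;

    // Illustration only: drains a MessageWithHeader into a channel.
    final class MessageWithHeaderExample {
      static void drain(MessageWithHeader msg, WritableByteChannel target) throws IOException {
        while (msg.transfered() < msg.count()) {
          long written = msg.transferTo(target, msg.transfered());
          if (written == 0) {
            break;  // target not currently writable; a real caller retries when it becomes writable
          }
        }
      }
    }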
diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java
new file mode 100644
index 0000000000..efe0470f35
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.buffer.NettyManagedBuffer;
+
+/**
+ * An RPC that does not expect a reply, which is handled by a remote
+ * {@link org.apache.spark.network.server.RpcHandler}.
+ */
+public final class OneWayMessage extends AbstractMessage implements RequestMessage {
+
+ public OneWayMessage(ManagedBuffer body) {
+ super(body, true);
+ }
+
+ @Override
+ public Type type() { return Type.OneWayMessage; }
+
+ @Override
+ public int encodedLength() {
+ // The integer (a.k.a. the body size) is not really used, since that information is already
+ // encoded in the frame length. But this maintains backwards compatibility with versions of
+ // RpcRequest that use Encoders.ByteArrays.
+ return 4;
+ }
+
+ @Override
+ public void encode(ByteBuf buf) {
+ // See comment in encodedLength().
+ buf.writeInt((int) body().size());
+ }
+
+ public static OneWayMessage decode(ByteBuf buf) {
+ // See comment in encodedLength().
+ buf.readInt();
+ return new OneWayMessage(new NettyManagedBuffer(buf.retain()));
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hashCode(body());
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (other instanceof OneWayMessage) {
+ OneWayMessage o = (OneWayMessage) other;
+ return super.equals(o);
+ }
+ return false;
+ }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("body", body())
+ .toString();
+ }
+}
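
A minimal round-trip sketch of the body layout described in encodedLength(): the 4-byte size written by encode() is redundant with the frame length, and decode() simply skips it and wraps whatever bytes follow. In real traffic the framing is done by MessageEncoder/MessageDecoder rather than by hand; the payload below is made up for illustration.

    import java.nio.ByteBuffer;
    import java.nio.charset.StandardCharsets;

    import io.netty.buffer.ByteBuf;
    import io.netty.buffer.Unpooled;

    import org.apache.spark.network.buffer.NioManagedBuffer;
    import org.apache.spark.network.protocol.OneWayMessage;

    class OneWayMessageExample {
      public static void main(String[] args) {
        ByteBuffer payload = ByteBuffer.wrap("ping".getBytes(StandardCharsets.UTF_8));
        OneWayMessage out = new OneWayMessage(new NioManagedBuffer(payload.duplicate()));

        ByteBuf wire = Unpooled.buffer();
        out.encode(wire);                      // writes the (unused) 4-byte body size
        wire.writeBytes(payload.duplicate());  // body bytes follow, as MessageEncoder would send them

        OneWayMessage in = OneWayMessage.decode(wire);  // skips the size, retains the rest as the body
        assert in.body().size() == 4;
      }
    }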
diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/RequestMessage.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/RequestMessage.java
new file mode 100644
index 0000000000..31b15bb17a
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/RequestMessage.java
@@ -0,0 +1,25 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import org.apache.spark.network.protocol.Message;
+
+/** Messages from the client to the server. */
+public interface RequestMessage extends Message {
+ // token interface
+}
diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/ResponseMessage.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/ResponseMessage.java
new file mode 100644
index 0000000000..6edffd11cf
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/ResponseMessage.java
@@ -0,0 +1,25 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import org.apache.spark.network.protocol.Message;
+
+/** Messages from the server to the client. */
+public interface ResponseMessage extends Message {
+ // token interface
+}
diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java
new file mode 100644
index 0000000000..a76624ef5d
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+
+/** Response to {@link RpcRequest} for a failed RPC. */
+public final class RpcFailure extends AbstractMessage implements ResponseMessage {
+ public final long requestId;
+ public final String errorString;
+
+ public RpcFailure(long requestId, String errorString) {
+ this.requestId = requestId;
+ this.errorString = errorString;
+ }
+
+ @Override
+ public Type type() { return Type.RpcFailure; }
+
+ @Override
+ public int encodedLength() {
+ return 8 + Encoders.Strings.encodedLength(errorString);
+ }
+
+ @Override
+ public void encode(ByteBuf buf) {
+ buf.writeLong(requestId);
+ Encoders.Strings.encode(buf, errorString);
+ }
+
+ public static RpcFailure decode(ByteBuf buf) {
+ long requestId = buf.readLong();
+ String errorString = Encoders.Strings.decode(buf);
+ return new RpcFailure(requestId, errorString);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hashCode(requestId, errorString);
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (other instanceof RpcFailure) {
+ RpcFailure o = (RpcFailure) other;
+ return requestId == o.requestId && errorString.equals(o.errorString);
+ }
+ return false;
+ }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("requestId", requestId)
+ .add("errorString", errorString)
+ .toString();
+ }
+}
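
The Encoders.Strings helpers make RpcFailure (and the other string-carrying messages) symmetric: encodedLength() predicts exactly how many bytes encode() writes, and decode() reads them back in the same order. A small round-trip sketch, with the id and error text made up for illustration:

    import io.netty.buffer.ByteBuf;
    import io.netty.buffer.Unpooled;

    import org.apache.spark.network.protocol.RpcFailure;

    class RpcFailureExample {
      public static void main(String[] args) {
        RpcFailure out = new RpcFailure(42L, "stream was not found");

        ByteBuf buf = Unpooled.buffer(out.encodedLength());
        out.encode(buf);
        assert buf.readableBytes() == out.encodedLength();  // 8-byte id + length-prefixed UTF-8 string

        RpcFailure in = RpcFailure.decode(buf);
        assert in.equals(out);
      }
    }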
diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java
new file mode 100644
index 0000000000..96213794a8
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.buffer.NettyManagedBuffer;
+
+/**
+ * A generic RPC which is handled by a remote {@link org.apache.spark.network.server.RpcHandler}.
+ * This will correspond to a single
+ * {@link org.apache.spark.network.protocol.ResponseMessage} (either success or failure).
+ */
+public final class RpcRequest extends AbstractMessage implements RequestMessage {
+ /** Used to link an RPC request with its response. */
+ public final long requestId;
+
+ public RpcRequest(long requestId, ManagedBuffer message) {
+ super(message, true);
+ this.requestId = requestId;
+ }
+
+ @Override
+ public Type type() { return Type.RpcRequest; }
+
+ @Override
+ public int encodedLength() {
+ // The integer (a.k.a. the body size) is not really used, since that information is already
+ // encoded in the frame length. But this maintains backwards compatibility with versions of
+ // RpcRequest that use Encoders.ByteArrays.
+ return 8 + 4;
+ }
+
+ @Override
+ public void encode(ByteBuf buf) {
+ buf.writeLong(requestId);
+ // See comment in encodedLength().
+ buf.writeInt((int) body().size());
+ }
+
+ public static RpcRequest decode(ByteBuf buf) {
+ long requestId = buf.readLong();
+ // See comment in encodedLength().
+ buf.readInt();
+ return new RpcRequest(requestId, new NettyManagedBuffer(buf.retain()));
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hashCode(requestId, body());
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (other instanceof RpcRequest) {
+ RpcRequest o = (RpcRequest) other;
+ return requestId == o.requestId && super.equals(o);
+ }
+ return false;
+ }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("requestId", requestId)
+ .add("body", body())
+ .toString();
+ }
+}
diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java
new file mode 100644
index 0000000000..bae866e14a
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.buffer.NettyManagedBuffer;
+
+/** Response to {@link RpcRequest} for a successful RPC. */
+public final class RpcResponse extends AbstractResponseMessage {
+ public final long requestId;
+
+ public RpcResponse(long requestId, ManagedBuffer message) {
+ super(message, true);
+ this.requestId = requestId;
+ }
+
+ @Override
+ public Type type() { return Type.RpcResponse; }
+
+ @Override
+ public int encodedLength() {
+ // The integer (a.k.a. the body size) is not really used, since that information is already
+ // encoded in the frame length. But this maintains backwards compatibility with versions of
+ // RpcRequest that use Encoders.ByteArrays.
+ return 8 + 4;
+ }
+
+ @Override
+ public void encode(ByteBuf buf) {
+ buf.writeLong(requestId);
+ // See comment in encodedLength().
+ buf.writeInt((int) body().size());
+ }
+
+ @Override
+ public ResponseMessage createFailureResponse(String error) {
+ return new RpcFailure(requestId, error);
+ }
+
+ public static RpcResponse decode(ByteBuf buf) {
+ long requestId = buf.readLong();
+ // See comment in encodedLength().
+ buf.readInt();
+ return new RpcResponse(requestId, new NettyManagedBuffer(buf.retain()));
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hashCode(requestId, body());
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (other instanceof RpcResponse) {
+ RpcResponse o = (RpcResponse) other;
+ return requestId == o.requestId && super.equals(o);
+ }
+ return false;
+ }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("requestId", requestId)
+ .add("body", body())
+ .toString();
+ }
+}
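
The only thing tying an RpcResponse (or RpcFailure) back to its RpcRequest is the requestId; createFailureResponse() exists so that a response whose body can no longer be sent can be downgraded to a failure carrying the same id. A hedged sketch, with the payload and id made up for illustration:

    import java.nio.ByteBuffer;
    import java.nio.charset.StandardCharsets;

    import org.apache.spark.network.buffer.NioManagedBuffer;
    import org.apache.spark.network.protocol.ResponseMessage;
    import org.apache.spark.network.protocol.RpcFailure;
    import org.apache.spark.network.protocol.RpcResponse;

    class RpcResponseExample {
      public static void main(String[] args) {
        ByteBuffer result = ByteBuffer.wrap("ok".getBytes(StandardCharsets.UTF_8));
        RpcResponse ok = new RpcResponse(7L, new NioManagedBuffer(result));

        // If sending the body fails, the transport can still report an error for request 7.
        ResponseMessage failed = ok.createFailureResponse("could not send response body");
        assert failed instanceof RpcFailure;
        assert ((RpcFailure) failed).requestId == 7L;
      }
    }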
diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamChunkId.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamChunkId.java
new file mode 100644
index 0000000000..d46a263884
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamChunkId.java
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+
+/**
+ * Encapsulates a request for a particular chunk of a stream.
+ */
+public final class StreamChunkId implements Encodable {
+ public final long streamId;
+ public final int chunkIndex;
+
+ public StreamChunkId(long streamId, int chunkIndex) {
+ this.streamId = streamId;
+ this.chunkIndex = chunkIndex;
+ }
+
+ @Override
+ public int encodedLength() {
+ return 8 + 4;
+ }
+
+ public void encode(ByteBuf buffer) {
+ buffer.writeLong(streamId);
+ buffer.writeInt(chunkIndex);
+ }
+
+ public static StreamChunkId decode(ByteBuf buffer) {
+ assert buffer.readableBytes() >= 8 + 4;
+ long streamId = buffer.readLong();
+ int chunkIndex = buffer.readInt();
+ return new StreamChunkId(streamId, chunkIndex);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hashCode(streamId, chunkIndex);
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (other instanceof StreamChunkId) {
+ StreamChunkId o = (StreamChunkId) other;
+ return streamId == o.streamId && chunkIndex == o.chunkIndex;
+ }
+ return false;
+ }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("streamId", streamId)
+ .add("chunkIndex", chunkIndex)
+ .toString();
+ }
+}
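
StreamChunkId is a plain value type; its equals()/hashCode() are what allow it to be used as a map key for outstanding chunk fetches on the client side. A minimal encode/decode round trip, assuming nothing beyond this file (the ids are arbitrary):

    import io.netty.buffer.ByteBuf;
    import io.netty.buffer.Unpooled;

    import org.apache.spark.network.protocol.StreamChunkId;

    class StreamChunkIdExample {
      public static void main(String[] args) {
        StreamChunkId id = new StreamChunkId(3L, 0);

        ByteBuf buf = Unpooled.buffer(id.encodedLength());  // 8 + 4 bytes
        id.encode(buf);

        StreamChunkId back = StreamChunkId.decode(buf);
        assert back.equals(id) && back.hashCode() == id.hashCode();
      }
    }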
diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java
new file mode 100644
index 0000000000..26747ee55b
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.buffer.NettyManagedBuffer;
+
+/**
+ * Message indicating an error when transferring a stream.
+ */
+public final class StreamFailure extends AbstractMessage implements ResponseMessage {
+ public final String streamId;
+ public final String error;
+
+ public StreamFailure(String streamId, String error) {
+ this.streamId = streamId;
+ this.error = error;
+ }
+
+ @Override
+ public Type type() { return Type.StreamFailure; }
+
+ @Override
+ public int encodedLength() {
+ return Encoders.Strings.encodedLength(streamId) + Encoders.Strings.encodedLength(error);
+ }
+
+ @Override
+ public void encode(ByteBuf buf) {
+ Encoders.Strings.encode(buf, streamId);
+ Encoders.Strings.encode(buf, error);
+ }
+
+ public static StreamFailure decode(ByteBuf buf) {
+ String streamId = Encoders.Strings.decode(buf);
+ String error = Encoders.Strings.decode(buf);
+ return new StreamFailure(streamId, error);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hashCode(streamId, error);
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (other instanceof StreamFailure) {
+ StreamFailure o = (StreamFailure) other;
+ return streamId.equals(o.streamId) && error.equals(o.error);
+ }
+ return false;
+ }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("streamId", streamId)
+ .add("error", error)
+ .toString();
+ }
+
+}
diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java
new file mode 100644
index 0000000000..35af5a84ba
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.buffer.NettyManagedBuffer;
+
+/**
+ * Request to stream data from the remote end.
+ * <p>
+ * The stream ID is an arbitrary string that needs to be negotiated between the two endpoints before
+ * the data can be streamed.
+ */
+public final class StreamRequest extends AbstractMessage implements RequestMessage {
+ public final String streamId;
+
+ public StreamRequest(String streamId) {
+ this.streamId = streamId;
+ }
+
+ @Override
+ public Type type() { return Type.StreamRequest; }
+
+ @Override
+ public int encodedLength() {
+ return Encoders.Strings.encodedLength(streamId);
+ }
+
+ @Override
+ public void encode(ByteBuf buf) {
+ Encoders.Strings.encode(buf, streamId);
+ }
+
+ public static StreamRequest decode(ByteBuf buf) {
+ String streamId = Encoders.Strings.decode(buf);
+ return new StreamRequest(streamId);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hashCode(streamId);
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (other instanceof StreamRequest) {
+ StreamRequest o = (StreamRequest) other;
+ return streamId.equals(o.streamId);
+ }
+ return false;
+ }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("streamId", streamId)
+ .toString();
+ }
+
+}
diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java
new file mode 100644
index 0000000000..51b899930f
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.buffer.NettyManagedBuffer;
+
+/**
+ * Response to {@link StreamRequest} when the stream has been successfully opened.
+ * <p>
+ * Note the message itself does not contain the stream data. That is written separately by the
+ * sender. The receiver is expected to install a temporary channel handler that will consume the
+ * number of bytes this message says the stream has.
+ */
+public final class StreamResponse extends AbstractResponseMessage {
+ public final String streamId;
+ public final long byteCount;
+
+ public StreamResponse(String streamId, long byteCount, ManagedBuffer buffer) {
+ super(buffer, false);
+ this.streamId = streamId;
+ this.byteCount = byteCount;
+ }
+
+ @Override
+ public Type type() { return Type.StreamResponse; }
+
+ @Override
+ public int encodedLength() {
+ return 8 + Encoders.Strings.encodedLength(streamId);
+ }
+
+ /** Encoding does NOT include 'buffer' itself. See {@link MessageEncoder}. */
+ @Override
+ public void encode(ByteBuf buf) {
+ Encoders.Strings.encode(buf, streamId);
+ buf.writeLong(byteCount);
+ }
+
+ @Override
+ public ResponseMessage createFailureResponse(String error) {
+ return new StreamFailure(streamId, error);
+ }
+
+ public static StreamResponse decode(ByteBuf buf) {
+ String streamId = Encoders.Strings.decode(buf);
+ long byteCount = buf.readLong();
+ return new StreamResponse(streamId, byteCount, null);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hashCode(byteCount, streamId, body());
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (other instanceof StreamResponse) {
+ StreamResponse o = (StreamResponse) other;
+ return byteCount == o.byteCount && streamId.equals(o.streamId);
+ }
+ return false;
+ }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("streamId", streamId)
+ .add("byteCount", byteCount)
+ .add("body", body())
+ .toString();
+ }
+
+}
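
Because the stream bytes are not part of the message, decode() deliberately returns a StreamResponse with a null body: the receiver reads streamId and byteCount and must then consume that many raw bytes from the channel before normal frame decoding resumes. A hedged sketch of the receiving side's bookkeeping; the handler method and print statement are made up for illustration, and the real wiring lives in the transport's response handler.

    import io.netty.buffer.ByteBuf;

    import org.apache.spark.network.protocol.StreamResponse;

    class StreamResponseExample {
      // Hypothetical handler method, invoked after the message type byte has been consumed.
      static void onStreamResponse(ByteBuf frame) {
        StreamResponse resp = StreamResponse.decode(frame);  // body() is null by design
        long remaining = resp.byteCount;
        // A real receiver would now install a temporary handler that consumes exactly
        // `remaining` bytes for stream `resp.streamId` before regular decoding continues.
        System.out.println("stream " + resp.streamId + " will carry " + remaining + " bytes");
      }
    }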
diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java
new file mode 100644
index 0000000000..68381037d6
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.sasl;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import javax.security.sasl.Sasl;
+import javax.security.sasl.SaslException;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+import io.netty.channel.Channel;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.client.TransportClientBootstrap;
+import org.apache.spark.network.util.JavaUtils;
+import org.apache.spark.network.util.TransportConf;
+
+/**
+ * Bootstraps a {@link TransportClient} by performing SASL authentication on the connection. The
+ * server should be set up with a {@link SaslRpcHandler} with matching keys for the given appId.
+ */
+public class SaslClientBootstrap implements TransportClientBootstrap {
+ private final Logger logger = LoggerFactory.getLogger(SaslClientBootstrap.class);
+
+ private final boolean encrypt;
+ private final TransportConf conf;
+ private final String appId;
+ private final SecretKeyHolder secretKeyHolder;
+
+ public SaslClientBootstrap(TransportConf conf, String appId, SecretKeyHolder secretKeyHolder) {
+ this(conf, appId, secretKeyHolder, false);
+ }
+
+ public SaslClientBootstrap(
+ TransportConf conf,
+ String appId,
+ SecretKeyHolder secretKeyHolder,
+ boolean encrypt) {
+ this.conf = conf;
+ this.appId = appId;
+ this.secretKeyHolder = secretKeyHolder;
+ this.encrypt = encrypt;
+ }
+
+ /**
+ * Performs SASL authentication by sending a token, and then proceeding with the SASL
+ * challenge-response tokens until we either successfully authenticate or throw an exception
+   * due to a mismatch.
+ */
+ @Override
+ public void doBootstrap(TransportClient client, Channel channel) {
+ SparkSaslClient saslClient = new SparkSaslClient(appId, secretKeyHolder, encrypt);
+ try {
+ byte[] payload = saslClient.firstToken();
+
+ while (!saslClient.isComplete()) {
+ SaslMessage msg = new SaslMessage(appId, payload);
+ ByteBuf buf = Unpooled.buffer(msg.encodedLength() + (int) msg.body().size());
+ msg.encode(buf);
+ buf.writeBytes(msg.body().nioByteBuffer());
+
+ ByteBuffer response = client.sendRpcSync(buf.nioBuffer(), conf.saslRTTimeoutMs());
+ payload = saslClient.response(JavaUtils.bufferToArray(response));
+ }
+
+ client.setClientId(appId);
+
+ if (encrypt) {
+ if (!SparkSaslServer.QOP_AUTH_CONF.equals(saslClient.getNegotiatedProperty(Sasl.QOP))) {
+ throw new RuntimeException(
+ new SaslException("Encryption requests by negotiated non-encrypted connection."));
+ }
+ SaslEncryption.addToChannel(channel, saslClient, conf.maxSaslEncryptedBlockSize());
+ saslClient = null;
+ logger.debug("Channel {} configured for SASL encryption.", client);
+ }
+ } catch (IOException ioe) {
+ throw new RuntimeException(ioe);
+ } finally {
+ if (saslClient != null) {
+ try {
+ // Once authentication is complete, the server will trust all remaining communication.
+ saslClient.dispose();
+ } catch (RuntimeException e) {
+ logger.error("Error while disposing SASL client", e);
+ }
+ }
+ }
+ }
+
+}
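
Client bootstraps run when a connection is created, before the client is handed back to the caller. A minimal sketch of plugging SaslClientBootstrap into the existing transport API; the appId "app-1" and the context, conf, keys, host, and port arguments are placeholders assumed to exist.

    import java.util.Arrays;
    import java.util.List;

    import org.apache.spark.network.TransportContext;
    import org.apache.spark.network.client.TransportClient;
    import org.apache.spark.network.client.TransportClientBootstrap;
    import org.apache.spark.network.client.TransportClientFactory;
    import org.apache.spark.network.sasl.SaslClientBootstrap;
    import org.apache.spark.network.sasl.SecretKeyHolder;
    import org.apache.spark.network.util.TransportConf;

    class SaslClientSetupExample {
      static TransportClient connect(
          TransportContext context,
          TransportConf conf,
          SecretKeyHolder keys,
          String host,
          int port) throws Exception {
        List<TransportClientBootstrap> bootstraps = Arrays.<TransportClientBootstrap>asList(
            new SaslClientBootstrap(conf, "app-1", keys, /* encrypt = */ true));
        TransportClientFactory factory = context.createClientFactory(bootstraps);
        // doBootstrap() runs as part of createClient(), so the returned client is authenticated.
        return factory.createClient(host, port);
      }
    }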
diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java
new file mode 100644
index 0000000000..127335e4d3
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java
@@ -0,0 +1,291 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.sasl;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.channels.WritableByteChannel;
+import java.util.List;
+
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Preconditions;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+import io.netty.channel.Channel;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelOutboundHandlerAdapter;
+import io.netty.channel.ChannelPromise;
+import io.netty.channel.FileRegion;
+import io.netty.handler.codec.MessageToMessageDecoder;
+import io.netty.util.AbstractReferenceCounted;
+import io.netty.util.ReferenceCountUtil;
+
+import org.apache.spark.network.util.ByteArrayWritableChannel;
+import org.apache.spark.network.util.NettyUtils;
+
+/**
+ * Provides SASL-based encryption for transport channels. The single method exposed by this
+ * class installs the needed channel handlers on a connected channel.
+ */
+class SaslEncryption {
+
+ @VisibleForTesting
+ static final String ENCRYPTION_HANDLER_NAME = "saslEncryption";
+
+ /**
+ * Adds channel handlers that perform encryption / decryption of data using SASL.
+ *
+ * @param channel The channel.
+ * @param backend The SASL backend.
+ * @param maxOutboundBlockSize Max size in bytes of outgoing encrypted blocks, to control
+ * memory usage.
+ */
+ static void addToChannel(
+ Channel channel,
+ SaslEncryptionBackend backend,
+ int maxOutboundBlockSize) {
+ channel.pipeline()
+ .addFirst(ENCRYPTION_HANDLER_NAME, new EncryptionHandler(backend, maxOutboundBlockSize))
+ .addFirst("saslDecryption", new DecryptionHandler(backend))
+ .addFirst("saslFrameDecoder", NettyUtils.createFrameDecoder());
+ }
+
+ private static class EncryptionHandler extends ChannelOutboundHandlerAdapter {
+
+ private final int maxOutboundBlockSize;
+ private final SaslEncryptionBackend backend;
+
+ EncryptionHandler(SaslEncryptionBackend backend, int maxOutboundBlockSize) {
+ this.backend = backend;
+ this.maxOutboundBlockSize = maxOutboundBlockSize;
+ }
+
+ /**
+     * Wraps the outbound message in an implementation that performs encryption lazily. This is
+     * needed to guarantee the ordering of the outgoing encrypted packets - they must be decrypted
+     * in the same order they were encrypted, and netty's ChannelHandlerContext.write() is not
+     * atomic, so by itself it does not guarantee any ordering.
+ */
+ @Override
+ public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise)
+ throws Exception {
+
+ ctx.write(new EncryptedMessage(backend, msg, maxOutboundBlockSize), promise);
+ }
+
+ @Override
+ public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
+ try {
+ backend.dispose();
+ } finally {
+ super.handlerRemoved(ctx);
+ }
+ }
+
+ }
+
+ private static class DecryptionHandler extends MessageToMessageDecoder<ByteBuf> {
+
+ private final SaslEncryptionBackend backend;
+
+ DecryptionHandler(SaslEncryptionBackend backend) {
+ this.backend = backend;
+ }
+
+ @Override
+ protected void decode(ChannelHandlerContext ctx, ByteBuf msg, List<Object> out)
+ throws Exception {
+
+ byte[] data;
+ int offset;
+ int length = msg.readableBytes();
+ if (msg.hasArray()) {
+ data = msg.array();
+ offset = msg.arrayOffset();
+ msg.skipBytes(length);
+ } else {
+ data = new byte[length];
+ msg.readBytes(data);
+ offset = 0;
+ }
+
+ out.add(Unpooled.wrappedBuffer(backend.unwrap(data, offset, length)));
+ }
+
+ }
+
+ @VisibleForTesting
+ static class EncryptedMessage extends AbstractReferenceCounted implements FileRegion {
+
+ private final SaslEncryptionBackend backend;
+ private final boolean isByteBuf;
+ private final ByteBuf buf;
+ private final FileRegion region;
+
+ /**
+ * A channel used to buffer input data for encryption. The channel has an upper size bound
+ * so that if the input is larger than the allowed buffer, it will be broken into multiple
+ * chunks.
+ */
+ private final ByteArrayWritableChannel byteChannel;
+
+ private ByteBuf currentHeader;
+ private ByteBuffer currentChunk;
+ private long currentChunkSize;
+ private long currentReportedBytes;
+ private long unencryptedChunkSize;
+ private long transferred;
+
+ EncryptedMessage(SaslEncryptionBackend backend, Object msg, int maxOutboundBlockSize) {
+ Preconditions.checkArgument(msg instanceof ByteBuf || msg instanceof FileRegion,
+ "Unrecognized message type: %s", msg.getClass().getName());
+ this.backend = backend;
+ this.isByteBuf = msg instanceof ByteBuf;
+ this.buf = isByteBuf ? (ByteBuf) msg : null;
+ this.region = isByteBuf ? null : (FileRegion) msg;
+ this.byteChannel = new ByteArrayWritableChannel(maxOutboundBlockSize);
+ }
+
+ /**
+ * Returns the size of the original (unencrypted) message.
+ *
+ * This makes assumptions about how netty treats FileRegion instances, because there's no way
+ * to know beforehand what will be the size of the encrypted message. Namely, it assumes
+ * that netty will try to transfer data from this message while
+ * <code>transfered() < count()</code>. So these two methods return, technically, wrong data,
+ * but netty doesn't know better.
+ */
+ @Override
+ public long count() {
+ return isByteBuf ? buf.readableBytes() : region.count();
+ }
+
+ @Override
+ public long position() {
+ return 0;
+ }
+
+ /**
+ * Returns an approximation of the amount of data transferred. See {@link #count()}.
+ */
+ @Override
+ public long transfered() {
+ return transferred;
+ }
+
+ /**
+ * Transfers data from the original message to the channel, encrypting it in the process.
+ *
+ * This method also breaks down the original message into smaller chunks when needed. This
+ * is done to keep memory usage under control. This avoids having to copy the whole message
+ * data into memory at once, and can avoid ballooning memory usage when transferring large
+ * messages such as shuffle blocks.
+ *
+ * The {@link #transfered()} counter also behaves a little funny, in that it won't go forward
+ * until a whole chunk has been written. This is done because the code can't use the actual
+ * number of bytes written to the channel as the transferred count (see {@link #count()}).
+ * Instead, once an encrypted chunk is written to the output (including its header), the
+ * size of the original block will be added to the {@link #transfered()} amount.
+ */
+ @Override
+ public long transferTo(final WritableByteChannel target, final long position)
+ throws IOException {
+
+ Preconditions.checkArgument(position == transfered(), "Invalid position.");
+
+ long reportedWritten = 0L;
+ long actuallyWritten = 0L;
+ do {
+ if (currentChunk == null) {
+ nextChunk();
+ }
+
+ if (currentHeader.readableBytes() > 0) {
+ int bytesWritten = target.write(currentHeader.nioBuffer());
+ currentHeader.skipBytes(bytesWritten);
+ actuallyWritten += bytesWritten;
+ if (currentHeader.readableBytes() > 0) {
+ // Break out of loop if there are still header bytes left to write.
+ break;
+ }
+ }
+
+ actuallyWritten += target.write(currentChunk);
+ if (!currentChunk.hasRemaining()) {
+ // Only update the count of written bytes once a full chunk has been written.
+ // See method javadoc.
+ long chunkBytesRemaining = unencryptedChunkSize - currentReportedBytes;
+ reportedWritten += chunkBytesRemaining;
+ transferred += chunkBytesRemaining;
+ currentHeader.release();
+ currentHeader = null;
+ currentChunk = null;
+ currentChunkSize = 0;
+ currentReportedBytes = 0;
+ }
+ } while (currentChunk == null && transfered() + reportedWritten < count());
+
+      // Returning 0 triggers a backoff mechanism in netty which may harm performance. Instead,
+      // we report 1 byte at a time for as long as the reported count stays below the size of the
+      // current chunk; once it would catch up, we resort to returning 0 so that the counts still
+      // match, at the cost of some performance. That situation should be rare, though.
+ if (reportedWritten != 0L) {
+ return reportedWritten;
+ }
+
+ if (actuallyWritten > 0 && currentReportedBytes < currentChunkSize - 1) {
+ transferred += 1L;
+ currentReportedBytes += 1L;
+ return 1L;
+ }
+
+ return 0L;
+ }
+
+ private void nextChunk() throws IOException {
+ byteChannel.reset();
+ if (isByteBuf) {
+ int copied = byteChannel.write(buf.nioBuffer());
+ buf.skipBytes(copied);
+ } else {
+ region.transferTo(byteChannel, region.transfered());
+ }
+
+ byte[] encrypted = backend.wrap(byteChannel.getData(), 0, byteChannel.length());
+ this.currentChunk = ByteBuffer.wrap(encrypted);
+ this.currentChunkSize = encrypted.length;
+ this.currentHeader = Unpooled.copyLong(8 + currentChunkSize);
+ this.unencryptedChunkSize = byteChannel.length();
+ }
+
+ @Override
+ protected void deallocate() {
+ if (currentHeader != null) {
+ currentHeader.release();
+ }
+ if (buf != null) {
+ buf.release();
+ }
+ if (region != null) {
+ region.release();
+ }
+ }
+
+ }
+
+}
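
What EncryptedMessage.nextChunk() puts on the wire per chunk is an 8-byte length (which includes the length field itself), followed by the SASL-wrapped bytes; the frame decoder installed by addToChannel() uses that length to reassemble the frame before DecryptionHandler unwraps it. A hedged sketch of building a single such frame; the class is hypothetical and would have to sit in org.apache.spark.network.sasl because the backend interface is package-private, and `backend` and `plain` are assumed inputs.

    package org.apache.spark.network.sasl;

    import javax.security.sasl.SaslException;

    import io.netty.buffer.ByteBuf;
    import io.netty.buffer.Unpooled;

    // Illustration only: one encrypted frame, laid out the way nextChunk() does it.
    final class SaslFrameExample {
      static ByteBuf encryptOneFrame(SaslEncryptionBackend backend, byte[] plain) throws SaslException {
        byte[] cipherText = backend.wrap(plain, 0, plain.length);
        ByteBuf frame = Unpooled.buffer(8 + cipherText.length);
        frame.writeLong(8 + cipherText.length);  // frame length includes the 8-byte length field itself
        frame.writeBytes(cipherText);
        return frame;
      }
    }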
diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslEncryptionBackend.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslEncryptionBackend.java
new file mode 100644
index 0000000000..89b78bc7e1
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslEncryptionBackend.java
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.sasl;
+
+import javax.security.sasl.SaslException;
+
+interface SaslEncryptionBackend {
+
+ /** Disposes of resources used by the backend. */
+ void dispose();
+
+ /** Encrypt data. */
+ byte[] wrap(byte[] data, int offset, int len) throws SaslException;
+
+ /** Decrypt data. */
+ byte[] unwrap(byte[] data, int offset, int len) throws SaslException;
+
+}
diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java
new file mode 100644
index 0000000000..e52b526f09
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.sasl;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+
+import org.apache.spark.network.buffer.NettyManagedBuffer;
+import org.apache.spark.network.protocol.Encoders;
+import org.apache.spark.network.protocol.AbstractMessage;
+
+/**
+ * Encodes a SASL-related message which is attempting to authenticate using some credentials tagged
+ * with the given appId. This appId allows a single SaslRpcHandler to multiplex different
+ * applications which may be using different sets of credentials.
+ */
+class SaslMessage extends AbstractMessage {
+
+ /** Serialization tag used to catch incorrect payloads. */
+ private static final byte TAG_BYTE = (byte) 0xEA;
+
+ public final String appId;
+
+ public SaslMessage(String appId, byte[] message) {
+ this(appId, Unpooled.wrappedBuffer(message));
+ }
+
+ public SaslMessage(String appId, ByteBuf message) {
+ super(new NettyManagedBuffer(message), true);
+ this.appId = appId;
+ }
+
+ @Override
+ public Type type() { return Type.User; }
+
+ @Override
+ public int encodedLength() {
+ // The integer (a.k.a. the body size) is not really used, since that information is already
+ // encoded in the frame length. But this maintains backwards compatibility with versions of
+ // RpcRequest that use Encoders.ByteArrays.
+ return 1 + Encoders.Strings.encodedLength(appId) + 4;
+ }
+
+ @Override
+ public void encode(ByteBuf buf) {
+ buf.writeByte(TAG_BYTE);
+ Encoders.Strings.encode(buf, appId);
+ // See comment in encodedLength().
+ buf.writeInt((int) body().size());
+ }
+
+ public static SaslMessage decode(ByteBuf buf) {
+ if (buf.readByte() != TAG_BYTE) {
+ throw new IllegalStateException("Expected SaslMessage, received something else"
+ + " (maybe your client does not have SASL enabled?)");
+ }
+
+ String appId = Encoders.Strings.decode(buf);
+ // See comment in encodedLength().
+ buf.readInt();
+ return new SaslMessage(appId, buf.retain());
+ }
+}
diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java
new file mode 100644
index 0000000000..c41f5b6873
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java
@@ -0,0 +1,158 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.sasl;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import javax.security.sasl.Sasl;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+import io.netty.channel.Channel;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.client.RpcResponseCallback;
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.server.RpcHandler;
+import org.apache.spark.network.server.StreamManager;
+import org.apache.spark.network.util.JavaUtils;
+import org.apache.spark.network.util.TransportConf;
+
+/**
+ * RPC Handler which performs SASL authentication before delegating to a child RPC handler.
+ * The delegate will only receive messages if the given connection has been successfully
+ * authenticated. A connection may be authenticated at most once.
+ *
+ * Note that the authentication process consists of multiple challenge-response pairs, each of
+ * which is an individual RPC.
+ */
+class SaslRpcHandler extends RpcHandler {
+ private static final Logger logger = LoggerFactory.getLogger(SaslRpcHandler.class);
+
+ /** Transport configuration. */
+ private final TransportConf conf;
+
+ /** The client channel. */
+ private final Channel channel;
+
+ /** RpcHandler we will delegate to for authenticated connections. */
+ private final RpcHandler delegate;
+
+ /** Class which provides secret keys which are shared by server and client on a per-app basis. */
+ private final SecretKeyHolder secretKeyHolder;
+
+ private SparkSaslServer saslServer;
+ private boolean isComplete;
+
+ SaslRpcHandler(
+ TransportConf conf,
+ Channel channel,
+ RpcHandler delegate,
+ SecretKeyHolder secretKeyHolder) {
+ this.conf = conf;
+ this.channel = channel;
+ this.delegate = delegate;
+ this.secretKeyHolder = secretKeyHolder;
+ this.saslServer = null;
+ this.isComplete = false;
+ }
+
+ @Override
+ public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) {
+ if (isComplete) {
+ // Authentication complete, delegate to base handler.
+ delegate.receive(client, message, callback);
+ return;
+ }
+
+ ByteBuf nettyBuf = Unpooled.wrappedBuffer(message);
+ SaslMessage saslMessage;
+ try {
+ saslMessage = SaslMessage.decode(nettyBuf);
+ } finally {
+ nettyBuf.release();
+ }
+
+ if (saslServer == null) {
+      // First message in the handshake, so set up the necessary state.
+ client.setClientId(saslMessage.appId);
+ saslServer = new SparkSaslServer(saslMessage.appId, secretKeyHolder,
+ conf.saslServerAlwaysEncrypt());
+ }
+
+ byte[] response;
+ try {
+ response = saslServer.response(JavaUtils.bufferToArray(
+ saslMessage.body().nioByteBuffer()));
+ } catch (IOException ioe) {
+ throw new RuntimeException(ioe);
+ }
+ callback.onSuccess(ByteBuffer.wrap(response));
+
+    // Set up encryption after the SASL response is sent, otherwise the client can't parse the
+ // response. It's ok to change the channel pipeline here since we are processing an incoming
+ // message, so the pipeline is busy and no new incoming messages will be fed to it before this
+ // method returns. This assumes that the code ensures, through other means, that no outbound
+ // messages are being written to the channel while negotiation is still going on.
+ if (saslServer.isComplete()) {
+ logger.debug("SASL authentication successful for channel {}", client);
+ isComplete = true;
+ if (SparkSaslServer.QOP_AUTH_CONF.equals(saslServer.getNegotiatedProperty(Sasl.QOP))) {
+ logger.debug("Enabling encryption for channel {}", client);
+ SaslEncryption.addToChannel(channel, saslServer, conf.maxSaslEncryptedBlockSize());
+ saslServer = null;
+ } else {
+ saslServer.dispose();
+ saslServer = null;
+ }
+ }
+ }
+
+ @Override
+ public void receive(TransportClient client, ByteBuffer message) {
+ delegate.receive(client, message);
+ }
+
+ @Override
+ public StreamManager getStreamManager() {
+ return delegate.getStreamManager();
+ }
+
+ @Override
+ public void channelActive(TransportClient client) {
+ delegate.channelActive(client);
+ }
+
+ @Override
+ public void channelInactive(TransportClient client) {
+ try {
+ delegate.channelInactive(client);
+ } finally {
+ if (saslServer != null) {
+ saslServer.dispose();
+ }
+ }
+ }
+
+ @Override
+ public void exceptionCaught(Throwable cause, TransportClient client) {
+ delegate.exceptionCaught(cause, client);
+ }
+
+}
diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslServerBootstrap.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslServerBootstrap.java
new file mode 100644
index 0000000000..f2f983856f
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslServerBootstrap.java
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.sasl;
+
+import io.netty.channel.Channel;
+
+import org.apache.spark.network.server.RpcHandler;
+import org.apache.spark.network.server.TransportServerBootstrap;
+import org.apache.spark.network.util.TransportConf;
+
+/**
+ * A bootstrap which is executed on a TransportServer's client channel once a client connects
+ * to the server. This allows customizing the client channel for things such as SASL
+ * authentication.
+ */
+public class SaslServerBootstrap implements TransportServerBootstrap {
+
+ private final TransportConf conf;
+ private final SecretKeyHolder secretKeyHolder;
+
+ public SaslServerBootstrap(TransportConf conf, SecretKeyHolder secretKeyHolder) {
+ this.conf = conf;
+ this.secretKeyHolder = secretKeyHolder;
+ }
+
+ /**
+ * Wrap the given application handler in a SaslRpcHandler that will handle the initial SASL
+ * negotiation.
+ */
+ public RpcHandler doBootstrap(Channel channel, RpcHandler rpcHandler) {
+ return new SaslRpcHandler(conf, channel, rpcHandler, secretKeyHolder);
+ }
+
+}
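
The server side mirrors the client bootstrap: the wrapper installed by doBootstrap() means every new channel starts life behind a SaslRpcHandler until it authenticates. A sketch of creating a server with it through the existing TransportContext API; the context, conf, and keys arguments are assumed to exist.

    import java.util.Arrays;
    import java.util.List;

    import org.apache.spark.network.TransportContext;
    import org.apache.spark.network.sasl.SaslServerBootstrap;
    import org.apache.spark.network.sasl.SecretKeyHolder;
    import org.apache.spark.network.server.TransportServer;
    import org.apache.spark.network.server.TransportServerBootstrap;
    import org.apache.spark.network.util.TransportConf;

    class SaslServerSetupExample {
      static TransportServer start(TransportContext context, TransportConf conf, SecretKeyHolder keys) {
        List<TransportServerBootstrap> bootstraps = Arrays.<TransportServerBootstrap>asList(
            new SaslServerBootstrap(conf, keys));
        // Each client connection is wrapped in a SaslRpcHandler until it authenticates.
        return context.createServer(bootstraps);
      }
    }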
diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SecretKeyHolder.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SecretKeyHolder.java
new file mode 100644
index 0000000000..81d5766794
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SecretKeyHolder.java
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.sasl;
+
+/**
+ * Interface for getting a secret key associated with some application.
+ */
+public interface SecretKeyHolder {
+ /**
+   * Gets an appropriate SASL user for the given appId.
+ * @throws IllegalArgumentException if the given appId is not associated with a SASL user.
+ */
+ String getSaslUser(String appId);
+
+ /**
+ * Gets an appropriate SASL secret key for the given appId.
+ * @throws IllegalArgumentException if the given appId is not associated with a SASL secret key.
+ */
+ String getSecretKey(String appId);
+}
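
A minimal sketch of an implementation that serves a single application, following the contract above of throwing IllegalArgumentException for an unknown appId. The class name and the SASL user name are made up for illustration; the same user/secret pair must be configured on both client and server.

    import org.apache.spark.network.sasl.SecretKeyHolder;

    class SingleAppSecretKeyHolder implements SecretKeyHolder {
      private final String appId;
      private final String secret;

      SingleAppSecretKeyHolder(String appId, String secret) {
        this.appId = appId;
        this.secret = secret;
      }

      private void check(String requestedAppId) {
        if (!appId.equals(requestedAppId)) {
          throw new IllegalArgumentException("Unknown appId: " + requestedAppId);
        }
      }

      @Override
      public String getSaslUser(String requestedAppId) {
        check(requestedAppId);
        return "sparkSaslUser";
      }

      @Override
      public String getSecretKey(String requestedAppId) {
        check(requestedAppId);
        return secret;
      }
    }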
diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java
new file mode 100644
index 0000000000..94685e91b8
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java
@@ -0,0 +1,162 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.sasl;
+
+import java.io.IOException;
+import java.util.Map;
+import javax.security.auth.callback.Callback;
+import javax.security.auth.callback.CallbackHandler;
+import javax.security.auth.callback.NameCallback;
+import javax.security.auth.callback.PasswordCallback;
+import javax.security.auth.callback.UnsupportedCallbackException;
+import javax.security.sasl.RealmCallback;
+import javax.security.sasl.RealmChoiceCallback;
+import javax.security.sasl.Sasl;
+import javax.security.sasl.SaslClient;
+import javax.security.sasl.SaslException;
+
+import com.google.common.base.Throwables;
+import com.google.common.collect.ImmutableMap;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import static org.apache.spark.network.sasl.SparkSaslServer.*;
+
+/**
+ * A SASL Client for Spark which simply keeps track of the state of a single SASL session, from the
+ * initial state to the "authenticated" state. This client initializes the protocol via a
+ * firstToken, which is then followed by a set of challenges and responses.
+ */
+public class SparkSaslClient implements SaslEncryptionBackend {
+ private final Logger logger = LoggerFactory.getLogger(SparkSaslClient.class);
+
+ private final String secretKeyId;
+ private final SecretKeyHolder secretKeyHolder;
+ private final String expectedQop;
+ private SaslClient saslClient;
+
+ public SparkSaslClient(String secretKeyId, SecretKeyHolder secretKeyHolder, boolean encrypt) {
+ this.secretKeyId = secretKeyId;
+ this.secretKeyHolder = secretKeyHolder;
+ this.expectedQop = encrypt ? QOP_AUTH_CONF : QOP_AUTH;
+
+ Map<String, String> saslProps = ImmutableMap.<String, String>builder()
+ .put(Sasl.QOP, expectedQop)
+ .build();
+ try {
+ this.saslClient = Sasl.createSaslClient(new String[] { DIGEST }, null, null, DEFAULT_REALM,
+ saslProps, new ClientCallbackHandler());
+ } catch (SaslException e) {
+ throw Throwables.propagate(e);
+ }
+ }
+
+ /** Used to initiate SASL handshake with server. */
+ public synchronized byte[] firstToken() {
+ if (saslClient != null && saslClient.hasInitialResponse()) {
+ try {
+ return saslClient.evaluateChallenge(new byte[0]);
+ } catch (SaslException e) {
+ throw Throwables.propagate(e);
+ }
+ } else {
+ return new byte[0];
+ }
+ }
+
+ /** Determines whether the authentication exchange has completed. */
+ public synchronized boolean isComplete() {
+ return saslClient != null && saslClient.isComplete();
+ }
+
+ /** Returns the value of a negotiated property. */
+ public Object getNegotiatedProperty(String name) {
+ return saslClient.getNegotiatedProperty(name);
+ }
+
+ /**
+ * Respond to server's SASL token.
+ * @param token contains server's SASL token
+ * @return client's response SASL token
+ */
+ public synchronized byte[] response(byte[] token) {
+ try {
+ return saslClient != null ? saslClient.evaluateChallenge(token) : new byte[0];
+ } catch (SaslException e) {
+ throw Throwables.propagate(e);
+ }
+ }
+
+ /**
+ * Disposes of any system resources or security-sensitive information the
+ * SaslClient might be using.
+ */
+ @Override
+ public synchronized void dispose() {
+ if (saslClient != null) {
+ try {
+ saslClient.dispose();
+ } catch (SaslException e) {
+ // ignore
+ } finally {
+ saslClient = null;
+ }
+ }
+ }
+
+ /**
+ * Implementation of javax.security.auth.callback.CallbackHandler
+ * that works with shared secrets.
+ */
+ private class ClientCallbackHandler implements CallbackHandler {
+ @Override
+ public void handle(Callback[] callbacks) throws IOException, UnsupportedCallbackException {
+
+ for (Callback callback : callbacks) {
+ if (callback instanceof NameCallback) {
+ logger.trace("SASL client callback: setting username");
+ NameCallback nc = (NameCallback) callback;
+ nc.setName(encodeIdentifier(secretKeyHolder.getSaslUser(secretKeyId)));
+ } else if (callback instanceof PasswordCallback) {
+ logger.trace("SASL client callback: setting password");
+ PasswordCallback pc = (PasswordCallback) callback;
+ pc.setPassword(encodePassword(secretKeyHolder.getSecretKey(secretKeyId)));
+ } else if (callback instanceof RealmCallback) {
+ logger.trace("SASL client callback: setting realm");
+ RealmCallback rc = (RealmCallback) callback;
+ rc.setText(rc.getDefaultText());
+ } else if (callback instanceof RealmChoiceCallback) {
+ // ignore (?)
+ } else {
+ throw new UnsupportedCallbackException(callback, "Unrecognized SASL DIGEST-MD5 Callback");
+ }
+ }
+ }
+ }
+
+ @Override
+ public byte[] wrap(byte[] data, int offset, int len) throws SaslException {
+ return saslClient.wrap(data, offset, len);
+ }
+
+ @Override
+ public byte[] unwrap(byte[] data, int offset, int len) throws SaslException {
+ return saslClient.unwrap(data, offset, len);
+ }
+
+}
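A minimal client-side sketch (not part of the patch): construct the client with a SecretKeyHolder, obtain the initial token, then feed each server challenge to response() until isComplete(); dispose() releases the underlying SaslClient. The credentials and class name below are hypothetical.

```java
import org.apache.spark.network.sasl.SecretKeyHolder;
import org.apache.spark.network.sasl.SparkSaslClient;

public class SparkSaslClientSketch {
  public static void main(String[] args) {
    // Hypothetical holder returning fixed credentials for any appId.
    SecretKeyHolder keyHolder = new SecretKeyHolder() {
      @Override
      public String getSaslUser(String appId) { return "sparkSaslUser"; }

      @Override
      public String getSecretKey(String appId) { return "secret"; }
    };

    SparkSaslClient client = new SparkSaslClient("app-1", keyHolder, false);
    try {
      // DIGEST-MD5 has no initial response, so this token is empty; real server
      // challenges would then be passed to client.response(...) until isComplete().
      byte[] firstToken = client.firstToken();
      System.out.println("first token bytes: " + firstToken.length
          + ", complete: " + client.isComplete());
    } finally {
      client.dispose();
    }
  }
}
```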
diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java
new file mode 100644
index 0000000000..431cb67a2a
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java
@@ -0,0 +1,200 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.sasl;
+
+import javax.security.auth.callback.Callback;
+import javax.security.auth.callback.CallbackHandler;
+import javax.security.auth.callback.NameCallback;
+import javax.security.auth.callback.PasswordCallback;
+import javax.security.auth.callback.UnsupportedCallbackException;
+import javax.security.sasl.AuthorizeCallback;
+import javax.security.sasl.RealmCallback;
+import javax.security.sasl.Sasl;
+import javax.security.sasl.SaslException;
+import javax.security.sasl.SaslServer;
+import java.io.IOException;
+import java.util.Map;
+
+import com.google.common.base.Charsets;
+import com.google.common.base.Preconditions;
+import com.google.common.base.Throwables;
+import com.google.common.collect.ImmutableMap;
+import io.netty.buffer.Unpooled;
+import io.netty.handler.codec.base64.Base64;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * A SASL Server for Spark which simply keeps track of the state of a single SASL session, from the
+ * initial state to the "authenticated" state. (It is not a server in the sense of accepting
+ * connections on some socket.)
+ */
+public class SparkSaslServer implements SaslEncryptionBackend {
+ private final Logger logger = LoggerFactory.getLogger(SparkSaslServer.class);
+
+ /**
+ * This is passed as the server name when creating the sasl client/server.
+ * This could be changed to be configurable in the future.
+ */
+ static final String DEFAULT_REALM = "default";
+
+ /**
+ * The authentication mechanism used here is DIGEST-MD5. This could be changed to be
+ * configurable in the future.
+ */
+ static final String DIGEST = "DIGEST-MD5";
+
+ /**
+ * Quality of protection value that includes encryption.
+ */
+ static final String QOP_AUTH_CONF = "auth-conf";
+
+ /**
+ * Quality of protection value that does not include encryption.
+ */
+ static final String QOP_AUTH = "auth";
+
+ /** Identifier for a certain secret key within the secretKeyHolder. */
+ private final String secretKeyId;
+ private final SecretKeyHolder secretKeyHolder;
+ private SaslServer saslServer;
+
+ public SparkSaslServer(
+ String secretKeyId,
+ SecretKeyHolder secretKeyHolder,
+ boolean alwaysEncrypt) {
+ this.secretKeyId = secretKeyId;
+ this.secretKeyHolder = secretKeyHolder;
+
+ // Sasl.QOP is a comma-separated list of supported values. The value that allows encryption
+ // is listed first since it's preferred over the non-encrypted one (if the client also
+ // lists both in the request).
+ String qop = alwaysEncrypt ? QOP_AUTH_CONF : String.format("%s,%s", QOP_AUTH_CONF, QOP_AUTH);
+ Map<String, String> saslProps = ImmutableMap.<String, String>builder()
+ .put(Sasl.SERVER_AUTH, "true")
+ .put(Sasl.QOP, qop)
+ .build();
+ try {
+ this.saslServer = Sasl.createSaslServer(DIGEST, null, DEFAULT_REALM, saslProps,
+ new DigestCallbackHandler());
+ } catch (SaslException e) {
+ throw Throwables.propagate(e);
+ }
+ }
+
+ /**
+ * Determines whether the authentication exchange has completed successfully.
+ */
+ public synchronized boolean isComplete() {
+ return saslServer != null && saslServer.isComplete();
+ }
+
+ /** Returns the value of a negotiated property. */
+ public Object getNegotiatedProperty(String name) {
+ return saslServer.getNegotiatedProperty(name);
+ }
+
+ /**
+ * Used to respond to the client's SASL token.
+ * @param token The client's SASL token
+ * @return the server's challenge or success token to send back to the client.
+ */
+ public synchronized byte[] response(byte[] token) {
+ try {
+ return saslServer != null ? saslServer.evaluateResponse(token) : new byte[0];
+ } catch (SaslException e) {
+ throw Throwables.propagate(e);
+ }
+ }
+
+ /**
+ * Disposes of any system resources or security-sensitive information the
+ * SaslServer might be using.
+ */
+ @Override
+ public synchronized void dispose() {
+ if (saslServer != null) {
+ try {
+ saslServer.dispose();
+ } catch (SaslException e) {
+ // ignore
+ } finally {
+ saslServer = null;
+ }
+ }
+ }
+
+ @Override
+ public byte[] wrap(byte[] data, int offset, int len) throws SaslException {
+ return saslServer.wrap(data, offset, len);
+ }
+
+ @Override
+ public byte[] unwrap(byte[] data, int offset, int len) throws SaslException {
+ return saslServer.unwrap(data, offset, len);
+ }
+
+ /**
+ * Implementation of javax.security.auth.callback.CallbackHandler for SASL DIGEST-MD5 mechanism.
+ */
+ private class DigestCallbackHandler implements CallbackHandler {
+ @Override
+ public void handle(Callback[] callbacks) throws IOException, UnsupportedCallbackException {
+ for (Callback callback : callbacks) {
+ if (callback instanceof NameCallback) {
+ logger.trace("SASL server callback: setting username");
+ NameCallback nc = (NameCallback) callback;
+ nc.setName(encodeIdentifier(secretKeyHolder.getSaslUser(secretKeyId)));
+ } else if (callback instanceof PasswordCallback) {
+ logger.trace("SASL server callback: setting password");
+ PasswordCallback pc = (PasswordCallback) callback;
+ pc.setPassword(encodePassword(secretKeyHolder.getSecretKey(secretKeyId)));
+ } else if (callback instanceof RealmCallback) {
+ logger.trace("SASL server callback: setting realm");
+ RealmCallback rc = (RealmCallback) callback;
+ rc.setText(rc.getDefaultText());
+ } else if (callback instanceof AuthorizeCallback) {
+ AuthorizeCallback ac = (AuthorizeCallback) callback;
+ String authId = ac.getAuthenticationID();
+ String authzId = ac.getAuthorizationID();
+ ac.setAuthorized(authId.equals(authzId));
+ if (ac.isAuthorized()) {
+ ac.setAuthorizedID(authzId);
+ }
+ logger.debug("SASL Authorization complete, authorized set to {}", ac.isAuthorized());
+ } else {
+ throw new UnsupportedCallbackException(callback, "Unrecognized SASL DIGEST-MD5 Callback");
+ }
+ }
+ }
+ }
+
+ /** Encodes an identifier as a Base64-encoded string. */
+ public static String encodeIdentifier(String identifier) {
+ Preconditions.checkNotNull(identifier, "User cannot be null if SASL is enabled");
+ return Base64.encode(Unpooled.wrappedBuffer(identifier.getBytes(Charsets.UTF_8)))
+ .toString(Charsets.UTF_8);
+ }
+
+ /** Encode a password as a base64-encoded char[] array. */
+ public static char[] encodePassword(String password) {
+ Preconditions.checkNotNull(password, "Password cannot be null if SASL is enabled");
+ return Base64.encode(Unpooled.wrappedBuffer(password.getBytes(Charsets.UTF_8)))
+ .toString(Charsets.UTF_8).toCharArray();
+ }
+}
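The two SASL classes can be exercised together entirely in memory; a hedged sketch of the challenge/response loop follows (credentials and class name are hypothetical, not part of the patch).

```java
import org.apache.spark.network.sasl.SecretKeyHolder;
import org.apache.spark.network.sasl.SparkSaslClient;
import org.apache.spark.network.sasl.SparkSaslServer;

public class SaslHandshakeSketch {
  public static void main(String[] args) {
    // Hypothetical holder returning fixed credentials for any appId.
    SecretKeyHolder keyHolder = new SecretKeyHolder() {
      @Override
      public String getSaslUser(String appId) { return "sparkSaslUser"; }

      @Override
      public String getSecretKey(String appId) { return "secret"; }
    };

    SparkSaslServer server = new SparkSaslServer("app-1", keyHolder, false);
    SparkSaslClient client = new SparkSaslClient("app-1", keyHolder, false);
    try {
      // Drive the DIGEST-MD5 exchange until the client reaches the authenticated state.
      byte[] clientToken = client.firstToken();
      while (!client.isComplete()) {
        byte[] serverToken = server.response(clientToken);
        clientToken = client.response(serverToken);
      }
      System.out.println("server complete: " + server.isComplete());
    } finally {
      client.dispose();
      server.dispose();
    }
  }
}
```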
diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/MessageHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/MessageHandler.java
new file mode 100644
index 0000000000..4a1f28e9ff
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/server/MessageHandler.java
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.server;
+
+import org.apache.spark.network.protocol.Message;
+
+/**
+ * Handles either request or response messages coming off of Netty. A MessageHandler instance
+ * is associated with a single Netty Channel (though it may have multiple clients on the same
+ * Channel.)
+ */
+public abstract class MessageHandler<T extends Message> {
+ /** Handles the receipt of a single message. */
+ public abstract void handle(T message) throws Exception;
+
+ /** Invoked when the channel this MessageHandler is on is active. */
+ public abstract void channelActive();
+
+ /** Invoked when an exception was caught on the Channel. */
+ public abstract void exceptionCaught(Throwable cause);
+
+ /** Invoked when the channel this MessageHandler is on is inactive. */
+ public abstract void channelInactive();
+}
diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java
new file mode 100644
index 0000000000..6ed61da5c7
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.server;
+
+import java.nio.ByteBuffer;
+
+import org.apache.spark.network.client.RpcResponseCallback;
+import org.apache.spark.network.client.TransportClient;
+
+/** An RpcHandler suitable for a client-only TransportContext, which cannot receive RPCs. */
+public class NoOpRpcHandler extends RpcHandler {
+ private final StreamManager streamManager;
+
+ public NoOpRpcHandler() {
+ streamManager = new OneForOneStreamManager();
+ }
+
+ @Override
+ public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) {
+ throw new UnsupportedOperationException("Cannot handle messages");
+ }
+
+ @Override
+ public StreamManager getStreamManager() { return streamManager; }
+}
diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java b/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java
new file mode 100644
index 0000000000..ea9e735e0a
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java
@@ -0,0 +1,143 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.server;
+
+import java.util.Iterator;
+import java.util.Map;
+import java.util.Random;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.atomic.AtomicLong;
+
+import com.google.common.base.Preconditions;
+import io.netty.channel.Channel;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.client.TransportClient;
+
+/**
+ * StreamManager which allows registration of an Iterator&lt;ManagedBuffer&gt;, whose buffers are
+ * individually fetched as chunks by the client. Each registered buffer is one chunk.
+ */
+public class OneForOneStreamManager extends StreamManager {
+ private final Logger logger = LoggerFactory.getLogger(OneForOneStreamManager.class);
+
+ private final AtomicLong nextStreamId;
+ private final ConcurrentHashMap<Long, StreamState> streams;
+
+ /** State of a single stream. */
+ private static class StreamState {
+ final String appId;
+ final Iterator<ManagedBuffer> buffers;
+
+ // The channel associated to the stream
+ Channel associatedChannel = null;
+
+ // Used to keep track of the index of the buffer that the user has retrieved, just to ensure
+ // that the caller only requests each chunk one at a time, in order.
+ int curChunk = 0;
+
+ StreamState(String appId, Iterator<ManagedBuffer> buffers) {
+ this.appId = appId;
+ this.buffers = Preconditions.checkNotNull(buffers);
+ }
+ }
+
+ public OneForOneStreamManager() {
+ // For debugging purposes, start with a random stream id to help identify different streams.
+ // This does not need to be globally unique, only unique to this class.
+ nextStreamId = new AtomicLong((long) new Random().nextInt(Integer.MAX_VALUE) * 1000);
+ streams = new ConcurrentHashMap<Long, StreamState>();
+ }
+
+ @Override
+ public void registerChannel(Channel channel, long streamId) {
+ if (streams.containsKey(streamId)) {
+ streams.get(streamId).associatedChannel = channel;
+ }
+ }
+
+ @Override
+ public ManagedBuffer getChunk(long streamId, int chunkIndex) {
+ StreamState state = streams.get(streamId);
+ if (chunkIndex != state.curChunk) {
+ throw new IllegalStateException(String.format(
+ "Received out-of-order chunk index %s (expected %s)", chunkIndex, state.curChunk));
+ } else if (!state.buffers.hasNext()) {
+ throw new IllegalStateException(String.format(
+ "Requested chunk index beyond end %s", chunkIndex));
+ }
+ state.curChunk += 1;
+ ManagedBuffer nextChunk = state.buffers.next();
+
+ if (!state.buffers.hasNext()) {
+ logger.trace("Removing stream id {}", streamId);
+ streams.remove(streamId);
+ }
+
+ return nextChunk;
+ }
+
+ @Override
+ public void connectionTerminated(Channel channel) {
+ // Close all streams which have been associated with the channel.
+ for (Map.Entry<Long, StreamState> entry: streams.entrySet()) {
+ StreamState state = entry.getValue();
+ if (state.associatedChannel == channel) {
+ streams.remove(entry.getKey());
+
+ // Release all remaining buffers.
+ while (state.buffers.hasNext()) {
+ state.buffers.next().release();
+ }
+ }
+ }
+ }
+
+ @Override
+ public void checkAuthorization(TransportClient client, long streamId) {
+ if (client.getClientId() != null) {
+ StreamState state = streams.get(streamId);
+ Preconditions.checkArgument(state != null, "Unknown stream ID.");
+ if (!client.getClientId().equals(state.appId)) {
+ throw new SecurityException(String.format(
+ "Client %s not authorized to read stream %d (app %s).",
+ client.getClientId(),
+ streamId,
+ state.appId));
+ }
+ }
+ }
+
+ /**
+ * Registers a stream of ManagedBuffers which are served as individual chunks one at a time to
+ * callers. Each ManagedBuffer will be release()'d after it is transferred on the wire. If a
+ * client connection is closed before the iterator is fully drained, then the remaining buffers
+ * will all be release()'d.
+ *
+ * If an app ID is provided, only callers who've authenticated with the given app ID will be
+ * allowed to fetch from this stream.
+ */
+ public long registerStream(String appId, Iterator<ManagedBuffer> buffers) {
+ long myStreamId = nextStreamId.getAndIncrement();
+ streams.put(myStreamId, new StreamState(appId, buffers));
+ return myStreamId;
+ }
+
+}
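A small usage sketch (not part of the patch): register an iterator of buffers, then fetch the chunks in order; the stream is dropped automatically once the iterator is drained. It relies only on NioManagedBuffer and the methods shown in this patch.

```java
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.Iterator;

import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.buffer.NioManagedBuffer;
import org.apache.spark.network.server.OneForOneStreamManager;

public class OneForOneStreamManagerSketch {
  public static void main(String[] args) {
    OneForOneStreamManager manager = new OneForOneStreamManager();

    // Each registered buffer becomes one chunk of the stream.
    Iterator<ManagedBuffer> buffers = Arrays.<ManagedBuffer>asList(
        new NioManagedBuffer(ByteBuffer.wrap(new byte[] {1, 2, 3})),
        new NioManagedBuffer(ByteBuffer.wrap(new byte[] {4, 5, 6}))).iterator();

    long streamId = manager.registerStream("app-1", buffers);

    // Chunks must be fetched in order; out-of-order requests throw IllegalStateException.
    ManagedBuffer chunk0 = manager.getChunk(streamId, 0);
    ManagedBuffer chunk1 = manager.getChunk(streamId, 1);
    System.out.println(chunk0.size() + " + " + chunk1.size() + " bytes served");
  }
}
```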
diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/RpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/RpcHandler.java
new file mode 100644
index 0000000000..a99c3015b0
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/server/RpcHandler.java
@@ -0,0 +1,100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.server;
+
+import java.nio.ByteBuffer;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.client.RpcResponseCallback;
+import org.apache.spark.network.client.TransportClient;
+
+/**
+ * Handler for sendRPC() messages sent by {@link org.apache.spark.network.client.TransportClient}s.
+ */
+public abstract class RpcHandler {
+
+ private static final RpcResponseCallback ONE_WAY_CALLBACK = new OneWayRpcCallback();
+
+ /**
+ * Receive a single RPC message. Any exception thrown while in this method will be sent back to
+ * the client in string form as a standard RPC failure.
+ *
+ * This method will not be called in parallel for a single TransportClient (i.e., channel).
+ *
+ * @param client A channel client which enables the handler to make requests back to the sender
+ * of this RPC. This will always be the exact same object for a particular channel.
+ * @param message The serialized bytes of the RPC.
+ * @param callback Callback which should be invoked exactly once upon success or failure of the
+ * RPC.
+ */
+ public abstract void receive(
+ TransportClient client,
+ ByteBuffer message,
+ RpcResponseCallback callback);
+
+ /**
+ * Returns the StreamManager which contains the state about which streams are currently being
+ * fetched by a TransportClient.
+ */
+ public abstract StreamManager getStreamManager();
+
+ /**
+ * Receives an RPC message that does not expect a reply. The default implementation will
+ * call "{@link #receive(TransportClient, ByteBuffer, RpcResponseCallback)}" and log a warning if
+ * any of the callback methods are called.
+ *
+ * @param client A channel client which enables the handler to make requests back to the sender
+ * of this RPC. This will always be the exact same object for a particular channel.
+ * @param message The serialized bytes of the RPC.
+ */
+ public void receive(TransportClient client, ByteBuffer message) {
+ receive(client, message, ONE_WAY_CALLBACK);
+ }
+
+ /**
+ * Invoked when the channel associated with the given client is active.
+ */
+ public void channelActive(TransportClient client) { }
+
+ /**
+ * Invoked when the channel associated with the given client is inactive.
+ * No further requests will come from this client.
+ */
+ public void channelInactive(TransportClient client) { }
+
+ public void exceptionCaught(Throwable cause, TransportClient client) { }
+
+ private static class OneWayRpcCallback implements RpcResponseCallback {
+
+ private final Logger logger = LoggerFactory.getLogger(OneWayRpcCallback.class);
+
+ @Override
+ public void onSuccess(ByteBuffer response) {
+ logger.warn("Response provided for one-way RPC.");
+ }
+
+ @Override
+ public void onFailure(Throwable e) {
+ logger.error("Error response provided for one-way RPC.", e);
+ }
+
+ }
+
+}
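For illustration (not part of the patch), a trivial RpcHandler that echoes every payload back through the callback; the class name is hypothetical and the StreamManager is the one-for-one implementation added above.

```java
import java.nio.ByteBuffer;

import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.server.OneForOneStreamManager;
import org.apache.spark.network.server.RpcHandler;
import org.apache.spark.network.server.StreamManager;

/** Echoes every RPC payload back to the caller; shown only to illustrate the contract. */
public class EchoRpcHandler extends RpcHandler {
  private final StreamManager streamManager = new OneForOneStreamManager();

  @Override
  public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) {
    // Copy the payload, since the request buffer may be released after receive() returns.
    ByteBuffer copy = ByteBuffer.allocate(message.remaining());
    copy.put(message.duplicate());
    copy.flip();
    callback.onSuccess(copy);
  }

  @Override
  public StreamManager getStreamManager() {
    return streamManager;
  }
}
```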
diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/StreamManager.java b/common/network-common/src/main/java/org/apache/spark/network/server/StreamManager.java
new file mode 100644
index 0000000000..07f161a29c
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/server/StreamManager.java
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.server;
+
+import io.netty.channel.Channel;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.client.TransportClient;
+
+/**
+ * The StreamManager is used to fetch individual chunks from a stream. This is used in
+ * {@link TransportRequestHandler} in order to respond to fetchChunk() requests. Creation of the
+ * stream is outside the scope of the transport layer, but a given stream is guaranteed to be read
+ * by only one client connection, meaning that getChunk() for a particular stream will be called
+ * serially and that once the connection associated with the stream is closed, that stream will
+ * never be used again.
+ */
+public abstract class StreamManager {
+ /**
+ * Called in response to a fetchChunk() request. The returned buffer will be passed as-is to the
+ * client. A single stream will be associated with a single TCP connection, so this method
+ * will not be called in parallel for a particular stream.
+ *
+ * Chunks may be requested in any order, and requests may be repeated, but it is not required
+ * that implementations support this behavior.
+ *
+ * The returned ManagedBuffer will be release()'d after being written to the network.
+ *
+ * @param streamId id of a stream that has been previously registered with the StreamManager.
+ * @param chunkIndex 0-indexed chunk of the stream that's requested
+ */
+ public abstract ManagedBuffer getChunk(long streamId, int chunkIndex);
+
+ /**
+ * Called in response to a stream() request. The returned data is streamed to the client
+ * through a single TCP connection.
+ *
+ * Note the <code>streamId</code> argument is not related to the similarly named argument in the
+ * {@link #getChunk(long, int)} method.
+ *
+ * @param streamId id of a stream that has been previously registered with the StreamManager.
+ * @return A managed buffer for the stream, or null if the stream was not found.
+ */
+ public ManagedBuffer openStream(String streamId) {
+ throw new UnsupportedOperationException();
+ }
+
+ /**
+ * Associates a stream with a single client connection, which is guaranteed to be the only reader
+ * of the stream. The getChunk() method will be called serially on this connection and once the
+ * connection is closed, the stream will never be used again, enabling cleanup.
+ *
+ * This must be called before the first getChunk() on the stream, but it may be invoked multiple
+ * times with the same channel and stream id.
+ */
+ public void registerChannel(Channel channel, long streamId) { }
+
+ /**
+ * Indicates that the given channel has been terminated. After this occurs, we are guaranteed not
+ * to read from the associated streams again, so any state can be cleaned up.
+ */
+ public void connectionTerminated(Channel channel) { }
+
+ /**
+ * Verify that the client is authorized to read from the given stream.
+ *
+ * @throws SecurityException If client is not authorized.
+ */
+ public void checkAuthorization(TransportClient client, long streamId) { }
+
+}
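Only getChunk() is abstract, so a custom StreamManager can be very small. A hedged sketch (hypothetical class, not part of the patch) that serves a single in-memory chunk per stream:

```java
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;

import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.buffer.NioManagedBuffer;
import org.apache.spark.network.server.StreamManager;

/** Serves exactly one in-memory chunk for every stream id; for illustration only. */
public class SingleChunkStreamManager extends StreamManager {
  @Override
  public ManagedBuffer getChunk(long streamId, int chunkIndex) {
    if (chunkIndex != 0) {
      throw new IllegalStateException("Only chunk 0 exists for stream " + streamId);
    }
    byte[] payload = ("chunk for stream " + streamId).getBytes(StandardCharsets.UTF_8);
    return new NioManagedBuffer(ByteBuffer.wrap(payload));
  }
}
```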
diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java
new file mode 100644
index 0000000000..18a9b7887e
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java
@@ -0,0 +1,163 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.server;
+
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.SimpleChannelInboundHandler;
+import io.netty.handler.timeout.IdleState;
+import io.netty.handler.timeout.IdleStateEvent;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.client.TransportResponseHandler;
+import org.apache.spark.network.protocol.Message;
+import org.apache.spark.network.protocol.RequestMessage;
+import org.apache.spark.network.protocol.ResponseMessage;
+import org.apache.spark.network.util.NettyUtils;
+
+/**
+ * The single Transport-level Channel handler which is used for delegating requests to the
+ * {@link TransportRequestHandler} and responses to the {@link TransportResponseHandler}.
+ *
+ * All channels created in the transport layer are bidirectional. When the Client initiates a Netty
+ * Channel with a RequestMessage (which gets handled by the Server's RequestHandler), the Server
+ * will produce a ResponseMessage (handled by the Client's ResponseHandler). However, the Server
+ * also gets a handle on the same Channel, so it may then begin to send RequestMessages to the
+ * Client.
+ * This means that the Client also needs a RequestHandler and the Server needs a ResponseHandler,
+ * for the Client's responses to the Server's requests.
+ *
+ * This class also handles timeouts from a {@link io.netty.handler.timeout.IdleStateHandler}.
+ * We consider a connection timed out if there are outstanding fetch or RPC requests but no traffic
+ * on the channel for at least `requestTimeoutMs`. Note that this is duplex traffic; we will not
+ * timeout if the client is continuously sending but getting no responses, for simplicity.
+ */
+public class TransportChannelHandler extends SimpleChannelInboundHandler<Message> {
+ private final Logger logger = LoggerFactory.getLogger(TransportChannelHandler.class);
+
+ private final TransportClient client;
+ private final TransportResponseHandler responseHandler;
+ private final TransportRequestHandler requestHandler;
+ private final long requestTimeoutNs;
+ private final boolean closeIdleConnections;
+
+ public TransportChannelHandler(
+ TransportClient client,
+ TransportResponseHandler responseHandler,
+ TransportRequestHandler requestHandler,
+ long requestTimeoutMs,
+ boolean closeIdleConnections) {
+ this.client = client;
+ this.responseHandler = responseHandler;
+ this.requestHandler = requestHandler;
+ this.requestTimeoutNs = requestTimeoutMs * 1000L * 1000;
+ this.closeIdleConnections = closeIdleConnections;
+ }
+
+ public TransportClient getClient() {
+ return client;
+ }
+
+ @Override
+ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
+ logger.warn("Exception in connection from " + NettyUtils.getRemoteAddress(ctx.channel()),
+ cause);
+ requestHandler.exceptionCaught(cause);
+ responseHandler.exceptionCaught(cause);
+ ctx.close();
+ }
+
+ @Override
+ public void channelActive(ChannelHandlerContext ctx) throws Exception {
+ try {
+ requestHandler.channelActive();
+ } catch (RuntimeException e) {
+ logger.error("Exception from request handler while registering channel", e);
+ }
+ try {
+ responseHandler.channelActive();
+ } catch (RuntimeException e) {
+ logger.error("Exception from response handler while registering channel", e);
+ }
+ super.channelRegistered(ctx);
+ }
+
+ @Override
+ public void channelInactive(ChannelHandlerContext ctx) throws Exception {
+ try {
+ requestHandler.channelInactive();
+ } catch (RuntimeException e) {
+ logger.error("Exception from request handler while unregistering channel", e);
+ }
+ try {
+ responseHandler.channelInactive();
+ } catch (RuntimeException e) {
+ logger.error("Exception from response handler while unregistering channel", e);
+ }
+ super.channelUnregistered(ctx);
+ }
+
+ @Override
+ public void channelRead0(ChannelHandlerContext ctx, Message request) throws Exception {
+ if (request instanceof RequestMessage) {
+ requestHandler.handle((RequestMessage) request);
+ } else {
+ responseHandler.handle((ResponseMessage) request);
+ }
+ }
+
+ /** Triggered based on events from an {@link io.netty.handler.timeout.IdleStateHandler}. */
+ @Override
+ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
+ if (evt instanceof IdleStateEvent) {
+ IdleStateEvent e = (IdleStateEvent) evt;
+ // See class comment for timeout semantics. In addition to ensuring we only timeout while
+ // there are outstanding requests, we also do a secondary consistency check to ensure
+ // there's no race between the idle timeout and incrementing the numOutstandingRequests
+ // (see SPARK-7003).
+ //
+ // To avoid a race between TransportClientFactory.createClient() and this code which could
+ // result in an inactive client being returned, this needs to run in a synchronized block.
+ synchronized (this) {
+ boolean isActuallyOverdue =
+ System.nanoTime() - responseHandler.getTimeOfLastRequestNs() > requestTimeoutNs;
+ if (e.state() == IdleState.ALL_IDLE && isActuallyOverdue) {
+ if (responseHandler.numOutstandingRequests() > 0) {
+ String address = NettyUtils.getRemoteAddress(ctx.channel());
+ logger.error("Connection to {} has been quiet for {} ms while there are outstanding " +
+ "requests. Assuming connection is dead; please adjust spark.network.timeout if this " +
+ "is wrong.", address, requestTimeoutNs / 1000 / 1000);
+ client.timeOut();
+ ctx.close();
+ } else if (closeIdleConnections) {
+ // When closeIdleConnections is enabled, we also close idle connections
+ client.timeOut();
+ ctx.close();
+ }
+ }
+ }
+ }
+ ctx.fireUserEventTriggered(evt);
+ }
+
+ public TransportResponseHandler getResponseHandler() {
+ return responseHandler;
+ }
+
+}
diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
new file mode 100644
index 0000000000..296ced3db0
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
@@ -0,0 +1,209 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.server;
+
+import java.nio.ByteBuffer;
+
+import com.google.common.base.Preconditions;
+import com.google.common.base.Throwables;
+import io.netty.channel.Channel;
+import io.netty.channel.ChannelFuture;
+import io.netty.channel.ChannelFutureListener;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.buffer.NioManagedBuffer;
+import org.apache.spark.network.client.RpcResponseCallback;
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.protocol.ChunkFetchRequest;
+import org.apache.spark.network.protocol.ChunkFetchFailure;
+import org.apache.spark.network.protocol.ChunkFetchSuccess;
+import org.apache.spark.network.protocol.Encodable;
+import org.apache.spark.network.protocol.OneWayMessage;
+import org.apache.spark.network.protocol.RequestMessage;
+import org.apache.spark.network.protocol.RpcFailure;
+import org.apache.spark.network.protocol.RpcRequest;
+import org.apache.spark.network.protocol.RpcResponse;
+import org.apache.spark.network.protocol.StreamFailure;
+import org.apache.spark.network.protocol.StreamRequest;
+import org.apache.spark.network.protocol.StreamResponse;
+import org.apache.spark.network.util.NettyUtils;
+
+/**
+ * A handler that processes requests from clients and writes chunk data back. Each handler is
+ * attached to a single Netty channel, and keeps track of which streams have been fetched via this
+ * channel, in order to clean them up if the channel is terminated (see #channelUnregistered).
+ *
+ * The messages should have been processed by the pipeline setup by {@link TransportServer}.
+ */
+public class TransportRequestHandler extends MessageHandler<RequestMessage> {
+ private final Logger logger = LoggerFactory.getLogger(TransportRequestHandler.class);
+
+ /** The Netty channel that this handler is associated with. */
+ private final Channel channel;
+
+ /** Client on the same channel allowing us to talk back to the requester. */
+ private final TransportClient reverseClient;
+
+ /** Handles all RPC messages. */
+ private final RpcHandler rpcHandler;
+
+ /** Returns each chunk that is part of a stream. */
+ private final StreamManager streamManager;
+
+ public TransportRequestHandler(
+ Channel channel,
+ TransportClient reverseClient,
+ RpcHandler rpcHandler) {
+ this.channel = channel;
+ this.reverseClient = reverseClient;
+ this.rpcHandler = rpcHandler;
+ this.streamManager = rpcHandler.getStreamManager();
+ }
+
+ @Override
+ public void exceptionCaught(Throwable cause) {
+ rpcHandler.exceptionCaught(cause, reverseClient);
+ }
+
+ @Override
+ public void channelActive() {
+ rpcHandler.channelActive(reverseClient);
+ }
+
+ @Override
+ public void channelInactive() {
+ if (streamManager != null) {
+ try {
+ streamManager.connectionTerminated(channel);
+ } catch (RuntimeException e) {
+ logger.error("StreamManager connectionTerminated() callback failed.", e);
+ }
+ }
+ rpcHandler.channelInactive(reverseClient);
+ }
+
+ @Override
+ public void handle(RequestMessage request) {
+ if (request instanceof ChunkFetchRequest) {
+ processFetchRequest((ChunkFetchRequest) request);
+ } else if (request instanceof RpcRequest) {
+ processRpcRequest((RpcRequest) request);
+ } else if (request instanceof OneWayMessage) {
+ processOneWayMessage((OneWayMessage) request);
+ } else if (request instanceof StreamRequest) {
+ processStreamRequest((StreamRequest) request);
+ } else {
+ throw new IllegalArgumentException("Unknown request type: " + request);
+ }
+ }
+
+ private void processFetchRequest(final ChunkFetchRequest req) {
+ final String client = NettyUtils.getRemoteAddress(channel);
+
+ logger.trace("Received req from {} to fetch block {}", client, req.streamChunkId);
+
+ ManagedBuffer buf;
+ try {
+ streamManager.checkAuthorization(reverseClient, req.streamChunkId.streamId);
+ streamManager.registerChannel(channel, req.streamChunkId.streamId);
+ buf = streamManager.getChunk(req.streamChunkId.streamId, req.streamChunkId.chunkIndex);
+ } catch (Exception e) {
+ logger.error(String.format(
+ "Error opening block %s for request from %s", req.streamChunkId, client), e);
+ respond(new ChunkFetchFailure(req.streamChunkId, Throwables.getStackTraceAsString(e)));
+ return;
+ }
+
+ respond(new ChunkFetchSuccess(req.streamChunkId, buf));
+ }
+
+ private void processStreamRequest(final StreamRequest req) {
+ final String client = NettyUtils.getRemoteAddress(channel);
+ ManagedBuffer buf;
+ try {
+ buf = streamManager.openStream(req.streamId);
+ } catch (Exception e) {
+ logger.error(String.format(
+ "Error opening stream %s for request from %s", req.streamId, client), e);
+ respond(new StreamFailure(req.streamId, Throwables.getStackTraceAsString(e)));
+ return;
+ }
+
+ if (buf != null) {
+ respond(new StreamResponse(req.streamId, buf.size(), buf));
+ } else {
+ respond(new StreamFailure(req.streamId, String.format(
+ "Stream '%s' was not found.", req.streamId)));
+ }
+ }
+
+ private void processRpcRequest(final RpcRequest req) {
+ try {
+ rpcHandler.receive(reverseClient, req.body().nioByteBuffer(), new RpcResponseCallback() {
+ @Override
+ public void onSuccess(ByteBuffer response) {
+ respond(new RpcResponse(req.requestId, new NioManagedBuffer(response)));
+ }
+
+ @Override
+ public void onFailure(Throwable e) {
+ respond(new RpcFailure(req.requestId, Throwables.getStackTraceAsString(e)));
+ }
+ });
+ } catch (Exception e) {
+ logger.error("Error while invoking RpcHandler#receive() on RPC id " + req.requestId, e);
+ respond(new RpcFailure(req.requestId, Throwables.getStackTraceAsString(e)));
+ } finally {
+ req.body().release();
+ }
+ }
+
+ private void processOneWayMessage(OneWayMessage req) {
+ try {
+ rpcHandler.receive(reverseClient, req.body().nioByteBuffer());
+ } catch (Exception e) {
+ logger.error("Error while invoking RpcHandler#receive() for one-way message.", e);
+ } finally {
+ req.body().release();
+ }
+ }
+
+ /**
+ * Responds to a single message with some Encodable object. If a failure occurs while sending,
+ * it will be logged and the channel closed.
+ */
+ private void respond(final Encodable result) {
+ final String remoteAddress = channel.remoteAddress().toString();
+ channel.writeAndFlush(result).addListener(
+ new ChannelFutureListener() {
+ @Override
+ public void operationComplete(ChannelFuture future) throws Exception {
+ if (future.isSuccess()) {
+ logger.trace(String.format("Sent result %s to client %s", result, remoteAddress));
+ } else {
+ logger.error(String.format("Error sending result %s to %s; closing connection",
+ result, remoteAddress), future.cause());
+ channel.close();
+ }
+ }
+ }
+ );
+ }
+}
diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java
new file mode 100644
index 0000000000..baae235e02
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java
@@ -0,0 +1,151 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.server;
+
+import java.io.Closeable;
+import java.net.InetSocketAddress;
+import java.util.List;
+import java.util.concurrent.TimeUnit;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.Lists;
+import io.netty.bootstrap.ServerBootstrap;
+import io.netty.buffer.PooledByteBufAllocator;
+import io.netty.channel.ChannelFuture;
+import io.netty.channel.ChannelInitializer;
+import io.netty.channel.ChannelOption;
+import io.netty.channel.EventLoopGroup;
+import io.netty.channel.socket.SocketChannel;
+import org.apache.spark.network.util.JavaUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.TransportContext;
+import org.apache.spark.network.util.IOMode;
+import org.apache.spark.network.util.NettyUtils;
+import org.apache.spark.network.util.TransportConf;
+
+/**
+ * Server for the efficient, low-level streaming service.
+ */
+public class TransportServer implements Closeable {
+ private final Logger logger = LoggerFactory.getLogger(TransportServer.class);
+
+ private final TransportContext context;
+ private final TransportConf conf;
+ private final RpcHandler appRpcHandler;
+ private final List<TransportServerBootstrap> bootstraps;
+
+ private ServerBootstrap bootstrap;
+ private ChannelFuture channelFuture;
+ private int port = -1;
+
+ /**
+ * Creates a TransportServer that binds to the given host and port, or to any available port
+ * if portToBind is 0. To avoid binding to a particular host, set "hostToBind" to null.
+ */
+ public TransportServer(
+ TransportContext context,
+ String hostToBind,
+ int portToBind,
+ RpcHandler appRpcHandler,
+ List<TransportServerBootstrap> bootstraps) {
+ this.context = context;
+ this.conf = context.getConf();
+ this.appRpcHandler = appRpcHandler;
+ this.bootstraps = Lists.newArrayList(Preconditions.checkNotNull(bootstraps));
+
+ try {
+ init(hostToBind, portToBind);
+ } catch (RuntimeException e) {
+ JavaUtils.closeQuietly(this);
+ throw e;
+ }
+ }
+
+ public int getPort() {
+ if (port == -1) {
+ throw new IllegalStateException("Server not initialized");
+ }
+ return port;
+ }
+
+ private void init(String hostToBind, int portToBind) {
+
+ IOMode ioMode = IOMode.valueOf(conf.ioMode());
+ EventLoopGroup bossGroup =
+ NettyUtils.createEventLoop(ioMode, conf.serverThreads(), "shuffle-server");
+ EventLoopGroup workerGroup = bossGroup;
+
+ PooledByteBufAllocator allocator = NettyUtils.createPooledByteBufAllocator(
+ conf.preferDirectBufs(), true /* allowCache */, conf.serverThreads());
+
+ bootstrap = new ServerBootstrap()
+ .group(bossGroup, workerGroup)
+ .channel(NettyUtils.getServerChannelClass(ioMode))
+ .option(ChannelOption.ALLOCATOR, allocator)
+ .childOption(ChannelOption.ALLOCATOR, allocator);
+
+ if (conf.backLog() > 0) {
+ bootstrap.option(ChannelOption.SO_BACKLOG, conf.backLog());
+ }
+
+ if (conf.receiveBuf() > 0) {
+ bootstrap.childOption(ChannelOption.SO_RCVBUF, conf.receiveBuf());
+ }
+
+ if (conf.sendBuf() > 0) {
+ bootstrap.childOption(ChannelOption.SO_SNDBUF, conf.sendBuf());
+ }
+
+ bootstrap.childHandler(new ChannelInitializer<SocketChannel>() {
+ @Override
+ protected void initChannel(SocketChannel ch) throws Exception {
+ RpcHandler rpcHandler = appRpcHandler;
+ for (TransportServerBootstrap bootstrap : bootstraps) {
+ rpcHandler = bootstrap.doBootstrap(ch, rpcHandler);
+ }
+ context.initializePipeline(ch, rpcHandler);
+ }
+ });
+
+ InetSocketAddress address = hostToBind == null ?
+ new InetSocketAddress(portToBind): new InetSocketAddress(hostToBind, portToBind);
+ channelFuture = bootstrap.bind(address);
+ channelFuture.syncUninterruptibly();
+
+ port = ((InetSocketAddress) channelFuture.channel().localAddress()).getPort();
+ logger.debug("Shuffle server started on port :" + port);
+ }
+
+ @Override
+ public void close() {
+ if (channelFuture != null) {
+ // close is a local operation and should finish within milliseconds; timeout just to be safe
+ channelFuture.channel().close().awaitUninterruptibly(10, TimeUnit.SECONDS);
+ channelFuture = null;
+ }
+ if (bootstrap != null && bootstrap.group() != null) {
+ bootstrap.group().shutdownGracefully();
+ }
+ if (bootstrap != null && bootstrap.childGroup() != null) {
+ bootstrap.childGroup().shutdownGracefully();
+ }
+ bootstrap = null;
+ }
+}
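A hedged sketch of starting a server on an ephemeral port (not part of the patch). It assumes TransportConf still takes a module name plus a ConfigProvider and that TransportContext is constructed from a conf and an RpcHandler; neither class is shown in this hunk, so treat those two constructors as assumptions.

```java
import java.util.NoSuchElementException;

import com.google.common.collect.Lists;

import org.apache.spark.network.TransportContext;
import org.apache.spark.network.server.NoOpRpcHandler;
import org.apache.spark.network.server.RpcHandler;
import org.apache.spark.network.server.TransportServer;
import org.apache.spark.network.server.TransportServerBootstrap;
import org.apache.spark.network.util.ConfigProvider;
import org.apache.spark.network.util.TransportConf;

public class TransportServerSketch {
  public static void main(String[] args) {
    // A provider that never returns an explicit value, so every setting falls back to defaults.
    ConfigProvider provider = new ConfigProvider() {
      @Override
      public String get(String name) {
        throw new NoSuchElementException(name);
      }
    };

    // Assumed constructors; TransportConf and TransportContext are not part of this hunk.
    TransportConf conf = new TransportConf("shuffle", provider);
    RpcHandler rpcHandler = new NoOpRpcHandler();
    TransportContext context = new TransportContext(conf, rpcHandler);

    // Bind to an ephemeral port on all interfaces (hostToBind == null, portToBind == 0).
    TransportServer server = new TransportServer(
        context, null, 0, rpcHandler, Lists.<TransportServerBootstrap>newArrayList());
    System.out.println("Listening on port " + server.getPort());
    server.close();
  }
}
```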
diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/TransportServerBootstrap.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportServerBootstrap.java
new file mode 100644
index 0000000000..05803ab1bb
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/server/TransportServerBootstrap.java
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.server;
+
+import io.netty.channel.Channel;
+
+/**
+ * A bootstrap which is executed on a TransportServer's client channel once a client connects
+ * to the server. This allows customizing the client channel to allow for things such as SASL
+ * authentication.
+ */
+public interface TransportServerBootstrap {
+ /**
+ * Customizes the channel to include new features, if needed.
+ *
+ * @param channel The connected channel opened by the client.
+ * @param rpcHandler The RPC handler for the server.
+ * @return The RPC handler to use for the channel.
+ */
+ RpcHandler doBootstrap(Channel channel, RpcHandler rpcHandler);
+}
diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/ByteArrayWritableChannel.java b/common/network-common/src/main/java/org/apache/spark/network/util/ByteArrayWritableChannel.java
new file mode 100644
index 0000000000..b141572004
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/util/ByteArrayWritableChannel.java
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.util;
+
+import java.nio.ByteBuffer;
+import java.nio.channels.WritableByteChannel;
+
+/**
+ * A writable channel that stores the written data in a byte array in memory.
+ */
+public class ByteArrayWritableChannel implements WritableByteChannel {
+
+ private final byte[] data;
+ private int offset;
+
+ public ByteArrayWritableChannel(int size) {
+ this.data = new byte[size];
+ }
+
+ public byte[] getData() {
+ return data;
+ }
+
+ public int length() {
+ return offset;
+ }
+
+ /** Resets the channel so that writing to it will overwrite the existing buffer. */
+ public void reset() {
+ offset = 0;
+ }
+
+ /**
+ * Reads from the given buffer into the internal byte array.
+ */
+ @Override
+ public int write(ByteBuffer src) {
+ int toTransfer = Math.min(src.remaining(), data.length - offset);
+ src.get(data, offset, toTransfer);
+ offset += toTransfer;
+ return toTransfer;
+ }
+
+ @Override
+ public void close() {
+
+ }
+
+ @Override
+ public boolean isOpen() {
+ return true;
+ }
+
+}
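A short usage sketch (not part of the patch): writes are capped by the remaining capacity of the backing array, and reset() rewinds the write offset so the buffer can be reused.

```java
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;

import org.apache.spark.network.util.ByteArrayWritableChannel;

public class ByteArrayWritableChannelSketch {
  public static void main(String[] args) {
    ByteArrayWritableChannel channel = new ByteArrayWritableChannel(16);

    // write() copies as many bytes as fit into the internal array and reports the count.
    int written = channel.write(ByteBuffer.wrap("hello".getBytes(StandardCharsets.UTF_8)));
    System.out.println(written + " bytes written, length=" + channel.length());

    // reset() makes subsequent writes overwrite the existing data.
    channel.reset();
    System.out.println("after reset, length=" + channel.length());
  }
}
```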
diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/ByteUnit.java b/common/network-common/src/main/java/org/apache/spark/network/util/ByteUnit.java
new file mode 100644
index 0000000000..a2f018373f
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/util/ByteUnit.java
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.network.util;
+
+public enum ByteUnit {
+ BYTE (1),
+ KiB (1024L),
+ MiB ((long) Math.pow(1024L, 2L)),
+ GiB ((long) Math.pow(1024L, 3L)),
+ TiB ((long) Math.pow(1024L, 4L)),
+ PiB ((long) Math.pow(1024L, 5L));
+
+ private ByteUnit(long multiplier) {
+ this.multiplier = multiplier;
+ }
+
+ // Interpret the provided number (d) with suffix (u) as this unit type.
+ // E.g. KiB.convertFrom(1, MiB) interprets 1 MiB as its KiB representation = 1024 KiB.
+ public long convertFrom(long d, ByteUnit u) {
+ return u.convertTo(d, this);
+ }
+
+ // Convert the provided number (d) interpreted as this unit type to unit type (u).
+ public long convertTo(long d, ByteUnit u) {
+ if (multiplier > u.multiplier) {
+ long ratio = multiplier / u.multiplier;
+ if (Long.MAX_VALUE / ratio < d) {
+ throw new IllegalArgumentException("Conversion of " + d + " exceeds Long.MAX_VALUE in "
+ + name() + ". Try a larger unit (e.g. MiB instead of KiB)");
+ }
+ return d * ratio;
+ } else {
+ // Perform operations in this order to avoid potential overflow
+ // when computing d * multiplier
+ return d / (u.multiplier / multiplier);
+ }
+ }
+
+ public double toBytes(long d) {
+ if (d < 0) {
+ throw new IllegalArgumentException("Negative size value. Size must be positive: " + d);
+ }
+ return d * multiplier;
+ }
+
+ public long toKiB(long d) { return convertTo(d, KiB); }
+ public long toMiB(long d) { return convertTo(d, MiB); }
+ public long toGiB(long d) { return convertTo(d, GiB); }
+ public long toTiB(long d) { return convertTo(d, TiB); }
+ public long toPiB(long d) { return convertTo(d, PiB); }
+
+ private final long multiplier;
+}
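A quick usage sketch of the conversion methods (not part of the patch); the expected values follow directly from the multipliers defined above.

```java
import org.apache.spark.network.util.ByteUnit;

public class ByteUnitSketch {
  public static void main(String[] args) {
    // 1 MiB expressed in KiB.
    System.out.println(ByteUnit.KiB.convertFrom(1, ByteUnit.MiB));   // 1024
    // 2048 KiB expressed in MiB; smaller-to-larger conversions use integer division.
    System.out.println(ByteUnit.KiB.convertTo(2048, ByteUnit.MiB));  // 2
    // The convenience helpers delegate to convertTo.
    System.out.println(ByteUnit.GiB.toMiB(3));                       // 3072
    System.out.println(ByteUnit.BYTE.toBytes(5));                    // 5.0
  }
}
```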
diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/ConfigProvider.java b/common/network-common/src/main/java/org/apache/spark/network/util/ConfigProvider.java
new file mode 100644
index 0000000000..d944d9da1c
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/util/ConfigProvider.java
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.util;
+
+import java.util.NoSuchElementException;
+
+/**
+ * Provides a mechanism for constructing a {@link TransportConf} using some sort of configuration.
+ */
+public abstract class ConfigProvider {
+ /** Obtains the value of the given config, throws NoSuchElementException if it doesn't exist. */
+ public abstract String get(String name);
+
+ public String get(String name, String defaultValue) {
+ try {
+ return get(name);
+ } catch (NoSuchElementException e) {
+ return defaultValue;
+ }
+ }
+
+ public int getInt(String name, int defaultValue) {
+ return Integer.parseInt(get(name, Integer.toString(defaultValue)));
+ }
+
+ public long getLong(String name, long defaultValue) {
+ return Long.parseLong(get(name, Long.toString(defaultValue)));
+ }
+
+ public double getDouble(String name, double defaultValue) {
+ return Double.parseDouble(get(name, Double.toString(defaultValue)));
+ }
+
+ public boolean getBoolean(String name, boolean defaultValue) {
+ return Boolean.parseBoolean(get(name, Boolean.toString(defaultValue)));
+ }
+}
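Concrete providers only need to supply get(String); the typed getters above are all derived from it. A minimal sketch of a map-backed provider (the class name SimpleConfigProvider is hypothetical; the module ships MapConfigProvider for the same purpose):

import java.util.HashMap;
import java.util.Map;
import java.util.NoSuchElementException;

import org.apache.spark.network.util.ConfigProvider;

class SimpleConfigProvider extends ConfigProvider {
  private final Map<String, String> values = new HashMap<>();

  SimpleConfigProvider set(String key, String value) {
    values.put(key, value);
    return this;
  }

  @Override
  public String get(String name) {
    String value = values.get(name);
    if (value == null) {
      // Missing keys must throw so the default-taking overloads can fall back.
      throw new NoSuchElementException(name);
    }
    return value;
  }
}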
diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/IOMode.java b/common/network-common/src/main/java/org/apache/spark/network/util/IOMode.java
new file mode 100644
index 0000000000..6b208d95bb
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/util/IOMode.java
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.util;
+
+/**
+ * Selector for which form of low-level IO we should use.
+ * NIO is always available, while EPOLL is only available on Linux.
+ */
+public enum IOMode {
+ NIO, EPOLL
+}
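Callers typically resolve the mode from a configuration string. A minimal sketch of that mapping (the string value is illustrative; TransportConf.ioMode() below upper-cases the configured value in the same way):

import org.apache.spark.network.util.IOMode;

public class IOModeExample {
  public static void main(String[] args) {
    // "nio" or "epoll" from configuration, upper-cased before the enum lookup.
    IOMode mode = IOMode.valueOf("nio".toUpperCase());
    System.out.println(mode); // NIO
  }
}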
diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/JavaUtils.java b/common/network-common/src/main/java/org/apache/spark/network/util/JavaUtils.java
new file mode 100644
index 0000000000..b3d8e0cd7c
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/util/JavaUtils.java
@@ -0,0 +1,303 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.util;
+
+import java.io.Closeable;
+import java.io.File;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.concurrent.TimeUnit;
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
+
+import com.google.common.base.Charsets;
+import com.google.common.base.Preconditions;
+import com.google.common.collect.ImmutableMap;
+import io.netty.buffer.Unpooled;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * General utilities available in the network package. Many of these are sourced from Spark's
+ * own Utils, just accessible within this package.
+ */
+public class JavaUtils {
+ private static final Logger logger = LoggerFactory.getLogger(JavaUtils.class);
+
+ /**
+ * Define a default value for driver memory here since this value is referenced across the code
+ * base and nearly all files already use Utils.scala
+ */
+ public static final long DEFAULT_DRIVER_MEM_MB = 1024;
+
+ /** Closes the given object, ignoring IOExceptions. */
+ public static void closeQuietly(Closeable closeable) {
+ try {
+ if (closeable != null) {
+ closeable.close();
+ }
+ } catch (IOException e) {
+ logger.error("IOException should not have been thrown.", e);
+ }
+ }
+
+ /** Returns a hash consistent with Spark's Utils.nonNegativeHash(). */
+ public static int nonNegativeHash(Object obj) {
+ if (obj == null) { return 0; }
+ int hash = obj.hashCode();
+ return hash != Integer.MIN_VALUE ? Math.abs(hash) : 0;
+ }
+
+ /**
+ * Convert the given string to a byte buffer. The resulting buffer can be
+ * converted back to the same string through {@link #bytesToString(ByteBuffer)}.
+ */
+ public static ByteBuffer stringToBytes(String s) {
+ return Unpooled.wrappedBuffer(s.getBytes(Charsets.UTF_8)).nioBuffer();
+ }
+
+ /**
+ * Convert the given byte buffer to a string. The resulting string can be
+ * converted back to the same byte buffer through {@link #stringToBytes(String)}.
+ */
+ public static String bytesToString(ByteBuffer b) {
+ return Unpooled.wrappedBuffer(b).toString(Charsets.UTF_8);
+ }
+
+ /*
+ * Delete a file or directory and its contents recursively.
+ * Don't follow directories if they are symlinks.
+ * Throws an exception if deletion is unsuccessful.
+ */
+ public static void deleteRecursively(File file) throws IOException {
+ if (file == null) { return; }
+
+ if (file.isDirectory() && !isSymlink(file)) {
+ IOException savedIOException = null;
+ for (File child : listFilesSafely(file)) {
+ try {
+ deleteRecursively(child);
+ } catch (IOException e) {
+ // In case of multiple exceptions, only last one will be thrown
+ savedIOException = e;
+ }
+ }
+ if (savedIOException != null) {
+ throw savedIOException;
+ }
+ }
+
+ boolean deleted = file.delete();
+ // Delete can also fail if the file simply did not exist.
+ if (!deleted && file.exists()) {
+ throw new IOException("Failed to delete: " + file.getAbsolutePath());
+ }
+ }
+
+ private static File[] listFilesSafely(File file) throws IOException {
+ if (file.exists()) {
+ File[] files = file.listFiles();
+ if (files == null) {
+ throw new IOException("Failed to list files for dir: " + file);
+ }
+ return files;
+ } else {
+ return new File[0];
+ }
+ }
+
+ private static boolean isSymlink(File file) throws IOException {
+ Preconditions.checkNotNull(file);
+ File fileInCanonicalDir = null;
+ if (file.getParent() == null) {
+ fileInCanonicalDir = file;
+ } else {
+ fileInCanonicalDir = new File(file.getParentFile().getCanonicalFile(), file.getName());
+ }
+ return !fileInCanonicalDir.getCanonicalFile().equals(fileInCanonicalDir.getAbsoluteFile());
+ }
+
+ private static final ImmutableMap<String, TimeUnit> timeSuffixes =
+ ImmutableMap.<String, TimeUnit>builder()
+ .put("us", TimeUnit.MICROSECONDS)
+ .put("ms", TimeUnit.MILLISECONDS)
+ .put("s", TimeUnit.SECONDS)
+ .put("m", TimeUnit.MINUTES)
+ .put("min", TimeUnit.MINUTES)
+ .put("h", TimeUnit.HOURS)
+ .put("d", TimeUnit.DAYS)
+ .build();
+
+ private static final ImmutableMap<String, ByteUnit> byteSuffixes =
+ ImmutableMap.<String, ByteUnit>builder()
+ .put("b", ByteUnit.BYTE)
+ .put("k", ByteUnit.KiB)
+ .put("kb", ByteUnit.KiB)
+ .put("m", ByteUnit.MiB)
+ .put("mb", ByteUnit.MiB)
+ .put("g", ByteUnit.GiB)
+ .put("gb", ByteUnit.GiB)
+ .put("t", ByteUnit.TiB)
+ .put("tb", ByteUnit.TiB)
+ .put("p", ByteUnit.PiB)
+ .put("pb", ByteUnit.PiB)
+ .build();
+
+ /**
+ * Convert a passed time string (e.g. 50s, 100ms, or 250us) to a time count for
+ * internal use. If no suffix is provided, the value is interpreted in the supplied default unit.
+ */
+ private static long parseTimeString(String str, TimeUnit unit) {
+ String lower = str.toLowerCase().trim();
+
+ try {
+ Matcher m = Pattern.compile("(-?[0-9]+)([a-z]+)?").matcher(lower);
+ if (!m.matches()) {
+ throw new NumberFormatException("Failed to parse time string: " + str);
+ }
+
+ long val = Long.parseLong(m.group(1));
+ String suffix = m.group(2);
+
+ // Check for invalid suffixes
+ if (suffix != null && !timeSuffixes.containsKey(suffix)) {
+ throw new NumberFormatException("Invalid suffix: \"" + suffix + "\"");
+ }
+
+ // If suffix is valid use that, otherwise none was provided and use the default passed
+ return unit.convert(val, suffix != null ? timeSuffixes.get(suffix) : unit);
+ } catch (NumberFormatException e) {
+ String timeError = "Time must be specified as seconds (s), " +
+ "milliseconds (ms), microseconds (us), minutes (m or min), hour (h), or day (d). " +
+ "E.g. 50s, 100ms, or 250us.";
+
+ throw new NumberFormatException(timeError + "\n" + e.getMessage());
+ }
+ }
+
+ /**
+ * Convert a time parameter such as (50s, 100ms, or 250us) to milliseconds for internal use. If
+ * no suffix is provided, the passed number is assumed to be in ms.
+ */
+ public static long timeStringAsMs(String str) {
+ return parseTimeString(str, TimeUnit.MILLISECONDS);
+ }
+
+ /**
+ * Convert a time parameter such as (50s, 100ms, or 250us) to seconds for internal use. If
+ * no suffix is provided, the passed number is assumed to be in seconds.
+ */
+ public static long timeStringAsSec(String str) {
+ return parseTimeString(str, TimeUnit.SECONDS);
+ }
+
+ /**
+ * Convert a passed byte string (e.g. 50b, 100kb, or 250mb) to a count in the given ByteUnit
+ * for internal use. If no suffix is provided, the value is interpreted in the supplied
+ * default unit.
+ */
+ private static long parseByteString(String str, ByteUnit unit) {
+ String lower = str.toLowerCase().trim();
+
+ try {
+ Matcher m = Pattern.compile("([0-9]+)([a-z]+)?").matcher(lower);
+ Matcher fractionMatcher = Pattern.compile("([0-9]+\\.[0-9]+)([a-z]+)?").matcher(lower);
+
+ if (m.matches()) {
+ long val = Long.parseLong(m.group(1));
+ String suffix = m.group(2);
+
+ // Check for invalid suffixes
+ if (suffix != null && !byteSuffixes.containsKey(suffix)) {
+ throw new NumberFormatException("Invalid suffix: \"" + suffix + "\"");
+ }
+
+ // If suffix is valid use that, otherwise none was provided and use the default passed
+ return unit.convertFrom(val, suffix != null ? byteSuffixes.get(suffix) : unit);
+ } else if (fractionMatcher.matches()) {
+ throw new NumberFormatException("Fractional values are not supported. Input was: "
+ + fractionMatcher.group(1));
+ } else {
+ throw new NumberFormatException("Failed to parse byte string: " + str);
+ }
+
+ } catch (NumberFormatException e) {
+ String timeError = "Size must be specified as bytes (b), " +
+ "kibibytes (k), mebibytes (m), gibibytes (g), tebibytes (t), or pebibytes(p). " +
+ "E.g. 50b, 100k, or 250m.";
+
+ throw new NumberFormatException(timeError + "\n" + e.getMessage());
+ }
+ }
+
+ /**
+ * Convert a passed byte string (e.g. 50b, 100k, or 250m) to bytes for
+ * internal use.
+ *
+ * If no suffix is provided, the passed number is assumed to be in bytes.
+ */
+ public static long byteStringAsBytes(String str) {
+ return parseByteString(str, ByteUnit.BYTE);
+ }
+
+ /**
+ * Convert a passed byte string (e.g. 50b, 100k, or 250m) to kibibytes for
+ * internal use.
+ *
+ * If no suffix is provided, the passed number is assumed to be in kibibytes.
+ */
+ public static long byteStringAsKb(String str) {
+ return parseByteString(str, ByteUnit.KiB);
+ }
+
+ /**
+ * Convert a passed byte string (e.g. 50b, 100k, or 250m) to mebibytes for
+ * internal use.
+ *
+ * If no suffix is provided, the passed number is assumed to be in mebibytes.
+ */
+ public static long byteStringAsMb(String str) {
+ return parseByteString(str, ByteUnit.MiB);
+ }
+
+ /**
+ * Convert a passed byte string (e.g. 50b, 100k, or 250m) to gibibytes for
+ * internal use.
+ *
+ * If no suffix is provided, the passed number is assumed to be in gibibytes.
+ */
+ public static long byteStringAsGb(String str) {
+ return parseByteString(str, ByteUnit.GiB);
+ }
+
+ /**
+ * Returns a byte array with the buffer's contents, trying to avoid copying the data if
+ * possible.
+ */
+ public static byte[] bufferToArray(ByteBuffer buffer) {
+ if (buffer.hasArray() && buffer.arrayOffset() == 0 &&
+ buffer.array().length == buffer.remaining()) {
+ return buffer.array();
+ } else {
+ byte[] bytes = new byte[buffer.remaining()];
+ buffer.get(bytes);
+ return bytes;
+ }
+ }
+
+}
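A short sketch of the time- and byte-string helpers above (the input strings are illustrative):

import org.apache.spark.network.util.JavaUtils;

public class JavaUtilsExample {
  public static void main(String[] args) {
    // A suffix always wins; without one, the method's default unit applies.
    long ms  = JavaUtils.timeStringAsMs("120s");   // 120000
    long sec = JavaUtils.timeStringAsSec("120");   // 120 (no suffix -> seconds)

    // "2m" means 2 mebibytes regardless of which byteStringAs* method is used.
    long bytes = JavaUtils.byteStringAsBytes("2m"); // 2097152
    long kib   = JavaUtils.byteStringAsKb("2m");    // 2048

    System.out.println(ms + " " + sec + " " + bytes + " " + kib);
  }
}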
diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java b/common/network-common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java
new file mode 100644
index 0000000000..922c37a10e
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+ * Based on LimitedInputStream.java from Google Guava
+ *
+ * Copyright (C) 2007 The Guava Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.util;
+
+import java.io.FilterInputStream;
+import java.io.IOException;
+import java.io.InputStream;
+
+import com.google.common.base.Preconditions;
+
+/**
+ * Wraps a {@link InputStream}, limiting the number of bytes which can be read.
+ *
+ * This code is from Guava's 14.0 source code, because there is no compatible way to
+ * use this functionality in both a Guava 11 environment and a Guava &gt;14 environment.
+ */
+public final class LimitedInputStream extends FilterInputStream {
+ private long left;
+ private long mark = -1;
+
+ public LimitedInputStream(InputStream in, long limit) {
+ super(in);
+ Preconditions.checkNotNull(in);
+ Preconditions.checkArgument(limit >= 0, "limit must be non-negative");
+ left = limit;
+ }
+ @Override public int available() throws IOException {
+ return (int) Math.min(in.available(), left);
+ }
+ // it's okay to mark even if mark isn't supported, as reset won't work
+ @Override public synchronized void mark(int readLimit) {
+ in.mark(readLimit);
+ mark = left;
+ }
+ @Override public int read() throws IOException {
+ if (left == 0) {
+ return -1;
+ }
+ int result = in.read();
+ if (result != -1) {
+ --left;
+ }
+ return result;
+ }
+ @Override public int read(byte[] b, int off, int len) throws IOException {
+ if (left == 0) {
+ return -1;
+ }
+ len = (int) Math.min(len, left);
+ int result = in.read(b, off, len);
+ if (result != -1) {
+ left -= result;
+ }
+ return result;
+ }
+ @Override public synchronized void reset() throws IOException {
+ if (!in.markSupported()) {
+ throw new IOException("Mark not supported");
+ }
+ if (mark == -1) {
+ throw new IOException("Mark not set");
+ }
+ in.reset();
+ left = mark;
+ }
+ @Override public long skip(long n) throws IOException {
+ n = Math.min(n, left);
+ long skipped = in.skip(n);
+ left -= skipped;
+ return skipped;
+ }
+}
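A minimal usage sketch: the wrapper makes only the first N bytes of the underlying stream visible (byte values are illustrative):

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;

import org.apache.spark.network.util.LimitedInputStream;

public class LimitedInputStreamExample {
  public static void main(String[] args) throws IOException {
    InputStream in = new ByteArrayInputStream(new byte[] {1, 2, 3, 4, 5});

    // Limit reads to the first three bytes.
    LimitedInputStream limited = new LimitedInputStream(in, 3);
    int count = 0;
    while (limited.read() != -1) {
      count++;
    }
    System.out.println(count); // 3
  }
}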
diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/MapConfigProvider.java b/common/network-common/src/main/java/org/apache/spark/network/util/MapConfigProvider.java
new file mode 100644
index 0000000000..668d2356b9
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/util/MapConfigProvider.java
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.util;
+
+import com.google.common.collect.Maps;
+
+import java.util.Map;
+import java.util.NoSuchElementException;
+
+/** ConfigProvider based on a Map (copied in the constructor). */
+public class MapConfigProvider extends ConfigProvider {
+ private final Map<String, String> config;
+
+ public MapConfigProvider(Map<String, String> config) {
+ this.config = Maps.newHashMap(config);
+ }
+
+ @Override
+ public String get(String name) {
+ String value = config.get(name);
+ if (value == null) {
+ throw new NoSuchElementException(name);
+ }
+ return value;
+ }
+}
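A minimal usage sketch (the keys and values are illustrative):

import java.util.HashMap;
import java.util.Map;

import org.apache.spark.network.util.MapConfigProvider;

public class MapConfigProviderExample {
  public static void main(String[] args) {
    Map<String, String> settings = new HashMap<>();
    settings.put("spark.shuffle.io.connectionTimeout", "2s");

    MapConfigProvider provider = new MapConfigProvider(settings);
    System.out.println(provider.get("spark.shuffle.io.connectionTimeout")); // "2s"
    System.out.println(provider.get("spark.shuffle.io.maxRetries", "3"));   // falls back to "3"
  }
}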
diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/NettyUtils.java b/common/network-common/src/main/java/org/apache/spark/network/util/NettyUtils.java
new file mode 100644
index 0000000000..caa7260bc8
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/util/NettyUtils.java
@@ -0,0 +1,139 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.util;
+
+import java.lang.reflect.Field;
+import java.util.concurrent.ThreadFactory;
+
+import com.google.common.util.concurrent.ThreadFactoryBuilder;
+import io.netty.buffer.PooledByteBufAllocator;
+import io.netty.channel.Channel;
+import io.netty.channel.EventLoopGroup;
+import io.netty.channel.ServerChannel;
+import io.netty.channel.epoll.EpollEventLoopGroup;
+import io.netty.channel.epoll.EpollServerSocketChannel;
+import io.netty.channel.epoll.EpollSocketChannel;
+import io.netty.channel.nio.NioEventLoopGroup;
+import io.netty.channel.socket.nio.NioServerSocketChannel;
+import io.netty.channel.socket.nio.NioSocketChannel;
+import io.netty.handler.codec.ByteToMessageDecoder;
+import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
+import io.netty.util.internal.PlatformDependent;
+
+/**
+ * Utilities for creating various Netty constructs based on whether we're using EPOLL or NIO.
+ */
+public class NettyUtils {
+ /** Creates a new ThreadFactory which prefixes each thread with the given name. */
+ public static ThreadFactory createThreadFactory(String threadPoolPrefix) {
+ return new ThreadFactoryBuilder()
+ .setDaemon(true)
+ .setNameFormat(threadPoolPrefix + "-%d")
+ .build();
+ }
+
+ /** Creates a Netty EventLoopGroup based on the IOMode. */
+ public static EventLoopGroup createEventLoop(IOMode mode, int numThreads, String threadPrefix) {
+ ThreadFactory threadFactory = createThreadFactory(threadPrefix);
+
+ switch (mode) {
+ case NIO:
+ return new NioEventLoopGroup(numThreads, threadFactory);
+ case EPOLL:
+ return new EpollEventLoopGroup(numThreads, threadFactory);
+ default:
+ throw new IllegalArgumentException("Unknown io mode: " + mode);
+ }
+ }
+
+ /** Returns the correct (client) SocketChannel class based on IOMode. */
+ public static Class<? extends Channel> getClientChannelClass(IOMode mode) {
+ switch (mode) {
+ case NIO:
+ return NioSocketChannel.class;
+ case EPOLL:
+ return EpollSocketChannel.class;
+ default:
+ throw new IllegalArgumentException("Unknown io mode: " + mode);
+ }
+ }
+
+ /** Returns the correct ServerSocketChannel class based on IOMode. */
+ public static Class<? extends ServerChannel> getServerChannelClass(IOMode mode) {
+ switch (mode) {
+ case NIO:
+ return NioServerSocketChannel.class;
+ case EPOLL:
+ return EpollServerSocketChannel.class;
+ default:
+ throw new IllegalArgumentException("Unknown io mode: " + mode);
+ }
+ }
+
+ /**
+ * Creates a frame decoder in which the first 8 bytes of the stream are the length of the frame.
+ * This is used before all decoders.
+ */
+ public static TransportFrameDecoder createFrameDecoder() {
+ return new TransportFrameDecoder();
+ }
+
+ /** Returns the remote address on the channel or "&lt;unknown remote&gt;" if none exists. */
+ public static String getRemoteAddress(Channel channel) {
+ if (channel != null && channel.remoteAddress() != null) {
+ return channel.remoteAddress().toString();
+ }
+ return "<unknown remote>";
+ }
+
+ /**
+ * Create a pooled ByteBuf allocator but disables the thread-local cache. Thread-local caches
+ * are disabled for TransportClients because the ByteBufs are allocated by the event loop thread,
+ * but released by the executor thread rather than the event loop thread. Those thread-local
+ * caches actually delay the recycling of buffers, leading to larger memory usage.
+ */
+ public static PooledByteBufAllocator createPooledByteBufAllocator(
+ boolean allowDirectBufs,
+ boolean allowCache,
+ int numCores) {
+ if (numCores == 0) {
+ numCores = Runtime.getRuntime().availableProcessors();
+ }
+ return new PooledByteBufAllocator(
+ allowDirectBufs && PlatformDependent.directBufferPreferred(),
+ Math.min(getPrivateStaticField("DEFAULT_NUM_HEAP_ARENA"), numCores),
+ Math.min(getPrivateStaticField("DEFAULT_NUM_DIRECT_ARENA"), allowDirectBufs ? numCores : 0),
+ getPrivateStaticField("DEFAULT_PAGE_SIZE"),
+ getPrivateStaticField("DEFAULT_MAX_ORDER"),
+ allowCache ? getPrivateStaticField("DEFAULT_TINY_CACHE_SIZE") : 0,
+ allowCache ? getPrivateStaticField("DEFAULT_SMALL_CACHE_SIZE") : 0,
+ allowCache ? getPrivateStaticField("DEFAULT_NORMAL_CACHE_SIZE") : 0
+ );
+ }
+
+ /** Used to get defaults from Netty's private static fields. */
+ private static int getPrivateStaticField(String name) {
+ try {
+ Field f = PooledByteBufAllocator.DEFAULT.getClass().getDeclaredField(name);
+ f.setAccessible(true);
+ return f.getInt(null);
+ } catch (Exception e) {
+ throw new RuntimeException(e);
+ }
+ }
+}
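A minimal sketch of picking the event loop and channel class consistently from one IOMode and plugging them into a Netty client bootstrap (the thread count and prefix are illustrative; this only exercises the helpers above):

import io.netty.bootstrap.Bootstrap;
import io.netty.channel.Channel;
import io.netty.channel.EventLoopGroup;

import org.apache.spark.network.util.IOMode;
import org.apache.spark.network.util.NettyUtils;

public class NettyUtilsExample {
  public static void main(String[] args) {
    IOMode mode = IOMode.NIO;

    // The event loop and the channel class must come from the same IOMode.
    EventLoopGroup workerGroup = NettyUtils.createEventLoop(mode, 4, "example-client");
    Class<? extends Channel> channelClass = NettyUtils.getClientChannelClass(mode);

    Bootstrap bootstrap = new Bootstrap()
      .group(workerGroup)
      .channel(channelClass);

    // ... add handlers and connect here ...
    workerGroup.shutdownGracefully();
  }
}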
diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/SystemPropertyConfigProvider.java b/common/network-common/src/main/java/org/apache/spark/network/util/SystemPropertyConfigProvider.java
new file mode 100644
index 0000000000..5f20b70678
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/util/SystemPropertyConfigProvider.java
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.util;
+
+import java.util.NoSuchElementException;
+
+import org.apache.spark.network.util.ConfigProvider;
+
+/** Uses System properties to obtain config values. */
+public class SystemPropertyConfigProvider extends ConfigProvider {
+ @Override
+ public String get(String name) {
+ String value = System.getProperty(name);
+ if (value == null) {
+ throw new NoSuchElementException(name);
+ }
+ return value;
+ }
+}
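A minimal usage sketch (the property name is illustrative):

import org.apache.spark.network.util.SystemPropertyConfigProvider;

public class SystemPropertyConfigProviderExample {
  public static void main(String[] args) {
    System.setProperty("spark.shuffle.io.numConnectionsPerPeer", "2");

    SystemPropertyConfigProvider provider = new SystemPropertyConfigProvider();
    System.out.println(provider.getInt("spark.shuffle.io.numConnectionsPerPeer", 1)); // 2
  }
}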
diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java
new file mode 100644
index 0000000000..9f030da2b3
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java
@@ -0,0 +1,169 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.util;
+
+import com.google.common.primitives.Ints;
+
+/**
+ * A central location that tracks all the settings we expose to users.
+ */
+public class TransportConf {
+
+ private final String SPARK_NETWORK_IO_MODE_KEY;
+ private final String SPARK_NETWORK_IO_PREFERDIRECTBUFS_KEY;
+ private final String SPARK_NETWORK_IO_CONNECTIONTIMEOUT_KEY;
+ private final String SPARK_NETWORK_IO_BACKLOG_KEY;
+ private final String SPARK_NETWORK_IO_NUMCONNECTIONSPERPEER_KEY;
+ private final String SPARK_NETWORK_IO_SERVERTHREADS_KEY;
+ private final String SPARK_NETWORK_IO_CLIENTTHREADS_KEY;
+ private final String SPARK_NETWORK_IO_RECEIVEBUFFER_KEY;
+ private final String SPARK_NETWORK_IO_SENDBUFFER_KEY;
+ private final String SPARK_NETWORK_SASL_TIMEOUT_KEY;
+ private final String SPARK_NETWORK_IO_MAXRETRIES_KEY;
+ private final String SPARK_NETWORK_IO_RETRYWAIT_KEY;
+ private final String SPARK_NETWORK_IO_LAZYFD_KEY;
+
+ private final ConfigProvider conf;
+
+ private final String module;
+
+ public TransportConf(String module, ConfigProvider conf) {
+ this.module = module;
+ this.conf = conf;
+ SPARK_NETWORK_IO_MODE_KEY = getConfKey("io.mode");
+ SPARK_NETWORK_IO_PREFERDIRECTBUFS_KEY = getConfKey("io.preferDirectBufs");
+ SPARK_NETWORK_IO_CONNECTIONTIMEOUT_KEY = getConfKey("io.connectionTimeout");
+ SPARK_NETWORK_IO_BACKLOG_KEY = getConfKey("io.backLog");
+ SPARK_NETWORK_IO_NUMCONNECTIONSPERPEER_KEY = getConfKey("io.numConnectionsPerPeer");
+ SPARK_NETWORK_IO_SERVERTHREADS_KEY = getConfKey("io.serverThreads");
+ SPARK_NETWORK_IO_CLIENTTHREADS_KEY = getConfKey("io.clientThreads");
+ SPARK_NETWORK_IO_RECEIVEBUFFER_KEY = getConfKey("io.receiveBuffer");
+ SPARK_NETWORK_IO_SENDBUFFER_KEY = getConfKey("io.sendBuffer");
+ SPARK_NETWORK_SASL_TIMEOUT_KEY = getConfKey("sasl.timeout");
+ SPARK_NETWORK_IO_MAXRETRIES_KEY = getConfKey("io.maxRetries");
+ SPARK_NETWORK_IO_RETRYWAIT_KEY = getConfKey("io.retryWait");
+ SPARK_NETWORK_IO_LAZYFD_KEY = getConfKey("io.lazyFD");
+ }
+
+ private String getConfKey(String suffix) {
+ return "spark." + module + "." + suffix;
+ }
+
+ /** IO mode: nio or epoll */
+ public String ioMode() { return conf.get(SPARK_NETWORK_IO_MODE_KEY, "NIO").toUpperCase(); }
+
+ /** If true, we will prefer allocating off-heap byte buffers within Netty. */
+ public boolean preferDirectBufs() {
+ return conf.getBoolean(SPARK_NETWORK_IO_PREFERDIRECTBUFS_KEY, true);
+ }
+
+ /** Connect timeout in milliseconds. Default 120 secs. */
+ public int connectionTimeoutMs() {
+ long defaultNetworkTimeoutS = JavaUtils.timeStringAsSec(
+ conf.get("spark.network.timeout", "120s"));
+ long defaultTimeoutMs = JavaUtils.timeStringAsSec(
+ conf.get(SPARK_NETWORK_IO_CONNECTIONTIMEOUT_KEY, defaultNetworkTimeoutS + "s")) * 1000;
+ return (int) defaultTimeoutMs;
+ }
+
+ /** Number of concurrent connections between two nodes for fetching data. */
+ public int numConnectionsPerPeer() {
+ return conf.getInt(SPARK_NETWORK_IO_NUMCONNECTIONSPERPEER_KEY, 1);
+ }
+
+ /** Requested maximum length of the queue of incoming connections. Default -1 for no backlog. */
+ public int backLog() { return conf.getInt(SPARK_NETWORK_IO_BACKLOG_KEY, -1); }
+
+ /** Number of threads used in the server thread pool. Default to 0, which is 2x#cores. */
+ public int serverThreads() { return conf.getInt(SPARK_NETWORK_IO_SERVERTHREADS_KEY, 0); }
+
+ /** Number of threads used in the client thread pool. Default to 0, which is 2x#cores. */
+ public int clientThreads() { return conf.getInt(SPARK_NETWORK_IO_CLIENTTHREADS_KEY, 0); }
+
+ /**
+ * Receive buffer size (SO_RCVBUF).
+ * Note: the optimal size for receive buffer and send buffer should be
+ * latency * network_bandwidth.
+ * Assuming latency = 1ms, network_bandwidth = 10Gbps
+ * buffer size should be ~ 1.25MB
+ */
+ public int receiveBuf() { return conf.getInt(SPARK_NETWORK_IO_RECEIVEBUFFER_KEY, -1); }
+
+ /** Send buffer size (SO_SNDBUF). */
+ public int sendBuf() { return conf.getInt(SPARK_NETWORK_IO_SENDBUFFER_KEY, -1); }
+
+ /** Timeout for a single round trip of SASL token exchange, in milliseconds. */
+ public int saslRTTimeoutMs() {
+ return (int) JavaUtils.timeStringAsSec(conf.get(SPARK_NETWORK_SASL_TIMEOUT_KEY, "30s")) * 1000;
+ }
+
+ /**
+ * Max number of times we will try IO exceptions (such as connection timeouts) per request.
+ * If set to 0, we will not do any retries.
+ */
+ public int maxIORetries() { return conf.getInt(SPARK_NETWORK_IO_MAXRETRIES_KEY, 3); }
+
+ /**
+ * Time (in milliseconds) that we will wait in order to perform a retry after an IOException.
+ * Only relevant if maxIORetries &gt; 0.
+ */
+ public int ioRetryWaitTimeMs() {
+ return (int) JavaUtils.timeStringAsSec(conf.get(SPARK_NETWORK_IO_RETRYWAIT_KEY, "5s")) * 1000;
+ }
+
+ /**
+ * Minimum size of a block that we should start using memory map rather than reading in through
+ * normal IO operations. This prevents Spark from memory mapping very small blocks. In general,
+ * memory mapping has high overhead for blocks close to or below the page size of the OS.
+ */
+ public int memoryMapBytes() {
+ return Ints.checkedCast(JavaUtils.byteStringAsBytes(
+ conf.get("spark.storage.memoryMapThreshold", "2m")));
+ }
+
+ /**
+ * Whether to initialize FileDescriptor lazily or not. If true, file descriptors are
+ * created only when data is going to be transferred. This can reduce the number of open files.
+ */
+ public boolean lazyFileDescriptor() {
+ return conf.getBoolean(SPARK_NETWORK_IO_LAZYFD_KEY, true);
+ }
+
+ /**
+ * Maximum number of retries when binding to a port before giving up.
+ */
+ public int portMaxRetries() {
+ return conf.getInt("spark.port.maxRetries", 16);
+ }
+
+ /**
+ * Maximum number of bytes to be encrypted at a time when SASL encryption is enabled.
+ */
+ public int maxSaslEncryptedBlockSize() {
+ return Ints.checkedCast(JavaUtils.byteStringAsBytes(
+ conf.get("spark.network.sasl.maxEncryptedBlockSize", "64k")));
+ }
+
+ /**
+ * Whether the server should enforce encryption on SASL-authenticated connections.
+ */
+ public boolean saslServerAlwaysEncrypt() {
+ return conf.getBoolean("spark.network.sasl.serverAlwaysEncrypt", false);
+ }
+
+}
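A minimal sketch of how the module name becomes the key prefix and how the typed getters resolve (keys and values are illustrative):

import java.util.HashMap;
import java.util.Map;

import org.apache.spark.network.util.MapConfigProvider;
import org.apache.spark.network.util.TransportConf;

public class TransportConfExample {
  public static void main(String[] args) {
    Map<String, String> settings = new HashMap<>();
    // Keys are prefixed with "spark.<module>."; the module here is "shuffle".
    settings.put("spark.shuffle.io.mode", "nio");
    settings.put("spark.shuffle.io.connectionTimeout", "30s");

    TransportConf conf = new TransportConf("shuffle", new MapConfigProvider(settings));

    System.out.println(conf.ioMode());                // "NIO"
    System.out.println(conf.connectionTimeoutMs());   // 30000
    System.out.println(conf.numConnectionsPerPeer()); // 1 (default)
  }
}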
diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java b/common/network-common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java
new file mode 100644
index 0000000000..a466c72915
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java
@@ -0,0 +1,227 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.util;
+
+import java.util.Iterator;
+import java.util.LinkedList;
+
+import com.google.common.base.Preconditions;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.CompositeByteBuf;
+import io.netty.buffer.Unpooled;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelInboundHandlerAdapter;
+
+/**
+ * A customized frame decoder that allows intercepting raw data.
+ * <p>
+ * This behaves like Netty's frame decoder (with hardcoded parameters that match this library's
+ * needs), except it allows an interceptor to be installed to read data directly before it's
+ * framed.
+ * <p>
+ * Unlike Netty's frame decoder, each frame is dispatched to child handlers as soon as it's
+ * decoded, instead of building as many frames as the current buffer allows and dispatching
+ * all of them. This allows a child handler to install an interceptor if needed.
+ * <p>
+ * If an interceptor is installed, framing stops, and data is instead fed directly to the
+ * interceptor. When the interceptor indicates that it doesn't need to read any more data,
+ * framing resumes. Interceptors should not hold references to the data buffers provided
+ * to their handle() method.
+ */
+public class TransportFrameDecoder extends ChannelInboundHandlerAdapter {
+
+ public static final String HANDLER_NAME = "frameDecoder";
+ private static final int LENGTH_SIZE = 8;
+ private static final int MAX_FRAME_SIZE = Integer.MAX_VALUE;
+ private static final int UNKNOWN_FRAME_SIZE = -1;
+
+ private final LinkedList<ByteBuf> buffers = new LinkedList<>();
+ private final ByteBuf frameLenBuf = Unpooled.buffer(LENGTH_SIZE, LENGTH_SIZE);
+
+ private long totalSize = 0;
+ private long nextFrameSize = UNKNOWN_FRAME_SIZE;
+ private volatile Interceptor interceptor;
+
+ @Override
+ public void channelRead(ChannelHandlerContext ctx, Object data) throws Exception {
+ ByteBuf in = (ByteBuf) data;
+ buffers.add(in);
+ totalSize += in.readableBytes();
+
+ while (!buffers.isEmpty()) {
+ // First, feed the interceptor, and if it's still active, try again.
+ if (interceptor != null) {
+ ByteBuf first = buffers.getFirst();
+ int available = first.readableBytes();
+ if (feedInterceptor(first)) {
+ assert !first.isReadable() : "Interceptor still active but buffer has data.";
+ }
+
+ int read = available - first.readableBytes();
+ if (read == available) {
+ buffers.removeFirst().release();
+ }
+ totalSize -= read;
+ } else {
+ // Interceptor is not active, so try to decode one frame.
+ ByteBuf frame = decodeNext();
+ if (frame == null) {
+ break;
+ }
+ ctx.fireChannelRead(frame);
+ }
+ }
+ }
+
+ private long decodeFrameSize() {
+ if (nextFrameSize != UNKNOWN_FRAME_SIZE || totalSize < LENGTH_SIZE) {
+ return nextFrameSize;
+ }
+
+ // We know there's enough data. If the first buffer contains all the data, great. Otherwise,
+ // hold the bytes for the frame length in a small staging buffer until we have enough data to read
+ // the frame size. Normally, it should be rare to need more than one buffer to read the frame
+ // size.
+ ByteBuf first = buffers.getFirst();
+ if (first.readableBytes() >= LENGTH_SIZE) {
+ nextFrameSize = first.readLong() - LENGTH_SIZE;
+ totalSize -= LENGTH_SIZE;
+ if (!first.isReadable()) {
+ buffers.removeFirst().release();
+ }
+ return nextFrameSize;
+ }
+
+ while (frameLenBuf.readableBytes() < LENGTH_SIZE) {
+ ByteBuf next = buffers.getFirst();
+ int toRead = Math.min(next.readableBytes(), LENGTH_SIZE - frameLenBuf.readableBytes());
+ frameLenBuf.writeBytes(next, toRead);
+ if (!next.isReadable()) {
+ buffers.removeFirst().release();
+ }
+ }
+
+ nextFrameSize = frameLenBuf.readLong() - LENGTH_SIZE;
+ totalSize -= LENGTH_SIZE;
+ frameLenBuf.clear();
+ return nextFrameSize;
+ }
+
+ private ByteBuf decodeNext() throws Exception {
+ long frameSize = decodeFrameSize();
+ if (frameSize == UNKNOWN_FRAME_SIZE || totalSize < frameSize) {
+ return null;
+ }
+
+ // Reset size for next frame.
+ nextFrameSize = UNKNOWN_FRAME_SIZE;
+
+ Preconditions.checkArgument(frameSize < MAX_FRAME_SIZE, "Too large frame: %s", frameSize);
+ Preconditions.checkArgument(frameSize > 0, "Frame length should be positive: %s", frameSize);
+
+ // If the first buffer holds the entire frame, return it.
+ int remaining = (int) frameSize;
+ if (buffers.getFirst().readableBytes() >= remaining) {
+ return nextBufferForFrame(remaining);
+ }
+
+ // Otherwise, create a composite buffer.
+ CompositeByteBuf frame = buffers.getFirst().alloc().compositeBuffer();
+ while (remaining > 0) {
+ ByteBuf next = nextBufferForFrame(remaining);
+ remaining -= next.readableBytes();
+ frame.addComponent(next).writerIndex(frame.writerIndex() + next.readableBytes());
+ }
+ assert remaining == 0;
+ return frame;
+ }
+
+ /**
+ * Takes the first buffer in the internal list, and either adjust it to fit in the frame
+ * (by taking a slice out of it) or remove it from the internal list.
+ */
+ private ByteBuf nextBufferForFrame(int bytesToRead) {
+ ByteBuf buf = buffers.getFirst();
+ ByteBuf frame;
+
+ if (buf.readableBytes() > bytesToRead) {
+ frame = buf.retain().readSlice(bytesToRead);
+ totalSize -= bytesToRead;
+ } else {
+ frame = buf;
+ buffers.removeFirst();
+ totalSize -= frame.readableBytes();
+ }
+
+ return frame;
+ }
+
+ @Override
+ public void channelInactive(ChannelHandlerContext ctx) throws Exception {
+ for (ByteBuf b : buffers) {
+ b.release();
+ }
+ if (interceptor != null) {
+ interceptor.channelInactive();
+ }
+ frameLenBuf.release();
+ super.channelInactive(ctx);
+ }
+
+ @Override
+ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
+ if (interceptor != null) {
+ interceptor.exceptionCaught(cause);
+ }
+ super.exceptionCaught(ctx, cause);
+ }
+
+ public void setInterceptor(Interceptor interceptor) {
+ Preconditions.checkState(this.interceptor == null, "Already have an interceptor.");
+ this.interceptor = interceptor;
+ }
+
+ /**
+ * @return Whether the interceptor is still active after processing the data.
+ */
+ private boolean feedInterceptor(ByteBuf buf) throws Exception {
+ if (interceptor != null && !interceptor.handle(buf)) {
+ interceptor = null;
+ }
+ return interceptor != null;
+ }
+
+ public static interface Interceptor {
+
+ /**
+ * Handles data received from the remote end.
+ *
+ * @param data Buffer containing data.
+ * @return "true" if the interceptor expects more data, "false" to uninstall the interceptor.
+ */
+ boolean handle(ByteBuf data) throws Exception;
+
+ /** Called if an exception is thrown in the channel pipeline. */
+ void exceptionCaught(Throwable cause) throws Exception;
+
+ /** Called if the channel is closed and the interceptor is still installed. */
+ void channelInactive() throws Exception;
+
+ }
+
+}
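The wire format the decoder expects is an 8-byte length (which counts itself) followed by the frame body. A minimal sketch that hand-builds one frame and feeds it through the decoder using Netty's EmbeddedChannel (payload bytes are illustrative):

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.embedded.EmbeddedChannel;

import org.apache.spark.network.util.TransportFrameDecoder;

public class TransportFrameDecoderExample {
  public static void main(String[] args) {
    byte[] payload = new byte[] {1, 2, 3, 4};

    // Frame = 8-byte length prefix (the length includes the prefix itself) + payload.
    ByteBuf frame = Unpooled.buffer();
    frame.writeLong(8 + payload.length);
    frame.writeBytes(payload);

    EmbeddedChannel channel = new EmbeddedChannel(new TransportFrameDecoder());
    channel.writeInbound(frame);

    // The decoder emits only the payload; the length prefix has been consumed.
    ByteBuf decoded = (ByteBuf) channel.readInbound();
    System.out.println(decoded.readableBytes()); // 4
    decoded.release();
    channel.finish();
  }
}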
diff --git a/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java
new file mode 100644
index 0000000000..70c849d60e
--- /dev/null
+++ b/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java
@@ -0,0 +1,244 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network;
+
+import java.io.File;
+import java.io.RandomAccessFile;
+import java.nio.ByteBuffer;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Random;
+import java.util.Set;
+import java.util.concurrent.Semaphore;
+import java.util.concurrent.TimeUnit;
+
+import com.google.common.collect.Lists;
+import com.google.common.collect.Sets;
+import com.google.common.io.Closeables;
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import static org.junit.Assert.*;
+
+import org.apache.spark.network.buffer.FileSegmentManagedBuffer;
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.buffer.NioManagedBuffer;
+import org.apache.spark.network.client.ChunkReceivedCallback;
+import org.apache.spark.network.client.RpcResponseCallback;
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.client.TransportClientFactory;
+import org.apache.spark.network.server.RpcHandler;
+import org.apache.spark.network.server.TransportServer;
+import org.apache.spark.network.server.StreamManager;
+import org.apache.spark.network.util.SystemPropertyConfigProvider;
+import org.apache.spark.network.util.TransportConf;
+
+public class ChunkFetchIntegrationSuite {
+ static final long STREAM_ID = 1;
+ static final int BUFFER_CHUNK_INDEX = 0;
+ static final int FILE_CHUNK_INDEX = 1;
+
+ static TransportServer server;
+ static TransportClientFactory clientFactory;
+ static StreamManager streamManager;
+ static File testFile;
+
+ static ManagedBuffer bufferChunk;
+ static ManagedBuffer fileChunk;
+
+ private TransportConf transportConf;
+
+ @BeforeClass
+ public static void setUp() throws Exception {
+ int bufSize = 100000;
+ final ByteBuffer buf = ByteBuffer.allocate(bufSize);
+ for (int i = 0; i < bufSize; i ++) {
+ buf.put((byte) i);
+ }
+ buf.flip();
+ bufferChunk = new NioManagedBuffer(buf);
+
+ testFile = File.createTempFile("shuffle-test-file", "txt");
+ testFile.deleteOnExit();
+ RandomAccessFile fp = new RandomAccessFile(testFile, "rw");
+ boolean shouldSuppressIOException = true;
+ try {
+ byte[] fileContent = new byte[1024];
+ new Random().nextBytes(fileContent);
+ fp.write(fileContent);
+ shouldSuppressIOException = false;
+ } finally {
+ Closeables.close(fp, shouldSuppressIOException);
+ }
+
+ final TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider());
+ fileChunk = new FileSegmentManagedBuffer(conf, testFile, 10, testFile.length() - 25);
+
+ streamManager = new StreamManager() {
+ @Override
+ public ManagedBuffer getChunk(long streamId, int chunkIndex) {
+ assertEquals(STREAM_ID, streamId);
+ if (chunkIndex == BUFFER_CHUNK_INDEX) {
+ return new NioManagedBuffer(buf);
+ } else if (chunkIndex == FILE_CHUNK_INDEX) {
+ return new FileSegmentManagedBuffer(conf, testFile, 10, testFile.length() - 25);
+ } else {
+ throw new IllegalArgumentException("Invalid chunk index: " + chunkIndex);
+ }
+ }
+ };
+ RpcHandler handler = new RpcHandler() {
+ @Override
+ public void receive(
+ TransportClient client,
+ ByteBuffer message,
+ RpcResponseCallback callback) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public StreamManager getStreamManager() {
+ return streamManager;
+ }
+ };
+ TransportContext context = new TransportContext(conf, handler);
+ server = context.createServer();
+ clientFactory = context.createClientFactory();
+ }
+
+ @AfterClass
+ public static void tearDown() {
+ bufferChunk.release();
+ server.close();
+ clientFactory.close();
+ testFile.delete();
+ }
+
+ class FetchResult {
+ public Set<Integer> successChunks;
+ public Set<Integer> failedChunks;
+ public List<ManagedBuffer> buffers;
+
+ public void releaseBuffers() {
+ for (ManagedBuffer buffer : buffers) {
+ buffer.release();
+ }
+ }
+ }
+
+ private FetchResult fetchChunks(List<Integer> chunkIndices) throws Exception {
+ TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
+ final Semaphore sem = new Semaphore(0);
+
+ final FetchResult res = new FetchResult();
+ res.successChunks = Collections.synchronizedSet(new HashSet<Integer>());
+ res.failedChunks = Collections.synchronizedSet(new HashSet<Integer>());
+ res.buffers = Collections.synchronizedList(new LinkedList<ManagedBuffer>());
+
+ ChunkReceivedCallback callback = new ChunkReceivedCallback() {
+ @Override
+ public void onSuccess(int chunkIndex, ManagedBuffer buffer) {
+ buffer.retain();
+ res.successChunks.add(chunkIndex);
+ res.buffers.add(buffer);
+ sem.release();
+ }
+
+ @Override
+ public void onFailure(int chunkIndex, Throwable e) {
+ res.failedChunks.add(chunkIndex);
+ sem.release();
+ }
+ };
+
+ for (int chunkIndex : chunkIndices) {
+ client.fetchChunk(STREAM_ID, chunkIndex, callback);
+ }
+ if (!sem.tryAcquire(chunkIndices.size(), 5, TimeUnit.SECONDS)) {
+ fail("Timeout getting response from the server");
+ }
+ client.close();
+ return res;
+ }
+
+ @Test
+ public void fetchBufferChunk() throws Exception {
+ FetchResult res = fetchChunks(Lists.newArrayList(BUFFER_CHUNK_INDEX));
+ assertEquals(res.successChunks, Sets.newHashSet(BUFFER_CHUNK_INDEX));
+ assertTrue(res.failedChunks.isEmpty());
+ assertBufferListsEqual(res.buffers, Lists.newArrayList(bufferChunk));
+ res.releaseBuffers();
+ }
+
+ @Test
+ public void fetchFileChunk() throws Exception {
+ FetchResult res = fetchChunks(Lists.newArrayList(FILE_CHUNK_INDEX));
+ assertEquals(res.successChunks, Sets.newHashSet(FILE_CHUNK_INDEX));
+ assertTrue(res.failedChunks.isEmpty());
+ assertBufferListsEqual(res.buffers, Lists.newArrayList(fileChunk));
+ res.releaseBuffers();
+ }
+
+ @Test
+ public void fetchNonExistentChunk() throws Exception {
+ FetchResult res = fetchChunks(Lists.newArrayList(12345));
+ assertTrue(res.successChunks.isEmpty());
+ assertEquals(res.failedChunks, Sets.newHashSet(12345));
+ assertTrue(res.buffers.isEmpty());
+ }
+
+ @Test
+ public void fetchBothChunks() throws Exception {
+ FetchResult res = fetchChunks(Lists.newArrayList(BUFFER_CHUNK_INDEX, FILE_CHUNK_INDEX));
+ assertEquals(res.successChunks, Sets.newHashSet(BUFFER_CHUNK_INDEX, FILE_CHUNK_INDEX));
+ assertTrue(res.failedChunks.isEmpty());
+ assertBufferListsEqual(res.buffers, Lists.newArrayList(bufferChunk, fileChunk));
+ res.releaseBuffers();
+ }
+
+ @Test
+ public void fetchChunkAndNonExistent() throws Exception {
+ FetchResult res = fetchChunks(Lists.newArrayList(BUFFER_CHUNK_INDEX, 12345));
+ assertEquals(res.successChunks, Sets.newHashSet(BUFFER_CHUNK_INDEX));
+ assertEquals(res.failedChunks, Sets.newHashSet(12345));
+ assertBufferListsEqual(res.buffers, Lists.newArrayList(bufferChunk));
+ res.releaseBuffers();
+ }
+
+ private void assertBufferListsEqual(List<ManagedBuffer> list0, List<ManagedBuffer> list1)
+ throws Exception {
+ assertEquals(list0.size(), list1.size());
+ for (int i = 0; i < list0.size(); i ++) {
+ assertBuffersEqual(list0.get(i), list1.get(i));
+ }
+ }
+
+ private void assertBuffersEqual(ManagedBuffer buffer0, ManagedBuffer buffer1) throws Exception {
+ ByteBuffer nio0 = buffer0.nioByteBuffer();
+ ByteBuffer nio1 = buffer1.nioByteBuffer();
+
+ int len = nio0.remaining();
+ assertEquals(nio0.remaining(), nio1.remaining());
+ for (int i = 0; i < len; i ++) {
+ assertEquals(nio0.get(), nio1.get());
+ }
+ }
+}
diff --git a/common/network-common/src/test/java/org/apache/spark/network/ProtocolSuite.java b/common/network-common/src/test/java/org/apache/spark/network/ProtocolSuite.java
new file mode 100644
index 0000000000..6c8dd742f4
--- /dev/null
+++ b/common/network-common/src/test/java/org/apache/spark/network/ProtocolSuite.java
@@ -0,0 +1,127 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network;
+
+import java.util.List;
+
+import com.google.common.primitives.Ints;
+import io.netty.buffer.Unpooled;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.FileRegion;
+import io.netty.channel.embedded.EmbeddedChannel;
+import io.netty.handler.codec.MessageToMessageEncoder;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+
+import org.apache.spark.network.protocol.ChunkFetchFailure;
+import org.apache.spark.network.protocol.ChunkFetchRequest;
+import org.apache.spark.network.protocol.ChunkFetchSuccess;
+import org.apache.spark.network.protocol.Message;
+import org.apache.spark.network.protocol.MessageDecoder;
+import org.apache.spark.network.protocol.MessageEncoder;
+import org.apache.spark.network.protocol.OneWayMessage;
+import org.apache.spark.network.protocol.RpcFailure;
+import org.apache.spark.network.protocol.RpcRequest;
+import org.apache.spark.network.protocol.RpcResponse;
+import org.apache.spark.network.protocol.StreamChunkId;
+import org.apache.spark.network.protocol.StreamFailure;
+import org.apache.spark.network.protocol.StreamRequest;
+import org.apache.spark.network.protocol.StreamResponse;
+import org.apache.spark.network.util.ByteArrayWritableChannel;
+import org.apache.spark.network.util.NettyUtils;
+
+public class ProtocolSuite {
+ private void testServerToClient(Message msg) {
+ EmbeddedChannel serverChannel = new EmbeddedChannel(new FileRegionEncoder(),
+ new MessageEncoder());
+ serverChannel.writeOutbound(msg);
+
+ EmbeddedChannel clientChannel = new EmbeddedChannel(
+ NettyUtils.createFrameDecoder(), new MessageDecoder());
+
+ while (!serverChannel.outboundMessages().isEmpty()) {
+ clientChannel.writeInbound(serverChannel.readOutbound());
+ }
+
+ assertEquals(1, clientChannel.inboundMessages().size());
+ assertEquals(msg, clientChannel.readInbound());
+ }
+
+ private void testClientToServer(Message msg) {
+ EmbeddedChannel clientChannel = new EmbeddedChannel(new FileRegionEncoder(),
+ new MessageEncoder());
+ clientChannel.writeOutbound(msg);
+
+ EmbeddedChannel serverChannel = new EmbeddedChannel(
+ NettyUtils.createFrameDecoder(), new MessageDecoder());
+
+ while (!clientChannel.outboundMessages().isEmpty()) {
+ serverChannel.writeInbound(clientChannel.readOutbound());
+ }
+
+ assertEquals(1, serverChannel.inboundMessages().size());
+ assertEquals(msg, serverChannel.readInbound());
+ }
+
+ @Test
+ public void requests() {
+ testClientToServer(new ChunkFetchRequest(new StreamChunkId(1, 2)));
+ testClientToServer(new RpcRequest(12345, new TestManagedBuffer(0)));
+ testClientToServer(new RpcRequest(12345, new TestManagedBuffer(10)));
+ testClientToServer(new StreamRequest("abcde"));
+ testClientToServer(new OneWayMessage(new TestManagedBuffer(10)));
+ }
+
+ @Test
+ public void responses() {
+ testServerToClient(new ChunkFetchSuccess(new StreamChunkId(1, 2), new TestManagedBuffer(10)));
+ testServerToClient(new ChunkFetchSuccess(new StreamChunkId(1, 2), new TestManagedBuffer(0)));
+ testServerToClient(new ChunkFetchFailure(new StreamChunkId(1, 2), "this is an error"));
+ testServerToClient(new ChunkFetchFailure(new StreamChunkId(1, 2), ""));
+ testServerToClient(new RpcResponse(12345, new TestManagedBuffer(0)));
+ testServerToClient(new RpcResponse(12345, new TestManagedBuffer(100)));
+ testServerToClient(new RpcFailure(0, "this is an error"));
+ testServerToClient(new RpcFailure(0, ""));
+ // Note: buffer size must be "0" since StreamResponse's buffer is written differently to the
+ // channel and cannot be tested like this.
+ testServerToClient(new StreamResponse("anId", 12345L, new TestManagedBuffer(0)));
+ testServerToClient(new StreamFailure("anId", "this is an error"));
+ }
+
+ /**
+ * Handler that transforms a FileRegion into a byte buffer. EmbeddedChannel transfers messages
+ * rather than raw bytes, so this encoder is needed so that the frame decoder on the receiving
+ * side can understand what MessageWithHeader actually contains.
+ */
+ private static class FileRegionEncoder extends MessageToMessageEncoder<FileRegion> {
+
+ @Override
+ public void encode(ChannelHandlerContext ctx, FileRegion in, List<Object> out)
+ throws Exception {
+
+ ByteArrayWritableChannel channel = new ByteArrayWritableChannel(Ints.checkedCast(in.count()));
+ while (in.transfered() < in.count()) {
+ in.transferTo(channel, in.transfered());
+ }
+ out.add(Unpooled.wrappedBuffer(channel.getData()));
+ }
+
+ }
+
+}
diff --git a/common/network-common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java
new file mode 100644
index 0000000000..f9b5bf96d6
--- /dev/null
+++ b/common/network-common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java
@@ -0,0 +1,288 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network;
+
+import com.google.common.collect.Maps;
+import com.google.common.util.concurrent.Uninterruptibles;
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.buffer.NioManagedBuffer;
+import org.apache.spark.network.client.ChunkReceivedCallback;
+import org.apache.spark.network.client.RpcResponseCallback;
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.client.TransportClientFactory;
+import org.apache.spark.network.server.RpcHandler;
+import org.apache.spark.network.server.StreamManager;
+import org.apache.spark.network.server.TransportServer;
+import org.apache.spark.network.util.MapConfigProvider;
+import org.apache.spark.network.util.TransportConf;
+import org.junit.*;
+import static org.junit.Assert.*;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.*;
+import java.util.concurrent.Semaphore;
+import java.util.concurrent.TimeUnit;
+
+/**
+ * Suite which ensures that requests that go without a response for the network timeout period are
+ * failed, and the connection closed.
+ *
+ * In this suite, we use 2 seconds as the connection timeout, with some slack given in the tests,
+ * to ensure stability in different test environments.
+ */
+public class RequestTimeoutIntegrationSuite {
+
+ private TransportServer server;
+ private TransportClientFactory clientFactory;
+
+ private StreamManager defaultManager;
+ private TransportConf conf;
+
+ // A large timeout that "shouldn't happen", for the sake of faulty tests not hanging forever.
+ private final int FOREVER = 60 * 1000;
+
+ @Before
+ public void setUp() throws Exception {
+ Map<String, String> configMap = Maps.newHashMap();
+ configMap.put("spark.shuffle.io.connectionTimeout", "2s");
+ conf = new TransportConf("shuffle", new MapConfigProvider(configMap));
+
+ defaultManager = new StreamManager() {
+ @Override
+ public ManagedBuffer getChunk(long streamId, int chunkIndex) {
+ throw new UnsupportedOperationException();
+ }
+ };
+ }
+
+ @After
+ public void tearDown() {
+ if (server != null) {
+ server.close();
+ }
+ if (clientFactory != null) {
+ clientFactory.close();
+ }
+ }
+
+ // Basic suite: First request completes quickly, and second waits for longer than network timeout.
+ @Test
+ public void timeoutInactiveRequests() throws Exception {
+ final Semaphore semaphore = new Semaphore(1);
+ final int responseSize = 16;
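+ // The handler below blocks on the semaphore before replying, so a response is only sent once a
+ // permit is available; with a single initial permit, only the first request is answered promptly.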
+ RpcHandler handler = new RpcHandler() {
+ @Override
+ public void receive(
+ TransportClient client,
+ ByteBuffer message,
+ RpcResponseCallback callback) {
+ try {
+ semaphore.tryAcquire(FOREVER, TimeUnit.MILLISECONDS);
+ callback.onSuccess(ByteBuffer.allocate(responseSize));
+ } catch (InterruptedException e) {
+ // do nothing
+ }
+ }
+
+ @Override
+ public StreamManager getStreamManager() {
+ return defaultManager;
+ }
+ };
+
+ TransportContext context = new TransportContext(conf, handler);
+ server = context.createServer();
+ clientFactory = context.createClientFactory();
+ TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
+
+ // First completes quickly (semaphore starts at 1).
+ TestCallback callback0 = new TestCallback();
+ synchronized (callback0) {
+ client.sendRpc(ByteBuffer.allocate(0), callback0);
+ callback0.wait(FOREVER);
+ assertEquals(responseSize, callback0.successLength);
+ }
+
+ // Second times out after 2 seconds, with slack. Must be IOException.
+ TestCallback callback1 = new TestCallback();
+ synchronized (callback1) {
+ client.sendRpc(ByteBuffer.allocate(0), callback1);
+ callback1.wait(4 * 1000);
+ assert (callback1.failure != null);
+ assert (callback1.failure instanceof IOException);
+ }
+ semaphore.release();
+ }
+
+ // A timeout will cause the connection to be closed, invalidating the current TransportClient.
+ // It should be the case that requesting a client from the factory produces a new, valid one.
+ @Test
+ public void timeoutCleanlyClosesClient() throws Exception {
+ final Semaphore semaphore = new Semaphore(0);
+ final int responseSize = 16;
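+ // The semaphore starts with no permits, so the handler never responds until release() is
+ // called, guaranteeing that the first request times out.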
+ RpcHandler handler = new RpcHandler() {
+ @Override
+ public void receive(
+ TransportClient client,
+ ByteBuffer message,
+ RpcResponseCallback callback) {
+ try {
+ semaphore.tryAcquire(FOREVER, TimeUnit.MILLISECONDS);
+ callback.onSuccess(ByteBuffer.allocate(responseSize));
+ } catch (InterruptedException e) {
+ // do nothing
+ }
+ }
+
+ @Override
+ public StreamManager getStreamManager() {
+ return defaultManager;
+ }
+ };
+
+ TransportContext context = new TransportContext(conf, handler);
+ server = context.createServer();
+ clientFactory = context.createClientFactory();
+
+ // First request should eventually fail.
+ TransportClient client0 =
+ clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
+ TestCallback callback0 = new TestCallback();
+ synchronized (callback0) {
+ client0.sendRpc(ByteBuffer.allocate(0), callback0);
+ callback0.wait(FOREVER);
+ assert (callback0.failure instanceof IOException);
+ assert (!client0.isActive());
+ }
+
+ // Increment the semaphore and the second request should succeed quickly.
+ semaphore.release(2);
+ TransportClient client1 =
+ clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
+ TestCallback callback1 = new TestCallback();
+ synchronized (callback1) {
+ client1.sendRpc(ByteBuffer.allocate(0), callback1);
+ callback1.wait(FOREVER);
+ assertEquals(responseSize, callback1.successLength);
+ assertNull(callback1.failure);
+ }
+ }
+
+ // The timeout is relative to the LAST request sent, which is somewhat counter-intuitive, but is
+ // the current behavior.
+ // This test also makes sure the timeout works for Fetch requests as well as RPCs.
+ @Test
+ public void furtherRequestsDelay() throws Exception {
+ final byte[] response = new byte[16];
+ final StreamManager manager = new StreamManager() {
+ @Override
+ public ManagedBuffer getChunk(long streamId, int chunkIndex) {
+ Uninterruptibles.sleepUninterruptibly(FOREVER, TimeUnit.MILLISECONDS);
+ return new NioManagedBuffer(ByteBuffer.wrap(response));
+ }
+ };
+ RpcHandler handler = new RpcHandler() {
+ @Override
+ public void receive(
+ TransportClient client,
+ ByteBuffer message,
+ RpcResponseCallback callback) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public StreamManager getStreamManager() {
+ return manager;
+ }
+ };
+
+ TransportContext context = new TransportContext(conf, handler);
+ server = context.createServer();
+ clientFactory = context.createClientFactory();
+ TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
+
+ // Send one request, which will eventually fail.
+ TestCallback callback0 = new TestCallback();
+ client.fetchChunk(0, 0, callback0);
+ Uninterruptibles.sleepUninterruptibly(1200, TimeUnit.MILLISECONDS);
+
+ // Send a second request before the first has failed.
+ TestCallback callback1 = new TestCallback();
+ client.fetchChunk(0, 1, callback1);
+ Uninterruptibles.sleepUninterruptibly(1200, TimeUnit.MILLISECONDS);
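+ // About 2.4 seconds have elapsed since the first request, but the 2 second timeout is measured
+ // from the most recent request, so neither callback should have been failed yet.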
+
+ synchronized (callback0) {
+ // not complete yet, but should complete soon
+ assertEquals(-1, callback0.successLength);
+ assertNull(callback0.failure);
+ callback0.wait(2 * 1000);
+ assertTrue(callback0.failure instanceof IOException);
+ }
+
+ synchronized (callback1) {
+ // should have failed at the same time as the previous request
+ assert (callback1.failure instanceof IOException);
+ }
+ }
+
+ /**
+ * Callback which sets 'success' or 'failure' on completion.
+ * Additionally notifies all waiters on this callback when invoked.
+ */
+ class TestCallback implements RpcResponseCallback, ChunkReceivedCallback {
+
+ int successLength = -1;
+ Throwable failure;
+
+ @Override
+ public void onSuccess(ByteBuffer response) {
+ synchronized(this) {
+ successLength = response.remaining();
+ this.notifyAll();
+ }
+ }
+
+ @Override
+ public void onFailure(Throwable e) {
+ synchronized(this) {
+ failure = e;
+ this.notifyAll();
+ }
+ }
+
+ @Override
+ public void onSuccess(int chunkIndex, ManagedBuffer buffer) {
+ synchronized(this) {
+ try {
+ successLength = buffer.nioByteBuffer().remaining();
+ this.notifyAll();
+ } catch (IOException e) {
+ // not expected for the in-memory buffers used in this suite; ignore
+ }
+ }
+ }
+
+ @Override
+ public void onFailure(int chunkIndex, Throwable e) {
+ synchronized(this) {
+ failure = e;
+ this.notifyAll();
+ }
+ }
+ }
+}
diff --git a/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java
new file mode 100644
index 0000000000..9e9be98c14
--- /dev/null
+++ b/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java
@@ -0,0 +1,215 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network;
+
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Set;
+import java.util.concurrent.Semaphore;
+import java.util.concurrent.TimeUnit;
+
+import com.google.common.collect.Sets;
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import static org.junit.Assert.*;
+
+import org.apache.spark.network.client.RpcResponseCallback;
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.client.TransportClientFactory;
+import org.apache.spark.network.server.OneForOneStreamManager;
+import org.apache.spark.network.server.RpcHandler;
+import org.apache.spark.network.server.StreamManager;
+import org.apache.spark.network.server.TransportServer;
+import org.apache.spark.network.util.JavaUtils;
+import org.apache.spark.network.util.SystemPropertyConfigProvider;
+import org.apache.spark.network.util.TransportConf;
+
+public class RpcIntegrationSuite {
+ static TransportServer server;
+ static TransportClientFactory clientFactory;
+ static RpcHandler rpcHandler;
+ static List<String> oneWayMsgs;
+
+ @BeforeClass
+ public static void setUp() throws Exception {
+ TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider());
+ rpcHandler = new RpcHandler() {
+ @Override
+ public void receive(
+ TransportClient client,
+ ByteBuffer message,
+ RpcResponseCallback callback) {
+ String msg = JavaUtils.bytesToString(message);
+ String[] parts = msg.split("/");
+ if (parts[0].equals("hello")) {
+ callback.onSuccess(JavaUtils.stringToBytes("Hello, " + parts[1] + "!"));
+ } else if (parts[0].equals("return error")) {
+ callback.onFailure(new RuntimeException("Returned: " + parts[1]));
+ } else if (parts[0].equals("throw error")) {
+ throw new RuntimeException("Thrown: " + parts[1]);
+ }
+ }
+
+ @Override
+ public void receive(TransportClient client, ByteBuffer message) {
+ oneWayMsgs.add(JavaUtils.bytesToString(message));
+ }
+
+ @Override
+ public StreamManager getStreamManager() { return new OneForOneStreamManager(); }
+ };
+ TransportContext context = new TransportContext(conf, rpcHandler);
+ server = context.createServer();
+ clientFactory = context.createClientFactory();
+ oneWayMsgs = new ArrayList<>();
+ }
+
+ @AfterClass
+ public static void tearDown() {
+ server.close();
+ clientFactory.close();
+ }
+
+ class RpcResult {
+ public Set<String> successMessages;
+ public Set<String> errorMessages;
+ }
+
+ private RpcResult sendRPC(String ... commands) throws Exception {
+ TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
+ final Semaphore sem = new Semaphore(0);
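+ // Each callback invocation (success or failure) releases one permit; sendRPC waits below until
+ // all commands have been answered or the 5 second timeout expires.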
+
+ final RpcResult res = new RpcResult();
+ res.successMessages = Collections.synchronizedSet(new HashSet<String>());
+ res.errorMessages = Collections.synchronizedSet(new HashSet<String>());
+
+ RpcResponseCallback callback = new RpcResponseCallback() {
+ @Override
+ public void onSuccess(ByteBuffer message) {
+ String response = JavaUtils.bytesToString(message);
+ res.successMessages.add(response);
+ sem.release();
+ }
+
+ @Override
+ public void onFailure(Throwable e) {
+ res.errorMessages.add(e.getMessage());
+ sem.release();
+ }
+ };
+
+ for (String command : commands) {
+ client.sendRpc(JavaUtils.stringToBytes(command), callback);
+ }
+
+ if (!sem.tryAcquire(commands.length, 5, TimeUnit.SECONDS)) {
+ fail("Timeout getting response from the server");
+ }
+ client.close();
+ return res;
+ }
+
+ @Test
+ public void singleRPC() throws Exception {
+ RpcResult res = sendRPC("hello/Aaron");
+ assertEquals(res.successMessages, Sets.newHashSet("Hello, Aaron!"));
+ assertTrue(res.errorMessages.isEmpty());
+ }
+
+ @Test
+ public void doubleRPC() throws Exception {
+ RpcResult res = sendRPC("hello/Aaron", "hello/Reynold");
+ assertEquals(res.successMessages, Sets.newHashSet("Hello, Aaron!", "Hello, Reynold!"));
+ assertTrue(res.errorMessages.isEmpty());
+ }
+
+ @Test
+ public void returnErrorRPC() throws Exception {
+ RpcResult res = sendRPC("return error/OK");
+ assertTrue(res.successMessages.isEmpty());
+ assertErrorsContain(res.errorMessages, Sets.newHashSet("Returned: OK"));
+ }
+
+ @Test
+ public void throwErrorRPC() throws Exception {
+ RpcResult res = sendRPC("throw error/uh-oh");
+ assertTrue(res.successMessages.isEmpty());
+ assertErrorsContain(res.errorMessages, Sets.newHashSet("Thrown: uh-oh"));
+ }
+
+ @Test
+ public void doubleTrouble() throws Exception {
+ RpcResult res = sendRPC("return error/OK", "throw error/uh-oh");
+ assertTrue(res.successMessages.isEmpty());
+ assertErrorsContain(res.errorMessages, Sets.newHashSet("Returned: OK", "Thrown: uh-oh"));
+ }
+
+ @Test
+ public void sendSuccessAndFailure() throws Exception {
+ RpcResult res = sendRPC("hello/Bob", "throw error/the", "hello/Builder", "return error/!");
+ assertEquals(res.successMessages, Sets.newHashSet("Hello, Bob!", "Hello, Builder!"));
+ assertErrorsContain(res.errorMessages, Sets.newHashSet("Thrown: the", "Returned: !"));
+ }
+
+ @Test
+ public void sendOneWayMessage() throws Exception {
+ final String message = "no reply";
+ TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
+ try {
+ client.send(JavaUtils.stringToBytes(message));
+ assertEquals(0, client.getHandler().numOutstandingRequests());
+
+ // Make sure the message arrives.
+ long deadline = System.nanoTime() + TimeUnit.NANOSECONDS.convert(10, TimeUnit.SECONDS);
+ while (System.nanoTime() < deadline && oneWayMsgs.size() == 0) {
+ TimeUnit.MILLISECONDS.sleep(10);
+ }
+
+ assertEquals(1, oneWayMsgs.size());
+ assertEquals(message, oneWayMsgs.get(0));
+ } finally {
+ client.close();
+ }
+ }
+
+ private void assertErrorsContain(Set<String> errors, Set<String> contains) {
+ assertEquals(contains.size(), errors.size());
+
+ Set<String> remainingErrors = Sets.newHashSet(errors);
+ for (String contain : contains) {
+ Iterator<String> it = remainingErrors.iterator();
+ boolean foundMatch = false;
+ while (it.hasNext()) {
+ if (it.next().contains(contain)) {
+ it.remove();
+ foundMatch = true;
+ break;
+ }
+ }
+ assertTrue("Could not find error containing " + contain + "; errors: " + errors, foundMatch);
+ }
+
+ assertTrue(remainingErrors.isEmpty());
+ }
+}
diff --git a/common/network-common/src/test/java/org/apache/spark/network/StreamSuite.java b/common/network-common/src/test/java/org/apache/spark/network/StreamSuite.java
new file mode 100644
index 0000000000..9c49556927
--- /dev/null
+++ b/common/network-common/src/test/java/org/apache/spark/network/StreamSuite.java
@@ -0,0 +1,349 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network;
+
+import java.io.ByteArrayOutputStream;
+import java.io.File;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.OutputStream;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Random;
+import java.util.concurrent.Executors;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.TimeUnit;
+
+import com.google.common.io.Files;
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.Test;
+import static org.junit.Assert.*;
+
+import org.apache.spark.network.buffer.FileSegmentManagedBuffer;
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.buffer.NioManagedBuffer;
+import org.apache.spark.network.client.RpcResponseCallback;
+import org.apache.spark.network.client.StreamCallback;
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.client.TransportClientFactory;
+import org.apache.spark.network.server.RpcHandler;
+import org.apache.spark.network.server.StreamManager;
+import org.apache.spark.network.server.TransportServer;
+import org.apache.spark.network.util.SystemPropertyConfigProvider;
+import org.apache.spark.network.util.TransportConf;
+
+public class StreamSuite {
+ private static final String[] STREAMS = { "largeBuffer", "smallBuffer", "emptyBuffer", "file" };
+
+ private static TransportServer server;
+ private static TransportClientFactory clientFactory;
+ private static File testFile;
+ private static File tempDir;
+
+ private static ByteBuffer emptyBuffer;
+ private static ByteBuffer smallBuffer;
+ private static ByteBuffer largeBuffer;
+
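+ // Builds a buffer whose contents are the byte values 0, 1, ..., bufSize - 1 (truncated to a
+ // byte), so received data can be verified against the source buffer byte-for-byte.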
+ private static ByteBuffer createBuffer(int bufSize) {
+ ByteBuffer buf = ByteBuffer.allocate(bufSize);
+ for (int i = 0; i < bufSize; i ++) {
+ buf.put((byte) i);
+ }
+ buf.flip();
+ return buf;
+ }
+
+ @BeforeClass
+ public static void setUp() throws Exception {
+ tempDir = Files.createTempDir();
+ emptyBuffer = createBuffer(0);
+ smallBuffer = createBuffer(100);
+ largeBuffer = createBuffer(100000);
+
+ testFile = File.createTempFile("stream-test-file", "txt", tempDir);
+ FileOutputStream fp = new FileOutputStream(testFile);
+ try {
+ Random rnd = new Random();
+ for (int i = 0; i < 512; i++) {
+ byte[] fileContent = new byte[1024];
+ rnd.nextBytes(fileContent);
+ fp.write(fileContent);
+ }
+ } finally {
+ fp.close();
+ }
+
+ final TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider());
+ final StreamManager streamManager = new StreamManager() {
+ @Override
+ public ManagedBuffer getChunk(long streamId, int chunkIndex) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public ManagedBuffer openStream(String streamId) {
+ switch (streamId) {
+ case "largeBuffer":
+ return new NioManagedBuffer(largeBuffer);
+ case "smallBuffer":
+ return new NioManagedBuffer(smallBuffer);
+ case "emptyBuffer":
+ return new NioManagedBuffer(emptyBuffer);
+ case "file":
+ return new FileSegmentManagedBuffer(conf, testFile, 0, testFile.length());
+ default:
+ throw new IllegalArgumentException("Invalid stream: " + streamId);
+ }
+ }
+ };
+ RpcHandler handler = new RpcHandler() {
+ @Override
+ public void receive(
+ TransportClient client,
+ ByteBuffer message,
+ RpcResponseCallback callback) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public StreamManager getStreamManager() {
+ return streamManager;
+ }
+ };
+ TransportContext context = new TransportContext(conf, handler);
+ server = context.createServer();
+ clientFactory = context.createClientFactory();
+ }
+
+ @AfterClass
+ public static void tearDown() {
+ server.close();
+ clientFactory.close();
+ if (tempDir != null) {
+ for (File f : tempDir.listFiles()) {
+ f.delete();
+ }
+ tempDir.delete();
+ }
+ }
+
+ @Test
+ public void testZeroLengthStream() throws Throwable {
+ TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
+ try {
+ StreamTask task = new StreamTask(client, "emptyBuffer", TimeUnit.SECONDS.toMillis(5));
+ task.run();
+ task.check();
+ } finally {
+ client.close();
+ }
+ }
+
+ @Test
+ public void testSingleStream() throws Throwable {
+ TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
+ try {
+ StreamTask task = new StreamTask(client, "largeBuffer", TimeUnit.SECONDS.toMillis(5));
+ task.run();
+ task.check();
+ } finally {
+ client.close();
+ }
+ }
+
+ @Test
+ public void testMultipleStreams() throws Throwable {
+ TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
+ try {
+ for (int i = 0; i < 20; i++) {
+ StreamTask task = new StreamTask(client, STREAMS[i % STREAMS.length],
+ TimeUnit.SECONDS.toMillis(5));
+ task.run();
+ task.check();
+ }
+ } finally {
+ client.close();
+ }
+ }
+
+ @Test
+ public void testConcurrentStreams() throws Throwable {
+ ExecutorService executor = Executors.newFixedThreadPool(20);
+ TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
+
+ try {
+ List<StreamTask> tasks = new ArrayList<>();
+ for (int i = 0; i < 20; i++) {
+ StreamTask task = new StreamTask(client, STREAMS[i % STREAMS.length],
+ TimeUnit.SECONDS.toMillis(20));
+ tasks.add(task);
+ executor.submit(task);
+ }
+
+ executor.shutdown();
+ assertTrue("Timed out waiting for tasks.", executor.awaitTermination(30, TimeUnit.SECONDS));
+ for (StreamTask task : tasks) {
+ task.check();
+ }
+ } finally {
+ executor.shutdownNow();
+ client.close();
+ }
+ }
+
+ private static class StreamTask implements Runnable {
+
+ private final TransportClient client;
+ private final String streamId;
+ private final long timeoutMs;
+ private Throwable error;
+
+ StreamTask(TransportClient client, String streamId, long timeoutMs) {
+ this.client = client;
+ this.streamId = streamId;
+ this.timeoutMs = timeoutMs;
+ }
+
+ @Override
+ public void run() {
+ ByteBuffer srcBuffer = null;
+ OutputStream out = null;
+ File outFile = null;
+ try {
+ ByteArrayOutputStream baos = null;
+
+ switch (streamId) {
+ case "largeBuffer":
+ baos = new ByteArrayOutputStream();
+ out = baos;
+ srcBuffer = largeBuffer;
+ break;
+ case "smallBuffer":
+ baos = new ByteArrayOutputStream();
+ out = baos;
+ srcBuffer = smallBuffer;
+ break;
+ case "file":
+ outFile = File.createTempFile("data", ".tmp", tempDir);
+ out = new FileOutputStream(outFile);
+ break;
+ case "emptyBuffer":
+ baos = new ByteArrayOutputStream();
+ out = baos;
+ srcBuffer = emptyBuffer;
+ break;
+ default:
+ throw new IllegalArgumentException(streamId);
+ }
+
+ TestCallback callback = new TestCallback(out);
+ client.stream(streamId, callback);
+ waitForCompletion(callback);
+
+ if (srcBuffer == null) {
+ assertTrue("File stream did not match.", Files.equal(testFile, outFile));
+ } else {
+ ByteBuffer base;
+ synchronized (srcBuffer) {
+ base = srcBuffer.duplicate();
+ }
+ byte[] result = baos.toByteArray();
+ byte[] expected = new byte[base.remaining()];
+ base.get(expected);
+ assertEquals(expected.length, result.length);
+ assertTrue("buffers don't match", Arrays.equals(expected, result));
+ }
+ } catch (Throwable t) {
+ error = t;
+ } finally {
+ if (out != null) {
+ try {
+ out.close();
+ } catch (Exception e) {
+ // ignore.
+ }
+ }
+ if (outFile != null) {
+ outFile.delete();
+ }
+ }
+ }
+
+ public void check() throws Throwable {
+ if (error != null) {
+ throw error;
+ }
+ }
+
+ private void waitForCompletion(TestCallback callback) throws Exception {
+ long now = System.currentTimeMillis();
+ long deadline = now + timeoutMs;
+ synchronized (callback) {
+ while (!callback.completed && now < deadline) {
+ callback.wait(deadline - now);
+ now = System.currentTimeMillis();
+ }
+ }
+ assertTrue("Timed out waiting for stream.", callback.completed);
+ assertNull(callback.error);
+ }
+
+ }
+
+ private static class TestCallback implements StreamCallback {
+
+ private final OutputStream out;
+ public volatile boolean completed;
+ public volatile Throwable error;
+
+ TestCallback(OutputStream out) {
+ this.out = out;
+ this.completed = false;
+ }
+
+ @Override
+ public void onData(String streamId, ByteBuffer buf) throws IOException {
+ byte[] tmp = new byte[buf.remaining()];
+ buf.get(tmp);
+ out.write(tmp);
+ }
+
+ @Override
+ public void onComplete(String streamId) throws IOException {
+ out.close();
+ synchronized (this) {
+ completed = true;
+ notifyAll();
+ }
+ }
+
+ @Override
+ public void onFailure(String streamId, Throwable cause) {
+ error = cause;
+ synchronized (this) {
+ completed = true;
+ notifyAll();
+ }
+ }
+
+ }
+
+}
diff --git a/common/network-common/src/test/java/org/apache/spark/network/TestManagedBuffer.java b/common/network-common/src/test/java/org/apache/spark/network/TestManagedBuffer.java
new file mode 100644
index 0000000000..83c90f9eff
--- /dev/null
+++ b/common/network-common/src/test/java/org/apache/spark/network/TestManagedBuffer.java
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.ByteBuffer;
+
+import com.google.common.base.Preconditions;
+import io.netty.buffer.Unpooled;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.buffer.NettyManagedBuffer;
+
+/**
+ * A ManagedBuffer implementation that contains 0, 1, 2, 3, ..., (len-1).
+ *
+ * Used for testing.
+ */
+public class TestManagedBuffer extends ManagedBuffer {
+
+ private final int len;
+ private NettyManagedBuffer underlying;
+
+ public TestManagedBuffer(int len) {
+ Preconditions.checkArgument(len <= Byte.MAX_VALUE);
+ this.len = len;
+ byte[] byteArray = new byte[len];
+ for (int i = 0; i < len; i ++) {
+ byteArray[i] = (byte) i;
+ }
+ this.underlying = new NettyManagedBuffer(Unpooled.wrappedBuffer(byteArray));
+ }
+
+ @Override
+ public long size() {
+ return underlying.size();
+ }
+
+ @Override
+ public ByteBuffer nioByteBuffer() throws IOException {
+ return underlying.nioByteBuffer();
+ }
+
+ @Override
+ public InputStream createInputStream() throws IOException {
+ return underlying.createInputStream();
+ }
+
+ @Override
+ public ManagedBuffer retain() {
+ underlying.retain();
+ return this;
+ }
+
+ @Override
+ public ManagedBuffer release() {
+ underlying.release();
+ return this;
+ }
+
+ @Override
+ public Object convertToNetty() throws IOException {
+ return underlying.convertToNetty();
+ }
+
+ @Override
+ public int hashCode() {
+ return underlying.hashCode();
+ }
+
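+ // Equality is based on content rather than identity: any ManagedBuffer whose bytes are
+ // 0, 1, ..., (len-1) compares equal, which is what allows decoded messages to be compared
+ // with assertEquals in the codec tests.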
+ @Override
+ public boolean equals(Object other) {
+ if (other instanceof ManagedBuffer) {
+ try {
+ ByteBuffer nioBuf = ((ManagedBuffer) other).nioByteBuffer();
+ if (nioBuf.remaining() != len) {
+ return false;
+ } else {
+ for (int i = 0; i < len; i ++) {
+ if (nioBuf.get() != i) {
+ return false;
+ }
+ }
+ return true;
+ }
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ }
+ return false;
+ }
+}
diff --git a/common/network-common/src/test/java/org/apache/spark/network/TestUtils.java b/common/network-common/src/test/java/org/apache/spark/network/TestUtils.java
new file mode 100644
index 0000000000..56a2b805f1
--- /dev/null
+++ b/common/network-common/src/test/java/org/apache/spark/network/TestUtils.java
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network;
+
+import java.net.InetAddress;
+
+public class TestUtils {
+ public static String getLocalHost() {
+ try {
+ return InetAddress.getLocalHost().getHostAddress();
+ } catch (Exception e) {
+ throw new RuntimeException(e);
+ }
+ }
+}
diff --git a/common/network-common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java b/common/network-common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java
new file mode 100644
index 0000000000..dac7d4a5b0
--- /dev/null
+++ b/common/network-common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java
@@ -0,0 +1,214 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.NoSuchElementException;
+import java.util.Set;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import com.google.common.collect.Maps;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.client.TransportClientFactory;
+import org.apache.spark.network.server.NoOpRpcHandler;
+import org.apache.spark.network.server.RpcHandler;
+import org.apache.spark.network.server.TransportServer;
+import org.apache.spark.network.util.ConfigProvider;
+import org.apache.spark.network.util.SystemPropertyConfigProvider;
+import org.apache.spark.network.util.JavaUtils;
+import org.apache.spark.network.util.MapConfigProvider;
+import org.apache.spark.network.util.TransportConf;
+
+public class TransportClientFactorySuite {
+ private TransportConf conf;
+ private TransportContext context;
+ private TransportServer server1;
+ private TransportServer server2;
+
+ @Before
+ public void setUp() {
+ conf = new TransportConf("shuffle", new SystemPropertyConfigProvider());
+ RpcHandler rpcHandler = new NoOpRpcHandler();
+ context = new TransportContext(conf, rpcHandler);
+ server1 = context.createServer();
+ server2 = context.createServer();
+ }
+
+ @After
+ public void tearDown() {
+ JavaUtils.closeQuietly(server1);
+ JavaUtils.closeQuietly(server2);
+ }
+
+ /**
+ * Request a bunch of clients to a single server to test that
+ * we create up to maxConnections clients.
+ *
+ * If concurrent is true, create multiple threads to create clients in parallel.
+ */
+ private void testClientReuse(final int maxConnections, boolean concurrent)
+ throws IOException, InterruptedException {
+
+ Map<String, String> configMap = Maps.newHashMap();
+ configMap.put("spark.shuffle.io.numConnectionsPerPeer", Integer.toString(maxConnections));
+ TransportConf conf = new TransportConf("shuffle", new MapConfigProvider(configMap));
+
+ RpcHandler rpcHandler = new NoOpRpcHandler();
+ TransportContext context = new TransportContext(conf, rpcHandler);
+ final TransportClientFactory factory = context.createClientFactory();
+ final Set<TransportClient> clients = Collections.synchronizedSet(
+ new HashSet<TransportClient>());
+
+ final AtomicInteger failed = new AtomicInteger();
+ Thread[] attempts = new Thread[maxConnections * 10];
+
+ // Launch a bunch of threads to create new clients.
+ for (int i = 0; i < attempts.length; i++) {
+ attempts[i] = new Thread() {
+ @Override
+ public void run() {
+ try {
+ TransportClient client =
+ factory.createClient(TestUtils.getLocalHost(), server1.getPort());
+ assert (client.isActive());
+ clients.add(client);
+ } catch (IOException e) {
+ failed.incrementAndGet();
+ }
+ }
+ };
+
+ if (concurrent) {
+ attempts[i].start();
+ } else {
+ attempts[i].run();
+ }
+ }
+
+ // Wait until all the threads complete.
+ for (int i = 0; i < attempts.length; i++) {
+ attempts[i].join();
+ }
+
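+ // Even though 10x more threads than maxConnections requested a client, the factory should have
+ // handed out (and reused) exactly maxConnections distinct clients.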
+ assert(failed.get() == 0);
+ assert(clients.size() == maxConnections);
+
+ for (TransportClient client : clients) {
+ client.close();
+ }
+ }
+
+ @Test
+ public void reuseClientsUpToConfigVariable() throws Exception {
+ testClientReuse(1, false);
+ testClientReuse(2, false);
+ testClientReuse(3, false);
+ testClientReuse(4, false);
+ }
+
+ @Test
+ public void reuseClientsUpToConfigVariableConcurrent() throws Exception {
+ testClientReuse(1, true);
+ testClientReuse(2, true);
+ testClientReuse(3, true);
+ testClientReuse(4, true);
+ }
+
+ @Test
+ public void returnDifferentClientsForDifferentServers() throws IOException {
+ TransportClientFactory factory = context.createClientFactory();
+ TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort());
+ TransportClient c2 = factory.createClient(TestUtils.getLocalHost(), server2.getPort());
+ assertTrue(c1.isActive());
+ assertTrue(c2.isActive());
+ assertTrue(c1 != c2);
+ factory.close();
+ }
+
+ @Test
+ public void neverReturnInactiveClients() throws IOException, InterruptedException {
+ TransportClientFactory factory = context.createClientFactory();
+ TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort());
+ c1.close();
+
+ long start = System.currentTimeMillis();
+ while (c1.isActive() && (System.currentTimeMillis() - start) < 3000) {
+ Thread.sleep(10);
+ }
+ assertFalse(c1.isActive());
+
+ TransportClient c2 = factory.createClient(TestUtils.getLocalHost(), server1.getPort());
+ assertFalse(c1 == c2);
+ assertTrue(c2.isActive());
+ factory.close();
+ }
+
+ @Test
+ public void closeBlockClientsWithFactory() throws IOException {
+ TransportClientFactory factory = context.createClientFactory();
+ TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort());
+ TransportClient c2 = factory.createClient(TestUtils.getLocalHost(), server2.getPort());
+ assertTrue(c1.isActive());
+ assertTrue(c2.isActive());
+ factory.close();
+ assertFalse(c1.isActive());
+ assertFalse(c2.isActive());
+ }
+
+ @Test
+ public void closeIdleConnectionForRequestTimeOut() throws IOException, InterruptedException {
+ TransportConf conf = new TransportConf("shuffle", new ConfigProvider() {
+
+ @Override
+ public String get(String name) {
+ if ("spark.shuffle.io.connectionTimeout".equals(name)) {
+ // We should make sure there is enough time for us to observe that the channel is active
+ return "1s";
+ }
+ String value = System.getProperty(name);
+ if (value == null) {
+ throw new NoSuchElementException(name);
+ }
+ return value;
+ }
+ });
+ TransportContext context = new TransportContext(conf, new NoOpRpcHandler(), true);
+ TransportClientFactory factory = context.createClientFactory();
+ try {
+ TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort());
+ assertTrue(c1.isActive());
+ long expiredTime = System.currentTimeMillis() + 10000; // 10 seconds
+ while (c1.isActive() && System.currentTimeMillis() < expiredTime) {
+ Thread.sleep(10);
+ }
+ assertFalse(c1.isActive());
+ } finally {
+ factory.close();
+ }
+ }
+}
diff --git a/common/network-common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java b/common/network-common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java
new file mode 100644
index 0000000000..128f7cba74
--- /dev/null
+++ b/common/network-common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java
@@ -0,0 +1,146 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network;
+
+import java.nio.ByteBuffer;
+
+import io.netty.channel.Channel;
+import io.netty.channel.local.LocalChannel;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+import static org.mockito.Matchers.any;
+import static org.mockito.Matchers.eq;
+import static org.mockito.Mockito.*;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.buffer.NioManagedBuffer;
+import org.apache.spark.network.client.ChunkReceivedCallback;
+import org.apache.spark.network.client.RpcResponseCallback;
+import org.apache.spark.network.client.StreamCallback;
+import org.apache.spark.network.client.TransportResponseHandler;
+import org.apache.spark.network.protocol.ChunkFetchFailure;
+import org.apache.spark.network.protocol.ChunkFetchSuccess;
+import org.apache.spark.network.protocol.RpcFailure;
+import org.apache.spark.network.protocol.RpcResponse;
+import org.apache.spark.network.protocol.StreamChunkId;
+import org.apache.spark.network.protocol.StreamFailure;
+import org.apache.spark.network.protocol.StreamResponse;
+import org.apache.spark.network.util.TransportFrameDecoder;
+
+public class TransportResponseHandlerSuite {
+ @Test
+ public void handleSuccessfulFetch() throws Exception {
+ StreamChunkId streamChunkId = new StreamChunkId(1, 0);
+
+ TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel());
+ ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class);
+ handler.addFetchRequest(streamChunkId, callback);
+ assertEquals(1, handler.numOutstandingRequests());
+
+ handler.handle(new ChunkFetchSuccess(streamChunkId, new TestManagedBuffer(123)));
+ verify(callback, times(1)).onSuccess(eq(0), (ManagedBuffer) any());
+ assertEquals(0, handler.numOutstandingRequests());
+ }
+
+ @Test
+ public void handleFailedFetch() throws Exception {
+ StreamChunkId streamChunkId = new StreamChunkId(1, 0);
+ TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel());
+ ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class);
+ handler.addFetchRequest(streamChunkId, callback);
+ assertEquals(1, handler.numOutstandingRequests());
+
+ handler.handle(new ChunkFetchFailure(streamChunkId, "some error msg"));
+ verify(callback, times(1)).onFailure(eq(0), (Throwable) any());
+ assertEquals(0, handler.numOutstandingRequests());
+ }
+
+ @Test
+ public void clearAllOutstandingRequests() throws Exception {
+ TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel());
+ ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class);
+ handler.addFetchRequest(new StreamChunkId(1, 0), callback);
+ handler.addFetchRequest(new StreamChunkId(1, 1), callback);
+ handler.addFetchRequest(new StreamChunkId(1, 2), callback);
+ assertEquals(3, handler.numOutstandingRequests());
+
+ handler.handle(new ChunkFetchSuccess(new StreamChunkId(1, 0), new TestManagedBuffer(12)));
+ handler.exceptionCaught(new Exception("duh duh duhhhh"));
+
+ // should fail the two remaining fetches (chunk indices 1 and 2)
+ verify(callback, times(1)).onSuccess(eq(0), (ManagedBuffer) any());
+ verify(callback, times(1)).onFailure(eq(1), (Throwable) any());
+ verify(callback, times(1)).onFailure(eq(2), (Throwable) any());
+ assertEquals(0, handler.numOutstandingRequests());
+ }
+
+ @Test
+ public void handleSuccessfulRPC() throws Exception {
+ TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel());
+ RpcResponseCallback callback = mock(RpcResponseCallback.class);
+ handler.addRpcRequest(12345, callback);
+ assertEquals(1, handler.numOutstandingRequests());
+
+ // This response should be ignored.
+ handler.handle(new RpcResponse(54321, new NioManagedBuffer(ByteBuffer.allocate(7))));
+ assertEquals(1, handler.numOutstandingRequests());
+
+ ByteBuffer resp = ByteBuffer.allocate(10);
+ handler.handle(new RpcResponse(12345, new NioManagedBuffer(resp)));
+ verify(callback, times(1)).onSuccess(eq(ByteBuffer.allocate(10)));
+ assertEquals(0, handler.numOutstandingRequests());
+ }
+
+ @Test
+ public void handleFailedRPC() throws Exception {
+ TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel());
+ RpcResponseCallback callback = mock(RpcResponseCallback.class);
+ handler.addRpcRequest(12345, callback);
+ assertEquals(1, handler.numOutstandingRequests());
+
+ handler.handle(new RpcFailure(54321, "uh-oh!")); // should be ignored
+ assertEquals(1, handler.numOutstandingRequests());
+
+ handler.handle(new RpcFailure(12345, "oh no"));
+ verify(callback, times(1)).onFailure((Throwable) any());
+ assertEquals(0, handler.numOutstandingRequests());
+ }
+
+ @Test
+ public void testActiveStreams() throws Exception {
+ Channel c = new LocalChannel();
+ c.pipeline().addLast(TransportFrameDecoder.HANDLER_NAME, new TransportFrameDecoder());
+ TransportResponseHandler handler = new TransportResponseHandler(c);
+
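+ // A StreamResponse keeps the request outstanding until the stream is deactivated, whereas a
+ // StreamFailure clears it immediately; the counts below verify both behaviors.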
+ StreamResponse response = new StreamResponse("stream", 1234L, null);
+ StreamCallback cb = mock(StreamCallback.class);
+ handler.addStreamCallback(cb);
+ assertEquals(1, handler.numOutstandingRequests());
+ handler.handle(response);
+ assertEquals(1, handler.numOutstandingRequests());
+ handler.deactivateStream();
+ assertEquals(0, handler.numOutstandingRequests());
+
+ StreamFailure failure = new StreamFailure("stream", "uh-oh");
+ handler.addStreamCallback(cb);
+ assertEquals(1, handler.numOutstandingRequests());
+ handler.handle(failure);
+ assertEquals(0, handler.numOutstandingRequests());
+ }
+}
diff --git a/common/network-common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java b/common/network-common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java
new file mode 100644
index 0000000000..fbbe4b7014
--- /dev/null
+++ b/common/network-common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java
@@ -0,0 +1,157 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.channels.WritableByteChannel;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+import io.netty.channel.FileRegion;
+import io.netty.util.AbstractReferenceCounted;
+import org.junit.Test;
+import org.mockito.Mockito;
+
+import static org.junit.Assert.*;
+
+import org.apache.spark.network.TestManagedBuffer;
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.buffer.NettyManagedBuffer;
+import org.apache.spark.network.util.ByteArrayWritableChannel;
+
+public class MessageWithHeaderSuite {
+
+ @Test
+ public void testSingleWrite() throws Exception {
+ testFileRegionBody(8, 8);
+ }
+
+ @Test
+ public void testShortWrite() throws Exception {
+ testFileRegionBody(8, 1);
+ }
+
+ @Test
+ public void testByteBufBody() throws Exception {
+ ByteBuf header = Unpooled.copyLong(42);
+ ByteBuf bodyPassedToNettyManagedBuffer = Unpooled.copyLong(84);
+ assertEquals(1, header.refCnt());
+ assertEquals(1, bodyPassedToNettyManagedBuffer.refCnt());
+ ManagedBuffer managedBuf = new NettyManagedBuffer(bodyPassedToNettyManagedBuffer);
+
+ Object body = managedBuf.convertToNetty();
+ assertEquals(2, bodyPassedToNettyManagedBuffer.refCnt());
+ assertEquals(1, header.refCnt());
+
+ MessageWithHeader msg = new MessageWithHeader(managedBuf, header, body, managedBuf.size());
+ ByteBuf result = doWrite(msg, 1);
+ assertEquals(msg.count(), result.readableBytes());
+ assertEquals(42, result.readLong());
+ assertEquals(84, result.readLong());
+
+ assert(msg.release());
+ assertEquals(0, bodyPassedToNettyManagedBuffer.refCnt());
+ assertEquals(0, header.refCnt());
+ }
+
+ @Test
+ public void testDeallocateReleasesManagedBuffer() throws Exception {
+ ByteBuf header = Unpooled.copyLong(42);
+ ManagedBuffer managedBuf = Mockito.spy(new TestManagedBuffer(84));
+ ByteBuf body = (ByteBuf) managedBuf.convertToNetty();
+ assertEquals(2, body.refCnt());
+ MessageWithHeader msg = new MessageWithHeader(managedBuf, header, body, body.readableBytes());
+ assert(msg.release());
+ Mockito.verify(managedBuf, Mockito.times(1)).release();
+ assertEquals(0, body.refCnt());
+ }
+
+ private void testFileRegionBody(int totalWrites, int writesPerCall) throws Exception {
+ ByteBuf header = Unpooled.copyLong(42);
+ int headerLength = header.readableBytes();
+ TestFileRegion region = new TestFileRegion(totalWrites, writesPerCall);
+ MessageWithHeader msg = new MessageWithHeader(null, header, region, region.count());
+
+ ByteBuf result = doWrite(msg, totalWrites / writesPerCall);
+ assertEquals(headerLength + region.count(), result.readableBytes());
+ assertEquals(42, result.readLong());
+ for (long i = 0; i < 8; i++) {
+ assertEquals(i, result.readLong());
+ }
+ assert(msg.release());
+ }
+
+ private ByteBuf doWrite(MessageWithHeader msg, int minExpectedWrites) throws Exception {
+ int writes = 0;
+ ByteArrayWritableChannel channel = new ByteArrayWritableChannel((int) msg.count());
+ while (msg.transfered() < msg.count()) {
+ msg.transferTo(channel, msg.transfered());
+ writes++;
+ }
+ assertTrue("Not enough writes!", minExpectedWrites <= writes);
+ return Unpooled.wrappedBuffer(channel.getData());
+ }
+
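+ /**
+ * FileRegion stub whose content is the sequence of longs 0, 1, ..., writeCount - 1, written
+ * writesPerCall longs per transferTo() invocation, so tests can control how many writes a full
+ * transfer takes.
+ */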
+ private static class TestFileRegion extends AbstractReferenceCounted implements FileRegion {
+
+ private final int writeCount;
+ private final int writesPerCall;
+ private int written;
+
+ TestFileRegion(int totalWrites, int writesPerCall) {
+ this.writeCount = totalWrites;
+ this.writesPerCall = writesPerCall;
+ }
+
+ @Override
+ public long count() {
+ return 8 * writeCount;
+ }
+
+ @Override
+ public long position() {
+ return 0;
+ }
+
+ @Override
+ public long transfered() {
+ return 8 * written;
+ }
+
+ @Override
+ public long transferTo(WritableByteChannel target, long position) throws IOException {
+ for (int i = 0; i < writesPerCall; i++) {
+ ByteBuf buf = Unpooled.copyLong((position / 8) + i);
+ ByteBuffer nio = buf.nioBuffer();
+ while (nio.remaining() > 0) {
+ target.write(nio);
+ }
+ buf.release();
+ written++;
+ }
+ return 8 * writesPerCall;
+ }
+
+ @Override
+ protected void deallocate() {
+ }
+
+ }
+
+}
diff --git a/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java
new file mode 100644
index 0000000000..045773317a
--- /dev/null
+++ b/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java
@@ -0,0 +1,476 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.sasl;
+
+import static org.junit.Assert.*;
+import static org.mockito.Mockito.*;
+
+import java.io.File;
+import java.lang.reflect.Method;
+import java.nio.ByteBuffer;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Random;
+import java.util.concurrent.TimeoutException;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicReference;
+import javax.security.sasl.SaslException;
+
+import com.google.common.collect.Lists;
+import com.google.common.io.ByteStreams;
+import com.google.common.io.Files;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+import io.netty.channel.Channel;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelOutboundHandlerAdapter;
+import io.netty.channel.ChannelPromise;
+import org.junit.Test;
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
+
+import org.apache.spark.network.TestUtils;
+import org.apache.spark.network.TransportContext;
+import org.apache.spark.network.buffer.FileSegmentManagedBuffer;
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.client.ChunkReceivedCallback;
+import org.apache.spark.network.client.RpcResponseCallback;
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.client.TransportClientBootstrap;
+import org.apache.spark.network.server.RpcHandler;
+import org.apache.spark.network.server.StreamManager;
+import org.apache.spark.network.server.TransportServer;
+import org.apache.spark.network.server.TransportServerBootstrap;
+import org.apache.spark.network.util.ByteArrayWritableChannel;
+import org.apache.spark.network.util.JavaUtils;
+import org.apache.spark.network.util.SystemPropertyConfigProvider;
+import org.apache.spark.network.util.TransportConf;
+
+/**
+ * Jointly tests SparkSaslClient and SparkSaslServer, as both are black boxes.
+ */
+public class SparkSaslSuite {
+
+ /** Provides a secret key holder which returns secret key == appId */
+ private SecretKeyHolder secretKeyHolder = new SecretKeyHolder() {
+ @Override
+ public String getSaslUser(String appId) {
+ return "user";
+ }
+
+ @Override
+ public String getSecretKey(String appId) {
+ return appId;
+ }
+ };
+
+ @Test
+ public void testMatching() {
+ SparkSaslClient client = new SparkSaslClient("shared-secret", secretKeyHolder, false);
+ SparkSaslServer server = new SparkSaslServer("shared-secret", secretKeyHolder, false);
+
+ assertFalse(client.isComplete());
+ assertFalse(server.isComplete());
+
+ byte[] clientMessage = client.firstToken();
+
+ while (!client.isComplete()) {
+ clientMessage = client.response(server.response(clientMessage));
+ }
+ assertTrue(server.isComplete());
+
+ // Disposal should invalidate
+ server.dispose();
+ assertFalse(server.isComplete());
+ client.dispose();
+ assertFalse(client.isComplete());
+ }
+
+ @Test
+ public void testNonMatching() {
+ SparkSaslClient client = new SparkSaslClient("my-secret", secretKeyHolder, false);
+ SparkSaslServer server = new SparkSaslServer("your-secret", secretKeyHolder, false);
+
+ assertFalse(client.isComplete());
+ assertFalse(server.isComplete());
+
+ byte[] clientMessage = client.firstToken();
+
+ try {
+ while (!client.isComplete()) {
+ clientMessage = client.response(server.response(clientMessage));
+ }
+ fail("Should not have completed");
+ } catch (Exception e) {
+ assertTrue(e.getMessage().contains("Mismatched response"));
+ assertFalse(client.isComplete());
+ assertFalse(server.isComplete());
+ }
+ }
+
+ @Test
+ public void testSaslAuthentication() throws Throwable {
+ testBasicSasl(false);
+ }
+
+ @Test
+ public void testSaslEncryption() throws Throwable {
+ testBasicSasl(true);
+ }
+
+ private void testBasicSasl(boolean encrypt) throws Throwable {
+ RpcHandler rpcHandler = mock(RpcHandler.class);
+ doAnswer(new Answer<Void>() {
+ @Override
+ public Void answer(InvocationOnMock invocation) {
+ ByteBuffer message = (ByteBuffer) invocation.getArguments()[1];
+ RpcResponseCallback cb = (RpcResponseCallback) invocation.getArguments()[2];
+ assertEquals("Ping", JavaUtils.bytesToString(message));
+ cb.onSuccess(JavaUtils.stringToBytes("Pong"));
+ return null;
+ }
+ })
+ .when(rpcHandler)
+ .receive(any(TransportClient.class), any(ByteBuffer.class), any(RpcResponseCallback.class));
+
+ SaslTestCtx ctx = new SaslTestCtx(rpcHandler, encrypt, false);
+ try {
+ ByteBuffer response = ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"),
+ TimeUnit.SECONDS.toMillis(10));
+ assertEquals("Pong", JavaUtils.bytesToString(response));
+ } finally {
+ ctx.close();
+ // There should be 2 terminated events; one for the client, one for the server.
+ Throwable error = null;
+ long deadline = System.nanoTime() + TimeUnit.NANOSECONDS.convert(10, TimeUnit.SECONDS);
+ while (deadline > System.nanoTime()) {
+ try {
+ verify(rpcHandler, times(2)).channelInactive(any(TransportClient.class));
+ error = null;
+ break;
+ } catch (Throwable t) {
+ error = t;
+ TimeUnit.MILLISECONDS.sleep(10);
+ }
+ }
+ if (error != null) {
+ throw error;
+ }
+ }
+ }
+
+ @Test
+ public void testEncryptedMessage() throws Exception {
+ SaslEncryptionBackend backend = mock(SaslEncryptionBackend.class);
+ byte[] data = new byte[1024];
+ new Random().nextBytes(data);
+ when(backend.wrap(any(byte[].class), anyInt(), anyInt())).thenReturn(data);
+
+ ByteBuf msg = Unpooled.buffer();
+ try {
+ msg.writeBytes(data);
+
+ // Create a channel with a really small buffer compared to the data. This means that on each
+ // call the outbound data will not be fully written, so transferTo() should return a dummy
+ // count of 1 to keep the channel alive when possible.
+ ByteArrayWritableChannel channel = new ByteArrayWritableChannel(32);
+
+ SaslEncryption.EncryptedMessage emsg =
+ new SaslEncryption.EncryptedMessage(backend, msg, 1024);
+ long count = emsg.transferTo(channel, emsg.transfered());
+ assertTrue(count < data.length);
+ assertTrue(count > 0);
+
+ // Here, the output buffer is full so nothing should be transferred.
+ assertEquals(0, emsg.transferTo(channel, emsg.transfered()));
+
+ // Now there's room in the buffer, but not enough to transfer all the remaining data,
+ // so the dummy count should be returned.
+ channel.reset();
+ assertEquals(1, emsg.transferTo(channel, emsg.transfered()));
+
+ // Eventually, the whole message should be transferred.
+ for (int i = 0; i < data.length / 32 - 2; i++) {
+ channel.reset();
+ assertEquals(1, emsg.transferTo(channel, emsg.transfered()));
+ }
+
+ channel.reset();
+ count = emsg.transferTo(channel, emsg.transfered());
+ assertTrue("Unexpected count: " + count, count > 1 && count < data.length);
+ assertEquals(data.length, emsg.transfered());
+ } finally {
+ msg.release();
+ }
+ }
+
+ @Test
+ public void testEncryptedMessageChunking() throws Exception {
+ File file = File.createTempFile("sasltest", ".txt");
+ try {
+ TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider());
+
+ byte[] data = new byte[8 * 1024];
+ new Random().nextBytes(data);
+ Files.write(data, file);
+
+ SaslEncryptionBackend backend = mock(SaslEncryptionBackend.class);
+ // It doesn't really matter what we return here, as long as it's not null.
+ when(backend.wrap(any(byte[].class), anyInt(), anyInt())).thenReturn(data);
+
+ FileSegmentManagedBuffer msg = new FileSegmentManagedBuffer(conf, file, 0, file.length());
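+ // data.length / 8 gives a 1 KB max block size, so the 8 KB payload should be wrapped in 8 chunks.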
+ SaslEncryption.EncryptedMessage emsg =
+ new SaslEncryption.EncryptedMessage(backend, msg.convertToNetty(), data.length / 8);
+
+ ByteArrayWritableChannel channel = new ByteArrayWritableChannel(data.length);
+ while (emsg.transfered() < emsg.count()) {
+ channel.reset();
+ emsg.transferTo(channel, emsg.transfered());
+ }
+
+ verify(backend, times(8)).wrap(any(byte[].class), anyInt(), anyInt());
+ } finally {
+ file.delete();
+ }
+ }
+
+ @Test
+ public void testFileRegionEncryption() throws Exception {
+ final String blockSizeConf = "spark.network.sasl.maxEncryptedBlockSize";
+ System.setProperty(blockSizeConf, "1k");
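+ // A 1 KB max encrypted block size forces the 8 KB file below to be sent as multiple blocks.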
+
+ final AtomicReference<ManagedBuffer> response = new AtomicReference<>();
+ final File file = File.createTempFile("sasltest", ".txt");
+ SaslTestCtx ctx = null;
+ try {
+ final TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider());
+ StreamManager sm = mock(StreamManager.class);
+ when(sm.getChunk(anyLong(), anyInt())).thenAnswer(new Answer<ManagedBuffer>() {
+ @Override
+ public ManagedBuffer answer(InvocationOnMock invocation) {
+ return new FileSegmentManagedBuffer(conf, file, 0, file.length());
+ }
+ });
+
+ RpcHandler rpcHandler = mock(RpcHandler.class);
+ when(rpcHandler.getStreamManager()).thenReturn(sm);
+
+ byte[] data = new byte[8 * 1024];
+ new Random().nextBytes(data);
+ Files.write(data, file);
+
+ ctx = new SaslTestCtx(rpcHandler, true, false);
+
+ final Object lock = new Object();
+
+ ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class);
+ doAnswer(new Answer<Void>() {
+ @Override
+ public Void answer(InvocationOnMock invocation) {
+ response.set((ManagedBuffer) invocation.getArguments()[1]);
+ response.get().retain();
+ synchronized (lock) {
+ lock.notifyAll();
+ }
+ return null;
+ }
+ }).when(callback).onSuccess(anyInt(), any(ManagedBuffer.class));
+
+ synchronized (lock) {
+ ctx.client.fetchChunk(0, 0, callback);
+ lock.wait(10 * 1000);
+ }
+
+ verify(callback, times(1)).onSuccess(anyInt(), any(ManagedBuffer.class));
+ verify(callback, never()).onFailure(anyInt(), any(Throwable.class));
+
+ byte[] received = ByteStreams.toByteArray(response.get().createInputStream());
+ assertTrue(Arrays.equals(data, received));
+ } finally {
+ file.delete();
+ if (ctx != null) {
+ ctx.close();
+ }
+ if (response.get() != null) {
+ response.get().release();
+ }
+ System.clearProperty(blockSizeConf);
+ }
+ }
+
+ @Test
+ public void testServerAlwaysEncrypt() throws Exception {
+ final String alwaysEncryptConfName = "spark.network.sasl.serverAlwaysEncrypt";
+ System.setProperty(alwaysEncryptConfName, "true");
+
+ SaslTestCtx ctx = null;
+ try {
+ ctx = new SaslTestCtx(mock(RpcHandler.class), false, false);
+ fail("Should have failed to connect without encryption.");
+ } catch (Exception e) {
+ assertTrue(e.getCause() instanceof SaslException);
+ } finally {
+ if (ctx != null) {
+ ctx.close();
+ }
+ System.clearProperty(alwaysEncryptConfName);
+ }
+ }
+
+ @Test
+ public void testDataEncryptionIsActuallyEnabled() throws Exception {
+ // This test sets up an encrypted connection but then, using a client bootstrap, removes
+ // the encryption handler from the client side. This should cause the server to not be
+ // able to understand RPCs sent to it and thus close the connection.
+ SaslTestCtx ctx = null;
+ try {
+ ctx = new SaslTestCtx(mock(RpcHandler.class), true, true);
+ ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"),
+ TimeUnit.SECONDS.toMillis(10));
+ fail("Should have failed to send RPC to server.");
+ } catch (Exception e) {
+ assertFalse(e.getCause() instanceof TimeoutException);
+ } finally {
+ if (ctx != null) {
+ ctx.close();
+ }
+ }
+ }
+
+ @Test
+ public void testRpcHandlerDelegate() throws Exception {
+ // Tests that all calls are delegated, except for receive(), which is more complicated and
+ // already covered by all other tests.
+ RpcHandler handler = mock(RpcHandler.class);
+ RpcHandler saslHandler = new SaslRpcHandler(null, null, handler, null);
+
+ saslHandler.getStreamManager();
+ verify(handler).getStreamManager();
+
+ saslHandler.channelInactive(null);
+ verify(handler).channelInactive(any(TransportClient.class));
+
+ saslHandler.exceptionCaught(null, null);
+ verify(handler).exceptionCaught(any(Throwable.class), any(TransportClient.class));
+ }
+
+ @Test
+ public void testDelegates() throws Exception {
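+ // SaslRpcHandler should override every method declared by RpcHandler; getDeclaredMethod()
+ // throws NoSuchMethodException for any method that is not overridden.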
+ Method[] rpcHandlerMethods = RpcHandler.class.getDeclaredMethods();
+ for (Method m : rpcHandlerMethods) {
+ SaslRpcHandler.class.getDeclaredMethod(m.getName(), m.getParameterTypes());
+ }
+ }
+
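+ /** Sets up a SASL-bootstrapped server and client pair for the tests above. */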
+ private static class SaslTestCtx {
+
+ final TransportClient client;
+ final TransportServer server;
+
+ private final boolean encrypt;
+ private final boolean disableClientEncryption;
+ private final EncryptionCheckerBootstrap checker;
+
+ SaslTestCtx(
+ RpcHandler rpcHandler,
+ boolean encrypt,
+ boolean disableClientEncryption)
+ throws Exception {
+
+ TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider());
+
+ SecretKeyHolder keyHolder = mock(SecretKeyHolder.class);
+ when(keyHolder.getSaslUser(anyString())).thenReturn("user");
+ when(keyHolder.getSecretKey(anyString())).thenReturn("secret");
+
+ TransportContext ctx = new TransportContext(conf, rpcHandler);
+
+ this.checker = new EncryptionCheckerBootstrap();
+ this.server = ctx.createServer(Arrays.asList(new SaslServerBootstrap(conf, keyHolder),
+ checker));
+
+ try {
+ List<TransportClientBootstrap> clientBootstraps = Lists.newArrayList();
+ clientBootstraps.add(new SaslClientBootstrap(conf, "user", keyHolder, encrypt));
+ if (disableClientEncryption) {
+ clientBootstraps.add(new EncryptionDisablerBootstrap());
+ }
+
+ this.client = ctx.createClientFactory(clientBootstraps)
+ .createClient(TestUtils.getLocalHost(), server.getPort());
+ } catch (Exception e) {
+ close();
+ throw e;
+ }
+
+ this.encrypt = encrypt;
+ this.disableClientEncryption = disableClientEncryption;
+ }
+
+ void close() {
+ if (!disableClientEncryption) {
+ assertEquals(encrypt, checker.foundEncryptionHandler);
+ }
+ if (client != null) {
+ client.close();
+ }
+ if (server != null) {
+ server.close();
+ }
+ }
+
+ }
+
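+ /** Server bootstrap that records whether the SASL encryption handler is installed in the pipeline. */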
+ private static class EncryptionCheckerBootstrap extends ChannelOutboundHandlerAdapter
+ implements TransportServerBootstrap {
+
+ boolean foundEncryptionHandler;
+
+ @Override
+ public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise)
+ throws Exception {
+ if (!foundEncryptionHandler) {
+ foundEncryptionHandler =
+ ctx.channel().pipeline().get(SaslEncryption.ENCRYPTION_HANDLER_NAME) != null;
+ }
+ ctx.write(msg, promise);
+ }
+
+ @Override
+ public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
+ super.handlerRemoved(ctx);
+ }
+
+ @Override
+ public RpcHandler doBootstrap(Channel channel, RpcHandler rpcHandler) {
+ channel.pipeline().addFirst("encryptionChecker", this);
+ return rpcHandler;
+ }
+
+ }
+
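+ /** Client bootstrap that removes the encryption handler, simulating a client that skips encryption. */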
+ private static class EncryptionDisablerBootstrap implements TransportClientBootstrap {
+
+ @Override
+ public void doBootstrap(TransportClient client, Channel channel) {
+ channel.pipeline().remove(SaslEncryption.ENCRYPTION_HANDLER_NAME);
+ }
+
+ }
+
+}
diff --git a/common/network-common/src/test/java/org/apache/spark/network/server/OneForOneStreamManagerSuite.java b/common/network-common/src/test/java/org/apache/spark/network/server/OneForOneStreamManagerSuite.java
new file mode 100644
index 0000000000..c647525d8f
--- /dev/null
+++ b/common/network-common/src/test/java/org/apache/spark/network/server/OneForOneStreamManagerSuite.java
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.server;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import io.netty.channel.Channel;
+import org.junit.Test;
+import org.mockito.Mockito;
+
+import org.apache.spark.network.TestManagedBuffer;
+import org.apache.spark.network.buffer.ManagedBuffer;
+
+public class OneForOneStreamManagerSuite {
+
+ @Test
+ public void managedBuffersAreFreedWhenConnectionIsClosed() throws Exception {
+ OneForOneStreamManager manager = new OneForOneStreamManager();
+ List<ManagedBuffer> buffers = new ArrayList<>();
+ TestManagedBuffer buffer1 = Mockito.spy(new TestManagedBuffer(10));
+ TestManagedBuffer buffer2 = Mockito.spy(new TestManagedBuffer(20));
+ buffers.add(buffer1);
+ buffers.add(buffer2);
+ long streamId = manager.registerStream("appId", buffers.iterator());
+
+ Channel dummyChannel = Mockito.mock(Channel.class, Mockito.RETURNS_SMART_NULLS);
+ manager.registerChannel(dummyChannel, streamId);
+
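+ // Terminating the connection should release every buffer registered for the stream.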
+ manager.connectionTerminated(dummyChannel);
+
+ Mockito.verify(buffer1, Mockito.times(1)).release();
+ Mockito.verify(buffer2, Mockito.times(1)).release();
+ }
+}
diff --git a/common/network-common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java b/common/network-common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java
new file mode 100644
index 0000000000..d4de4a941d
--- /dev/null
+++ b/common/network-common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java
@@ -0,0 +1,258 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.util;
+
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Random;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+import io.netty.channel.ChannelHandlerContext;
+import org.junit.AfterClass;
+import org.junit.Test;
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
+import static org.junit.Assert.*;
+import static org.mockito.Mockito.*;
+
+public class TransportFrameDecoderSuite {
+
+ private static Random RND = new Random();
+
+ @AfterClass
+ public static void cleanup() {
+ RND = null;
+ }
+
+ @Test
+ public void testFrameDecoding() throws Exception {
+ TransportFrameDecoder decoder = new TransportFrameDecoder();
+ ChannelHandlerContext ctx = mockChannelHandlerContext();
+ ByteBuf data = createAndFeedFrames(100, decoder, ctx);
+ verifyAndCloseDecoder(decoder, ctx, data);
+ }
+
+ @Test
+ public void testInterception() throws Exception {
+ final int interceptedReads = 3;
+ TransportFrameDecoder decoder = new TransportFrameDecoder();
+ TransportFrameDecoder.Interceptor interceptor = spy(new MockInterceptor(interceptedReads));
+ ChannelHandlerContext ctx = mockChannelHandlerContext();
+
+ byte[] data = new byte[8];
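+ // The frame length written to the wire includes the 8-byte length field itself.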
+ ByteBuf len = Unpooled.copyLong(8 + data.length);
+ ByteBuf dataBuf = Unpooled.wrappedBuffer(data);
+
+ try {
+ decoder.setInterceptor(interceptor);
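+ // While the interceptor is active, incoming buffers are handed to it rather than decoded into frames.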
+ for (int i = 0; i < interceptedReads; i++) {
+ decoder.channelRead(ctx, dataBuf);
+ assertEquals(0, dataBuf.refCnt());
+ dataBuf = Unpooled.wrappedBuffer(data);
+ }
+ decoder.channelRead(ctx, len);
+ decoder.channelRead(ctx, dataBuf);
+ verify(interceptor, times(interceptedReads)).handle(any(ByteBuf.class));
+ verify(ctx).fireChannelRead(any(ByteBuffer.class));
+ assertEquals(0, len.refCnt());
+ assertEquals(0, dataBuf.refCnt());
+ } finally {
+ release(len);
+ release(dataBuf);
+ }
+ }
+
+ @Test
+ public void testRetainedFrames() throws Exception {
+ TransportFrameDecoder decoder = new TransportFrameDecoder();
+
+ final AtomicInteger count = new AtomicInteger();
+ final List<ByteBuf> retained = new ArrayList<>();
+
+ ChannelHandlerContext ctx = mock(ChannelHandlerContext.class);
+ when(ctx.fireChannelRead(any())).thenAnswer(new Answer<Void>() {
+ @Override
+ public Void answer(InvocationOnMock in) {
+ // Retain a few frames but not others.
+ ByteBuf buf = (ByteBuf) in.getArguments()[0];
+ if (count.incrementAndGet() % 2 == 0) {
+ retained.add(buf);
+ } else {
+ buf.release();
+ }
+ return null;
+ }
+ });
+
+ ByteBuf data = createAndFeedFrames(100, decoder, ctx);
+ try {
+ // Verify all retained buffers are readable.
+ for (ByteBuf b : retained) {
+ byte[] tmp = new byte[b.readableBytes()];
+ b.readBytes(tmp);
+ b.release();
+ }
+ verifyAndCloseDecoder(decoder, ctx, data);
+ } finally {
+ for (ByteBuf b : retained) {
+ release(b);
+ }
+ }
+ }
+
+ @Test
+ public void testSplitLengthField() throws Exception {
+ byte[] frame = new byte[1024 * (RND.nextInt(31) + 1)];
+ ByteBuf buf = Unpooled.buffer(frame.length + 8);
+ buf.writeLong(frame.length + 8);
+ buf.writeBytes(frame);
+
+ TransportFrameDecoder decoder = new TransportFrameDecoder();
+ ChannelHandlerContext ctx = mockChannelHandlerContext();
+ try {
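+ // Feed fewer than 8 bytes first, so the decoder cannot yet read the full length field.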
+ decoder.channelRead(ctx, buf.readSlice(RND.nextInt(7)).retain());
+ verify(ctx, never()).fireChannelRead(any(ByteBuf.class));
+ decoder.channelRead(ctx, buf);
+ verify(ctx).fireChannelRead(any(ByteBuf.class));
+ assertEquals(0, buf.refCnt());
+ } finally {
+ decoder.channelInactive(ctx);
+ release(buf);
+ }
+ }
+
+ @Test(expected = IllegalArgumentException.class)
+ public void testNegativeFrameSize() throws Exception {
+ testInvalidFrame(-1);
+ }
+
+ @Test(expected = IllegalArgumentException.class)
+ public void testEmptyFrame() throws Exception {
+ // 8 because frame size includes the frame length.
+ testInvalidFrame(8);
+ }
+
+ @Test(expected = IllegalArgumentException.class)
+ public void testLargeFrame() throws Exception {
+ // Frame length includes the frame size field, so need to add a few more bytes.
+ testInvalidFrame(Integer.MAX_VALUE + 9L);
+ }
+
+ /**
+ * Creates a number of randomly sized frames and feeds them to the given decoder, verifying
+ * that the frames were read.
+ */
+ private ByteBuf createAndFeedFrames(
+ int frameCount,
+ TransportFrameDecoder decoder,
+ ChannelHandlerContext ctx) throws Exception {
+ ByteBuf data = Unpooled.buffer();
+ for (int i = 0; i < frameCount; i++) {
+ byte[] frame = new byte[1024 * (RND.nextInt(31) + 1)];
+ data.writeLong(frame.length + 8);
+ data.writeBytes(frame);
+ }
+
+ try {
+ while (data.isReadable()) {
+ int size = RND.nextInt(4 * 1024) + 256;
+ decoder.channelRead(ctx, data.readSlice(Math.min(data.readableBytes(), size)).retain());
+ }
+
+ verify(ctx, times(frameCount)).fireChannelRead(any(ByteBuf.class));
+ } catch (Exception e) {
+ release(data);
+ throw e;
+ }
+ return data;
+ }
+
+ private void verifyAndCloseDecoder(
+ TransportFrameDecoder decoder,
+ ChannelHandlerContext ctx,
+ ByteBuf data) throws Exception {
+ try {
+ decoder.channelInactive(ctx);
+ assertTrue("There shouldn't be dangling references to the data.", data.release());
+ } finally {
+ release(data);
+ }
+ }
+
+ private void testInvalidFrame(long size) throws Exception {
+ TransportFrameDecoder decoder = new TransportFrameDecoder();
+ ChannelHandlerContext ctx = mock(ChannelHandlerContext.class);
+ ByteBuf frame = Unpooled.copyLong(size);
+ try {
+ decoder.channelRead(ctx, frame);
+ } finally {
+ release(frame);
+ }
+ }
+
+ private ChannelHandlerContext mockChannelHandlerContext() {
+ ChannelHandlerContext ctx = mock(ChannelHandlerContext.class);
+ when(ctx.fireChannelRead(any())).thenAnswer(new Answer<Void>() {
+ @Override
+ public Void answer(InvocationOnMock in) {
+ ByteBuf buf = (ByteBuf) in.getArguments()[0];
+ buf.release();
+ return null;
+ }
+ });
+ return ctx;
+ }
+
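+ /** Drops any remaining references so a leaked buffer does not mask the real test failure. */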
+ private void release(ByteBuf buf) {
+ if (buf.refCnt() > 0) {
+ buf.release(buf.refCnt());
+ }
+ }
+
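+ /** Interceptor that consumes whole buffers and deactivates itself after a fixed number of reads. */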
+ private static class MockInterceptor implements TransportFrameDecoder.Interceptor {
+
+ private int remainingReads;
+
+ MockInterceptor(int readCount) {
+ this.remainingReads = readCount;
+ }
+
+ @Override
+ public boolean handle(ByteBuf data) throws Exception {
+ data.readerIndex(data.readerIndex() + data.readableBytes());
+ assertFalse(data.isReadable());
+ remainingReads -= 1;
+ return remainingReads != 0;
+ }
+
+ @Override
+ public void exceptionCaught(Throwable cause) throws Exception {
+
+ }
+
+ @Override
+ public void channelInactive() throws Exception {
+
+ }
+
+ }
+
+}
diff --git a/common/network-common/src/test/resources/log4j.properties b/common/network-common/src/test/resources/log4j.properties
new file mode 100644
index 0000000000..e8da774f7c
--- /dev/null
+++ b/common/network-common/src/test/resources/log4j.properties
@@ -0,0 +1,27 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# Set everything to be logged to the file target/unit-tests.log
+log4j.rootCategory=DEBUG, file
+log4j.appender.file=org.apache.log4j.FileAppender
+log4j.appender.file.append=true
+log4j.appender.file.file=target/unit-tests.log
+log4j.appender.file.layout=org.apache.log4j.PatternLayout
+log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n
+
+# Silence verbose logs from 3rd-party libraries.
+log4j.logger.io.netty=INFO