author     Marcelo Vanzin <vanzin@cloudera.com>   2015-11-04 09:11:54 -0800
committer  Marcelo Vanzin <vanzin@cloudera.com>   2015-11-04 09:11:54 -0800
commit     27feafccbd6945b000ca51b14c57912acbad9031 (patch)
tree       d32ac36287d2f82afffa44c792c5487d07ef182d /network
parent     8790ee6d69e50ca84eb849742be48f2476743b5b (diff)
download   spark-27feafccbd6945b000ca51b14c57912acbad9031.tar.gz
           spark-27feafccbd6945b000ca51b14c57912acbad9031.tar.bz2
           spark-27feafccbd6945b000ca51b14c57912acbad9031.zip
[SPARK-11235][NETWORK] Add ability to stream data using network lib.
The current interface used to fetch shuffle data is not very efficient for large buffers; it requires the receiver to buffer the entire contents being downloaded in memory before processing the data.

To use the network library to transfer large files (such as those that can be added using SparkContext addJar / addFile), this change adds a more efficient way of downloading data: the data is streamed and fed to a callback as it arrives. This is achieved by a custom frame decoder that replaces the current Netty one; this decoder allows entering a mode where framing is skipped and data is instead provided directly to a callback. The existing Netty classes (ByteToMessageDecoder and LengthFieldBasedFrameDecoder) could not be reused, since their semantics do not allow for the interception approach the new decoder uses.

Author: Marcelo Vanzin <vanzin@cloudera.com>

Closes #9206 from vanzin/SPARK-11235.
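For context, a hedged sketch of how a caller might consume the new API once this change is in place. It assumes a TransportClient has already been obtained from a TransportClientFactory, and that the stream id ("someStream" below is a placeholder) has been negotiated with the server out of band:

    // Illustrative only: stream data into an in-memory sink.
    final ByteArrayOutputStream sink = new ByteArrayOutputStream();
    client.stream("someStream", new StreamCallback() {
      @Override
      public void onData(String streamId, ByteBuffer buf) throws IOException {
        byte[] chunk = new byte[buf.remaining()];
        buf.get(chunk);       // copy out: the buffer is only valid during this call
        sink.write(chunk);
      }

      @Override
      public void onComplete(String streamId) throws IOException {
        sink.close();         // all byteCount bytes have arrived
      }

      @Override
      public void onFailure(String streamId, Throwable cause) throws IOException {
        // surface the error to the application
      }
    });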
Diffstat (limited to 'network')
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/TransportContext.java                     3
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/client/StreamCallback.java                40
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java             76
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/client/TransportClient.java               41
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java      47
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java           16
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/protocol/Message.java                     6
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java              9
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java              27
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/protocol/ResponseWithBody.java            40
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java               80
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java               78
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java              91
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/server/StreamManager.java                 13
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java       20
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java                      9
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java           154
-rw-r--r--  network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java                        8
-rw-r--r--  network/common/src/test/java/org/apache/spark/network/StreamSuite.java                          325
-rw-r--r--  network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java      142
20 files changed, 1196 insertions, 29 deletions
diff --git a/network/common/src/main/java/org/apache/spark/network/TransportContext.java b/network/common/src/main/java/org/apache/spark/network/TransportContext.java
index b8d073fa16..43900e6f2c 100644
--- a/network/common/src/main/java/org/apache/spark/network/TransportContext.java
+++ b/network/common/src/main/java/org/apache/spark/network/TransportContext.java
@@ -39,6 +39,7 @@ import org.apache.spark.network.server.TransportServer;
import org.apache.spark.network.server.TransportServerBootstrap;
import org.apache.spark.network.util.NettyUtils;
import org.apache.spark.network.util.TransportConf;
+import org.apache.spark.network.util.TransportFrameDecoder;
/**
* Contains the context to create a {@link TransportServer}, {@link TransportClientFactory}, and to
@@ -119,7 +120,7 @@ public class TransportContext {
TransportChannelHandler channelHandler = createChannelHandler(channel, channelRpcHandler);
channel.pipeline()
.addLast("encoder", encoder)
- .addLast("frameDecoder", NettyUtils.createFrameDecoder())
+ .addLast(TransportFrameDecoder.HANDLER_NAME, NettyUtils.createFrameDecoder())
.addLast("decoder", decoder)
.addLast("idleStateHandler", new IdleStateHandler(0, 0, conf.connectionTimeoutMs() / 1000))
// NOTE: Chunks are currently guaranteed to be returned in the order of request, but this
diff --git a/network/common/src/main/java/org/apache/spark/network/client/StreamCallback.java b/network/common/src/main/java/org/apache/spark/network/client/StreamCallback.java
new file mode 100644
index 0000000000..093fada320
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/client/StreamCallback.java
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.client;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+
+/**
+ * Callback for streaming data. Stream data will be offered to the
+ * {@link #onData(String, ByteBuffer)} method as it arrives. Once all the stream data has been
+ * received, {@link #onComplete(String)} will be called.
+ * <p>
+ * The network library guarantees that only a single thread will call these methods at a time,
+ * but different calls may be made by different threads.
+ */
+public interface StreamCallback {
+ /** Called upon receipt of stream data. */
+ void onData(String streamId, ByteBuffer buf) throws IOException;
+
+ /** Called when all data from the stream has been received. */
+ void onComplete(String streamId) throws IOException;
+
+ /** Called if there's an error reading data from the stream. */
+ void onFailure(String streamId, Throwable cause) throws IOException;
+}
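As a usage sketch, a file-backed implementation of this interface might look like the following (similar in spirit to the TestCallback in StreamSuite further down; the constructor argument is an assumption):

    // Sketch: persist incoming stream data to a caller-supplied OutputStream.
    class FileStreamCallback implements StreamCallback {
      private final OutputStream out;

      FileStreamCallback(OutputStream out) {
        this.out = out;
      }

      @Override
      public void onData(String streamId, ByteBuffer buf) throws IOException {
        byte[] tmp = new byte[buf.remaining()];
        buf.get(tmp);          // data must be consumed before returning
        out.write(tmp);
      }

      @Override
      public void onComplete(String streamId) throws IOException {
        out.close();
      }

      @Override
      public void onFailure(String streamId, Throwable cause) throws IOException {
        out.close();           // best-effort cleanup on error
      }
    }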
diff --git a/network/common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java b/network/common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java
new file mode 100644
index 0000000000..02230a00e6
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.client;
+
+import java.nio.ByteBuffer;
+import java.nio.channels.ClosedChannelException;
+
+import io.netty.buffer.ByteBuf;
+
+import org.apache.spark.network.util.TransportFrameDecoder;
+
+/**
+ * An interceptor that is registered with the frame decoder to feed stream data to a
+ * callback.
+ */
+class StreamInterceptor implements TransportFrameDecoder.Interceptor {
+
+ private final String streamId;
+ private final long byteCount;
+ private final StreamCallback callback;
+
+ private volatile long bytesRead;
+
+ StreamInterceptor(String streamId, long byteCount, StreamCallback callback) {
+ this.streamId = streamId;
+ this.byteCount = byteCount;
+ this.callback = callback;
+ this.bytesRead = 0;
+ }
+
+ @Override
+ public void exceptionCaught(Throwable cause) throws Exception {
+ callback.onFailure(streamId, cause);
+ }
+
+ @Override
+ public void channelInactive() throws Exception {
+ callback.onFailure(streamId, new ClosedChannelException());
+ }
+
+ @Override
+ public boolean handle(ByteBuf buf) throws Exception {
+ int toRead = (int) Math.min(buf.readableBytes(), byteCount - bytesRead);
+ ByteBuffer nioBuffer = buf.readSlice(toRead).nioBuffer();
+
+ int available = nioBuffer.remaining();
+ callback.onData(streamId, nioBuffer);
+ bytesRead += available;
+ if (bytesRead > byteCount) {
+ RuntimeException re = new IllegalStateException(String.format(
+ "Read too many bytes? Expected %d, but read %d.", byteCount, bytesRead));
+ callback.onFailure(streamId, re);
+ throw re;
+ } else if (bytesRead == byteCount) {
+ callback.onComplete(streamId);
+ }
+
+ return bytesRead != byteCount;
+ }
+
+}
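The return value of handle() is what drives uninstallation: the frame decoder keeps feeding the interceptor until it returns false. A minimal interceptor that just skips a fixed number of bytes (a sketch; the byte count is an assumption) could look like:

    // Sketch: consume exactly `expected` bytes, then let normal framing resume.
    class SkippingInterceptor implements TransportFrameDecoder.Interceptor {
      private final long expected;
      private long seen = 0;

      SkippingInterceptor(long expected) {
        this.expected = expected;
      }

      @Override
      public boolean handle(ByteBuf data) {
        int toRead = (int) Math.min(data.readableBytes(), expected - seen);
        data.skipBytes(toRead);   // a real interceptor would copy the data out
        seen += toRead;
        return seen < expected;   // false -> the decoder uninstalls us
      }

      @Override
      public void exceptionCaught(Throwable cause) { }

      @Override
      public void channelInactive() { }
    }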
diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
index fbb8bb6b2f..a0ba223e34 100644
--- a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
+++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
@@ -38,6 +38,7 @@ import org.slf4j.LoggerFactory;
import org.apache.spark.network.protocol.ChunkFetchRequest;
import org.apache.spark.network.protocol.RpcRequest;
import org.apache.spark.network.protocol.StreamChunkId;
+import org.apache.spark.network.protocol.StreamRequest;
import org.apache.spark.network.util.NettyUtils;
/**
@@ -160,6 +161,46 @@ public class TransportClient implements Closeable {
}
/**
+ * Request to stream the data with the given stream ID from the remote end.
+ *
+ * @param streamId The stream to fetch.
+ * @param callback Object to call with the stream data.
+ */
+ public void stream(final String streamId, final StreamCallback callback) {
+ final String serverAddr = NettyUtils.getRemoteAddress(channel);
+ final long startTime = System.currentTimeMillis();
+ logger.debug("Sending stream request for {} to {}", streamId, serverAddr);
+
+ // Need to synchronize here so that the callback is added to the queue and the RPC is
+ // written to the socket atomically, so that callbacks are called in the right order
+ // when responses arrive.
+ synchronized (this) {
+ handler.addStreamCallback(callback);
+ channel.writeAndFlush(new StreamRequest(streamId)).addListener(
+ new ChannelFutureListener() {
+ @Override
+ public void operationComplete(ChannelFuture future) throws Exception {
+ if (future.isSuccess()) {
+ long timeTaken = System.currentTimeMillis() - startTime;
+ logger.trace("Sending request for {} to {} took {} ms", streamId, serverAddr,
+ timeTaken);
+ } else {
+ String errorMsg = String.format("Failed to send request for %s to %s: %s", streamId,
+ serverAddr, future.cause());
+ logger.error(errorMsg, future.cause());
+ channel.close();
+ try {
+ callback.onFailure(streamId, new IOException(errorMsg, future.cause()));
+ } catch (Exception e) {
+ logger.error("Uncaught exception in RPC response callback handler!", e);
+ }
+ }
+ }
+ });
+ }
+ }
+
+ /**
* Sends an opaque message to the RpcHandler on the server-side. The callback will be invoked
* with the server's response or upon any failure.
*/
diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java
index 94fc21af5e..ed3f36af58 100644
--- a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java
@@ -19,7 +19,9 @@ package org.apache.spark.network.client;
import java.io.IOException;
import java.util.Map;
+import java.util.Queue;
import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.atomic.AtomicLong;
import io.netty.channel.Channel;
@@ -32,8 +34,11 @@ import org.apache.spark.network.protocol.ResponseMessage;
import org.apache.spark.network.protocol.RpcFailure;
import org.apache.spark.network.protocol.RpcResponse;
import org.apache.spark.network.protocol.StreamChunkId;
+import org.apache.spark.network.protocol.StreamFailure;
+import org.apache.spark.network.protocol.StreamResponse;
import org.apache.spark.network.server.MessageHandler;
import org.apache.spark.network.util.NettyUtils;
+import org.apache.spark.network.util.TransportFrameDecoder;
/**
* Handler that processes server responses, in response to requests issued from a
@@ -50,6 +55,8 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
private final Map<Long, RpcResponseCallback> outstandingRpcs;
+ private final Queue<StreamCallback> streamCallbacks;
+
/** Records the time (in system nanoseconds) that the last fetch or RPC request was sent. */
private final AtomicLong timeOfLastRequestNs;
@@ -57,6 +64,7 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
this.channel = channel;
this.outstandingFetches = new ConcurrentHashMap<StreamChunkId, ChunkReceivedCallback>();
this.outstandingRpcs = new ConcurrentHashMap<Long, RpcResponseCallback>();
+ this.streamCallbacks = new ConcurrentLinkedQueue<StreamCallback>();
this.timeOfLastRequestNs = new AtomicLong(0);
}
@@ -78,6 +86,10 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
outstandingRpcs.remove(requestId);
}
+ public void addStreamCallback(StreamCallback callback) {
+ streamCallbacks.offer(callback);
+ }
+
/**
* Fire the failure callback for all outstanding requests. This is called when we have an
* uncaught exception or premature connection termination.
@@ -124,11 +136,11 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
if (listener == null) {
logger.warn("Ignoring response for block {} from {} since it is not outstanding",
resp.streamChunkId, remoteAddress);
- resp.buffer.release();
+ resp.body.release();
} else {
outstandingFetches.remove(resp.streamChunkId);
- listener.onSuccess(resp.streamChunkId.chunkIndex, resp.buffer);
- resp.buffer.release();
+ listener.onSuccess(resp.streamChunkId.chunkIndex, resp.body);
+ resp.body.release();
}
} else if (message instanceof ChunkFetchFailure) {
ChunkFetchFailure resp = (ChunkFetchFailure) message;
@@ -161,6 +173,34 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
outstandingRpcs.remove(resp.requestId);
listener.onFailure(new RuntimeException(resp.errorString));
}
+ } else if (message instanceof StreamResponse) {
+ StreamResponse resp = (StreamResponse) message;
+ StreamCallback callback = streamCallbacks.poll();
+ if (callback != null) {
+ StreamInterceptor interceptor = new StreamInterceptor(resp.streamId, resp.byteCount,
+ callback);
+ try {
+ TransportFrameDecoder frameDecoder = (TransportFrameDecoder)
+ channel.pipeline().get(TransportFrameDecoder.HANDLER_NAME);
+ frameDecoder.setInterceptor(interceptor);
+ } catch (Exception e) {
+ logger.error("Error installing stream handler.", e);
+ }
+ } else {
+ logger.error("Could not find callback for StreamResponse.");
+ }
+ } else if (message instanceof StreamFailure) {
+ StreamFailure resp = (StreamFailure) message;
+ StreamCallback callback = streamCallbacks.poll();
+ if (callback != null) {
+ try {
+ callback.onFailure(resp.streamId, new RuntimeException(resp.error));
+ } catch (IOException ioe) {
+ logger.warn("Error in stream failure handler.", ioe);
+ }
+ } else {
+ logger.warn("Stream failure with unknown callback: {}", resp.error);
+ }
} else {
throw new IllegalStateException("Unknown response type: " + message.type());
}
@@ -175,4 +215,5 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
public long getTimeOfLastRequestNs() {
return timeOfLastRequestNs.get();
}
+
}
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java
index c962fb7ecf..e6a7e9a8b4 100644
--- a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java
@@ -30,13 +30,12 @@ import org.apache.spark.network.buffer.NettyManagedBuffer;
* may be written by Netty in a more efficient manner (i.e., zero-copy write).
* Similarly, the client-side decoding will reuse the Netty ByteBuf as the buffer.
*/
-public final class ChunkFetchSuccess implements ResponseMessage {
+public final class ChunkFetchSuccess extends ResponseWithBody {
public final StreamChunkId streamChunkId;
- public final ManagedBuffer buffer;
public ChunkFetchSuccess(StreamChunkId streamChunkId, ManagedBuffer buffer) {
+ super(buffer, true);
this.streamChunkId = streamChunkId;
- this.buffer = buffer;
}
@Override
@@ -53,6 +52,11 @@ public final class ChunkFetchSuccess implements ResponseMessage {
streamChunkId.encode(buf);
}
+ @Override
+ public ResponseMessage createFailureResponse(String error) {
+ return new ChunkFetchFailure(streamChunkId, error);
+ }
+
/** Decoding uses the given ByteBuf as our data, and will retain() it. */
public static ChunkFetchSuccess decode(ByteBuf buf) {
StreamChunkId streamChunkId = StreamChunkId.decode(buf);
@@ -63,14 +67,14 @@ public final class ChunkFetchSuccess implements ResponseMessage {
@Override
public int hashCode() {
- return Objects.hashCode(streamChunkId, buffer);
+ return Objects.hashCode(streamChunkId, body);
}
@Override
public boolean equals(Object other) {
if (other instanceof ChunkFetchSuccess) {
ChunkFetchSuccess o = (ChunkFetchSuccess) other;
- return streamChunkId.equals(o.streamChunkId) && buffer.equals(o.buffer);
+ return streamChunkId.equals(o.streamChunkId) && body.equals(o.body);
}
return false;
}
@@ -79,7 +83,7 @@ public final class ChunkFetchSuccess implements ResponseMessage {
public String toString() {
return Objects.toStringHelper(this)
.add("streamChunkId", streamChunkId)
- .add("buffer", buffer)
+ .add("buffer", body)
.toString();
}
}
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/Message.java b/network/common/src/main/java/org/apache/spark/network/protocol/Message.java
index d568370125..d01598c20f 100644
--- a/network/common/src/main/java/org/apache/spark/network/protocol/Message.java
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/Message.java
@@ -27,7 +27,8 @@ public interface Message extends Encodable {
/** Preceding every serialized Message is its type, which allows us to deserialize it. */
public static enum Type implements Encodable {
ChunkFetchRequest(0), ChunkFetchSuccess(1), ChunkFetchFailure(2),
- RpcRequest(3), RpcResponse(4), RpcFailure(5);
+ RpcRequest(3), RpcResponse(4), RpcFailure(5),
+ StreamRequest(6), StreamResponse(7), StreamFailure(8);
private final byte id;
@@ -51,6 +52,9 @@ public interface Message extends Encodable {
case 3: return RpcRequest;
case 4: return RpcResponse;
case 5: return RpcFailure;
+ case 6: return StreamRequest;
+ case 7: return StreamResponse;
+ case 8: return StreamFailure;
default: throw new IllegalArgumentException("Unknown message type: " + id);
}
}
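The new enum values round-trip through the existing one-byte tag encoding; a quick hedged sketch of that invariant using Netty's Unpooled helper:

    // Sketch: Type.encode() writes the id byte, Type.decode() recovers the constant.
    ByteBuf buf = Unpooled.buffer();
    Message.Type.StreamRequest.encode(buf);                      // writes the byte 6
    assert Message.Type.decode(buf) == Message.Type.StreamRequest;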
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java b/network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java
index 81f8d7f963..3c04048f38 100644
--- a/network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java
@@ -63,6 +63,15 @@ public final class MessageDecoder extends MessageToMessageDecoder<ByteBuf> {
case RpcFailure:
return RpcFailure.decode(in);
+ case StreamRequest:
+ return StreamRequest.decode(in);
+
+ case StreamResponse:
+ return StreamResponse.decode(in);
+
+ case StreamFailure:
+ return StreamFailure.decode(in);
+
default:
throw new IllegalArgumentException("Unexpected message type: " + msgType);
}
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java b/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java
index 0f999f5dfe..6cce97c807 100644
--- a/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java
@@ -45,27 +45,32 @@ public final class MessageEncoder extends MessageToMessageEncoder<Message> {
public void encode(ChannelHandlerContext ctx, Message in, List<Object> out) {
Object body = null;
long bodyLength = 0;
+ boolean isBodyInFrame = false;
- // Only ChunkFetchSuccesses have data besides the header.
+ // Detect ResponseWithBody messages and get the data buffer out of them.
// The body is used in order to enable zero-copy transfer for the payload.
- if (in instanceof ChunkFetchSuccess) {
- ChunkFetchSuccess resp = (ChunkFetchSuccess) in;
+ if (in instanceof ResponseWithBody) {
+ ResponseWithBody resp = (ResponseWithBody) in;
try {
- bodyLength = resp.buffer.size();
- body = resp.buffer.convertToNetty();
+ bodyLength = resp.body.size();
+ body = resp.body.convertToNetty();
+ isBodyInFrame = resp.isBodyInFrame;
} catch (Exception e) {
- // Re-encode this message as BlockFetchFailure.
- logger.error(String.format("Error opening block %s for client %s",
- resp.streamChunkId, ctx.channel().remoteAddress()), e);
- encode(ctx, new ChunkFetchFailure(resp.streamChunkId, e.getMessage()), out);
+ // Re-encode this message as a failure response.
+ String error = e.getMessage() != null ? e.getMessage() : "null";
+ logger.error(String.format("Error processing %s for client %s",
+ resp, ctx.channel().remoteAddress()), e);
+ encode(ctx, resp.createFailureResponse(error), out);
return;
}
}
Message.Type msgType = in.type();
- // All messages have the frame length, message type, and message itself.
+ // All messages have the frame length, message type, and message itself. The frame length
+ // may optionally include the length of the body data, depending on what message is being
+ // sent.
int headerLength = 8 + msgType.encodedLength() + in.encodedLength();
- long frameLength = headerLength + bodyLength;
+ long frameLength = headerLength + (isBodyInFrame ? bodyLength : 0);
ByteBuf header = ctx.alloc().heapBuffer(headerLength);
header.writeLong(frameLength);
msgType.encode(header);
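As a worked example of the frame-length rule above (the 1 MB body size is illustrative): for a StreamResponse the body stays out of the frame, while for a ChunkFetchSuccess it does not:

    // Sketch: frame accounting for a 1 MB (1048576-byte) body.
    // headerLength = 8 (length field) + 1 (type byte) + in.encodedLength()
    //
    // StreamResponse    (isBodyInFrame == false):
    //   frameLength = headerLength            // body bytes follow the frame, raw
    // ChunkFetchSuccess (isBodyInFrame == true):
    //   frameLength = headerLength + 1048576  // body is accounted for in the frame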
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ResponseWithBody.java b/network/common/src/main/java/org/apache/spark/network/protocol/ResponseWithBody.java
new file mode 100644
index 0000000000..67be77e39f
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/ResponseWithBody.java
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+
+/**
+ * Abstract class for response messages that contain a large data portion kept in a separate
+ * buffer. These messages are handled specially by MessageEncoder.
+ */
+public abstract class ResponseWithBody implements ResponseMessage {
+ public final ManagedBuffer body;
+ public final boolean isBodyInFrame;
+
+ protected ResponseWithBody(ManagedBuffer body, boolean isBodyInFrame) {
+ this.body = body;
+ this.isBodyInFrame = isBodyInFrame;
+ }
+
+ public abstract ResponseMessage createFailureResponse(String error);
+}
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java b/network/common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java
new file mode 100644
index 0000000000..e3dade2ebf
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+
+
+/**
+ * Message indicating an error when transferring a stream.
+ */
+public final class StreamFailure implements ResponseMessage {
+ public final String streamId;
+ public final String error;
+
+ public StreamFailure(String streamId, String error) {
+ this.streamId = streamId;
+ this.error = error;
+ }
+
+ @Override
+ public Type type() { return Type.StreamFailure; }
+
+ @Override
+ public int encodedLength() {
+ return Encoders.Strings.encodedLength(streamId) + Encoders.Strings.encodedLength(error);
+ }
+
+ @Override
+ public void encode(ByteBuf buf) {
+ Encoders.Strings.encode(buf, streamId);
+ Encoders.Strings.encode(buf, error);
+ }
+
+ public static StreamFailure decode(ByteBuf buf) {
+ String streamId = Encoders.Strings.decode(buf);
+ String error = Encoders.Strings.decode(buf);
+ return new StreamFailure(streamId, error);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hashCode(streamId, error);
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (other instanceof StreamFailure) {
+ StreamFailure o = (StreamFailure) other;
+ return streamId.equals(o.streamId) && error.equals(o.error);
+ }
+ return false;
+ }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("streamId", streamId)
+ .add("error", error)
+ .toString();
+ }
+
+}
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java b/network/common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java
new file mode 100644
index 0000000000..821e8f5388
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+
+
+/**
+ * Request to stream data from the remote end.
+ * <p>
+ * The stream ID is an arbitrary string that needs to be negotiated between the two endpoints before
+ * the data can be streamed.
+ */
+public final class StreamRequest implements RequestMessage {
+ public final String streamId;
+
+ public StreamRequest(String streamId) {
+ this.streamId = streamId;
+ }
+
+ @Override
+ public Type type() { return Type.StreamRequest; }
+
+ @Override
+ public int encodedLength() {
+ return Encoders.Strings.encodedLength(streamId);
+ }
+
+ @Override
+ public void encode(ByteBuf buf) {
+ Encoders.Strings.encode(buf, streamId);
+ }
+
+ public static StreamRequest decode(ByteBuf buf) {
+ String streamId = Encoders.Strings.decode(buf);
+ return new StreamRequest(streamId);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hashCode(streamId);
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (other instanceof StreamRequest) {
+ StreamRequest o = (StreamRequest) other;
+ return streamId.equals(o.streamId);
+ }
+ return false;
+ }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("streamId", streamId)
+ .toString();
+ }
+
+}
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java b/network/common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java
new file mode 100644
index 0000000000..ac5ab9a323
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+
+/**
+ * Response to {@link StreamRequest} when the stream has been successfully opened.
+ * <p>
+ * Note the message itself does not contain the stream data. That is written separately by the
+ * sender. The receiver is expected to set a temporary channel handler that will consume the
+ * number of bytes this message says the stream has.
+ */
+public final class StreamResponse extends ResponseWithBody {
+ public final String streamId;
+ public final long byteCount;
+
+ public StreamResponse(String streamId, long byteCount, ManagedBuffer buffer) {
+ super(buffer, false);
+ this.streamId = streamId;
+ this.byteCount = byteCount;
+ }
+
+ @Override
+ public Type type() { return Type.StreamResponse; }
+
+ @Override
+ public int encodedLength() {
+ return 8 + Encoders.Strings.encodedLength(streamId);
+ }
+
+ /** Encoding does NOT include 'buffer' itself. See {@link MessageEncoder}. */
+ @Override
+ public void encode(ByteBuf buf) {
+ Encoders.Strings.encode(buf, streamId);
+ buf.writeLong(byteCount);
+ }
+
+ @Override
+ public ResponseMessage createFailureResponse(String error) {
+ return new StreamFailure(streamId, error);
+ }
+
+ public static StreamResponse decode(ByteBuf buf) {
+ String streamId = Encoders.Strings.decode(buf);
+ long byteCount = buf.readLong();
+ return new StreamResponse(streamId, byteCount, null);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hashCode(byteCount, streamId);
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (other instanceof StreamResponse) {
+ StreamResponse o = (StreamResponse) other;
+ return byteCount == o.byteCount && streamId.equals(o.streamId);
+ }
+ return false;
+ }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("streamId", streamId)
+ .add("byteCount", byteCount)
+ .toString();
+ }
+
+}
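Putting the protocol pieces together, the bytes a client sees for a successful stream look roughly like this (a sketch inferred from encode() above and MessageEncoder):

    // [ 8-byte frameLength ][ 1-byte type = StreamResponse ]
    // [ length-prefixed UTF-8 streamId ][ 8-byte byteCount ]
    // [ ... byteCount raw bytes, written outside the frame ... ]
    //
    // On receipt, TransportResponseHandler installs a StreamInterceptor for
    // byteCount bytes, and the frame decoder feeds it the raw bytes that follow.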
diff --git a/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java b/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java
index aaa677c965..3f0155957a 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java
@@ -47,6 +47,19 @@ public abstract class StreamManager {
public abstract ManagedBuffer getChunk(long streamId, int chunkIndex);
/**
+ * Called in response to a stream() request. The returned data is streamed to the client
+ * through a single TCP connection.
+ *
+ * Note the <code>streamId</code> argument is not related to the similarly named argument in the
+ * {@link #getChunk(long, int)} method.
+ *
+ * @param streamId id of a stream that has been previously registered with the StreamManager.
+ */
+ public ManagedBuffer openStream(String streamId) {
+ throw new UnsupportedOperationException();
+ }
+
+ /**
* Associates a stream with a single client connection, which is guaranteed to be the only reader
* of the stream. The getChunk() method will be called serially on this connection and once the
* connection is closed, the stream will never be used again, enabling cleanup.
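Server-side, an application exposes data by overriding openStream(); the StreamSuite below does exactly that for a fixed set of streams. A condensed sketch (the id-to-buffer map is an assumption):

    // Sketch: serve pre-registered buffers by stream id.
    StreamManager manager = new StreamManager() {
      private final Map<String, ManagedBuffer> streams = new ConcurrentHashMap<>();

      @Override
      public ManagedBuffer getChunk(long streamId, int chunkIndex) {
        throw new UnsupportedOperationException();
      }

      @Override
      public ManagedBuffer openStream(String streamId) {
        ManagedBuffer buf = streams.get(streamId);
        if (buf == null) {
          throw new IllegalArgumentException("Unknown stream: " + streamId);
        }
        return buf;
      }
    };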
diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
index 9b8b047b49..4f67bd573b 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
@@ -35,6 +35,9 @@ import org.apache.spark.network.protocol.ChunkFetchFailure;
import org.apache.spark.network.protocol.ChunkFetchSuccess;
import org.apache.spark.network.protocol.RpcFailure;
import org.apache.spark.network.protocol.RpcResponse;
+import org.apache.spark.network.protocol.StreamFailure;
+import org.apache.spark.network.protocol.StreamRequest;
+import org.apache.spark.network.protocol.StreamResponse;
import org.apache.spark.network.util.NettyUtils;
/**
@@ -92,6 +95,8 @@ public class TransportRequestHandler extends MessageHandler<RequestMessage> {
processFetchRequest((ChunkFetchRequest) request);
} else if (request instanceof RpcRequest) {
processRpcRequest((RpcRequest) request);
+ } else if (request instanceof StreamRequest) {
+ processStreamRequest((StreamRequest) request);
} else {
throw new IllegalArgumentException("Unknown request type: " + request);
}
@@ -117,6 +122,21 @@ public class TransportRequestHandler extends MessageHandler<RequestMessage> {
respond(new ChunkFetchSuccess(req.streamChunkId, buf));
}
+ private void processStreamRequest(final StreamRequest req) {
+ final String client = NettyUtils.getRemoteAddress(channel);
+ ManagedBuffer buf;
+ try {
+ buf = streamManager.openStream(req.streamId);
+ } catch (Exception e) {
+ logger.error(String.format(
+ "Error opening stream %s for request from %s", req.streamId, client), e);
+ respond(new StreamFailure(req.streamId, Throwables.getStackTraceAsString(e)));
+ return;
+ }
+
+ respond(new StreamResponse(req.streamId, buf.size(), buf));
+ }
+
private void processRpcRequest(final RpcRequest req) {
try {
rpcHandler.receive(reverseClient, req.message, new RpcResponseCallback() {
diff --git a/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java b/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java
index 26c6399ce7..caa7260bc8 100644
--- a/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java
+++ b/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java
@@ -89,13 +89,8 @@ public class NettyUtils {
* Creates a LengthFieldBasedFrameDecoder where the first 8 bytes are the length of the frame.
* This is used before all decoders.
*/
- public static ByteToMessageDecoder createFrameDecoder() {
- // maxFrameLength = 2G
- // lengthFieldOffset = 0
- // lengthFieldLength = 8
- // lengthAdjustment = -8, i.e. exclude the 8 byte length itself
- // initialBytesToStrip = 8, i.e. strip out the length field itself
- return new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 8, -8, 8);
+ public static TransportFrameDecoder createFrameDecoder() {
+ return new TransportFrameDecoder();
}
/** Returns the remote address on the channel or "&lt;unknown remote&gt;" if none exists. */
diff --git a/network/common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java b/network/common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java
new file mode 100644
index 0000000000..272ea84e61
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java
@@ -0,0 +1,154 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.util;
+
+import com.google.common.base.Preconditions;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.CompositeByteBuf;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelInboundHandlerAdapter;
+
+/**
+ * A customized frame decoder that allows intercepting raw data.
+ * <p>
+ * This behaves like Netty's frame decoder (with hardcoded parameters that match this library's
+ * needs), except it allows an interceptor to be installed to read data directly before it's
+ * framed.
+ * <p>
+ * Unlike Netty's frame decoder, each frame is dispatched to child handlers as soon as it's
+ * decoded, instead of building as many frames as the current buffer allows and dispatching
+ * all of them. This allows a child handler to install an interceptor if needed.
+ * <p>
+ * If an interceptor is installed, framing stops, and data is instead fed directly to the
+ * interceptor. When the interceptor indicates that it doesn't need to read any more data,
+ * framing resumes. Interceptors should not hold references to the data buffers provided
+ * to their handle() method.
+ */
+public class TransportFrameDecoder extends ChannelInboundHandlerAdapter {
+
+ public static final String HANDLER_NAME = "frameDecoder";
+ private static final int LENGTH_SIZE = 8;
+ private static final int MAX_FRAME_SIZE = Integer.MAX_VALUE;
+
+ private CompositeByteBuf buffer;
+ private volatile Interceptor interceptor;
+
+ @Override
+ public void channelRead(ChannelHandlerContext ctx, Object data) throws Exception {
+ ByteBuf in = (ByteBuf) data;
+
+ if (buffer == null) {
+ buffer = in.alloc().compositeBuffer();
+ }
+
+ buffer.writeBytes(in);
+
+ while (buffer.isReadable()) {
+ feedInterceptor();
+ if (interceptor != null) {
+ continue;
+ }
+
+ ByteBuf frame = decodeNext();
+ if (frame != null) {
+ ctx.fireChannelRead(frame);
+ } else {
+ break;
+ }
+ }
+
+ // We can't discard read sub-buffers if there are other references to the buffer (e.g.
+ // through slices used for framing). This assumes that code that retains references
+ // will call retain() from the thread that called "fireChannelRead()" above, otherwise
+ // ref counting will go awry.
+ if (buffer != null && buffer.refCnt() == 1) {
+ buffer.discardReadComponents();
+ }
+ }
+
+  protected ByteBuf decodeNext() throws Exception {
+    if (buffer.readableBytes() < LENGTH_SIZE) {
+      return null;
+    }
+
+    // Validate the length as a long before narrowing to an int, so that frame sizes
+    // above 2G fail fast instead of silently wrapping around.
+    long frameSize = buffer.readLong() - LENGTH_SIZE;
+    Preconditions.checkArgument(frameSize < MAX_FRAME_SIZE, "Too large frame: %s", frameSize);
+    Preconditions.checkArgument(frameSize > 0, "Frame length should be positive: %s", frameSize);
+
+    int frameLen = (int) frameSize;
+    if (buffer.readableBytes() < frameLen) {
+      buffer.readerIndex(buffer.readerIndex() - LENGTH_SIZE);
+      return null;
+    }
+
+    ByteBuf frame = buffer.readSlice(frameLen);
+    frame.retain();
+    return frame;
+  }
+
+ @Override
+ public void channelInactive(ChannelHandlerContext ctx) throws Exception {
+ if (buffer != null) {
+ if (buffer.isReadable()) {
+ feedInterceptor();
+ }
+ buffer.release();
+ }
+ if (interceptor != null) {
+ interceptor.channelInactive();
+ }
+ super.channelInactive(ctx);
+ }
+
+ @Override
+ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
+ if (interceptor != null) {
+ interceptor.exceptionCaught(cause);
+ }
+ super.exceptionCaught(ctx, cause);
+ }
+
+ public void setInterceptor(Interceptor interceptor) {
+ Preconditions.checkState(this.interceptor == null, "Already have an interceptor.");
+ this.interceptor = interceptor;
+ }
+
+ private void feedInterceptor() throws Exception {
+ if (interceptor != null && !interceptor.handle(buffer)) {
+ interceptor = null;
+ }
+ }
+
+ public static interface Interceptor {
+
+ /**
+ * Handles data received from the remote end.
+ *
+ * @param data Buffer containing data.
+ * @return "true" if the interceptor expects more data, "false" to uninstall the interceptor.
+ */
+ boolean handle(ByteBuf data) throws Exception;
+
+ /** Called if an exception is thrown in the channel pipeline. */
+ void exceptionCaught(Throwable cause) throws Exception;
+
+ /** Called if the channel is closed and the interceptor is still installed. */
+ void channelInactive() throws Exception;
+
+ }
+
+}
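The framing contract the decoder expects matches what the test suite below writes: the 8-byte length prefix counts itself. A hedged sketch of building one well-formed frame:

    // Sketch: [length][payload], where length includes its own 8 bytes.
    byte[] payload = new byte[1024];
    ByteBuf frame = Unpooled.buffer();
    frame.writeLong(payload.length + 8);   // 1032: payload plus the length field
    frame.writeBytes(payload);
    // decoder.channelRead(ctx, frame) would fire one 1024-byte frame downstream.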
diff --git a/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java b/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java
index d500bc3c98..22b451fc0e 100644
--- a/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java
+++ b/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java
@@ -39,6 +39,9 @@ import org.apache.spark.network.protocol.RpcFailure;
import org.apache.spark.network.protocol.RpcRequest;
import org.apache.spark.network.protocol.RpcResponse;
import org.apache.spark.network.protocol.StreamChunkId;
+import org.apache.spark.network.protocol.StreamFailure;
+import org.apache.spark.network.protocol.StreamRequest;
+import org.apache.spark.network.protocol.StreamResponse;
import org.apache.spark.network.util.ByteArrayWritableChannel;
import org.apache.spark.network.util.NettyUtils;
@@ -80,6 +83,7 @@ public class ProtocolSuite {
testClientToServer(new ChunkFetchRequest(new StreamChunkId(1, 2)));
testClientToServer(new RpcRequest(12345, new byte[0]));
testClientToServer(new RpcRequest(12345, new byte[100]));
+ testClientToServer(new StreamRequest("abcde"));
}
@Test
@@ -92,6 +96,10 @@ public class ProtocolSuite {
testServerToClient(new RpcResponse(12345, new byte[1000]));
testServerToClient(new RpcFailure(0, "this is an error"));
testServerToClient(new RpcFailure(0, ""));
+ // Note: buffer size must be "0" since StreamResponse's buffer is written differently to the
+ // channel and cannot be tested like this.
+ testServerToClient(new StreamResponse("anId", 12345L, new TestManagedBuffer(0)));
+ testServerToClient(new StreamFailure("anId", "this is an error"));
}
/**
diff --git a/network/common/src/test/java/org/apache/spark/network/StreamSuite.java b/network/common/src/test/java/org/apache/spark/network/StreamSuite.java
new file mode 100644
index 0000000000..6dcec831de
--- /dev/null
+++ b/network/common/src/test/java/org/apache/spark/network/StreamSuite.java
@@ -0,0 +1,325 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network;
+
+import java.io.ByteArrayOutputStream;
+import java.io.File;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.OutputStream;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Random;
+import java.util.concurrent.Executors;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.TimeUnit;
+
+import com.google.common.io.Files;
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.Test;
+import static org.junit.Assert.*;
+
+import org.apache.spark.network.buffer.FileSegmentManagedBuffer;
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.buffer.NioManagedBuffer;
+import org.apache.spark.network.client.RpcResponseCallback;
+import org.apache.spark.network.client.StreamCallback;
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.client.TransportClientFactory;
+import org.apache.spark.network.server.RpcHandler;
+import org.apache.spark.network.server.StreamManager;
+import org.apache.spark.network.server.TransportServer;
+import org.apache.spark.network.util.SystemPropertyConfigProvider;
+import org.apache.spark.network.util.TransportConf;
+
+public class StreamSuite {
+ private static final String[] STREAMS = { "largeBuffer", "smallBuffer", "file" };
+
+ private static TransportServer server;
+ private static TransportClientFactory clientFactory;
+ private static File testFile;
+ private static File tempDir;
+
+ private static ByteBuffer smallBuffer;
+ private static ByteBuffer largeBuffer;
+
+ private static ByteBuffer createBuffer(int bufSize) {
+ ByteBuffer buf = ByteBuffer.allocate(bufSize);
+ for (int i = 0; i < bufSize; i ++) {
+ buf.put((byte) i);
+ }
+ buf.flip();
+ return buf;
+ }
+
+ @BeforeClass
+ public static void setUp() throws Exception {
+ tempDir = Files.createTempDir();
+ smallBuffer = createBuffer(100);
+ largeBuffer = createBuffer(100000);
+
+ testFile = File.createTempFile("stream-test-file", "txt", tempDir);
+ FileOutputStream fp = new FileOutputStream(testFile);
+ try {
+ Random rnd = new Random();
+ for (int i = 0; i < 512; i++) {
+ byte[] fileContent = new byte[1024];
+ rnd.nextBytes(fileContent);
+ fp.write(fileContent);
+ }
+ } finally {
+ fp.close();
+ }
+
+ final TransportConf conf = new TransportConf(new SystemPropertyConfigProvider());
+ final StreamManager streamManager = new StreamManager() {
+ @Override
+ public ManagedBuffer getChunk(long streamId, int chunkIndex) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public ManagedBuffer openStream(String streamId) {
+ switch (streamId) {
+ case "largeBuffer":
+ return new NioManagedBuffer(largeBuffer);
+ case "smallBuffer":
+ return new NioManagedBuffer(smallBuffer);
+ case "file":
+ return new FileSegmentManagedBuffer(conf, testFile, 0, testFile.length());
+ default:
+ throw new IllegalArgumentException("Invalid stream: " + streamId);
+ }
+ }
+ };
+ RpcHandler handler = new RpcHandler() {
+ @Override
+ public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public StreamManager getStreamManager() {
+ return streamManager;
+ }
+ };
+ TransportContext context = new TransportContext(conf, handler);
+ server = context.createServer();
+ clientFactory = context.createClientFactory();
+ }
+
+ @AfterClass
+ public static void tearDown() {
+ server.close();
+ clientFactory.close();
+ if (tempDir != null) {
+ for (File f : tempDir.listFiles()) {
+ f.delete();
+ }
+ tempDir.delete();
+ }
+ }
+
+ @Test
+ public void testSingleStream() throws Throwable {
+ TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
+ try {
+ StreamTask task = new StreamTask(client, "largeBuffer", TimeUnit.SECONDS.toMillis(5));
+ task.run();
+ task.check();
+ } finally {
+ client.close();
+ }
+ }
+
+ @Test
+ public void testMultipleStreams() throws Throwable {
+ TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
+ try {
+ for (int i = 0; i < 20; i++) {
+ StreamTask task = new StreamTask(client, STREAMS[i % STREAMS.length],
+ TimeUnit.SECONDS.toMillis(5));
+ task.run();
+ task.check();
+ }
+ } finally {
+ client.close();
+ }
+ }
+
+ @Test
+ public void testConcurrentStreams() throws Throwable {
+ ExecutorService executor = Executors.newFixedThreadPool(20);
+ TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
+
+ try {
+ List<StreamTask> tasks = new ArrayList<>();
+ for (int i = 0; i < 20; i++) {
+ StreamTask task = new StreamTask(client, STREAMS[i % STREAMS.length],
+ TimeUnit.SECONDS.toMillis(20));
+ tasks.add(task);
+ executor.submit(task);
+ }
+
+ executor.shutdown();
+ assertTrue("Timed out waiting for tasks.", executor.awaitTermination(30, TimeUnit.SECONDS));
+ for (StreamTask task : tasks) {
+ task.check();
+ }
+ } finally {
+ executor.shutdownNow();
+ client.close();
+ }
+ }
+
+ private static class StreamTask implements Runnable {
+
+ private final TransportClient client;
+ private final String streamId;
+ private final long timeoutMs;
+ private Throwable error;
+
+ StreamTask(TransportClient client, String streamId, long timeoutMs) {
+ this.client = client;
+ this.streamId = streamId;
+ this.timeoutMs = timeoutMs;
+ }
+
+ @Override
+ public void run() {
+ ByteBuffer srcBuffer = null;
+ OutputStream out = null;
+ File outFile = null;
+ try {
+ ByteArrayOutputStream baos = null;
+
+ switch (streamId) {
+ case "largeBuffer":
+ baos = new ByteArrayOutputStream();
+ out = baos;
+ srcBuffer = largeBuffer;
+ break;
+ case "smallBuffer":
+ baos = new ByteArrayOutputStream();
+ out = baos;
+ srcBuffer = smallBuffer;
+ break;
+ case "file":
+ outFile = File.createTempFile("data", ".tmp", tempDir);
+ out = new FileOutputStream(outFile);
+ break;
+ default:
+ throw new IllegalArgumentException(streamId);
+ }
+
+ TestCallback callback = new TestCallback(out);
+ client.stream(streamId, callback);
+ waitForCompletion(callback);
+
+ if (srcBuffer == null) {
+ assertTrue("File stream did not match.", Files.equal(testFile, outFile));
+ } else {
+ ByteBuffer base;
+ synchronized (srcBuffer) {
+ base = srcBuffer.duplicate();
+ }
+ byte[] result = baos.toByteArray();
+ byte[] expected = new byte[base.remaining()];
+ base.get(expected);
+ assertEquals(expected.length, result.length);
+ assertTrue("buffers don't match", Arrays.equals(expected, result));
+ }
+ } catch (Throwable t) {
+ error = t;
+ } finally {
+ if (out != null) {
+ try {
+ out.close();
+ } catch (Exception e) {
+ // ignore.
+ }
+ }
+ if (outFile != null) {
+ outFile.delete();
+ }
+ }
+ }
+
+ public void check() throws Throwable {
+ if (error != null) {
+ throw error;
+ }
+ }
+
+ private void waitForCompletion(TestCallback callback) throws Exception {
+ long now = System.currentTimeMillis();
+ long deadline = now + timeoutMs;
+ synchronized (callback) {
+ while (!callback.completed && now < deadline) {
+ callback.wait(deadline - now);
+ now = System.currentTimeMillis();
+ }
+ }
+ assertTrue("Timed out waiting for stream.", callback.completed);
+ assertNull(callback.error);
+ }
+
+ }
+
+ private static class TestCallback implements StreamCallback {
+
+ private final OutputStream out;
+ public volatile boolean completed;
+ public volatile Throwable error;
+
+ TestCallback(OutputStream out) {
+ this.out = out;
+ this.completed = false;
+ }
+
+ @Override
+ public void onData(String streamId, ByteBuffer buf) throws IOException {
+ byte[] tmp = new byte[buf.remaining()];
+ buf.get(tmp);
+ out.write(tmp);
+ }
+
+ @Override
+ public void onComplete(String streamId) throws IOException {
+ out.close();
+ synchronized (this) {
+ completed = true;
+ notifyAll();
+ }
+ }
+
+ @Override
+ public void onFailure(String streamId, Throwable cause) {
+ error = cause;
+ synchronized (this) {
+ completed = true;
+ notifyAll();
+ }
+ }
+
+ }
+
+}
diff --git a/network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java b/network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java
new file mode 100644
index 0000000000..ca74f0a00c
--- /dev/null
+++ b/network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.util;
+
+import java.util.Random;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+import io.netty.channel.ChannelHandlerContext;
+import org.junit.Test;
+import static org.junit.Assert.*;
+import static org.mockito.Mockito.*;
+
+public class TransportFrameDecoderSuite {
+
+ @Test
+ public void testFrameDecoding() throws Exception {
+ Random rnd = new Random();
+ TransportFrameDecoder decoder = new TransportFrameDecoder();
+ ChannelHandlerContext ctx = mock(ChannelHandlerContext.class);
+
+ final int frameCount = 100;
+ ByteBuf data = Unpooled.buffer();
+ try {
+ for (int i = 0; i < frameCount; i++) {
+ byte[] frame = new byte[1024 * (rnd.nextInt(31) + 1)];
+ data.writeLong(frame.length + 8);
+ data.writeBytes(frame);
+ }
+
+ while (data.isReadable()) {
+ int size = rnd.nextInt(16 * 1024) + 256;
+ decoder.channelRead(ctx, data.readSlice(Math.min(data.readableBytes(), size)));
+ }
+
+ verify(ctx, times(frameCount)).fireChannelRead(any(ByteBuf.class));
+ } finally {
+ data.release();
+ }
+ }
+
+ @Test
+ public void testInterception() throws Exception {
+ final int interceptedReads = 3;
+ TransportFrameDecoder decoder = new TransportFrameDecoder();
+ TransportFrameDecoder.Interceptor interceptor = spy(new MockInterceptor(interceptedReads));
+ ChannelHandlerContext ctx = mock(ChannelHandlerContext.class);
+
+ byte[] data = new byte[8];
+ ByteBuf len = Unpooled.copyLong(8 + data.length);
+ ByteBuf dataBuf = Unpooled.wrappedBuffer(data);
+
+ try {
+ decoder.setInterceptor(interceptor);
+ for (int i = 0; i < interceptedReads; i++) {
+ decoder.channelRead(ctx, dataBuf);
+ dataBuf.release();
+ dataBuf = Unpooled.wrappedBuffer(data);
+ }
+ decoder.channelRead(ctx, len);
+ decoder.channelRead(ctx, dataBuf);
+ verify(interceptor, times(interceptedReads)).handle(any(ByteBuf.class));
+      verify(ctx).fireChannelRead(any(ByteBuf.class));
+ } finally {
+ len.release();
+ dataBuf.release();
+ }
+ }
+
+ @Test(expected = IllegalArgumentException.class)
+ public void testNegativeFrameSize() throws Exception {
+ testInvalidFrame(-1);
+ }
+
+ @Test(expected = IllegalArgumentException.class)
+ public void testEmptyFrame() throws Exception {
+ // 8 because frame size includes the frame length.
+ testInvalidFrame(8);
+ }
+
+ @Test(expected = IllegalArgumentException.class)
+ public void testLargeFrame() throws Exception {
+ // Frame length includes the frame size field, so need to add a few more bytes.
+    testInvalidFrame(Integer.MAX_VALUE + 9L);
+ }
+
+ private void testInvalidFrame(long size) throws Exception {
+ TransportFrameDecoder decoder = new TransportFrameDecoder();
+ ChannelHandlerContext ctx = mock(ChannelHandlerContext.class);
+ ByteBuf frame = Unpooled.copyLong(size);
+ try {
+ decoder.channelRead(ctx, frame);
+ } finally {
+ frame.release();
+ }
+ }
+
+ private static class MockInterceptor implements TransportFrameDecoder.Interceptor {
+
+ private int remainingReads;
+
+ MockInterceptor(int readCount) {
+ this.remainingReads = readCount;
+ }
+
+ @Override
+ public boolean handle(ByteBuf data) throws Exception {
+ data.readerIndex(data.readerIndex() + data.readableBytes());
+ assertFalse(data.isReadable());
+ remainingReads -= 1;
+ return remainingReads != 0;
+ }
+
+ @Override
+ public void exceptionCaught(Throwable cause) throws Exception {
+
+ }
+
+ @Override
+ public void channelInactive() throws Exception {
+
+ }
+
+ }
+
+}