aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--network/common/src/main/java/org/apache/spark/network/TransportContext.java3
-rw-r--r--network/common/src/main/java/org/apache/spark/network/client/StreamCallback.java40
-rw-r--r--network/common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java76
-rw-r--r--network/common/src/main/java/org/apache/spark/network/client/TransportClient.java41
-rw-r--r--network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java47
-rw-r--r--network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java16
-rw-r--r--network/common/src/main/java/org/apache/spark/network/protocol/Message.java6
-rw-r--r--network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java9
-rw-r--r--network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java27
-rw-r--r--network/common/src/main/java/org/apache/spark/network/protocol/ResponseWithBody.java40
-rw-r--r--network/common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java80
-rw-r--r--network/common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java78
-rw-r--r--network/common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java91
-rw-r--r--network/common/src/main/java/org/apache/spark/network/server/StreamManager.java13
-rw-r--r--network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java20
-rw-r--r--network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java9
-rw-r--r--network/common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java154
-rw-r--r--network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java8
-rw-r--r--network/common/src/test/java/org/apache/spark/network/StreamSuite.java325
-rw-r--r--network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java142
20 files changed, 1196 insertions, 29 deletions
diff --git a/network/common/src/main/java/org/apache/spark/network/TransportContext.java b/network/common/src/main/java/org/apache/spark/network/TransportContext.java
index b8d073fa16..43900e6f2c 100644
--- a/network/common/src/main/java/org/apache/spark/network/TransportContext.java
+++ b/network/common/src/main/java/org/apache/spark/network/TransportContext.java
@@ -39,6 +39,7 @@ import org.apache.spark.network.server.TransportServer;
import org.apache.spark.network.server.TransportServerBootstrap;
import org.apache.spark.network.util.NettyUtils;
import org.apache.spark.network.util.TransportConf;
+import org.apache.spark.network.util.TransportFrameDecoder;
/**
* Contains the context to create a {@link TransportServer}, {@link TransportClientFactory}, and to
@@ -119,7 +120,7 @@ public class TransportContext {
TransportChannelHandler channelHandler = createChannelHandler(channel, channelRpcHandler);
channel.pipeline()
.addLast("encoder", encoder)
- .addLast("frameDecoder", NettyUtils.createFrameDecoder())
+ .addLast(TransportFrameDecoder.HANDLER_NAME, NettyUtils.createFrameDecoder())
.addLast("decoder", decoder)
.addLast("idleStateHandler", new IdleStateHandler(0, 0, conf.connectionTimeoutMs() / 1000))
// NOTE: Chunks are currently guaranteed to be returned in the order of request, but this
diff --git a/network/common/src/main/java/org/apache/spark/network/client/StreamCallback.java b/network/common/src/main/java/org/apache/spark/network/client/StreamCallback.java
new file mode 100644
index 0000000000..093fada320
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/client/StreamCallback.java
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.client;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+
+/**
+ * Callback for streaming data. Stream data will be offered to the {@link onData(ByteBuffer)}
+ * method as it arrives. Once all the stream data is received, {@link onComplete()} will be
+ * called.
+ * <p>
+ * The network library guarantees that a single thread will call these methods at a time, but
+ * different call may be made by different threads.
+ */
+public interface StreamCallback {
+ /** Called upon receipt of stream data. */
+ void onData(String streamId, ByteBuffer buf) throws IOException;
+
+ /** Called when all data from the stream has been received. */
+ void onComplete(String streamId) throws IOException;
+
+ /** Called if there's an error reading data from the stream. */
+ void onFailure(String streamId, Throwable cause) throws IOException;
+}
diff --git a/network/common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java b/network/common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java
new file mode 100644
index 0000000000..02230a00e6
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.client;
+
+import java.nio.ByteBuffer;
+import java.nio.channels.ClosedChannelException;
+
+import io.netty.buffer.ByteBuf;
+
+import org.apache.spark.network.util.TransportFrameDecoder;
+
+/**
+ * An interceptor that is registered with the frame decoder to feed stream data to a
+ * callback.
+ */
+class StreamInterceptor implements TransportFrameDecoder.Interceptor {
+
+ private final String streamId;
+ private final long byteCount;
+ private final StreamCallback callback;
+
+ private volatile long bytesRead;
+
+ StreamInterceptor(String streamId, long byteCount, StreamCallback callback) {
+ this.streamId = streamId;
+ this.byteCount = byteCount;
+ this.callback = callback;
+ this.bytesRead = 0;
+ }
+
+ @Override
+ public void exceptionCaught(Throwable cause) throws Exception {
+ callback.onFailure(streamId, cause);
+ }
+
+ @Override
+ public void channelInactive() throws Exception {
+ callback.onFailure(streamId, new ClosedChannelException());
+ }
+
+ @Override
+ public boolean handle(ByteBuf buf) throws Exception {
+ int toRead = (int) Math.min(buf.readableBytes(), byteCount - bytesRead);
+ ByteBuffer nioBuffer = buf.readSlice(toRead).nioBuffer();
+
+ int available = nioBuffer.remaining();
+ callback.onData(streamId, nioBuffer);
+ bytesRead += available;
+ if (bytesRead > byteCount) {
+ RuntimeException re = new IllegalStateException(String.format(
+ "Read too many bytes? Expected %d, but read %d.", byteCount, bytesRead));
+ callback.onFailure(streamId, re);
+ throw re;
+ } else if (bytesRead == byteCount) {
+ callback.onComplete(streamId);
+ }
+
+ return bytesRead != byteCount;
+ }
+
+}
diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
index fbb8bb6b2f..a0ba223e34 100644
--- a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
+++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
@@ -38,6 +38,7 @@ import org.slf4j.LoggerFactory;
import org.apache.spark.network.protocol.ChunkFetchRequest;
import org.apache.spark.network.protocol.RpcRequest;
import org.apache.spark.network.protocol.StreamChunkId;
+import org.apache.spark.network.protocol.StreamRequest;
import org.apache.spark.network.util.NettyUtils;
/**
@@ -160,6 +161,46 @@ public class TransportClient implements Closeable {
}
/**
+ * Request to stream the data with the given stream ID from the remote end.
+ *
+ * @param streamId The stream to fetch.
+ * @param callback Object to call with the stream data.
+ */
+ public void stream(final String streamId, final StreamCallback callback) {
+ final String serverAddr = NettyUtils.getRemoteAddress(channel);
+ final long startTime = System.currentTimeMillis();
+ logger.debug("Sending stream request for {} to {}", streamId, serverAddr);
+
+ // Need to synchronize here so that the callback is added to the queue and the RPC is
+ // written to the socket atomically, so that callbacks are called in the right order
+ // when responses arrive.
+ synchronized (this) {
+ handler.addStreamCallback(callback);
+ channel.writeAndFlush(new StreamRequest(streamId)).addListener(
+ new ChannelFutureListener() {
+ @Override
+ public void operationComplete(ChannelFuture future) throws Exception {
+ if (future.isSuccess()) {
+ long timeTaken = System.currentTimeMillis() - startTime;
+ logger.trace("Sending request for {} to {} took {} ms", streamId, serverAddr,
+ timeTaken);
+ } else {
+ String errorMsg = String.format("Failed to send request for %s to %s: %s", streamId,
+ serverAddr, future.cause());
+ logger.error(errorMsg, future.cause());
+ channel.close();
+ try {
+ callback.onFailure(streamId, new IOException(errorMsg, future.cause()));
+ } catch (Exception e) {
+ logger.error("Uncaught exception in RPC response callback handler!", e);
+ }
+ }
+ }
+ });
+ }
+ }
+
+ /**
* Sends an opaque message to the RpcHandler on the server-side. The callback will be invoked
* with the server's response or upon any failure.
*/
diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java
index 94fc21af5e..ed3f36af58 100644
--- a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java
@@ -19,7 +19,9 @@ package org.apache.spark.network.client;
import java.io.IOException;
import java.util.Map;
+import java.util.Queue;
import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.atomic.AtomicLong;
import io.netty.channel.Channel;
@@ -32,8 +34,11 @@ import org.apache.spark.network.protocol.ResponseMessage;
import org.apache.spark.network.protocol.RpcFailure;
import org.apache.spark.network.protocol.RpcResponse;
import org.apache.spark.network.protocol.StreamChunkId;
+import org.apache.spark.network.protocol.StreamFailure;
+import org.apache.spark.network.protocol.StreamResponse;
import org.apache.spark.network.server.MessageHandler;
import org.apache.spark.network.util.NettyUtils;
+import org.apache.spark.network.util.TransportFrameDecoder;
/**
* Handler that processes server responses, in response to requests issued from a
@@ -50,6 +55,8 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
private final Map<Long, RpcResponseCallback> outstandingRpcs;
+ private final Queue<StreamCallback> streamCallbacks;
+
/** Records the time (in system nanoseconds) that the last fetch or RPC request was sent. */
private final AtomicLong timeOfLastRequestNs;
@@ -57,6 +64,7 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
this.channel = channel;
this.outstandingFetches = new ConcurrentHashMap<StreamChunkId, ChunkReceivedCallback>();
this.outstandingRpcs = new ConcurrentHashMap<Long, RpcResponseCallback>();
+ this.streamCallbacks = new ConcurrentLinkedQueue<StreamCallback>();
this.timeOfLastRequestNs = new AtomicLong(0);
}
@@ -78,6 +86,10 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
outstandingRpcs.remove(requestId);
}
+ public void addStreamCallback(StreamCallback callback) {
+ streamCallbacks.offer(callback);
+ }
+
/**
* Fire the failure callback for all outstanding requests. This is called when we have an
* uncaught exception or pre-mature connection termination.
@@ -124,11 +136,11 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
if (listener == null) {
logger.warn("Ignoring response for block {} from {} since it is not outstanding",
resp.streamChunkId, remoteAddress);
- resp.buffer.release();
+ resp.body.release();
} else {
outstandingFetches.remove(resp.streamChunkId);
- listener.onSuccess(resp.streamChunkId.chunkIndex, resp.buffer);
- resp.buffer.release();
+ listener.onSuccess(resp.streamChunkId.chunkIndex, resp.body);
+ resp.body.release();
}
} else if (message instanceof ChunkFetchFailure) {
ChunkFetchFailure resp = (ChunkFetchFailure) message;
@@ -161,6 +173,34 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
outstandingRpcs.remove(resp.requestId);
listener.onFailure(new RuntimeException(resp.errorString));
}
+ } else if (message instanceof StreamResponse) {
+ StreamResponse resp = (StreamResponse) message;
+ StreamCallback callback = streamCallbacks.poll();
+ if (callback != null) {
+ StreamInterceptor interceptor = new StreamInterceptor(resp.streamId, resp.byteCount,
+ callback);
+ try {
+ TransportFrameDecoder frameDecoder = (TransportFrameDecoder)
+ channel.pipeline().get(TransportFrameDecoder.HANDLER_NAME);
+ frameDecoder.setInterceptor(interceptor);
+ } catch (Exception e) {
+ logger.error("Error installing stream handler.", e);
+ }
+ } else {
+ logger.error("Could not find callback for StreamResponse.");
+ }
+ } else if (message instanceof StreamFailure) {
+ StreamFailure resp = (StreamFailure) message;
+ StreamCallback callback = streamCallbacks.poll();
+ if (callback != null) {
+ try {
+ callback.onFailure(resp.streamId, new RuntimeException(resp.error));
+ } catch (IOException ioe) {
+ logger.warn("Error in stream failure handler.", ioe);
+ }
+ } else {
+ logger.warn("Stream failure with unknown callback: {}", resp.error);
+ }
} else {
throw new IllegalStateException("Unknown response type: " + message.type());
}
@@ -175,4 +215,5 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
public long getTimeOfLastRequestNs() {
return timeOfLastRequestNs.get();
}
+
}
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java
index c962fb7ecf..e6a7e9a8b4 100644
--- a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java
@@ -30,13 +30,12 @@ import org.apache.spark.network.buffer.NettyManagedBuffer;
* may be written by Netty in a more efficient manner (i.e., zero-copy write).
* Similarly, the client-side decoding will reuse the Netty ByteBuf as the buffer.
*/
-public final class ChunkFetchSuccess implements ResponseMessage {
+public final class ChunkFetchSuccess extends ResponseWithBody {
public final StreamChunkId streamChunkId;
- public final ManagedBuffer buffer;
public ChunkFetchSuccess(StreamChunkId streamChunkId, ManagedBuffer buffer) {
+ super(buffer, true);
this.streamChunkId = streamChunkId;
- this.buffer = buffer;
}
@Override
@@ -53,6 +52,11 @@ public final class ChunkFetchSuccess implements ResponseMessage {
streamChunkId.encode(buf);
}
+ @Override
+ public ResponseMessage createFailureResponse(String error) {
+ return new ChunkFetchFailure(streamChunkId, error);
+ }
+
/** Decoding uses the given ByteBuf as our data, and will retain() it. */
public static ChunkFetchSuccess decode(ByteBuf buf) {
StreamChunkId streamChunkId = StreamChunkId.decode(buf);
@@ -63,14 +67,14 @@ public final class ChunkFetchSuccess implements ResponseMessage {
@Override
public int hashCode() {
- return Objects.hashCode(streamChunkId, buffer);
+ return Objects.hashCode(streamChunkId, body);
}
@Override
public boolean equals(Object other) {
if (other instanceof ChunkFetchSuccess) {
ChunkFetchSuccess o = (ChunkFetchSuccess) other;
- return streamChunkId.equals(o.streamChunkId) && buffer.equals(o.buffer);
+ return streamChunkId.equals(o.streamChunkId) && body.equals(o.body);
}
return false;
}
@@ -79,7 +83,7 @@ public final class ChunkFetchSuccess implements ResponseMessage {
public String toString() {
return Objects.toStringHelper(this)
.add("streamChunkId", streamChunkId)
- .add("buffer", buffer)
+ .add("buffer", body)
.toString();
}
}
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/Message.java b/network/common/src/main/java/org/apache/spark/network/protocol/Message.java
index d568370125..d01598c20f 100644
--- a/network/common/src/main/java/org/apache/spark/network/protocol/Message.java
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/Message.java
@@ -27,7 +27,8 @@ public interface Message extends Encodable {
/** Preceding every serialized Message is its type, which allows us to deserialize it. */
public static enum Type implements Encodable {
ChunkFetchRequest(0), ChunkFetchSuccess(1), ChunkFetchFailure(2),
- RpcRequest(3), RpcResponse(4), RpcFailure(5);
+ RpcRequest(3), RpcResponse(4), RpcFailure(5),
+ StreamRequest(6), StreamResponse(7), StreamFailure(8);
private final byte id;
@@ -51,6 +52,9 @@ public interface Message extends Encodable {
case 3: return RpcRequest;
case 4: return RpcResponse;
case 5: return RpcFailure;
+ case 6: return StreamRequest;
+ case 7: return StreamResponse;
+ case 8: return StreamFailure;
default: throw new IllegalArgumentException("Unknown message type: " + id);
}
}
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java b/network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java
index 81f8d7f963..3c04048f38 100644
--- a/network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java
@@ -63,6 +63,15 @@ public final class MessageDecoder extends MessageToMessageDecoder<ByteBuf> {
case RpcFailure:
return RpcFailure.decode(in);
+ case StreamRequest:
+ return StreamRequest.decode(in);
+
+ case StreamResponse:
+ return StreamResponse.decode(in);
+
+ case StreamFailure:
+ return StreamFailure.decode(in);
+
default:
throw new IllegalArgumentException("Unexpected message type: " + msgType);
}
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java b/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java
index 0f999f5dfe..6cce97c807 100644
--- a/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java
@@ -45,27 +45,32 @@ public final class MessageEncoder extends MessageToMessageEncoder<Message> {
public void encode(ChannelHandlerContext ctx, Message in, List<Object> out) {
Object body = null;
long bodyLength = 0;
+ boolean isBodyInFrame = false;
- // Only ChunkFetchSuccesses have data besides the header.
+ // Detect ResponseWithBody messages and get the data buffer out of them.
// The body is used in order to enable zero-copy transfer for the payload.
- if (in instanceof ChunkFetchSuccess) {
- ChunkFetchSuccess resp = (ChunkFetchSuccess) in;
+ if (in instanceof ResponseWithBody) {
+ ResponseWithBody resp = (ResponseWithBody) in;
try {
- bodyLength = resp.buffer.size();
- body = resp.buffer.convertToNetty();
+ bodyLength = resp.body.size();
+ body = resp.body.convertToNetty();
+ isBodyInFrame = resp.isBodyInFrame;
} catch (Exception e) {
- // Re-encode this message as BlockFetchFailure.
- logger.error(String.format("Error opening block %s for client %s",
- resp.streamChunkId, ctx.channel().remoteAddress()), e);
- encode(ctx, new ChunkFetchFailure(resp.streamChunkId, e.getMessage()), out);
+ // Re-encode this message as a failure response.
+ String error = e.getMessage() != null ? e.getMessage() : "null";
+ logger.error(String.format("Error processing %s for client %s",
+ resp, ctx.channel().remoteAddress()), e);
+ encode(ctx, resp.createFailureResponse(error), out);
return;
}
}
Message.Type msgType = in.type();
- // All messages have the frame length, message type, and message itself.
+ // All messages have the frame length, message type, and message itself. The frame length
+ // may optionally include the length of the body data, depending on what message is being
+ // sent.
int headerLength = 8 + msgType.encodedLength() + in.encodedLength();
- long frameLength = headerLength + bodyLength;
+ long frameLength = headerLength + (isBodyInFrame ? bodyLength : 0);
ByteBuf header = ctx.alloc().heapBuffer(headerLength);
header.writeLong(frameLength);
msgType.encode(header);
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ResponseWithBody.java b/network/common/src/main/java/org/apache/spark/network/protocol/ResponseWithBody.java
new file mode 100644
index 0000000000..67be77e39f
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/ResponseWithBody.java
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.buffer.NettyManagedBuffer;
+
+/**
+ * Abstract class for response messages that contain a large data portion kept in a separate
+ * buffer. These messages are treated especially by MessageEncoder.
+ */
+public abstract class ResponseWithBody implements ResponseMessage {
+ public final ManagedBuffer body;
+ public final boolean isBodyInFrame;
+
+ protected ResponseWithBody(ManagedBuffer body, boolean isBodyInFrame) {
+ this.body = body;
+ this.isBodyInFrame = isBodyInFrame;
+ }
+
+ public abstract ResponseMessage createFailureResponse(String error);
+}
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java b/network/common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java
new file mode 100644
index 0000000000..e3dade2ebf
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.buffer.NettyManagedBuffer;
+
+/**
+ * Message indicating an error when transferring a stream.
+ */
+public final class StreamFailure implements ResponseMessage {
+ public final String streamId;
+ public final String error;
+
+ public StreamFailure(String streamId, String error) {
+ this.streamId = streamId;
+ this.error = error;
+ }
+
+ @Override
+ public Type type() { return Type.StreamFailure; }
+
+ @Override
+ public int encodedLength() {
+ return Encoders.Strings.encodedLength(streamId) + Encoders.Strings.encodedLength(error);
+ }
+
+ @Override
+ public void encode(ByteBuf buf) {
+ Encoders.Strings.encode(buf, streamId);
+ Encoders.Strings.encode(buf, error);
+ }
+
+ public static StreamFailure decode(ByteBuf buf) {
+ String streamId = Encoders.Strings.decode(buf);
+ String error = Encoders.Strings.decode(buf);
+ return new StreamFailure(streamId, error);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hashCode(streamId, error);
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (other instanceof StreamFailure) {
+ StreamFailure o = (StreamFailure) other;
+ return streamId.equals(o.streamId) && error.equals(o.error);
+ }
+ return false;
+ }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("streamId", streamId)
+ .add("error", error)
+ .toString();
+ }
+
+}
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java b/network/common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java
new file mode 100644
index 0000000000..821e8f5388
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.buffer.NettyManagedBuffer;
+
+/**
+ * Request to stream data from the remote end.
+ * <p>
+ * The stream ID is an arbitrary string that needs to be negotiated between the two endpoints before
+ * the data can be streamed.
+ */
+public final class StreamRequest implements RequestMessage {
+ public final String streamId;
+
+ public StreamRequest(String streamId) {
+ this.streamId = streamId;
+ }
+
+ @Override
+ public Type type() { return Type.StreamRequest; }
+
+ @Override
+ public int encodedLength() {
+ return Encoders.Strings.encodedLength(streamId);
+ }
+
+ @Override
+ public void encode(ByteBuf buf) {
+ Encoders.Strings.encode(buf, streamId);
+ }
+
+ public static StreamRequest decode(ByteBuf buf) {
+ String streamId = Encoders.Strings.decode(buf);
+ return new StreamRequest(streamId);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hashCode(streamId);
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (other instanceof StreamRequest) {
+ StreamRequest o = (StreamRequest) other;
+ return streamId.equals(o.streamId);
+ }
+ return false;
+ }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("streamId", streamId)
+ .toString();
+ }
+
+}
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java b/network/common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java
new file mode 100644
index 0000000000..ac5ab9a323
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.buffer.NettyManagedBuffer;
+
+/**
+ * Response to {@link StreamRequest} when the stream has been successfully opened.
+ * <p>
+ * Note the message itself does not contain the stream data. That is written separately by the
+ * sender. The receiver is expected to set a temporary channel handler that will consume the
+ * number of bytes this message says the stream has.
+ */
+public final class StreamResponse extends ResponseWithBody {
+ public final String streamId;
+ public final long byteCount;
+
+ public StreamResponse(String streamId, long byteCount, ManagedBuffer buffer) {
+ super(buffer, false);
+ this.streamId = streamId;
+ this.byteCount = byteCount;
+ }
+
+ @Override
+ public Type type() { return Type.StreamResponse; }
+
+ @Override
+ public int encodedLength() {
+ return 8 + Encoders.Strings.encodedLength(streamId);
+ }
+
+ /** Encoding does NOT include 'buffer' itself. See {@link MessageEncoder}. */
+ @Override
+ public void encode(ByteBuf buf) {
+ Encoders.Strings.encode(buf, streamId);
+ buf.writeLong(byteCount);
+ }
+
+ @Override
+ public ResponseMessage createFailureResponse(String error) {
+ return new StreamFailure(streamId, error);
+ }
+
+ public static StreamResponse decode(ByteBuf buf) {
+ String streamId = Encoders.Strings.decode(buf);
+ long byteCount = buf.readLong();
+ return new StreamResponse(streamId, byteCount, null);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hashCode(byteCount, streamId);
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (other instanceof StreamResponse) {
+ StreamResponse o = (StreamResponse) other;
+ return byteCount == o.byteCount && streamId.equals(o.streamId);
+ }
+ return false;
+ }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("streamId", streamId)
+ .add("byteCount", byteCount)
+ .toString();
+ }
+
+}
diff --git a/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java b/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java
index aaa677c965..3f0155957a 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java
@@ -47,6 +47,19 @@ public abstract class StreamManager {
public abstract ManagedBuffer getChunk(long streamId, int chunkIndex);
/**
+ * Called in response to a stream() request. The returned data is streamed to the client
+ * through a single TCP connection.
+ *
+ * Note the <code>streamId</code> argument is not related to the similarly named argument in the
+ * {@link #getChunk(long, int)} method.
+ *
+ * @param streamId id of a stream that has been previously registered with the StreamManager.
+ */
+ public ManagedBuffer openStream(String streamId) {
+ throw new UnsupportedOperationException();
+ }
+
+ /**
* Associates a stream with a single client connection, which is guaranteed to be the only reader
* of the stream. The getChunk() method will be called serially on this connection and once the
* connection is closed, the stream will never be used again, enabling cleanup.
diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
index 9b8b047b49..4f67bd573b 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
@@ -35,6 +35,9 @@ import org.apache.spark.network.protocol.ChunkFetchFailure;
import org.apache.spark.network.protocol.ChunkFetchSuccess;
import org.apache.spark.network.protocol.RpcFailure;
import org.apache.spark.network.protocol.RpcResponse;
+import org.apache.spark.network.protocol.StreamFailure;
+import org.apache.spark.network.protocol.StreamRequest;
+import org.apache.spark.network.protocol.StreamResponse;
import org.apache.spark.network.util.NettyUtils;
/**
@@ -92,6 +95,8 @@ public class TransportRequestHandler extends MessageHandler<RequestMessage> {
processFetchRequest((ChunkFetchRequest) request);
} else if (request instanceof RpcRequest) {
processRpcRequest((RpcRequest) request);
+ } else if (request instanceof StreamRequest) {
+ processStreamRequest((StreamRequest) request);
} else {
throw new IllegalArgumentException("Unknown request type: " + request);
}
@@ -117,6 +122,21 @@ public class TransportRequestHandler extends MessageHandler<RequestMessage> {
respond(new ChunkFetchSuccess(req.streamChunkId, buf));
}
+ private void processStreamRequest(final StreamRequest req) {
+ final String client = NettyUtils.getRemoteAddress(channel);
+ ManagedBuffer buf;
+ try {
+ buf = streamManager.openStream(req.streamId);
+ } catch (Exception e) {
+ logger.error(String.format(
+ "Error opening stream %s for request from %s", req.streamId, client), e);
+ respond(new StreamFailure(req.streamId, Throwables.getStackTraceAsString(e)));
+ return;
+ }
+
+ respond(new StreamResponse(req.streamId, buf.size(), buf));
+ }
+
private void processRpcRequest(final RpcRequest req) {
try {
rpcHandler.receive(reverseClient, req.message, new RpcResponseCallback() {
diff --git a/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java b/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java
index 26c6399ce7..caa7260bc8 100644
--- a/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java
+++ b/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java
@@ -89,13 +89,8 @@ public class NettyUtils {
* Creates a LengthFieldBasedFrameDecoder where the first 8 bytes are the length of the frame.
* This is used before all decoders.
*/
- public static ByteToMessageDecoder createFrameDecoder() {
- // maxFrameLength = 2G
- // lengthFieldOffset = 0
- // lengthFieldLength = 8
- // lengthAdjustment = -8, i.e. exclude the 8 byte length itself
- // initialBytesToStrip = 8, i.e. strip out the length field itself
- return new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 8, -8, 8);
+ public static TransportFrameDecoder createFrameDecoder() {
+ return new TransportFrameDecoder();
}
/** Returns the remote address on the channel or "&lt;unknown remote&gt;" if none exists. */
diff --git a/network/common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java b/network/common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java
new file mode 100644
index 0000000000..272ea84e61
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java
@@ -0,0 +1,154 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.util;
+
+import com.google.common.base.Preconditions;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.CompositeByteBuf;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelInboundHandlerAdapter;
+
+/**
+ * A customized frame decoder that allows intercepting raw data.
+ * <p>
+ * This behaves like Netty's frame decoder (with harcoded parameters that match this library's
+ * needs), except it allows an interceptor to be installed to read data directly before it's
+ * framed.
+ * <p>
+ * Unlike Netty's frame decoder, each frame is dispatched to child handlers as soon as it's
+ * decoded, instead of building as many frames as the current buffer allows and dispatching
+ * all of them. This allows a child handler to install an interceptor if needed.
+ * <p>
+ * If an interceptor is installed, framing stops, and data is instead fed directly to the
+ * interceptor. When the interceptor indicates that it doesn't need to read any more data,
+ * framing resumes. Interceptors should not hold references to the data buffers provided
+ * to their handle() method.
+ */
+public class TransportFrameDecoder extends ChannelInboundHandlerAdapter {
+
+ public static final String HANDLER_NAME = "frameDecoder";
+ private static final int LENGTH_SIZE = 8;
+ private static final int MAX_FRAME_SIZE = Integer.MAX_VALUE;
+
+ private CompositeByteBuf buffer;
+ private volatile Interceptor interceptor;
+
+ @Override
+ public void channelRead(ChannelHandlerContext ctx, Object data) throws Exception {
+ ByteBuf in = (ByteBuf) data;
+
+ if (buffer == null) {
+ buffer = in.alloc().compositeBuffer();
+ }
+
+ buffer.writeBytes(in);
+
+ while (buffer.isReadable()) {
+ feedInterceptor();
+ if (interceptor != null) {
+ continue;
+ }
+
+ ByteBuf frame = decodeNext();
+ if (frame != null) {
+ ctx.fireChannelRead(frame);
+ } else {
+ break;
+ }
+ }
+
+ // We can't discard read sub-buffers if there are other references to the buffer (e.g.
+ // through slices used for framing). This assumes that code that retains references
+ // will call retain() from the thread that called "fireChannelRead()" above, otherwise
+ // ref counting will go awry.
+ if (buffer != null && buffer.refCnt() == 1) {
+ buffer.discardReadComponents();
+ }
+ }
+
+ protected ByteBuf decodeNext() throws Exception {
+ if (buffer.readableBytes() < LENGTH_SIZE) {
+ return null;
+ }
+
+ int frameLen = (int) buffer.readLong() - LENGTH_SIZE;
+ if (buffer.readableBytes() < frameLen) {
+ buffer.readerIndex(buffer.readerIndex() - LENGTH_SIZE);
+ return null;
+ }
+
+ Preconditions.checkArgument(frameLen < MAX_FRAME_SIZE, "Too large frame: %s", frameLen);
+ Preconditions.checkArgument(frameLen > 0, "Frame length should be positive: %s", frameLen);
+
+ ByteBuf frame = buffer.readSlice(frameLen);
+ frame.retain();
+ return frame;
+ }
+
+ @Override
+ public void channelInactive(ChannelHandlerContext ctx) throws Exception {
+ if (buffer != null) {
+ if (buffer.isReadable()) {
+ feedInterceptor();
+ }
+ buffer.release();
+ }
+ if (interceptor != null) {
+ interceptor.channelInactive();
+ }
+ super.channelInactive(ctx);
+ }
+
+ @Override
+ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
+ if (interceptor != null) {
+ interceptor.exceptionCaught(cause);
+ }
+ super.exceptionCaught(ctx, cause);
+ }
+
+ public void setInterceptor(Interceptor interceptor) {
+ Preconditions.checkState(this.interceptor == null, "Already have an interceptor.");
+ this.interceptor = interceptor;
+ }
+
+ private void feedInterceptor() throws Exception {
+ if (interceptor != null && !interceptor.handle(buffer)) {
+ interceptor = null;
+ }
+ }
+
+ public static interface Interceptor {
+
+ /**
+ * Handles data received from the remote end.
+ *
+ * @param data Buffer containing data.
+ * @return "true" if the interceptor expects more data, "false" to uninstall the interceptor.
+ */
+ boolean handle(ByteBuf data) throws Exception;
+
+ /** Called if an exception is thrown in the channel pipeline. */
+ void exceptionCaught(Throwable cause) throws Exception;
+
+ /** Called if the channel is closed and the interceptor is still installed. */
+ void channelInactive() throws Exception;
+
+ }
+
+}
diff --git a/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java b/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java
index d500bc3c98..22b451fc0e 100644
--- a/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java
+++ b/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java
@@ -39,6 +39,9 @@ import org.apache.spark.network.protocol.RpcFailure;
import org.apache.spark.network.protocol.RpcRequest;
import org.apache.spark.network.protocol.RpcResponse;
import org.apache.spark.network.protocol.StreamChunkId;
+import org.apache.spark.network.protocol.StreamFailure;
+import org.apache.spark.network.protocol.StreamRequest;
+import org.apache.spark.network.protocol.StreamResponse;
import org.apache.spark.network.util.ByteArrayWritableChannel;
import org.apache.spark.network.util.NettyUtils;
@@ -80,6 +83,7 @@ public class ProtocolSuite {
testClientToServer(new ChunkFetchRequest(new StreamChunkId(1, 2)));
testClientToServer(new RpcRequest(12345, new byte[0]));
testClientToServer(new RpcRequest(12345, new byte[100]));
+ testClientToServer(new StreamRequest("abcde"));
}
@Test
@@ -92,6 +96,10 @@ public class ProtocolSuite {
testServerToClient(new RpcResponse(12345, new byte[1000]));
testServerToClient(new RpcFailure(0, "this is an error"));
testServerToClient(new RpcFailure(0, ""));
+ // Note: buffer size must be "0" since StreamResponse's buffer is written differently to the
+ // channel and cannot be tested like this.
+ testServerToClient(new StreamResponse("anId", 12345L, new TestManagedBuffer(0)));
+ testServerToClient(new StreamFailure("anId", "this is an error"));
}
/**
diff --git a/network/common/src/test/java/org/apache/spark/network/StreamSuite.java b/network/common/src/test/java/org/apache/spark/network/StreamSuite.java
new file mode 100644
index 0000000000..6dcec831de
--- /dev/null
+++ b/network/common/src/test/java/org/apache/spark/network/StreamSuite.java
@@ -0,0 +1,325 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network;
+
+import java.io.ByteArrayOutputStream;
+import java.io.File;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.OutputStream;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Random;
+import java.util.concurrent.Executors;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.TimeUnit;
+
+import com.google.common.io.Files;
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.Test;
+import static org.junit.Assert.*;
+
+import org.apache.spark.network.buffer.FileSegmentManagedBuffer;
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.buffer.NioManagedBuffer;
+import org.apache.spark.network.client.RpcResponseCallback;
+import org.apache.spark.network.client.StreamCallback;
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.client.TransportClientFactory;
+import org.apache.spark.network.server.RpcHandler;
+import org.apache.spark.network.server.StreamManager;
+import org.apache.spark.network.server.TransportServer;
+import org.apache.spark.network.util.SystemPropertyConfigProvider;
+import org.apache.spark.network.util.TransportConf;
+
+public class StreamSuite {
+ private static final String[] STREAMS = { "largeBuffer", "smallBuffer", "file" };
+
+ private static TransportServer server;
+ private static TransportClientFactory clientFactory;
+ private static File testFile;
+ private static File tempDir;
+
+ private static ByteBuffer smallBuffer;
+ private static ByteBuffer largeBuffer;
+
+ private static ByteBuffer createBuffer(int bufSize) {
+ ByteBuffer buf = ByteBuffer.allocate(bufSize);
+ for (int i = 0; i < bufSize; i ++) {
+ buf.put((byte) i);
+ }
+ buf.flip();
+ return buf;
+ }
+
+ @BeforeClass
+ public static void setUp() throws Exception {
+ tempDir = Files.createTempDir();
+ smallBuffer = createBuffer(100);
+ largeBuffer = createBuffer(100000);
+
+ testFile = File.createTempFile("stream-test-file", "txt", tempDir);
+ FileOutputStream fp = new FileOutputStream(testFile);
+ try {
+ Random rnd = new Random();
+ for (int i = 0; i < 512; i++) {
+ byte[] fileContent = new byte[1024];
+ rnd.nextBytes(fileContent);
+ fp.write(fileContent);
+ }
+ } finally {
+ fp.close();
+ }
+
+ final TransportConf conf = new TransportConf(new SystemPropertyConfigProvider());
+ final StreamManager streamManager = new StreamManager() {
+ @Override
+ public ManagedBuffer getChunk(long streamId, int chunkIndex) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public ManagedBuffer openStream(String streamId) {
+ switch (streamId) {
+ case "largeBuffer":
+ return new NioManagedBuffer(largeBuffer);
+ case "smallBuffer":
+ return new NioManagedBuffer(smallBuffer);
+ case "file":
+ return new FileSegmentManagedBuffer(conf, testFile, 0, testFile.length());
+ default:
+ throw new IllegalArgumentException("Invalid stream: " + streamId);
+ }
+ }
+ };
+ RpcHandler handler = new RpcHandler() {
+ @Override
+ public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public StreamManager getStreamManager() {
+ return streamManager;
+ }
+ };
+ TransportContext context = new TransportContext(conf, handler);
+ server = context.createServer();
+ clientFactory = context.createClientFactory();
+ }
+
+ @AfterClass
+ public static void tearDown() {
+ server.close();
+ clientFactory.close();
+ if (tempDir != null) {
+ for (File f : tempDir.listFiles()) {
+ f.delete();
+ }
+ tempDir.delete();
+ }
+ }
+
+ @Test
+ public void testSingleStream() throws Throwable {
+ TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
+ try {
+ StreamTask task = new StreamTask(client, "largeBuffer", TimeUnit.SECONDS.toMillis(5));
+ task.run();
+ task.check();
+ } finally {
+ client.close();
+ }
+ }
+
+ @Test
+ public void testMultipleStreams() throws Throwable {
+ TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
+ try {
+ for (int i = 0; i < 20; i++) {
+ StreamTask task = new StreamTask(client, STREAMS[i % STREAMS.length],
+ TimeUnit.SECONDS.toMillis(5));
+ task.run();
+ task.check();
+ }
+ } finally {
+ client.close();
+ }
+ }
+
+ @Test
+ public void testConcurrentStreams() throws Throwable {
+ ExecutorService executor = Executors.newFixedThreadPool(20);
+ TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
+
+ try {
+ List<StreamTask> tasks = new ArrayList<>();
+ for (int i = 0; i < 20; i++) {
+ StreamTask task = new StreamTask(client, STREAMS[i % STREAMS.length],
+ TimeUnit.SECONDS.toMillis(20));
+ tasks.add(task);
+ executor.submit(task);
+ }
+
+ executor.shutdown();
+ assertTrue("Timed out waiting for tasks.", executor.awaitTermination(30, TimeUnit.SECONDS));
+ for (StreamTask task : tasks) {
+ task.check();
+ }
+ } finally {
+ executor.shutdownNow();
+ client.close();
+ }
+ }
+
+ private static class StreamTask implements Runnable {
+
+ private final TransportClient client;
+ private final String streamId;
+ private final long timeoutMs;
+ private Throwable error;
+
+ StreamTask(TransportClient client, String streamId, long timeoutMs) {
+ this.client = client;
+ this.streamId = streamId;
+ this.timeoutMs = timeoutMs;
+ }
+
+ @Override
+ public void run() {
+ ByteBuffer srcBuffer = null;
+ OutputStream out = null;
+ File outFile = null;
+ try {
+ ByteArrayOutputStream baos = null;
+
+ switch (streamId) {
+ case "largeBuffer":
+ baos = new ByteArrayOutputStream();
+ out = baos;
+ srcBuffer = largeBuffer;
+ break;
+ case "smallBuffer":
+ baos = new ByteArrayOutputStream();
+ out = baos;
+ srcBuffer = smallBuffer;
+ break;
+ case "file":
+ outFile = File.createTempFile("data", ".tmp", tempDir);
+ out = new FileOutputStream(outFile);
+ break;
+ default:
+ throw new IllegalArgumentException(streamId);
+ }
+
+ TestCallback callback = new TestCallback(out);
+ client.stream(streamId, callback);
+ waitForCompletion(callback);
+
+ if (srcBuffer == null) {
+ assertTrue("File stream did not match.", Files.equal(testFile, outFile));
+ } else {
+ ByteBuffer base;
+ synchronized (srcBuffer) {
+ base = srcBuffer.duplicate();
+ }
+ byte[] result = baos.toByteArray();
+ byte[] expected = new byte[base.remaining()];
+ base.get(expected);
+ assertEquals(expected.length, result.length);
+ assertTrue("buffers don't match", Arrays.equals(expected, result));
+ }
+ } catch (Throwable t) {
+ error = t;
+ } finally {
+ if (out != null) {
+ try {
+ out.close();
+ } catch (Exception e) {
+ // ignore.
+ }
+ }
+ if (outFile != null) {
+ outFile.delete();
+ }
+ }
+ }
+
+ public void check() throws Throwable {
+ if (error != null) {
+ throw error;
+ }
+ }
+
+ private void waitForCompletion(TestCallback callback) throws Exception {
+ long now = System.currentTimeMillis();
+ long deadline = now + timeoutMs;
+ synchronized (callback) {
+ while (!callback.completed && now < deadline) {
+ callback.wait(deadline - now);
+ now = System.currentTimeMillis();
+ }
+ }
+ assertTrue("Timed out waiting for stream.", callback.completed);
+ assertNull(callback.error);
+ }
+
+ }
+
+ private static class TestCallback implements StreamCallback {
+
+ private final OutputStream out;
+ public volatile boolean completed;
+ public volatile Throwable error;
+
+ TestCallback(OutputStream out) {
+ this.out = out;
+ this.completed = false;
+ }
+
+ @Override
+ public void onData(String streamId, ByteBuffer buf) throws IOException {
+ byte[] tmp = new byte[buf.remaining()];
+ buf.get(tmp);
+ out.write(tmp);
+ }
+
+ @Override
+ public void onComplete(String streamId) throws IOException {
+ out.close();
+ synchronized (this) {
+ completed = true;
+ notifyAll();
+ }
+ }
+
+ @Override
+ public void onFailure(String streamId, Throwable cause) {
+ error = cause;
+ synchronized (this) {
+ completed = true;
+ notifyAll();
+ }
+ }
+
+ }
+
+}
diff --git a/network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java b/network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java
new file mode 100644
index 0000000000..ca74f0a00c
--- /dev/null
+++ b/network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.util;
+
+import java.nio.ByteBuffer;
+import java.util.Random;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+import io.netty.channel.ChannelHandlerContext;
+import org.junit.Test;
+import static org.junit.Assert.*;
+import static org.mockito.Mockito.*;
+
+public class TransportFrameDecoderSuite {
+
+ @Test
+ public void testFrameDecoding() throws Exception {
+ Random rnd = new Random();
+ TransportFrameDecoder decoder = new TransportFrameDecoder();
+ ChannelHandlerContext ctx = mock(ChannelHandlerContext.class);
+
+ final int frameCount = 100;
+ ByteBuf data = Unpooled.buffer();
+ try {
+ for (int i = 0; i < frameCount; i++) {
+ byte[] frame = new byte[1024 * (rnd.nextInt(31) + 1)];
+ data.writeLong(frame.length + 8);
+ data.writeBytes(frame);
+ }
+
+ while (data.isReadable()) {
+ int size = rnd.nextInt(16 * 1024) + 256;
+ decoder.channelRead(ctx, data.readSlice(Math.min(data.readableBytes(), size)));
+ }
+
+ verify(ctx, times(frameCount)).fireChannelRead(any(ByteBuf.class));
+ } finally {
+ data.release();
+ }
+ }
+
+ @Test
+ public void testInterception() throws Exception {
+ final int interceptedReads = 3;
+ TransportFrameDecoder decoder = new TransportFrameDecoder();
+ TransportFrameDecoder.Interceptor interceptor = spy(new MockInterceptor(interceptedReads));
+ ChannelHandlerContext ctx = mock(ChannelHandlerContext.class);
+
+ byte[] data = new byte[8];
+ ByteBuf len = Unpooled.copyLong(8 + data.length);
+ ByteBuf dataBuf = Unpooled.wrappedBuffer(data);
+
+ try {
+ decoder.setInterceptor(interceptor);
+ for (int i = 0; i < interceptedReads; i++) {
+ decoder.channelRead(ctx, dataBuf);
+ dataBuf.release();
+ dataBuf = Unpooled.wrappedBuffer(data);
+ }
+ decoder.channelRead(ctx, len);
+ decoder.channelRead(ctx, dataBuf);
+ verify(interceptor, times(interceptedReads)).handle(any(ByteBuf.class));
+ verify(ctx).fireChannelRead(any(ByteBuffer.class));
+ } finally {
+ len.release();
+ dataBuf.release();
+ }
+ }
+
+ @Test(expected = IllegalArgumentException.class)
+ public void testNegativeFrameSize() throws Exception {
+ testInvalidFrame(-1);
+ }
+
+ @Test(expected = IllegalArgumentException.class)
+ public void testEmptyFrame() throws Exception {
+ // 8 because frame size includes the frame length.
+ testInvalidFrame(8);
+ }
+
+ @Test(expected = IllegalArgumentException.class)
+ public void testLargeFrame() throws Exception {
+ // Frame length includes the frame size field, so need to add a few more bytes.
+ testInvalidFrame(Integer.MAX_VALUE + 9);
+ }
+
+ private void testInvalidFrame(long size) throws Exception {
+ TransportFrameDecoder decoder = new TransportFrameDecoder();
+ ChannelHandlerContext ctx = mock(ChannelHandlerContext.class);
+ ByteBuf frame = Unpooled.copyLong(size);
+ try {
+ decoder.channelRead(ctx, frame);
+ } finally {
+ frame.release();
+ }
+ }
+
+ private static class MockInterceptor implements TransportFrameDecoder.Interceptor {
+
+ private int remainingReads;
+
+ MockInterceptor(int readCount) {
+ this.remainingReads = readCount;
+ }
+
+ @Override
+ public boolean handle(ByteBuf data) throws Exception {
+ data.readerIndex(data.readerIndex() + data.readableBytes());
+ assertFalse(data.isReadable());
+ remainingReads -= 1;
+ return remainingReads != 0;
+ }
+
+ @Override
+ public void exceptionCaught(Throwable cause) throws Exception {
+
+ }
+
+ @Override
+ public void channelInactive() throws Exception {
+
+ }
+
+ }
+
+}