aboutsummaryrefslogtreecommitdiff
path: root/network
diff options
context:
space:
mode:
authorMarcelo Vanzin <vanzin@cloudera.com>2015-11-30 17:22:05 -0800
committerAndrew Or <andrew@databricks.com>2015-11-30 17:22:05 -0800
commit9bf2120672ae0f620a217ccd96bef189ab75e0d6 (patch)
tree6ad03e91cd0c3679e13e41274cdac07c002432b7 /network
parent0a46e4377216a1f7de478f220c3b3042a77789e2 (diff)
downloadspark-9bf2120672ae0f620a217ccd96bef189ab75e0d6.tar.gz
spark-9bf2120672ae0f620a217ccd96bef189ab75e0d6.tar.bz2
spark-9bf2120672ae0f620a217ccd96bef189ab75e0d6.zip
[SPARK-12007][NETWORK] Avoid copies in the network lib's RPC layer.
This change seems large, but most of it is just replacing `byte[]` with `ByteBuffer` and `new byte[]` with `ByteBuffer.allocate()`, since it changes the network library's API. The following are parts of the code that actually have meaningful changes: - The Message implementations were changed to inherit from a new AbstractMessage that can optionally hold a reference to a body (in the form of a ManagedBuffer); this is similar to how ResponseWithBody worked before, except now it's not restricted to just responses. - The TransportFrameDecoder was pretty much rewritten to avoid copies as much as possible; it doesn't rely on CompositeByteBuf to accumulate incoming data anymore, since CompositeByteBuf has issues when slices are retained. The code now is able to create frames without having to resort to copying bytes except for a few bytes (containing the frame length) in very rare cases. - Some minor changes in the SASL layer to convert things back to `byte[]` since the JDK SASL API operates on those. Author: Marcelo Vanzin <vanzin@cloudera.com> Closes #9987 from vanzin/SPARK-12007.
Diffstat (limited to 'network')
-rw-r--r--network/common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java4
-rw-r--r--network/common/src/main/java/org/apache/spark/network/client/TransportClient.java16
-rw-r--r--network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java16
-rw-r--r--network/common/src/main/java/org/apache/spark/network/protocol/AbstractMessage.java54
-rw-r--r--network/common/src/main/java/org/apache/spark/network/protocol/AbstractResponseMessage.java (renamed from network/common/src/main/java/org/apache/spark/network/protocol/ResponseWithBody.java)16
-rw-r--r--network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java2
-rw-r--r--network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java2
-rw-r--r--network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java8
-rw-r--r--network/common/src/main/java/org/apache/spark/network/protocol/Message.java11
-rw-r--r--network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java29
-rw-r--r--network/common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java33
-rw-r--r--network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java2
-rw-r--r--network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java34
-rw-r--r--network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java39
-rw-r--r--network/common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java2
-rw-r--r--network/common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java2
-rw-r--r--network/common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java19
-rw-r--r--network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java12
-rw-r--r--network/common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java31
-rw-r--r--network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java26
-rw-r--r--network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java2
-rw-r--r--network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java8
-rw-r--r--network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java8
-rw-r--r--network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java2
-rw-r--r--network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java15
-rw-r--r--network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java48
-rw-r--r--network/common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java142
-rw-r--r--network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java5
-rw-r--r--network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java10
-rw-r--r--network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java51
-rw-r--r--network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java26
-rw-r--r--network/common/src/test/java/org/apache/spark/network/StreamSuite.java5
-rw-r--r--network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java24
-rw-r--r--network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java43
-rw-r--r--network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java23
-rw-r--r--network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java9
-rw-r--r--network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java3
-rw-r--r--network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java7
-rw-r--r--network/shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java5
-rw-r--r--network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java8
-rw-r--r--network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java18
-rw-r--r--network/shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java2
-rw-r--r--network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java21
-rw-r--r--network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java8
44 files changed, 565 insertions, 286 deletions
diff --git a/network/common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java b/network/common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java
index 6ec960d795..47e93f9846 100644
--- a/network/common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java
+++ b/network/common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java
@@ -17,13 +17,15 @@
package org.apache.spark.network.client;
+import java.nio.ByteBuffer;
+
/**
* Callback for the result of a single RPC. This will be invoked once with either success or
* failure.
*/
public interface RpcResponseCallback {
/** Successful serialized result from server. */
- void onSuccess(byte[] response);
+ void onSuccess(ByteBuffer response);
/** Exception either propagated from server or raised on client side. */
void onFailure(Throwable e);
diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
index 8a58e7b245..c49ca4d5ee 100644
--- a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
+++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
@@ -20,6 +20,7 @@ package org.apache.spark.network.client;
import java.io.Closeable;
import java.io.IOException;
import java.net.SocketAddress;
+import java.nio.ByteBuffer;
import java.util.UUID;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
@@ -36,6 +37,7 @@ import io.netty.channel.ChannelFutureListener;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+import org.apache.spark.network.buffer.NioManagedBuffer;
import org.apache.spark.network.protocol.ChunkFetchRequest;
import org.apache.spark.network.protocol.OneWayMessage;
import org.apache.spark.network.protocol.RpcRequest;
@@ -212,7 +214,7 @@ public class TransportClient implements Closeable {
* @param callback Callback to handle the RPC's reply.
* @return The RPC's id.
*/
- public long sendRpc(byte[] message, final RpcResponseCallback callback) {
+ public long sendRpc(ByteBuffer message, final RpcResponseCallback callback) {
final String serverAddr = NettyUtils.getRemoteAddress(channel);
final long startTime = System.currentTimeMillis();
logger.trace("Sending RPC to {}", serverAddr);
@@ -220,7 +222,7 @@ public class TransportClient implements Closeable {
final long requestId = Math.abs(UUID.randomUUID().getLeastSignificantBits());
handler.addRpcRequest(requestId, callback);
- channel.writeAndFlush(new RpcRequest(requestId, message)).addListener(
+ channel.writeAndFlush(new RpcRequest(requestId, new NioManagedBuffer(message))).addListener(
new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) throws Exception {
@@ -249,12 +251,12 @@ public class TransportClient implements Closeable {
* Synchronously sends an opaque message to the RpcHandler on the server-side, waiting for up to
* a specified timeout for a response.
*/
- public byte[] sendRpcSync(byte[] message, long timeoutMs) {
- final SettableFuture<byte[]> result = SettableFuture.create();
+ public ByteBuffer sendRpcSync(ByteBuffer message, long timeoutMs) {
+ final SettableFuture<ByteBuffer> result = SettableFuture.create();
sendRpc(message, new RpcResponseCallback() {
@Override
- public void onSuccess(byte[] response) {
+ public void onSuccess(ByteBuffer response) {
result.set(response);
}
@@ -279,8 +281,8 @@ public class TransportClient implements Closeable {
*
* @param message The message to send.
*/
- public void send(byte[] message) {
- channel.writeAndFlush(new OneWayMessage(message));
+ public void send(ByteBuffer message) {
+ channel.writeAndFlush(new OneWayMessage(new NioManagedBuffer(message)));
}
/**
diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java
index 4c15045363..23a8dba593 100644
--- a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java
@@ -136,7 +136,7 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
}
@Override
- public void handle(ResponseMessage message) {
+ public void handle(ResponseMessage message) throws Exception {
String remoteAddress = NettyUtils.getRemoteAddress(channel);
if (message instanceof ChunkFetchSuccess) {
ChunkFetchSuccess resp = (ChunkFetchSuccess) message;
@@ -144,11 +144,11 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
if (listener == null) {
logger.warn("Ignoring response for block {} from {} since it is not outstanding",
resp.streamChunkId, remoteAddress);
- resp.body.release();
+ resp.body().release();
} else {
outstandingFetches.remove(resp.streamChunkId);
- listener.onSuccess(resp.streamChunkId.chunkIndex, resp.body);
- resp.body.release();
+ listener.onSuccess(resp.streamChunkId.chunkIndex, resp.body());
+ resp.body().release();
}
} else if (message instanceof ChunkFetchFailure) {
ChunkFetchFailure resp = (ChunkFetchFailure) message;
@@ -166,10 +166,14 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
RpcResponseCallback listener = outstandingRpcs.get(resp.requestId);
if (listener == null) {
logger.warn("Ignoring response for RPC {} from {} ({} bytes) since it is not outstanding",
- resp.requestId, remoteAddress, resp.response.length);
+ resp.requestId, remoteAddress, resp.body().size());
} else {
outstandingRpcs.remove(resp.requestId);
- listener.onSuccess(resp.response);
+ try {
+ listener.onSuccess(resp.body().nioByteBuffer());
+ } finally {
+ resp.body().release();
+ }
}
} else if (message instanceof RpcFailure) {
RpcFailure resp = (RpcFailure) message;
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/AbstractMessage.java b/network/common/src/main/java/org/apache/spark/network/protocol/AbstractMessage.java
new file mode 100644
index 0000000000..2924218c2f
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/AbstractMessage.java
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import com.google.common.base.Objects;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+
+/**
+ * Abstract class for messages which optionally contain a body kept in a separate buffer.
+ */
+public abstract class AbstractMessage implements Message {
+ private final ManagedBuffer body;
+ private final boolean isBodyInFrame;
+
+ protected AbstractMessage() {
+ this(null, false);
+ }
+
+ protected AbstractMessage(ManagedBuffer body, boolean isBodyInFrame) {
+ this.body = body;
+ this.isBodyInFrame = isBodyInFrame;
+ }
+
+ @Override
+ public ManagedBuffer body() {
+ return body;
+ }
+
+ @Override
+ public boolean isBodyInFrame() {
+ return isBodyInFrame;
+ }
+
+ protected boolean equals(AbstractMessage other) {
+ return isBodyInFrame == other.isBodyInFrame && Objects.equal(body, other.body);
+ }
+
+}
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ResponseWithBody.java b/network/common/src/main/java/org/apache/spark/network/protocol/AbstractResponseMessage.java
index 67be77e39f..c362c92fc4 100644
--- a/network/common/src/main/java/org/apache/spark/network/protocol/ResponseWithBody.java
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/AbstractResponseMessage.java
@@ -17,23 +17,15 @@
package org.apache.spark.network.protocol;
-import com.google.common.base.Objects;
-import io.netty.buffer.ByteBuf;
-
import org.apache.spark.network.buffer.ManagedBuffer;
-import org.apache.spark.network.buffer.NettyManagedBuffer;
/**
- * Abstract class for response messages that contain a large data portion kept in a separate
- * buffer. These messages are treated especially by MessageEncoder.
+ * Abstract class for response messages.
*/
-public abstract class ResponseWithBody implements ResponseMessage {
- public final ManagedBuffer body;
- public final boolean isBodyInFrame;
+public abstract class AbstractResponseMessage extends AbstractMessage implements ResponseMessage {
- protected ResponseWithBody(ManagedBuffer body, boolean isBodyInFrame) {
- this.body = body;
- this.isBodyInFrame = isBodyInFrame;
+ protected AbstractResponseMessage(ManagedBuffer body, boolean isBodyInFrame) {
+ super(body, isBodyInFrame);
}
public abstract ResponseMessage createFailureResponse(String error);
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java
index f0363830b6..7b28a9a969 100644
--- a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java
@@ -23,7 +23,7 @@ import io.netty.buffer.ByteBuf;
/**
* Response to {@link ChunkFetchRequest} when there is an error fetching the chunk.
*/
-public final class ChunkFetchFailure implements ResponseMessage {
+public final class ChunkFetchFailure extends AbstractMessage implements ResponseMessage {
public final StreamChunkId streamChunkId;
public final String errorString;
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java
index 5a173af54f..26d063feb5 100644
--- a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java
@@ -24,7 +24,7 @@ import io.netty.buffer.ByteBuf;
* Request to fetch a sequence of a single chunk of a stream. This will correspond to a single
* {@link org.apache.spark.network.protocol.ResponseMessage} (either success or failure).
*/
-public final class ChunkFetchRequest implements RequestMessage {
+public final class ChunkFetchRequest extends AbstractMessage implements RequestMessage {
public final StreamChunkId streamChunkId;
public ChunkFetchRequest(StreamChunkId streamChunkId) {
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java
index e6a7e9a8b4..94c2ac9b20 100644
--- a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java
@@ -30,7 +30,7 @@ import org.apache.spark.network.buffer.NettyManagedBuffer;
* may be written by Netty in a more efficient manner (i.e., zero-copy write).
* Similarly, the client-side decoding will reuse the Netty ByteBuf as the buffer.
*/
-public final class ChunkFetchSuccess extends ResponseWithBody {
+public final class ChunkFetchSuccess extends AbstractResponseMessage {
public final StreamChunkId streamChunkId;
public ChunkFetchSuccess(StreamChunkId streamChunkId, ManagedBuffer buffer) {
@@ -67,14 +67,14 @@ public final class ChunkFetchSuccess extends ResponseWithBody {
@Override
public int hashCode() {
- return Objects.hashCode(streamChunkId, body);
+ return Objects.hashCode(streamChunkId, body());
}
@Override
public boolean equals(Object other) {
if (other instanceof ChunkFetchSuccess) {
ChunkFetchSuccess o = (ChunkFetchSuccess) other;
- return streamChunkId.equals(o.streamChunkId) && body.equals(o.body);
+ return streamChunkId.equals(o.streamChunkId) && super.equals(o);
}
return false;
}
@@ -83,7 +83,7 @@ public final class ChunkFetchSuccess extends ResponseWithBody {
public String toString() {
return Objects.toStringHelper(this)
.add("streamChunkId", streamChunkId)
- .add("buffer", body)
+ .add("buffer", body())
.toString();
}
}
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/Message.java b/network/common/src/main/java/org/apache/spark/network/protocol/Message.java
index 39afd03db6..66f5b8b3a5 100644
--- a/network/common/src/main/java/org/apache/spark/network/protocol/Message.java
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/Message.java
@@ -19,17 +19,25 @@ package org.apache.spark.network.protocol;
import io.netty.buffer.ByteBuf;
+import org.apache.spark.network.buffer.ManagedBuffer;
+
/** An on-the-wire transmittable message. */
public interface Message extends Encodable {
/** Used to identify this request type. */
Type type();
+ /** An optional body for the message. */
+ ManagedBuffer body();
+
+ /** Whether to include the body of the message in the same frame as the message. */
+ boolean isBodyInFrame();
+
/** Preceding every serialized Message is its type, which allows us to deserialize it. */
public static enum Type implements Encodable {
ChunkFetchRequest(0), ChunkFetchSuccess(1), ChunkFetchFailure(2),
RpcRequest(3), RpcResponse(4), RpcFailure(5),
StreamRequest(6), StreamResponse(7), StreamFailure(8),
- OneWayMessage(9);
+ OneWayMessage(9), User(-1);
private final byte id;
@@ -57,6 +65,7 @@ public interface Message extends Encodable {
case 7: return StreamResponse;
case 8: return StreamFailure;
case 9: return OneWayMessage;
+ case -1: throw new IllegalArgumentException("User type messages cannot be decoded.");
default: throw new IllegalArgumentException("Unknown message type: " + id);
}
}
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java b/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java
index 6cce97c807..abca22347b 100644
--- a/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java
@@ -42,25 +42,28 @@ public final class MessageEncoder extends MessageToMessageEncoder<Message> {
* data to 'out', in order to enable zero-copy transfer.
*/
@Override
- public void encode(ChannelHandlerContext ctx, Message in, List<Object> out) {
+ public void encode(ChannelHandlerContext ctx, Message in, List<Object> out) throws Exception {
Object body = null;
long bodyLength = 0;
boolean isBodyInFrame = false;
- // Detect ResponseWithBody messages and get the data buffer out of them.
- // The body is used in order to enable zero-copy transfer for the payload.
- if (in instanceof ResponseWithBody) {
- ResponseWithBody resp = (ResponseWithBody) in;
+ // If the message has a body, take it out to enable zero-copy transfer for the payload.
+ if (in.body() != null) {
try {
- bodyLength = resp.body.size();
- body = resp.body.convertToNetty();
- isBodyInFrame = resp.isBodyInFrame;
+ bodyLength = in.body().size();
+ body = in.body().convertToNetty();
+ isBodyInFrame = in.isBodyInFrame();
} catch (Exception e) {
- // Re-encode this message as a failure response.
- String error = e.getMessage() != null ? e.getMessage() : "null";
- logger.error(String.format("Error processing %s for client %s",
- resp, ctx.channel().remoteAddress()), e);
- encode(ctx, resp.createFailureResponse(error), out);
+ if (in instanceof AbstractResponseMessage) {
+ AbstractResponseMessage resp = (AbstractResponseMessage) in;
+ // Re-encode this message as a failure response.
+ String error = e.getMessage() != null ? e.getMessage() : "null";
+ logger.error(String.format("Error processing %s for client %s",
+ in, ctx.channel().remoteAddress()), e);
+ encode(ctx, resp.createFailureResponse(error), out);
+ } else {
+ throw e;
+ }
return;
}
}
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java b/network/common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java
index 95a0270be3..efe0470f35 100644
--- a/network/common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java
@@ -17,21 +17,21 @@
package org.apache.spark.network.protocol;
-import java.util.Arrays;
-
import com.google.common.base.Objects;
import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.buffer.NettyManagedBuffer;
/**
* A RPC that does not expect a reply, which is handled by a remote
* {@link org.apache.spark.network.server.RpcHandler}.
*/
-public final class OneWayMessage implements RequestMessage {
- /** Serialized message to send to remote RpcHandler. */
- public final byte[] message;
+public final class OneWayMessage extends AbstractMessage implements RequestMessage {
- public OneWayMessage(byte[] message) {
- this.message = message;
+ public OneWayMessage(ManagedBuffer body) {
+ super(body, true);
}
@Override
@@ -39,29 +39,34 @@ public final class OneWayMessage implements RequestMessage {
@Override
public int encodedLength() {
- return Encoders.ByteArrays.encodedLength(message);
+ // The integer (a.k.a. the body size) is not really used, since that information is already
+ // encoded in the frame length. But this maintains backwards compatibility with versions of
+ // RpcRequest that use Encoders.ByteArrays.
+ return 4;
}
@Override
public void encode(ByteBuf buf) {
- Encoders.ByteArrays.encode(buf, message);
+ // See comment in encodedLength().
+ buf.writeInt((int) body().size());
}
public static OneWayMessage decode(ByteBuf buf) {
- byte[] message = Encoders.ByteArrays.decode(buf);
- return new OneWayMessage(message);
+ // See comment in encodedLength().
+ buf.readInt();
+ return new OneWayMessage(new NettyManagedBuffer(buf.retain()));
}
@Override
public int hashCode() {
- return Arrays.hashCode(message);
+ return Objects.hashCode(body());
}
@Override
public boolean equals(Object other) {
if (other instanceof OneWayMessage) {
OneWayMessage o = (OneWayMessage) other;
- return Arrays.equals(message, o.message);
+ return super.equals(o);
}
return false;
}
@@ -69,7 +74,7 @@ public final class OneWayMessage implements RequestMessage {
@Override
public String toString() {
return Objects.toStringHelper(this)
- .add("message", message)
+ .add("body", body())
.toString();
}
}
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java b/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java
index 2dfc7876ba..a76624ef5d 100644
--- a/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java
@@ -21,7 +21,7 @@ import com.google.common.base.Objects;
import io.netty.buffer.ByteBuf;
/** Response to {@link RpcRequest} for a failed RPC. */
-public final class RpcFailure implements ResponseMessage {
+public final class RpcFailure extends AbstractMessage implements ResponseMessage {
public final long requestId;
public final String errorString;
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java b/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java
index 745039db74..96213794a8 100644
--- a/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java
@@ -17,26 +17,25 @@
package org.apache.spark.network.protocol;
-import java.util.Arrays;
-
import com.google.common.base.Objects;
import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.buffer.NettyManagedBuffer;
/**
* A generic RPC which is handled by a remote {@link org.apache.spark.network.server.RpcHandler}.
* This will correspond to a single
* {@link org.apache.spark.network.protocol.ResponseMessage} (either success or failure).
*/
-public final class RpcRequest implements RequestMessage {
+public final class RpcRequest extends AbstractMessage implements RequestMessage {
/** Used to link an RPC request with its response. */
public final long requestId;
- /** Serialized message to send to remote RpcHandler. */
- public final byte[] message;
-
- public RpcRequest(long requestId, byte[] message) {
+ public RpcRequest(long requestId, ManagedBuffer message) {
+ super(message, true);
this.requestId = requestId;
- this.message = message;
}
@Override
@@ -44,31 +43,36 @@ public final class RpcRequest implements RequestMessage {
@Override
public int encodedLength() {
- return 8 + Encoders.ByteArrays.encodedLength(message);
+ // The integer (a.k.a. the body size) is not really used, since that information is already
+ // encoded in the frame length. But this maintains backwards compatibility with versions of
+ // RpcRequest that use Encoders.ByteArrays.
+ return 8 + 4;
}
@Override
public void encode(ByteBuf buf) {
buf.writeLong(requestId);
- Encoders.ByteArrays.encode(buf, message);
+ // See comment in encodedLength().
+ buf.writeInt((int) body().size());
}
public static RpcRequest decode(ByteBuf buf) {
long requestId = buf.readLong();
- byte[] message = Encoders.ByteArrays.decode(buf);
- return new RpcRequest(requestId, message);
+ // See comment in encodedLength().
+ buf.readInt();
+ return new RpcRequest(requestId, new NettyManagedBuffer(buf.retain()));
}
@Override
public int hashCode() {
- return Objects.hashCode(requestId, Arrays.hashCode(message));
+ return Objects.hashCode(requestId, body());
}
@Override
public boolean equals(Object other) {
if (other instanceof RpcRequest) {
RpcRequest o = (RpcRequest) other;
- return requestId == o.requestId && Arrays.equals(message, o.message);
+ return requestId == o.requestId && super.equals(o);
}
return false;
}
@@ -77,7 +81,7 @@ public final class RpcRequest implements RequestMessage {
public String toString() {
return Objects.toStringHelper(this)
.add("requestId", requestId)
- .add("message", message)
+ .add("body", body())
.toString();
}
}
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java b/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java
index 1671cd444f..bae866e14a 100644
--- a/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java
@@ -17,49 +17,62 @@
package org.apache.spark.network.protocol;
-import java.util.Arrays;
-
import com.google.common.base.Objects;
import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.buffer.NettyManagedBuffer;
/** Response to {@link RpcRequest} for a successful RPC. */
-public final class RpcResponse implements ResponseMessage {
+public final class RpcResponse extends AbstractResponseMessage {
public final long requestId;
- public final byte[] response;
- public RpcResponse(long requestId, byte[] response) {
+ public RpcResponse(long requestId, ManagedBuffer message) {
+ super(message, true);
this.requestId = requestId;
- this.response = response;
}
@Override
public Type type() { return Type.RpcResponse; }
@Override
- public int encodedLength() { return 8 + Encoders.ByteArrays.encodedLength(response); }
+ public int encodedLength() {
+ // The integer (a.k.a. the body size) is not really used, since that information is already
+ // encoded in the frame length. But this maintains backwards compatibility with versions of
+ // RpcRequest that use Encoders.ByteArrays.
+ return 8 + 4;
+ }
@Override
public void encode(ByteBuf buf) {
buf.writeLong(requestId);
- Encoders.ByteArrays.encode(buf, response);
+ // See comment in encodedLength().
+ buf.writeInt((int) body().size());
+ }
+
+ @Override
+ public ResponseMessage createFailureResponse(String error) {
+ return new RpcFailure(requestId, error);
}
public static RpcResponse decode(ByteBuf buf) {
long requestId = buf.readLong();
- byte[] response = Encoders.ByteArrays.decode(buf);
- return new RpcResponse(requestId, response);
+ // See comment in encodedLength().
+ buf.readInt();
+ return new RpcResponse(requestId, new NettyManagedBuffer(buf.retain()));
}
@Override
public int hashCode() {
- return Objects.hashCode(requestId, Arrays.hashCode(response));
+ return Objects.hashCode(requestId, body());
}
@Override
public boolean equals(Object other) {
if (other instanceof RpcResponse) {
RpcResponse o = (RpcResponse) other;
- return requestId == o.requestId && Arrays.equals(response, o.response);
+ return requestId == o.requestId && super.equals(o);
}
return false;
}
@@ -68,7 +81,7 @@ public final class RpcResponse implements ResponseMessage {
public String toString() {
return Objects.toStringHelper(this)
.add("requestId", requestId)
- .add("response", response)
+ .add("body", body())
.toString();
}
}
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java b/network/common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java
index e3dade2ebf..26747ee55b 100644
--- a/network/common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java
@@ -26,7 +26,7 @@ import org.apache.spark.network.buffer.NettyManagedBuffer;
/**
* Message indicating an error when transferring a stream.
*/
-public final class StreamFailure implements ResponseMessage {
+public final class StreamFailure extends AbstractMessage implements ResponseMessage {
public final String streamId;
public final String error;
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java b/network/common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java
index 821e8f5388..35af5a84ba 100644
--- a/network/common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java
@@ -29,7 +29,7 @@ import org.apache.spark.network.buffer.NettyManagedBuffer;
* The stream ID is an arbitrary string that needs to be negotiated between the two endpoints before
* the data can be streamed.
*/
-public final class StreamRequest implements RequestMessage {
+public final class StreamRequest extends AbstractMessage implements RequestMessage {
public final String streamId;
public StreamRequest(String streamId) {
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java b/network/common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java
index ac5ab9a323..51b899930f 100644
--- a/network/common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java
@@ -30,15 +30,15 @@ import org.apache.spark.network.buffer.NettyManagedBuffer;
* sender. The receiver is expected to set a temporary channel handler that will consume the
* number of bytes this message says the stream has.
*/
-public final class StreamResponse extends ResponseWithBody {
- public final String streamId;
- public final long byteCount;
+public final class StreamResponse extends AbstractResponseMessage {
+ public final String streamId;
+ public final long byteCount;
- public StreamResponse(String streamId, long byteCount, ManagedBuffer buffer) {
- super(buffer, false);
- this.streamId = streamId;
- this.byteCount = byteCount;
- }
+ public StreamResponse(String streamId, long byteCount, ManagedBuffer buffer) {
+ super(buffer, false);
+ this.streamId = streamId;
+ this.byteCount = byteCount;
+ }
@Override
public Type type() { return Type.StreamResponse; }
@@ -68,7 +68,7 @@ public final class StreamResponse extends ResponseWithBody {
@Override
public int hashCode() {
- return Objects.hashCode(byteCount, streamId);
+ return Objects.hashCode(byteCount, streamId, body());
}
@Override
@@ -85,6 +85,7 @@ public final class StreamResponse extends ResponseWithBody {
return Objects.toStringHelper(this)
.add("streamId", streamId)
.add("byteCount", byteCount)
+ .add("body", body())
.toString();
}
diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java b/network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java
index 69923769d4..68381037d6 100644
--- a/network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java
+++ b/network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java
@@ -17,6 +17,8 @@
package org.apache.spark.network.sasl;
+import java.io.IOException;
+import java.nio.ByteBuffer;
import javax.security.sasl.Sasl;
import javax.security.sasl.SaslException;
@@ -28,6 +30,7 @@ import org.slf4j.LoggerFactory;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.client.TransportClientBootstrap;
+import org.apache.spark.network.util.JavaUtils;
import org.apache.spark.network.util.TransportConf;
/**
@@ -70,11 +73,12 @@ public class SaslClientBootstrap implements TransportClientBootstrap {
while (!saslClient.isComplete()) {
SaslMessage msg = new SaslMessage(appId, payload);
- ByteBuf buf = Unpooled.buffer(msg.encodedLength());
+ ByteBuf buf = Unpooled.buffer(msg.encodedLength() + (int) msg.body().size());
msg.encode(buf);
+ buf.writeBytes(msg.body().nioByteBuffer());
- byte[] response = client.sendRpcSync(buf.array(), conf.saslRTTimeoutMs());
- payload = saslClient.response(response);
+ ByteBuffer response = client.sendRpcSync(buf.nioBuffer(), conf.saslRTTimeoutMs());
+ payload = saslClient.response(JavaUtils.bufferToArray(response));
}
client.setClientId(appId);
@@ -88,6 +92,8 @@ public class SaslClientBootstrap implements TransportClientBootstrap {
saslClient = null;
logger.debug("Channel {} configured for SASL encryption.", client);
}
+ } catch (IOException ioe) {
+ throw new RuntimeException(ioe);
} finally {
if (saslClient != null) {
try {
diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java b/network/common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java
index cad76ab7aa..e52b526f09 100644
--- a/network/common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java
+++ b/network/common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java
@@ -18,38 +18,50 @@
package org.apache.spark.network.sasl;
import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
-import org.apache.spark.network.protocol.Encodable;
+import org.apache.spark.network.buffer.NettyManagedBuffer;
import org.apache.spark.network.protocol.Encoders;
+import org.apache.spark.network.protocol.AbstractMessage;
/**
* Encodes a Sasl-related message which is attempting to authenticate using some credentials tagged
* with the given appId. This appId allows a single SaslRpcHandler to multiplex different
* applications which may be using different sets of credentials.
*/
-class SaslMessage implements Encodable {
+class SaslMessage extends AbstractMessage {
/** Serialization tag used to catch incorrect payloads. */
private static final byte TAG_BYTE = (byte) 0xEA;
public final String appId;
- public final byte[] payload;
- public SaslMessage(String appId, byte[] payload) {
+ public SaslMessage(String appId, byte[] message) {
+ this(appId, Unpooled.wrappedBuffer(message));
+ }
+
+ public SaslMessage(String appId, ByteBuf message) {
+ super(new NettyManagedBuffer(message), true);
this.appId = appId;
- this.payload = payload;
}
@Override
+ public Type type() { return Type.User; }
+
+ @Override
public int encodedLength() {
- return 1 + Encoders.Strings.encodedLength(appId) + Encoders.ByteArrays.encodedLength(payload);
+ // The integer (a.k.a. the body size) is not really used, since that information is already
+ // encoded in the frame length. But this maintains backwards compatibility with versions of
+ // RpcRequest that use Encoders.ByteArrays.
+ return 1 + Encoders.Strings.encodedLength(appId) + 4;
}
@Override
public void encode(ByteBuf buf) {
buf.writeByte(TAG_BYTE);
Encoders.Strings.encode(buf, appId);
- Encoders.ByteArrays.encode(buf, payload);
+ // See comment in encodedLength().
+ buf.writeInt((int) body().size());
}
public static SaslMessage decode(ByteBuf buf) {
@@ -59,7 +71,8 @@ class SaslMessage implements Encodable {
}
String appId = Encoders.Strings.decode(buf);
- byte[] payload = Encoders.ByteArrays.decode(buf);
- return new SaslMessage(appId, payload);
+ // See comment in encodedLength().
+ buf.readInt();
+ return new SaslMessage(appId, buf.retain());
}
}
diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java b/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java
index 830db94b89..c215bd9d15 100644
--- a/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java
@@ -17,8 +17,11 @@
package org.apache.spark.network.sasl;
+import java.io.IOException;
+import java.nio.ByteBuffer;
import javax.security.sasl.Sasl;
+import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import org.slf4j.Logger;
@@ -28,6 +31,7 @@ import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.server.RpcHandler;
import org.apache.spark.network.server.StreamManager;
+import org.apache.spark.network.util.JavaUtils;
import org.apache.spark.network.util.TransportConf;
/**
@@ -70,14 +74,20 @@ class SaslRpcHandler extends RpcHandler {
}
@Override
- public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) {
+ public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) {
if (isComplete) {
// Authentication complete, delegate to base handler.
delegate.receive(client, message, callback);
return;
}
- SaslMessage saslMessage = SaslMessage.decode(Unpooled.wrappedBuffer(message));
+ ByteBuf nettyBuf = Unpooled.wrappedBuffer(message);
+ SaslMessage saslMessage;
+ try {
+ saslMessage = SaslMessage.decode(nettyBuf);
+ } finally {
+ nettyBuf.release();
+ }
if (saslServer == null) {
// First message in the handshake, setup the necessary state.
@@ -86,8 +96,14 @@ class SaslRpcHandler extends RpcHandler {
conf.saslServerAlwaysEncrypt());
}
- byte[] response = saslServer.response(saslMessage.payload);
- callback.onSuccess(response);
+ byte[] response;
+ try {
+ response = saslServer.response(JavaUtils.bufferToArray(
+ saslMessage.body().nioByteBuffer()));
+ } catch (IOException ioe) {
+ throw new RuntimeException(ioe);
+ }
+ callback.onSuccess(ByteBuffer.wrap(response));
// Setup encryption after the SASL response is sent, otherwise the client can't parse the
// response. It's ok to change the channel pipeline here since we are processing an incoming
@@ -109,7 +125,7 @@ class SaslRpcHandler extends RpcHandler {
}
@Override
- public void receive(TransportClient client, byte[] message) {
+ public void receive(TransportClient client, ByteBuffer message) {
delegate.receive(client, message);
}
diff --git a/network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java b/network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java
index b80c15106e..3843406b27 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java
@@ -26,7 +26,7 @@ import org.apache.spark.network.protocol.Message;
*/
public abstract class MessageHandler<T extends Message> {
/** Handles the receipt of a single message. */
- public abstract void handle(T message);
+ public abstract void handle(T message) throws Exception;
/** Invoked when an exception was caught on the Channel. */
public abstract void exceptionCaught(Throwable cause);
diff --git a/network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java b/network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java
index 1502b7489e..6ed61da5c7 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java
@@ -1,5 +1,3 @@
-package org.apache.spark.network.server;
-
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
@@ -17,6 +15,10 @@ package org.apache.spark.network.server;
* limitations under the License.
*/
+package org.apache.spark.network.server;
+
+import java.nio.ByteBuffer;
+
import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.TransportClient;
@@ -29,7 +31,7 @@ public class NoOpRpcHandler extends RpcHandler {
}
@Override
- public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) {
+ public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) {
throw new UnsupportedOperationException("Cannot handle messages");
}
diff --git a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java
index 1a11f7b382..ee1c683699 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java
@@ -17,6 +17,8 @@
package org.apache.spark.network.server;
+import java.nio.ByteBuffer;
+
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -44,7 +46,7 @@ public abstract class RpcHandler {
*/
public abstract void receive(
TransportClient client,
- byte[] message,
+ ByteBuffer message,
RpcResponseCallback callback);
/**
@@ -62,7 +64,7 @@ public abstract class RpcHandler {
* of this RPC. This will always be the exact same object for a particular channel.
* @param message The serialized bytes of the RPC.
*/
- public void receive(TransportClient client, byte[] message) {
+ public void receive(TransportClient client, ByteBuffer message) {
receive(client, message, ONE_WAY_CALLBACK);
}
@@ -79,7 +81,7 @@ public abstract class RpcHandler {
private final Logger logger = LoggerFactory.getLogger(OneWayRpcCallback.class);
@Override
- public void onSuccess(byte[] response) {
+ public void onSuccess(ByteBuffer response) {
logger.warn("Response provided for one-way RPC.");
}
diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java
index 3164e00679..09435bcbab 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java
@@ -99,7 +99,7 @@ public class TransportChannelHandler extends SimpleChannelInboundHandler<Message
}
@Override
- public void channelRead0(ChannelHandlerContext ctx, Message request) {
+ public void channelRead0(ChannelHandlerContext ctx, Message request) throws Exception {
if (request instanceof RequestMessage) {
requestHandler.handle((RequestMessage) request);
} else {
diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
index db18ea77d1..c864d7ce16 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
@@ -17,6 +17,8 @@
package org.apache.spark.network.server;
+import java.nio.ByteBuffer;
+
import com.google.common.base.Preconditions;
import com.google.common.base.Throwables;
import io.netty.channel.Channel;
@@ -26,6 +28,7 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.buffer.NioManagedBuffer;
import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.protocol.ChunkFetchRequest;
@@ -143,10 +146,10 @@ public class TransportRequestHandler extends MessageHandler<RequestMessage> {
private void processRpcRequest(final RpcRequest req) {
try {
- rpcHandler.receive(reverseClient, req.message, new RpcResponseCallback() {
+ rpcHandler.receive(reverseClient, req.body().nioByteBuffer(), new RpcResponseCallback() {
@Override
- public void onSuccess(byte[] response) {
- respond(new RpcResponse(req.requestId, response));
+ public void onSuccess(ByteBuffer response) {
+ respond(new RpcResponse(req.requestId, new NioManagedBuffer(response)));
}
@Override
@@ -157,14 +160,18 @@ public class TransportRequestHandler extends MessageHandler<RequestMessage> {
} catch (Exception e) {
logger.error("Error while invoking RpcHandler#receive() on RPC id " + req.requestId, e);
respond(new RpcFailure(req.requestId, Throwables.getStackTraceAsString(e)));
+ } finally {
+ req.body().release();
}
}
private void processOneWayMessage(OneWayMessage req) {
try {
- rpcHandler.receive(reverseClient, req.message);
+ rpcHandler.receive(reverseClient, req.body().nioByteBuffer());
} catch (Exception e) {
logger.error("Error while invoking RpcHandler#receive() for one-way message.", e);
+ } finally {
+ req.body().release();
}
}
diff --git a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java
index 7d27439cfd..b3d8e0cd7c 100644
--- a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java
+++ b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java
@@ -132,7 +132,7 @@ public class JavaUtils {
return !fileInCanonicalDir.getCanonicalFile().equals(fileInCanonicalDir.getAbsoluteFile());
}
- private static final ImmutableMap<String, TimeUnit> timeSuffixes =
+ private static final ImmutableMap<String, TimeUnit> timeSuffixes =
ImmutableMap.<String, TimeUnit>builder()
.put("us", TimeUnit.MICROSECONDS)
.put("ms", TimeUnit.MILLISECONDS)
@@ -164,32 +164,32 @@ public class JavaUtils {
*/
private static long parseTimeString(String str, TimeUnit unit) {
String lower = str.toLowerCase().trim();
-
+
try {
Matcher m = Pattern.compile("(-?[0-9]+)([a-z]+)?").matcher(lower);
if (!m.matches()) {
throw new NumberFormatException("Failed to parse time string: " + str);
}
-
+
long val = Long.parseLong(m.group(1));
String suffix = m.group(2);
-
+
// Check for invalid suffixes
if (suffix != null && !timeSuffixes.containsKey(suffix)) {
throw new NumberFormatException("Invalid suffix: \"" + suffix + "\"");
}
-
+
// If suffix is valid use that, otherwise none was provided and use the default passed
return unit.convert(val, suffix != null ? timeSuffixes.get(suffix) : unit);
} catch (NumberFormatException e) {
String timeError = "Time must be specified as seconds (s), " +
"milliseconds (ms), microseconds (us), minutes (m or min), hour (h), or day (d). " +
"E.g. 50s, 100ms, or 250us.";
-
+
throw new NumberFormatException(timeError + "\n" + e.getMessage());
}
}
-
+
/**
* Convert a time parameter such as (50s, 100ms, or 250us) to milliseconds for internal use. If
* no suffix is provided, the passed number is assumed to be in ms.
@@ -205,10 +205,10 @@ public class JavaUtils {
public static long timeStringAsSec(String str) {
return parseTimeString(str, TimeUnit.SECONDS);
}
-
+
/**
* Convert a passed byte string (e.g. 50b, 100kb, or 250mb) to a ByteUnit for
- * internal use. If no suffix is provided a direct conversion of the provided default is
+ * internal use. If no suffix is provided a direct conversion of the provided default is
* attempted.
*/
private static long parseByteString(String str, ByteUnit unit) {
@@ -217,7 +217,7 @@ public class JavaUtils {
try {
Matcher m = Pattern.compile("([0-9]+)([a-z]+)?").matcher(lower);
Matcher fractionMatcher = Pattern.compile("([0-9]+\\.[0-9]+)([a-z]+)?").matcher(lower);
-
+
if (m.matches()) {
long val = Long.parseLong(m.group(1));
String suffix = m.group(2);
@@ -228,14 +228,14 @@ public class JavaUtils {
}
// If suffix is valid use that, otherwise none was provided and use the default passed
- return unit.convertFrom(val, suffix != null ? byteSuffixes.get(suffix) : unit);
+ return unit.convertFrom(val, suffix != null ? byteSuffixes.get(suffix) : unit);
} else if (fractionMatcher.matches()) {
- throw new NumberFormatException("Fractional values are not supported. Input was: "
+ throw new NumberFormatException("Fractional values are not supported. Input was: "
+ fractionMatcher.group(1));
} else {
- throw new NumberFormatException("Failed to parse byte string: " + str);
+ throw new NumberFormatException("Failed to parse byte string: " + str);
}
-
+
} catch (NumberFormatException e) {
String timeError = "Size must be specified as bytes (b), " +
"kibibytes (k), mebibytes (m), gibibytes (g), tebibytes (t), or pebibytes(p). " +
@@ -248,7 +248,7 @@ public class JavaUtils {
/**
* Convert a passed byte string (e.g. 50b, 100k, or 250m) to bytes for
* internal use.
- *
+ *
* If no suffix is provided, the passed number is assumed to be in bytes.
*/
public static long byteStringAsBytes(String str) {
@@ -264,7 +264,7 @@ public class JavaUtils {
public static long byteStringAsKb(String str) {
return parseByteString(str, ByteUnit.KiB);
}
-
+
/**
* Convert a passed byte string (e.g. 50b, 100k, or 250m) to mebibytes for
* internal use.
@@ -284,4 +284,20 @@ public class JavaUtils {
public static long byteStringAsGb(String str) {
return parseByteString(str, ByteUnit.GiB);
}
+
+ /**
+ * Returns a byte array with the buffer's contents, trying to avoid copying the data if
+ * possible.
+ */
+ public static byte[] bufferToArray(ByteBuffer buffer) {
+ if (buffer.hasArray() && buffer.arrayOffset() == 0 &&
+ buffer.array().length == buffer.remaining()) {
+ return buffer.array();
+ } else {
+ byte[] bytes = new byte[buffer.remaining()];
+ buffer.get(bytes);
+ return bytes;
+ }
+ }
+
}
diff --git a/network/common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java b/network/common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java
index 5889562dd9..a466c72915 100644
--- a/network/common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java
+++ b/network/common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java
@@ -17,9 +17,13 @@
package org.apache.spark.network.util;
+import java.util.Iterator;
+import java.util.LinkedList;
+
import com.google.common.base.Preconditions;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.CompositeByteBuf;
+import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
@@ -44,84 +48,138 @@ public class TransportFrameDecoder extends ChannelInboundHandlerAdapter {
public static final String HANDLER_NAME = "frameDecoder";
private static final int LENGTH_SIZE = 8;
private static final int MAX_FRAME_SIZE = Integer.MAX_VALUE;
+ private static final int UNKNOWN_FRAME_SIZE = -1;
+
+ private final LinkedList<ByteBuf> buffers = new LinkedList<>();
+ private final ByteBuf frameLenBuf = Unpooled.buffer(LENGTH_SIZE, LENGTH_SIZE);
- private CompositeByteBuf buffer;
+ private long totalSize = 0;
+ private long nextFrameSize = UNKNOWN_FRAME_SIZE;
private volatile Interceptor interceptor;
@Override
public void channelRead(ChannelHandlerContext ctx, Object data) throws Exception {
ByteBuf in = (ByteBuf) data;
+ buffers.add(in);
+ totalSize += in.readableBytes();
+
+ while (!buffers.isEmpty()) {
+ // First, feed the interceptor, and if it's still, active, try again.
+ if (interceptor != null) {
+ ByteBuf first = buffers.getFirst();
+ int available = first.readableBytes();
+ if (feedInterceptor(first)) {
+ assert !first.isReadable() : "Interceptor still active but buffer has data.";
+ }
- if (buffer == null) {
- buffer = in.alloc().compositeBuffer();
- }
-
- buffer.addComponent(in).writerIndex(buffer.writerIndex() + in.readableBytes());
-
- while (buffer.isReadable()) {
- discardReadBytes();
- if (!feedInterceptor()) {
+ int read = available - first.readableBytes();
+ if (read == available) {
+ buffers.removeFirst().release();
+ }
+ totalSize -= read;
+ } else {
+ // Interceptor is not active, so try to decode one frame.
ByteBuf frame = decodeNext();
if (frame == null) {
break;
}
-
ctx.fireChannelRead(frame);
}
}
-
- discardReadBytes();
}
- private void discardReadBytes() {
- // If the buffer's been retained by downstream code, then make a copy of the remaining
- // bytes into a new buffer. Otherwise, just discard stale components.
- if (buffer.refCnt() > 1) {
- CompositeByteBuf newBuffer = buffer.alloc().compositeBuffer();
+ private long decodeFrameSize() {
+ if (nextFrameSize != UNKNOWN_FRAME_SIZE || totalSize < LENGTH_SIZE) {
+ return nextFrameSize;
+ }
- if (buffer.readableBytes() > 0) {
- ByteBuf spillBuf = buffer.alloc().buffer(buffer.readableBytes());
- spillBuf.writeBytes(buffer);
- newBuffer.addComponent(spillBuf).writerIndex(spillBuf.readableBytes());
+ // We know there's enough data. If the first buffer contains all the data, great. Otherwise,
+ // hold the bytes for the frame length in a composite buffer until we have enough data to read
+ // the frame size. Normally, it should be rare to need more than one buffer to read the frame
+ // size.
+ ByteBuf first = buffers.getFirst();
+ if (first.readableBytes() >= LENGTH_SIZE) {
+ nextFrameSize = first.readLong() - LENGTH_SIZE;
+ totalSize -= LENGTH_SIZE;
+ if (!first.isReadable()) {
+ buffers.removeFirst().release();
}
+ return nextFrameSize;
+ }
- buffer.release();
- buffer = newBuffer;
- } else {
- buffer.discardReadComponents();
+ while (frameLenBuf.readableBytes() < LENGTH_SIZE) {
+ ByteBuf next = buffers.getFirst();
+ int toRead = Math.min(next.readableBytes(), LENGTH_SIZE - frameLenBuf.readableBytes());
+ frameLenBuf.writeBytes(next, toRead);
+ if (!next.isReadable()) {
+ buffers.removeFirst().release();
+ }
}
+
+ nextFrameSize = frameLenBuf.readLong() - LENGTH_SIZE;
+ totalSize -= LENGTH_SIZE;
+ frameLenBuf.clear();
+ return nextFrameSize;
}
private ByteBuf decodeNext() throws Exception {
- if (buffer.readableBytes() < LENGTH_SIZE) {
+ long frameSize = decodeFrameSize();
+ if (frameSize == UNKNOWN_FRAME_SIZE || totalSize < frameSize) {
return null;
}
- int frameLen = (int) buffer.readLong() - LENGTH_SIZE;
- if (buffer.readableBytes() < frameLen) {
- buffer.readerIndex(buffer.readerIndex() - LENGTH_SIZE);
- return null;
+ // Reset size for next frame.
+ nextFrameSize = UNKNOWN_FRAME_SIZE;
+
+ Preconditions.checkArgument(frameSize < MAX_FRAME_SIZE, "Too large frame: %s", frameSize);
+ Preconditions.checkArgument(frameSize > 0, "Frame length should be positive: %s", frameSize);
+
+ // If the first buffer holds the entire frame, return it.
+ int remaining = (int) frameSize;
+ if (buffers.getFirst().readableBytes() >= remaining) {
+ return nextBufferForFrame(remaining);
}
- Preconditions.checkArgument(frameLen < MAX_FRAME_SIZE, "Too large frame: %s", frameLen);
- Preconditions.checkArgument(frameLen > 0, "Frame length should be positive: %s", frameLen);
+ // Otherwise, create a composite buffer.
+ CompositeByteBuf frame = buffers.getFirst().alloc().compositeBuffer();
+ while (remaining > 0) {
+ ByteBuf next = nextBufferForFrame(remaining);
+ remaining -= next.readableBytes();
+ frame.addComponent(next).writerIndex(frame.writerIndex() + next.readableBytes());
+ }
+ assert remaining == 0;
+ return frame;
+ }
+
+ /**
+ * Takes the first buffer in the internal list, and either adjust it to fit in the frame
+ * (by taking a slice out of it) or remove it from the internal list.
+ */
+ private ByteBuf nextBufferForFrame(int bytesToRead) {
+ ByteBuf buf = buffers.getFirst();
+ ByteBuf frame;
+
+ if (buf.readableBytes() > bytesToRead) {
+ frame = buf.retain().readSlice(bytesToRead);
+ totalSize -= bytesToRead;
+ } else {
+ frame = buf;
+ buffers.removeFirst();
+ totalSize -= frame.readableBytes();
+ }
- ByteBuf frame = buffer.readSlice(frameLen);
- frame.retain();
return frame;
}
@Override
public void channelInactive(ChannelHandlerContext ctx) throws Exception {
- if (buffer != null) {
- if (buffer.isReadable()) {
- feedInterceptor();
- }
- buffer.release();
+ for (ByteBuf b : buffers) {
+ b.release();
}
if (interceptor != null) {
interceptor.channelInactive();
}
+ frameLenBuf.release();
super.channelInactive(ctx);
}
@@ -141,8 +199,8 @@ public class TransportFrameDecoder extends ChannelInboundHandlerAdapter {
/**
* @return Whether the interceptor is still active after processing the data.
*/
- private boolean feedInterceptor() throws Exception {
- if (interceptor != null && !interceptor.handle(buffer)) {
+ private boolean feedInterceptor(ByteBuf buf) throws Exception {
+ if (interceptor != null && !interceptor.handle(buf)) {
interceptor = null;
}
return interceptor != null;
diff --git a/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java
index 50a324e293..70c849d60e 100644
--- a/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java
+++ b/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java
@@ -107,7 +107,10 @@ public class ChunkFetchIntegrationSuite {
};
RpcHandler handler = new RpcHandler() {
@Override
- public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) {
+ public void receive(
+ TransportClient client,
+ ByteBuffer message,
+ RpcResponseCallback callback) {
throw new UnsupportedOperationException();
}
diff --git a/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java b/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java
index 1aa20900ff..6c8dd742f4 100644
--- a/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java
+++ b/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java
@@ -82,10 +82,10 @@ public class ProtocolSuite {
@Test
public void requests() {
testClientToServer(new ChunkFetchRequest(new StreamChunkId(1, 2)));
- testClientToServer(new RpcRequest(12345, new byte[0]));
- testClientToServer(new RpcRequest(12345, new byte[100]));
+ testClientToServer(new RpcRequest(12345, new TestManagedBuffer(0)));
+ testClientToServer(new RpcRequest(12345, new TestManagedBuffer(10)));
testClientToServer(new StreamRequest("abcde"));
- testClientToServer(new OneWayMessage(new byte[100]));
+ testClientToServer(new OneWayMessage(new TestManagedBuffer(10)));
}
@Test
@@ -94,8 +94,8 @@ public class ProtocolSuite {
testServerToClient(new ChunkFetchSuccess(new StreamChunkId(1, 2), new TestManagedBuffer(0)));
testServerToClient(new ChunkFetchFailure(new StreamChunkId(1, 2), "this is an error"));
testServerToClient(new ChunkFetchFailure(new StreamChunkId(1, 2), ""));
- testServerToClient(new RpcResponse(12345, new byte[0]));
- testServerToClient(new RpcResponse(12345, new byte[1000]));
+ testServerToClient(new RpcResponse(12345, new TestManagedBuffer(0)));
+ testServerToClient(new RpcResponse(12345, new TestManagedBuffer(100)));
testServerToClient(new RpcFailure(0, "this is an error"));
testServerToClient(new RpcFailure(0, ""));
// Note: buffer size must be "0" since StreamResponse's buffer is written differently to the
diff --git a/network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java
index 42955ef692..f9b5bf96d6 100644
--- a/network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java
+++ b/network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java
@@ -31,6 +31,7 @@ import org.apache.spark.network.server.TransportServer;
import org.apache.spark.network.util.MapConfigProvider;
import org.apache.spark.network.util.TransportConf;
import org.junit.*;
+import static org.junit.Assert.*;
import java.io.IOException;
import java.nio.ByteBuffer;
@@ -84,13 +85,16 @@ public class RequestTimeoutIntegrationSuite {
@Test
public void timeoutInactiveRequests() throws Exception {
final Semaphore semaphore = new Semaphore(1);
- final byte[] response = new byte[16];
+ final int responseSize = 16;
RpcHandler handler = new RpcHandler() {
@Override
- public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) {
+ public void receive(
+ TransportClient client,
+ ByteBuffer message,
+ RpcResponseCallback callback) {
try {
semaphore.tryAcquire(FOREVER, TimeUnit.MILLISECONDS);
- callback.onSuccess(response);
+ callback.onSuccess(ByteBuffer.allocate(responseSize));
} catch (InterruptedException e) {
// do nothing
}
@@ -110,15 +114,15 @@ public class RequestTimeoutIntegrationSuite {
// First completes quickly (semaphore starts at 1).
TestCallback callback0 = new TestCallback();
synchronized (callback0) {
- client.sendRpc(new byte[0], callback0);
+ client.sendRpc(ByteBuffer.allocate(0), callback0);
callback0.wait(FOREVER);
- assert (callback0.success.length == response.length);
+ assertEquals(responseSize, callback0.successLength);
}
// Second times out after 2 seconds, with slack. Must be IOException.
TestCallback callback1 = new TestCallback();
synchronized (callback1) {
- client.sendRpc(new byte[0], callback1);
+ client.sendRpc(ByteBuffer.allocate(0), callback1);
callback1.wait(4 * 1000);
assert (callback1.failure != null);
assert (callback1.failure instanceof IOException);
@@ -131,13 +135,16 @@ public class RequestTimeoutIntegrationSuite {
@Test
public void timeoutCleanlyClosesClient() throws Exception {
final Semaphore semaphore = new Semaphore(0);
- final byte[] response = new byte[16];
+ final int responseSize = 16;
RpcHandler handler = new RpcHandler() {
@Override
- public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) {
+ public void receive(
+ TransportClient client,
+ ByteBuffer message,
+ RpcResponseCallback callback) {
try {
semaphore.tryAcquire(FOREVER, TimeUnit.MILLISECONDS);
- callback.onSuccess(response);
+ callback.onSuccess(ByteBuffer.allocate(responseSize));
} catch (InterruptedException e) {
// do nothing
}
@@ -158,7 +165,7 @@ public class RequestTimeoutIntegrationSuite {
clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
TestCallback callback0 = new TestCallback();
synchronized (callback0) {
- client0.sendRpc(new byte[0], callback0);
+ client0.sendRpc(ByteBuffer.allocate(0), callback0);
callback0.wait(FOREVER);
assert (callback0.failure instanceof IOException);
assert (!client0.isActive());
@@ -170,10 +177,10 @@ public class RequestTimeoutIntegrationSuite {
clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
TestCallback callback1 = new TestCallback();
synchronized (callback1) {
- client1.sendRpc(new byte[0], callback1);
+ client1.sendRpc(ByteBuffer.allocate(0), callback1);
callback1.wait(FOREVER);
- assert (callback1.success.length == response.length);
- assert (callback1.failure == null);
+ assertEquals(responseSize, callback1.successLength);
+ assertNull(callback1.failure);
}
}
@@ -191,7 +198,10 @@ public class RequestTimeoutIntegrationSuite {
};
RpcHandler handler = new RpcHandler() {
@Override
- public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) {
+ public void receive(
+ TransportClient client,
+ ByteBuffer message,
+ RpcResponseCallback callback) {
throw new UnsupportedOperationException();
}
@@ -218,9 +228,10 @@ public class RequestTimeoutIntegrationSuite {
synchronized (callback0) {
// not complete yet, but should complete soon
- assert (callback0.success == null && callback0.failure == null);
+ assertEquals(-1, callback0.successLength);
+ assertNull(callback0.failure);
callback0.wait(2 * 1000);
- assert (callback0.failure instanceof IOException);
+ assertTrue(callback0.failure instanceof IOException);
}
synchronized (callback1) {
@@ -235,13 +246,13 @@ public class RequestTimeoutIntegrationSuite {
*/
class TestCallback implements RpcResponseCallback, ChunkReceivedCallback {
- byte[] success;
+ int successLength = -1;
Throwable failure;
@Override
- public void onSuccess(byte[] response) {
+ public void onSuccess(ByteBuffer response) {
synchronized(this) {
- success = response;
+ successLength = response.remaining();
this.notifyAll();
}
}
@@ -258,7 +269,7 @@ public class RequestTimeoutIntegrationSuite {
public void onSuccess(int chunkIndex, ManagedBuffer buffer) {
synchronized(this) {
try {
- success = buffer.nioByteBuffer().array();
+ successLength = buffer.nioByteBuffer().remaining();
this.notifyAll();
} catch (IOException e) {
// weird
diff --git a/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java
index 88fa2258bb..9e9be98c14 100644
--- a/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java
+++ b/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java
@@ -17,6 +17,7 @@
package org.apache.spark.network;
+import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
@@ -26,7 +27,6 @@ import java.util.Set;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
-import com.google.common.base.Charsets;
import com.google.common.collect.Sets;
import org.junit.AfterClass;
import org.junit.BeforeClass;
@@ -41,6 +41,7 @@ import org.apache.spark.network.server.OneForOneStreamManager;
import org.apache.spark.network.server.RpcHandler;
import org.apache.spark.network.server.StreamManager;
import org.apache.spark.network.server.TransportServer;
+import org.apache.spark.network.util.JavaUtils;
import org.apache.spark.network.util.SystemPropertyConfigProvider;
import org.apache.spark.network.util.TransportConf;
@@ -55,11 +56,14 @@ public class RpcIntegrationSuite {
TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider());
rpcHandler = new RpcHandler() {
@Override
- public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) {
- String msg = new String(message, Charsets.UTF_8);
+ public void receive(
+ TransportClient client,
+ ByteBuffer message,
+ RpcResponseCallback callback) {
+ String msg = JavaUtils.bytesToString(message);
String[] parts = msg.split("/");
if (parts[0].equals("hello")) {
- callback.onSuccess(("Hello, " + parts[1] + "!").getBytes(Charsets.UTF_8));
+ callback.onSuccess(JavaUtils.stringToBytes("Hello, " + parts[1] + "!"));
} else if (parts[0].equals("return error")) {
callback.onFailure(new RuntimeException("Returned: " + parts[1]));
} else if (parts[0].equals("throw error")) {
@@ -68,9 +72,8 @@ public class RpcIntegrationSuite {
}
@Override
- public void receive(TransportClient client, byte[] message) {
- String msg = new String(message, Charsets.UTF_8);
- oneWayMsgs.add(msg);
+ public void receive(TransportClient client, ByteBuffer message) {
+ oneWayMsgs.add(JavaUtils.bytesToString(message));
}
@Override
@@ -103,8 +106,9 @@ public class RpcIntegrationSuite {
RpcResponseCallback callback = new RpcResponseCallback() {
@Override
- public void onSuccess(byte[] message) {
- res.successMessages.add(new String(message, Charsets.UTF_8));
+ public void onSuccess(ByteBuffer message) {
+ String response = JavaUtils.bytesToString(message);
+ res.successMessages.add(response);
sem.release();
}
@@ -116,7 +120,7 @@ public class RpcIntegrationSuite {
};
for (String command : commands) {
- client.sendRpc(command.getBytes(Charsets.UTF_8), callback);
+ client.sendRpc(JavaUtils.stringToBytes(command), callback);
}
if (!sem.tryAcquire(commands.length, 5, TimeUnit.SECONDS)) {
@@ -173,7 +177,7 @@ public class RpcIntegrationSuite {
final String message = "no reply";
TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
try {
- client.send(message.getBytes(Charsets.UTF_8));
+ client.send(JavaUtils.stringToBytes(message));
assertEquals(0, client.getHandler().numOutstandingRequests());
// Make sure the message arrives.
diff --git a/network/common/src/test/java/org/apache/spark/network/StreamSuite.java b/network/common/src/test/java/org/apache/spark/network/StreamSuite.java
index 538f3efe8d..9c49556927 100644
--- a/network/common/src/test/java/org/apache/spark/network/StreamSuite.java
+++ b/network/common/src/test/java/org/apache/spark/network/StreamSuite.java
@@ -116,7 +116,10 @@ public class StreamSuite {
};
RpcHandler handler = new RpcHandler() {
@Override
- public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) {
+ public void receive(
+ TransportClient client,
+ ByteBuffer message,
+ RpcResponseCallback callback) {
throw new UnsupportedOperationException();
}
diff --git a/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java b/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java
index 30144f4a9f..128f7cba74 100644
--- a/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java
+++ b/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java
@@ -17,6 +17,8 @@
package org.apache.spark.network;
+import java.nio.ByteBuffer;
+
import io.netty.channel.Channel;
import io.netty.channel.local.LocalChannel;
import org.junit.Test;
@@ -27,6 +29,7 @@ import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.*;
import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.buffer.NioManagedBuffer;
import org.apache.spark.network.client.ChunkReceivedCallback;
import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.StreamCallback;
@@ -42,7 +45,7 @@ import org.apache.spark.network.util.TransportFrameDecoder;
public class TransportResponseHandlerSuite {
@Test
- public void handleSuccessfulFetch() {
+ public void handleSuccessfulFetch() throws Exception {
StreamChunkId streamChunkId = new StreamChunkId(1, 0);
TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel());
@@ -56,7 +59,7 @@ public class TransportResponseHandlerSuite {
}
@Test
- public void handleFailedFetch() {
+ public void handleFailedFetch() throws Exception {
StreamChunkId streamChunkId = new StreamChunkId(1, 0);
TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel());
ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class);
@@ -69,7 +72,7 @@ public class TransportResponseHandlerSuite {
}
@Test
- public void clearAllOutstandingRequests() {
+ public void clearAllOutstandingRequests() throws Exception {
TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel());
ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class);
handler.addFetchRequest(new StreamChunkId(1, 0), callback);
@@ -88,23 +91,24 @@ public class TransportResponseHandlerSuite {
}
@Test
- public void handleSuccessfulRPC() {
+ public void handleSuccessfulRPC() throws Exception {
TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel());
RpcResponseCallback callback = mock(RpcResponseCallback.class);
handler.addRpcRequest(12345, callback);
assertEquals(1, handler.numOutstandingRequests());
- handler.handle(new RpcResponse(54321, new byte[7])); // should be ignored
+ // This response should be ignored.
+ handler.handle(new RpcResponse(54321, new NioManagedBuffer(ByteBuffer.allocate(7))));
assertEquals(1, handler.numOutstandingRequests());
- byte[] arr = new byte[10];
- handler.handle(new RpcResponse(12345, arr));
- verify(callback, times(1)).onSuccess(eq(arr));
+ ByteBuffer resp = ByteBuffer.allocate(10);
+ handler.handle(new RpcResponse(12345, new NioManagedBuffer(resp)));
+ verify(callback, times(1)).onSuccess(eq(ByteBuffer.allocate(10)));
assertEquals(0, handler.numOutstandingRequests());
}
@Test
- public void handleFailedRPC() {
+ public void handleFailedRPC() throws Exception {
TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel());
RpcResponseCallback callback = mock(RpcResponseCallback.class);
handler.addRpcRequest(12345, callback);
@@ -119,7 +123,7 @@ public class TransportResponseHandlerSuite {
}
@Test
- public void testActiveStreams() {
+ public void testActiveStreams() throws Exception {
Channel c = new LocalChannel();
c.pipeline().addLast(TransportFrameDecoder.HANDLER_NAME, new TransportFrameDecoder());
TransportResponseHandler handler = new TransportResponseHandler(c);
diff --git a/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java
index a6f180bc40..751516b9d8 100644
--- a/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java
+++ b/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java
@@ -22,7 +22,7 @@ import static org.mockito.Mockito.*;
import java.io.File;
import java.lang.reflect.Method;
-import java.nio.charset.StandardCharsets;
+import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
@@ -57,6 +57,7 @@ import org.apache.spark.network.server.StreamManager;
import org.apache.spark.network.server.TransportServer;
import org.apache.spark.network.server.TransportServerBootstrap;
import org.apache.spark.network.util.ByteArrayWritableChannel;
+import org.apache.spark.network.util.JavaUtils;
import org.apache.spark.network.util.SystemPropertyConfigProvider;
import org.apache.spark.network.util.TransportConf;
@@ -123,39 +124,53 @@ public class SparkSaslSuite {
}
@Test
- public void testSaslAuthentication() throws Exception {
+ public void testSaslAuthentication() throws Throwable {
testBasicSasl(false);
}
@Test
- public void testSaslEncryption() throws Exception {
+ public void testSaslEncryption() throws Throwable {
testBasicSasl(true);
}
- private void testBasicSasl(boolean encrypt) throws Exception {
+ private void testBasicSasl(boolean encrypt) throws Throwable {
RpcHandler rpcHandler = mock(RpcHandler.class);
doAnswer(new Answer<Void>() {
@Override
public Void answer(InvocationOnMock invocation) {
- byte[] message = (byte[]) invocation.getArguments()[1];
+ ByteBuffer message = (ByteBuffer) invocation.getArguments()[1];
RpcResponseCallback cb = (RpcResponseCallback) invocation.getArguments()[2];
- assertEquals("Ping", new String(message, StandardCharsets.UTF_8));
- cb.onSuccess("Pong".getBytes(StandardCharsets.UTF_8));
+ assertEquals("Ping", JavaUtils.bytesToString(message));
+ cb.onSuccess(JavaUtils.stringToBytes("Pong"));
return null;
}
})
.when(rpcHandler)
- .receive(any(TransportClient.class), any(byte[].class), any(RpcResponseCallback.class));
+ .receive(any(TransportClient.class), any(ByteBuffer.class), any(RpcResponseCallback.class));
SaslTestCtx ctx = new SaslTestCtx(rpcHandler, encrypt, false);
try {
- byte[] response = ctx.client.sendRpcSync("Ping".getBytes(StandardCharsets.UTF_8),
- TimeUnit.SECONDS.toMillis(10));
- assertEquals("Pong", new String(response, StandardCharsets.UTF_8));
+ ByteBuffer response = ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"),
+ TimeUnit.SECONDS.toMillis(10));
+ assertEquals("Pong", JavaUtils.bytesToString(response));
} finally {
ctx.close();
// There should be 2 terminated events; one for the client, one for the server.
- verify(rpcHandler, times(2)).connectionTerminated(any(TransportClient.class));
+ Throwable error = null;
+ long deadline = System.nanoTime() + TimeUnit.NANOSECONDS.convert(10, TimeUnit.SECONDS);
+ while (deadline > System.nanoTime()) {
+ try {
+ verify(rpcHandler, times(2)).connectionTerminated(any(TransportClient.class));
+ error = null;
+ break;
+ } catch (Throwable t) {
+ error = t;
+ TimeUnit.MILLISECONDS.sleep(10);
+ }
+ }
+ if (error != null) {
+ throw error;
+ }
}
}
@@ -325,8 +340,8 @@ public class SparkSaslSuite {
SaslTestCtx ctx = null;
try {
ctx = new SaslTestCtx(mock(RpcHandler.class), true, true);
- ctx.client.sendRpcSync("Ping".getBytes(StandardCharsets.UTF_8),
- TimeUnit.SECONDS.toMillis(10));
+ ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"),
+ TimeUnit.SECONDS.toMillis(10));
fail("Should have failed to send RPC to server.");
} catch (Exception e) {
assertFalse(e.getCause() instanceof TimeoutException);
diff --git a/network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java b/network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java
index 19475c21ff..d4de4a941d 100644
--- a/network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java
+++ b/network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java
@@ -118,6 +118,27 @@ public class TransportFrameDecoderSuite {
}
}
+ @Test
+ public void testSplitLengthField() throws Exception {
+ byte[] frame = new byte[1024 * (RND.nextInt(31) + 1)];
+ ByteBuf buf = Unpooled.buffer(frame.length + 8);
+ buf.writeLong(frame.length + 8);
+ buf.writeBytes(frame);
+
+ TransportFrameDecoder decoder = new TransportFrameDecoder();
+ ChannelHandlerContext ctx = mockChannelHandlerContext();
+ try {
+ decoder.channelRead(ctx, buf.readSlice(RND.nextInt(7)).retain());
+ verify(ctx, never()).fireChannelRead(any(ByteBuf.class));
+ decoder.channelRead(ctx, buf);
+ verify(ctx).fireChannelRead(any(ByteBuf.class));
+ assertEquals(0, buf.refCnt());
+ } finally {
+ decoder.channelInactive(ctx);
+ release(buf);
+ }
+ }
+
@Test(expected = IllegalArgumentException.class)
public void testNegativeFrameSize() throws Exception {
testInvalidFrame(-1);
@@ -183,7 +204,7 @@ public class TransportFrameDecoderSuite {
try {
decoder.channelRead(ctx, frame);
} finally {
- frame.release();
+ release(frame);
}
}
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java
index 3ddf5c3c39..f22187a01d 100644
--- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java
@@ -19,6 +19,7 @@ package org.apache.spark.network.shuffle;
import java.io.File;
import java.io.IOException;
+import java.nio.ByteBuffer;
import java.util.List;
import com.google.common.annotations.VisibleForTesting;
@@ -66,8 +67,8 @@ public class ExternalShuffleBlockHandler extends RpcHandler {
}
@Override
- public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) {
- BlockTransferMessage msgObj = BlockTransferMessage.Decoder.fromByteArray(message);
+ public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) {
+ BlockTransferMessage msgObj = BlockTransferMessage.Decoder.fromByteBuffer(message);
handleMessage(msgObj, client, callback);
}
@@ -85,13 +86,13 @@ public class ExternalShuffleBlockHandler extends RpcHandler {
}
long streamId = streamManager.registerStream(client.getClientId(), blocks.iterator());
logger.trace("Registered streamId {} with {} buffers", streamId, msg.blockIds.length);
- callback.onSuccess(new StreamHandle(streamId, msg.blockIds.length).toByteArray());
+ callback.onSuccess(new StreamHandle(streamId, msg.blockIds.length).toByteBuffer());
} else if (msgObj instanceof RegisterExecutor) {
RegisterExecutor msg = (RegisterExecutor) msgObj;
checkAuth(client, msg.appId);
blockManager.registerExecutor(msg.appId, msg.execId, msg.executorInfo);
- callback.onSuccess(new byte[0]);
+ callback.onSuccess(ByteBuffer.wrap(new byte[0]));
} else {
throw new UnsupportedOperationException("Unexpected message: " + msgObj);
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
index ef3a9dcc87..58ca87d9d3 100644
--- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
@@ -18,6 +18,7 @@
package org.apache.spark.network.shuffle;
import java.io.IOException;
+import java.nio.ByteBuffer;
import java.util.List;
import com.google.common.base.Preconditions;
@@ -139,7 +140,7 @@ public class ExternalShuffleClient extends ShuffleClient {
checkInit();
TransportClient client = clientFactory.createUnmanagedClient(host, port);
try {
- byte[] registerMessage = new RegisterExecutor(appId, execId, executorInfo).toByteArray();
+ ByteBuffer registerMessage = new RegisterExecutor(appId, execId, executorInfo).toByteBuffer();
client.sendRpcSync(registerMessage, 5000 /* timeoutMs */);
} finally {
client.close();
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
index e653f5cb14..1b2ddbf1ed 100644
--- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
@@ -17,6 +17,7 @@
package org.apache.spark.network.shuffle;
+import java.nio.ByteBuffer;
import java.util.Arrays;
import org.slf4j.Logger;
@@ -89,11 +90,11 @@ public class OneForOneBlockFetcher {
throw new IllegalArgumentException("Zero-sized blockIds array");
}
- client.sendRpc(openMessage.toByteArray(), new RpcResponseCallback() {
+ client.sendRpc(openMessage.toByteBuffer(), new RpcResponseCallback() {
@Override
- public void onSuccess(byte[] response) {
+ public void onSuccess(ByteBuffer response) {
try {
- streamHandle = (StreamHandle) BlockTransferMessage.Decoder.fromByteArray(response);
+ streamHandle = (StreamHandle) BlockTransferMessage.Decoder.fromByteBuffer(response);
logger.trace("Successfully opened blocks {}, preparing to fetch chunks.", streamHandle);
// Immediately request all chunks -- we expect that the total size of the request is
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java
index 7543b6be4f..675820308b 100644
--- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java
@@ -18,6 +18,7 @@
package org.apache.spark.network.shuffle.mesos;
import java.io.IOException;
+import java.nio.ByteBuffer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -54,11 +55,11 @@ public class MesosExternalShuffleClient extends ExternalShuffleClient {
public void registerDriverWithShuffleService(String host, int port) throws IOException {
checkInit();
- byte[] registerDriver = new RegisterDriver(appId).toByteArray();
+ ByteBuffer registerDriver = new RegisterDriver(appId).toByteBuffer();
TransportClient client = clientFactory.createClient(host, port);
client.sendRpc(registerDriver, new RpcResponseCallback() {
@Override
- public void onSuccess(byte[] response) {
+ public void onSuccess(ByteBuffer response) {
logger.info("Successfully registered app " + appId + " with external shuffle service.");
}
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java
index fcb52363e6..7fbe3384b4 100644
--- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java
@@ -17,6 +17,8 @@
package org.apache.spark.network.shuffle.protocol;
+import java.nio.ByteBuffer;
+
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
@@ -53,7 +55,7 @@ public abstract class BlockTransferMessage implements Encodable {
// NB: Java does not support static methods in interfaces, so we must put this in a static class.
public static class Decoder {
/** Deserializes the 'type' byte followed by the message itself. */
- public static BlockTransferMessage fromByteArray(byte[] msg) {
+ public static BlockTransferMessage fromByteBuffer(ByteBuffer msg) {
ByteBuf buf = Unpooled.wrappedBuffer(msg);
byte type = buf.readByte();
switch (type) {
@@ -68,12 +70,12 @@ public abstract class BlockTransferMessage implements Encodable {
}
/** Serializes the 'type' byte followed by the message itself. */
- public byte[] toByteArray() {
+ public ByteBuffer toByteBuffer() {
// Allow room for encoded message, plus the type byte
ByteBuf buf = Unpooled.buffer(encodedLength() + 1);
buf.writeByte(type().id);
encode(buf);
assert buf.writableBytes() == 0 : "Writable bytes remain: " + buf.writableBytes();
- return buf.array();
+ return buf.nioBuffer();
}
}
diff --git a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java
index 1c2fa4d0d4..19c870aebb 100644
--- a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java
+++ b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java
@@ -18,6 +18,7 @@
package org.apache.spark.network.sasl;
import java.io.IOException;
+import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.concurrent.atomic.AtomicReference;
@@ -52,6 +53,7 @@ import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo;
import org.apache.spark.network.shuffle.protocol.OpenBlocks;
import org.apache.spark.network.shuffle.protocol.RegisterExecutor;
import org.apache.spark.network.shuffle.protocol.StreamHandle;
+import org.apache.spark.network.util.JavaUtils;
import org.apache.spark.network.util.SystemPropertyConfigProvider;
import org.apache.spark.network.util.TransportConf;
@@ -107,8 +109,8 @@ public class SaslIntegrationSuite {
TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
String msg = "Hello, World!";
- byte[] resp = client.sendRpcSync(msg.getBytes(), TIMEOUT_MS);
- assertEquals(msg, new String(resp)); // our rpc handler should just return the given msg
+ ByteBuffer resp = client.sendRpcSync(JavaUtils.stringToBytes(msg), TIMEOUT_MS);
+ assertEquals(msg, JavaUtils.bytesToString(resp));
}
@Test
@@ -136,7 +138,7 @@ public class SaslIntegrationSuite {
TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
try {
- client.sendRpcSync(new byte[13], TIMEOUT_MS);
+ client.sendRpcSync(ByteBuffer.allocate(13), TIMEOUT_MS);
fail("Should have failed");
} catch (Exception e) {
assertTrue(e.getMessage(), e.getMessage().contains("Expected SaslMessage"));
@@ -144,7 +146,7 @@ public class SaslIntegrationSuite {
try {
// Guessing the right tag byte doesn't magically get you in...
- client.sendRpcSync(new byte[] { (byte) 0xEA }, TIMEOUT_MS);
+ client.sendRpcSync(ByteBuffer.wrap(new byte[] { (byte) 0xEA }), TIMEOUT_MS);
fail("Should have failed");
} catch (Exception e) {
assertTrue(e.getMessage(), e.getMessage().contains("java.lang.IndexOutOfBoundsException"));
@@ -222,13 +224,13 @@ public class SaslIntegrationSuite {
new String[] { System.getProperty("java.io.tmpdir") }, 1,
"org.apache.spark.shuffle.sort.SortShuffleManager");
RegisterExecutor regmsg = new RegisterExecutor("app-1", "0", executorInfo);
- client1.sendRpcSync(regmsg.toByteArray(), TIMEOUT_MS);
+ client1.sendRpcSync(regmsg.toByteBuffer(), TIMEOUT_MS);
// Make a successful request to fetch blocks, which creates a new stream. But do not actually
// fetch any blocks, to keep the stream open.
OpenBlocks openMessage = new OpenBlocks("app-1", "0", blockIds);
- byte[] response = client1.sendRpcSync(openMessage.toByteArray(), TIMEOUT_MS);
- StreamHandle stream = (StreamHandle) BlockTransferMessage.Decoder.fromByteArray(response);
+ ByteBuffer response = client1.sendRpcSync(openMessage.toByteBuffer(), TIMEOUT_MS);
+ StreamHandle stream = (StreamHandle) BlockTransferMessage.Decoder.fromByteBuffer(response);
long streamId = stream.streamId;
// Create a second client, authenticated with a different app ID, and try to read from
@@ -275,7 +277,7 @@ public class SaslIntegrationSuite {
/** RPC handler which simply responds with the message it received. */
public static class TestRpcHandler extends RpcHandler {
@Override
- public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) {
+ public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) {
callback.onSuccess(message);
}
diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java
index d65de9ca55..86c8609e70 100644
--- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java
+++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java
@@ -36,7 +36,7 @@ public class BlockTransferMessagesSuite {
}
private void checkSerializeDeserialize(BlockTransferMessage msg) {
- BlockTransferMessage msg2 = BlockTransferMessage.Decoder.fromByteArray(msg.toByteArray());
+ BlockTransferMessage msg2 = BlockTransferMessage.Decoder.fromByteBuffer(msg.toByteBuffer());
assertEquals(msg, msg2);
assertEquals(msg.hashCode(), msg2.hashCode());
assertEquals(msg.toString(), msg2.toString());
diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java
index e61390cf57..9379412155 100644
--- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java
+++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java
@@ -60,12 +60,12 @@ public class ExternalShuffleBlockHandlerSuite {
RpcResponseCallback callback = mock(RpcResponseCallback.class);
ExecutorShuffleInfo config = new ExecutorShuffleInfo(new String[] {"/a", "/b"}, 16, "sort");
- byte[] registerMessage = new RegisterExecutor("app0", "exec1", config).toByteArray();
+ ByteBuffer registerMessage = new RegisterExecutor("app0", "exec1", config).toByteBuffer();
handler.receive(client, registerMessage, callback);
verify(blockResolver, times(1)).registerExecutor("app0", "exec1", config);
- verify(callback, times(1)).onSuccess((byte[]) any());
- verify(callback, never()).onFailure((Throwable) any());
+ verify(callback, times(1)).onSuccess(any(ByteBuffer.class));
+ verify(callback, never()).onFailure(any(Throwable.class));
}
@SuppressWarnings("unchecked")
@@ -77,17 +77,18 @@ public class ExternalShuffleBlockHandlerSuite {
ManagedBuffer block1Marker = new NioManagedBuffer(ByteBuffer.wrap(new byte[7]));
when(blockResolver.getBlockData("app0", "exec1", "b0")).thenReturn(block0Marker);
when(blockResolver.getBlockData("app0", "exec1", "b1")).thenReturn(block1Marker);
- byte[] openBlocks = new OpenBlocks("app0", "exec1", new String[] { "b0", "b1" }).toByteArray();
+ ByteBuffer openBlocks = new OpenBlocks("app0", "exec1", new String[] { "b0", "b1" })
+ .toByteBuffer();
handler.receive(client, openBlocks, callback);
verify(blockResolver, times(1)).getBlockData("app0", "exec1", "b0");
verify(blockResolver, times(1)).getBlockData("app0", "exec1", "b1");
- ArgumentCaptor<byte[]> response = ArgumentCaptor.forClass(byte[].class);
+ ArgumentCaptor<ByteBuffer> response = ArgumentCaptor.forClass(ByteBuffer.class);
verify(callback, times(1)).onSuccess(response.capture());
verify(callback, never()).onFailure((Throwable) any());
StreamHandle handle =
- (StreamHandle) BlockTransferMessage.Decoder.fromByteArray(response.getValue());
+ (StreamHandle) BlockTransferMessage.Decoder.fromByteBuffer(response.getValue());
assertEquals(2, handle.numChunks);
@SuppressWarnings("unchecked")
@@ -104,7 +105,7 @@ public class ExternalShuffleBlockHandlerSuite {
public void testBadMessages() {
RpcResponseCallback callback = mock(RpcResponseCallback.class);
- byte[] unserializableMsg = new byte[] { 0x12, 0x34, 0x56 };
+ ByteBuffer unserializableMsg = ByteBuffer.wrap(new byte[] { 0x12, 0x34, 0x56 });
try {
handler.receive(client, unserializableMsg, callback);
fail("Should have thrown");
@@ -112,7 +113,7 @@ public class ExternalShuffleBlockHandlerSuite {
// pass
}
- byte[] unexpectedMsg = new UploadBlock("a", "e", "b", new byte[1], new byte[2]).toByteArray();
+ ByteBuffer unexpectedMsg = new UploadBlock("a", "e", "b", new byte[1], new byte[2]).toByteBuffer();
try {
handler.receive(client, unexpectedMsg, callback);
fail("Should have thrown");
@@ -120,7 +121,7 @@ public class ExternalShuffleBlockHandlerSuite {
// pass
}
- verify(callback, never()).onSuccess((byte[]) any());
- verify(callback, never()).onFailure((Throwable) any());
+ verify(callback, never()).onSuccess(any(ByteBuffer.class));
+ verify(callback, never()).onFailure(any(Throwable.class));
}
}
diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java
index b35a6d685d..2590b9ce4c 100644
--- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java
+++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java
@@ -134,14 +134,14 @@ public class OneForOneBlockFetcherSuite {
doAnswer(new Answer<Void>() {
@Override
public Void answer(InvocationOnMock invocationOnMock) throws Throwable {
- BlockTransferMessage message = BlockTransferMessage.Decoder.fromByteArray(
- (byte[]) invocationOnMock.getArguments()[0]);
+ BlockTransferMessage message = BlockTransferMessage.Decoder.fromByteBuffer(
+ (ByteBuffer) invocationOnMock.getArguments()[0]);
RpcResponseCallback callback = (RpcResponseCallback) invocationOnMock.getArguments()[1];
- callback.onSuccess(new StreamHandle(123, blocks.size()).toByteArray());
+ callback.onSuccess(new StreamHandle(123, blocks.size()).toByteBuffer());
assertEquals(new OpenBlocks("app-id", "exec-id", blockIds), message);
return null;
}
- }).when(client).sendRpc((byte[]) any(), (RpcResponseCallback) any());
+ }).when(client).sendRpc(any(ByteBuffer.class), any(RpcResponseCallback.class));
// Respond to each chunk request with a single buffer from our blocks array.
final AtomicInteger expectedChunkIndex = new AtomicInteger(0);