aboutsummaryrefslogtreecommitdiff
path: root/network
diff options
context:
space:
mode:
Diffstat (limited to 'network')
-rw-r--r--network/common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java4
-rw-r--r--network/common/src/main/java/org/apache/spark/network/client/TransportClient.java16
-rw-r--r--network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java16
-rw-r--r--network/common/src/main/java/org/apache/spark/network/protocol/AbstractMessage.java54
-rw-r--r--network/common/src/main/java/org/apache/spark/network/protocol/AbstractResponseMessage.java (renamed from network/common/src/main/java/org/apache/spark/network/protocol/ResponseWithBody.java)16
-rw-r--r--network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java2
-rw-r--r--network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java2
-rw-r--r--network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java8
-rw-r--r--network/common/src/main/java/org/apache/spark/network/protocol/Message.java11
-rw-r--r--network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java29
-rw-r--r--network/common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java33
-rw-r--r--network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java2
-rw-r--r--network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java34
-rw-r--r--network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java39
-rw-r--r--network/common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java2
-rw-r--r--network/common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java2
-rw-r--r--network/common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java19
-rw-r--r--network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java12
-rw-r--r--network/common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java31
-rw-r--r--network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java26
-rw-r--r--network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java2
-rw-r--r--network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java8
-rw-r--r--network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java8
-rw-r--r--network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java2
-rw-r--r--network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java15
-rw-r--r--network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java48
-rw-r--r--network/common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java142
-rw-r--r--network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java5
-rw-r--r--network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java10
-rw-r--r--network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java51
-rw-r--r--network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java26
-rw-r--r--network/common/src/test/java/org/apache/spark/network/StreamSuite.java5
-rw-r--r--network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java24
-rw-r--r--network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java43
-rw-r--r--network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java23
-rw-r--r--network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java9
-rw-r--r--network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java3
-rw-r--r--network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java7
-rw-r--r--network/shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java5
-rw-r--r--network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java8
-rw-r--r--network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java18
-rw-r--r--network/shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java2
-rw-r--r--network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java21
-rw-r--r--network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java8
44 files changed, 565 insertions, 286 deletions
diff --git a/network/common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java b/network/common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java
index 6ec960d795..47e93f9846 100644
--- a/network/common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java
+++ b/network/common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java
@@ -17,13 +17,15 @@
package org.apache.spark.network.client;
+import java.nio.ByteBuffer;
+
/**
* Callback for the result of a single RPC. This will be invoked once with either success or
* failure.
*/
public interface RpcResponseCallback {
/** Successful serialized result from server. */
- void onSuccess(byte[] response);
+ void onSuccess(ByteBuffer response);
/** Exception either propagated from server or raised on client side. */
void onFailure(Throwable e);
diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
index 8a58e7b245..c49ca4d5ee 100644
--- a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
+++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
@@ -20,6 +20,7 @@ package org.apache.spark.network.client;
import java.io.Closeable;
import java.io.IOException;
import java.net.SocketAddress;
+import java.nio.ByteBuffer;
import java.util.UUID;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
@@ -36,6 +37,7 @@ import io.netty.channel.ChannelFutureListener;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+import org.apache.spark.network.buffer.NioManagedBuffer;
import org.apache.spark.network.protocol.ChunkFetchRequest;
import org.apache.spark.network.protocol.OneWayMessage;
import org.apache.spark.network.protocol.RpcRequest;
@@ -212,7 +214,7 @@ public class TransportClient implements Closeable {
* @param callback Callback to handle the RPC's reply.
* @return The RPC's id.
*/
- public long sendRpc(byte[] message, final RpcResponseCallback callback) {
+ public long sendRpc(ByteBuffer message, final RpcResponseCallback callback) {
final String serverAddr = NettyUtils.getRemoteAddress(channel);
final long startTime = System.currentTimeMillis();
logger.trace("Sending RPC to {}", serverAddr);
@@ -220,7 +222,7 @@ public class TransportClient implements Closeable {
final long requestId = Math.abs(UUID.randomUUID().getLeastSignificantBits());
handler.addRpcRequest(requestId, callback);
- channel.writeAndFlush(new RpcRequest(requestId, message)).addListener(
+ channel.writeAndFlush(new RpcRequest(requestId, new NioManagedBuffer(message))).addListener(
new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) throws Exception {
@@ -249,12 +251,12 @@ public class TransportClient implements Closeable {
* Synchronously sends an opaque message to the RpcHandler on the server-side, waiting for up to
* a specified timeout for a response.
*/
- public byte[] sendRpcSync(byte[] message, long timeoutMs) {
- final SettableFuture<byte[]> result = SettableFuture.create();
+ public ByteBuffer sendRpcSync(ByteBuffer message, long timeoutMs) {
+ final SettableFuture<ByteBuffer> result = SettableFuture.create();
sendRpc(message, new RpcResponseCallback() {
@Override
- public void onSuccess(byte[] response) {
+ public void onSuccess(ByteBuffer response) {
result.set(response);
}
@@ -279,8 +281,8 @@ public class TransportClient implements Closeable {
*
* @param message The message to send.
*/
- public void send(byte[] message) {
- channel.writeAndFlush(new OneWayMessage(message));
+ public void send(ByteBuffer message) {
+ channel.writeAndFlush(new OneWayMessage(new NioManagedBuffer(message)));
}
/**
diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java
index 4c15045363..23a8dba593 100644
--- a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java
@@ -136,7 +136,7 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
}
@Override
- public void handle(ResponseMessage message) {
+ public void handle(ResponseMessage message) throws Exception {
String remoteAddress = NettyUtils.getRemoteAddress(channel);
if (message instanceof ChunkFetchSuccess) {
ChunkFetchSuccess resp = (ChunkFetchSuccess) message;
@@ -144,11 +144,11 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
if (listener == null) {
logger.warn("Ignoring response for block {} from {} since it is not outstanding",
resp.streamChunkId, remoteAddress);
- resp.body.release();
+ resp.body().release();
} else {
outstandingFetches.remove(resp.streamChunkId);
- listener.onSuccess(resp.streamChunkId.chunkIndex, resp.body);
- resp.body.release();
+ listener.onSuccess(resp.streamChunkId.chunkIndex, resp.body());
+ resp.body().release();
}
} else if (message instanceof ChunkFetchFailure) {
ChunkFetchFailure resp = (ChunkFetchFailure) message;
@@ -166,10 +166,14 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
RpcResponseCallback listener = outstandingRpcs.get(resp.requestId);
if (listener == null) {
logger.warn("Ignoring response for RPC {} from {} ({} bytes) since it is not outstanding",
- resp.requestId, remoteAddress, resp.response.length);
+ resp.requestId, remoteAddress, resp.body().size());
} else {
outstandingRpcs.remove(resp.requestId);
- listener.onSuccess(resp.response);
+ try {
+ listener.onSuccess(resp.body().nioByteBuffer());
+ } finally {
+ resp.body().release();
+ }
}
} else if (message instanceof RpcFailure) {
RpcFailure resp = (RpcFailure) message;
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/AbstractMessage.java b/network/common/src/main/java/org/apache/spark/network/protocol/AbstractMessage.java
new file mode 100644
index 0000000000..2924218c2f
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/AbstractMessage.java
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import com.google.common.base.Objects;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+
+/**
+ * Abstract class for messages which optionally contain a body kept in a separate buffer.
+ */
+public abstract class AbstractMessage implements Message {
+ private final ManagedBuffer body;
+ private final boolean isBodyInFrame;
+
+ protected AbstractMessage() {
+ this(null, false);
+ }
+
+ protected AbstractMessage(ManagedBuffer body, boolean isBodyInFrame) {
+ this.body = body;
+ this.isBodyInFrame = isBodyInFrame;
+ }
+
+ @Override
+ public ManagedBuffer body() {
+ return body;
+ }
+
+ @Override
+ public boolean isBodyInFrame() {
+ return isBodyInFrame;
+ }
+
+ protected boolean equals(AbstractMessage other) {
+ return isBodyInFrame == other.isBodyInFrame && Objects.equal(body, other.body);
+ }
+
+}
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ResponseWithBody.java b/network/common/src/main/java/org/apache/spark/network/protocol/AbstractResponseMessage.java
index 67be77e39f..c362c92fc4 100644
--- a/network/common/src/main/java/org/apache/spark/network/protocol/ResponseWithBody.java
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/AbstractResponseMessage.java
@@ -17,23 +17,15 @@
package org.apache.spark.network.protocol;
-import com.google.common.base.Objects;
-import io.netty.buffer.ByteBuf;
-
import org.apache.spark.network.buffer.ManagedBuffer;
-import org.apache.spark.network.buffer.NettyManagedBuffer;
/**
- * Abstract class for response messages that contain a large data portion kept in a separate
- * buffer. These messages are treated especially by MessageEncoder.
+ * Abstract class for response messages.
*/
-public abstract class ResponseWithBody implements ResponseMessage {
- public final ManagedBuffer body;
- public final boolean isBodyInFrame;
+public abstract class AbstractResponseMessage extends AbstractMessage implements ResponseMessage {
- protected ResponseWithBody(ManagedBuffer body, boolean isBodyInFrame) {
- this.body = body;
- this.isBodyInFrame = isBodyInFrame;
+ protected AbstractResponseMessage(ManagedBuffer body, boolean isBodyInFrame) {
+ super(body, isBodyInFrame);
}
public abstract ResponseMessage createFailureResponse(String error);
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java
index f0363830b6..7b28a9a969 100644
--- a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java
@@ -23,7 +23,7 @@ import io.netty.buffer.ByteBuf;
/**
* Response to {@link ChunkFetchRequest} when there is an error fetching the chunk.
*/
-public final class ChunkFetchFailure implements ResponseMessage {
+public final class ChunkFetchFailure extends AbstractMessage implements ResponseMessage {
public final StreamChunkId streamChunkId;
public final String errorString;
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java
index 5a173af54f..26d063feb5 100644
--- a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java
@@ -24,7 +24,7 @@ import io.netty.buffer.ByteBuf;
* Request to fetch a sequence of a single chunk of a stream. This will correspond to a single
* {@link org.apache.spark.network.protocol.ResponseMessage} (either success or failure).
*/
-public final class ChunkFetchRequest implements RequestMessage {
+public final class ChunkFetchRequest extends AbstractMessage implements RequestMessage {
public final StreamChunkId streamChunkId;
public ChunkFetchRequest(StreamChunkId streamChunkId) {
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java
index e6a7e9a8b4..94c2ac9b20 100644
--- a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java
@@ -30,7 +30,7 @@ import org.apache.spark.network.buffer.NettyManagedBuffer;
* may be written by Netty in a more efficient manner (i.e., zero-copy write).
* Similarly, the client-side decoding will reuse the Netty ByteBuf as the buffer.
*/
-public final class ChunkFetchSuccess extends ResponseWithBody {
+public final class ChunkFetchSuccess extends AbstractResponseMessage {
public final StreamChunkId streamChunkId;
public ChunkFetchSuccess(StreamChunkId streamChunkId, ManagedBuffer buffer) {
@@ -67,14 +67,14 @@ public final class ChunkFetchSuccess extends ResponseWithBody {
@Override
public int hashCode() {
- return Objects.hashCode(streamChunkId, body);
+ return Objects.hashCode(streamChunkId, body());
}
@Override
public boolean equals(Object other) {
if (other instanceof ChunkFetchSuccess) {
ChunkFetchSuccess o = (ChunkFetchSuccess) other;
- return streamChunkId.equals(o.streamChunkId) && body.equals(o.body);
+ return streamChunkId.equals(o.streamChunkId) && super.equals(o);
}
return false;
}
@@ -83,7 +83,7 @@ public final class ChunkFetchSuccess extends ResponseWithBody {
public String toString() {
return Objects.toStringHelper(this)
.add("streamChunkId", streamChunkId)
- .add("buffer", body)
+ .add("buffer", body())
.toString();
}
}
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/Message.java b/network/common/src/main/java/org/apache/spark/network/protocol/Message.java
index 39afd03db6..66f5b8b3a5 100644
--- a/network/common/src/main/java/org/apache/spark/network/protocol/Message.java
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/Message.java
@@ -19,17 +19,25 @@ package org.apache.spark.network.protocol;
import io.netty.buffer.ByteBuf;
+import org.apache.spark.network.buffer.ManagedBuffer;
+
/** An on-the-wire transmittable message. */
public interface Message extends Encodable {
/** Used to identify this request type. */
Type type();
+ /** An optional body for the message. */
+ ManagedBuffer body();
+
+ /** Whether to include the body of the message in the same frame as the message. */
+ boolean isBodyInFrame();
+
/** Preceding every serialized Message is its type, which allows us to deserialize it. */
public static enum Type implements Encodable {
ChunkFetchRequest(0), ChunkFetchSuccess(1), ChunkFetchFailure(2),
RpcRequest(3), RpcResponse(4), RpcFailure(5),
StreamRequest(6), StreamResponse(7), StreamFailure(8),
- OneWayMessage(9);
+ OneWayMessage(9), User(-1);
private final byte id;
@@ -57,6 +65,7 @@ public interface Message extends Encodable {
case 7: return StreamResponse;
case 8: return StreamFailure;
case 9: return OneWayMessage;
+ case -1: throw new IllegalArgumentException("User type messages cannot be decoded.");
default: throw new IllegalArgumentException("Unknown message type: " + id);
}
}
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java b/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java
index 6cce97c807..abca22347b 100644
--- a/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java
@@ -42,25 +42,28 @@ public final class MessageEncoder extends MessageToMessageEncoder<Message> {
* data to 'out', in order to enable zero-copy transfer.
*/
@Override
- public void encode(ChannelHandlerContext ctx, Message in, List<Object> out) {
+ public void encode(ChannelHandlerContext ctx, Message in, List<Object> out) throws Exception {
Object body = null;
long bodyLength = 0;
boolean isBodyInFrame = false;
- // Detect ResponseWithBody messages and get the data buffer out of them.
- // The body is used in order to enable zero-copy transfer for the payload.
- if (in instanceof ResponseWithBody) {
- ResponseWithBody resp = (ResponseWithBody) in;
+ // If the message has a body, take it out to enable zero-copy transfer for the payload.
+ if (in.body() != null) {
try {
- bodyLength = resp.body.size();
- body = resp.body.convertToNetty();
- isBodyInFrame = resp.isBodyInFrame;
+ bodyLength = in.body().size();
+ body = in.body().convertToNetty();
+ isBodyInFrame = in.isBodyInFrame();
} catch (Exception e) {
- // Re-encode this message as a failure response.
- String error = e.getMessage() != null ? e.getMessage() : "null";
- logger.error(String.format("Error processing %s for client %s",
- resp, ctx.channel().remoteAddress()), e);
- encode(ctx, resp.createFailureResponse(error), out);
+ if (in instanceof AbstractResponseMessage) {
+ AbstractResponseMessage resp = (AbstractResponseMessage) in;
+ // Re-encode this message as a failure response.
+ String error = e.getMessage() != null ? e.getMessage() : "null";
+ logger.error(String.format("Error processing %s for client %s",
+ in, ctx.channel().remoteAddress()), e);
+ encode(ctx, resp.createFailureResponse(error), out);
+ } else {
+ throw e;
+ }
return;
}
}
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java b/network/common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java
index 95a0270be3..efe0470f35 100644
--- a/network/common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java
@@ -17,21 +17,21 @@
package org.apache.spark.network.protocol;
-import java.util.Arrays;
-
import com.google.common.base.Objects;
import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.buffer.NettyManagedBuffer;
/**
* A RPC that does not expect a reply, which is handled by a remote
* {@link org.apache.spark.network.server.RpcHandler}.
*/
-public final class OneWayMessage implements RequestMessage {
- /** Serialized message to send to remote RpcHandler. */
- public final byte[] message;
+public final class OneWayMessage extends AbstractMessage implements RequestMessage {
- public OneWayMessage(byte[] message) {
- this.message = message;
+ public OneWayMessage(ManagedBuffer body) {
+ super(body, true);
}
@Override
@@ -39,29 +39,34 @@ public final class OneWayMessage implements RequestMessage {
@Override
public int encodedLength() {
- return Encoders.ByteArrays.encodedLength(message);
+ // The integer (a.k.a. the body size) is not really used, since that information is already
+ // encoded in the frame length. But this maintains backwards compatibility with versions of
+ // RpcRequest that use Encoders.ByteArrays.
+ return 4;
}
@Override
public void encode(ByteBuf buf) {
- Encoders.ByteArrays.encode(buf, message);
+ // See comment in encodedLength().
+ buf.writeInt((int) body().size());
}
public static OneWayMessage decode(ByteBuf buf) {
- byte[] message = Encoders.ByteArrays.decode(buf);
- return new OneWayMessage(message);
+ // See comment in encodedLength().
+ buf.readInt();
+ return new OneWayMessage(new NettyManagedBuffer(buf.retain()));
}
@Override
public int hashCode() {
- return Arrays.hashCode(message);
+ return Objects.hashCode(body());
}
@Override
public boolean equals(Object other) {
if (other instanceof OneWayMessage) {
OneWayMessage o = (OneWayMessage) other;
- return Arrays.equals(message, o.message);
+ return super.equals(o);
}
return false;
}
@@ -69,7 +74,7 @@ public final class OneWayMessage implements RequestMessage {
@Override
public String toString() {
return Objects.toStringHelper(this)
- .add("message", message)
+ .add("body", body())
.toString();
}
}
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java b/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java
index 2dfc7876ba..a76624ef5d 100644
--- a/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java
@@ -21,7 +21,7 @@ import com.google.common.base.Objects;
import io.netty.buffer.ByteBuf;
/** Response to {@link RpcRequest} for a failed RPC. */
-public final class RpcFailure implements ResponseMessage {
+public final class RpcFailure extends AbstractMessage implements ResponseMessage {
public final long requestId;
public final String errorString;
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java b/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java
index 745039db74..96213794a8 100644
--- a/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java
@@ -17,26 +17,25 @@
package org.apache.spark.network.protocol;
-import java.util.Arrays;
-
import com.google.common.base.Objects;
import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.buffer.NettyManagedBuffer;
/**
* A generic RPC which is handled by a remote {@link org.apache.spark.network.server.RpcHandler}.
* This will correspond to a single
* {@link org.apache.spark.network.protocol.ResponseMessage} (either success or failure).
*/
-public final class RpcRequest implements RequestMessage {
+public final class RpcRequest extends AbstractMessage implements RequestMessage {
/** Used to link an RPC request with its response. */
public final long requestId;
- /** Serialized message to send to remote RpcHandler. */
- public final byte[] message;
-
- public RpcRequest(long requestId, byte[] message) {
+ public RpcRequest(long requestId, ManagedBuffer message) {
+ super(message, true);
this.requestId = requestId;
- this.message = message;
}
@Override
@@ -44,31 +43,36 @@ public final class RpcRequest implements RequestMessage {
@Override
public int encodedLength() {
- return 8 + Encoders.ByteArrays.encodedLength(message);
+ // The integer (a.k.a. the body size) is not really used, since that information is already
+ // encoded in the frame length. But this maintains backwards compatibility with versions of
+ // RpcRequest that use Encoders.ByteArrays.
+ return 8 + 4;
}
@Override
public void encode(ByteBuf buf) {
buf.writeLong(requestId);
- Encoders.ByteArrays.encode(buf, message);
+ // See comment in encodedLength().
+ buf.writeInt((int) body().size());
}
public static RpcRequest decode(ByteBuf buf) {
long requestId = buf.readLong();
- byte[] message = Encoders.ByteArrays.decode(buf);
- return new RpcRequest(requestId, message);
+ // See comment in encodedLength().
+ buf.readInt();
+ return new RpcRequest(requestId, new NettyManagedBuffer(buf.retain()));
}
@Override
public int hashCode() {
- return Objects.hashCode(requestId, Arrays.hashCode(message));
+ return Objects.hashCode(requestId, body());
}
@Override
public boolean equals(Object other) {
if (other instanceof RpcRequest) {
RpcRequest o = (RpcRequest) other;
- return requestId == o.requestId && Arrays.equals(message, o.message);
+ return requestId == o.requestId && super.equals(o);
}
return false;
}
@@ -77,7 +81,7 @@ public final class RpcRequest implements RequestMessage {
public String toString() {
return Objects.toStringHelper(this)
.add("requestId", requestId)
- .add("message", message)
+ .add("body", body())
.toString();
}
}
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java b/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java
index 1671cd444f..bae866e14a 100644
--- a/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java
@@ -17,49 +17,62 @@
package org.apache.spark.network.protocol;
-import java.util.Arrays;
-
import com.google.common.base.Objects;
import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.buffer.NettyManagedBuffer;
/** Response to {@link RpcRequest} for a successful RPC. */
-public final class RpcResponse implements ResponseMessage {
+public final class RpcResponse extends AbstractResponseMessage {
public final long requestId;
- public final byte[] response;
- public RpcResponse(long requestId, byte[] response) {
+ public RpcResponse(long requestId, ManagedBuffer message) {
+ super(message, true);
this.requestId = requestId;
- this.response = response;
}
@Override
public Type type() { return Type.RpcResponse; }
@Override
- public int encodedLength() { return 8 + Encoders.ByteArrays.encodedLength(response); }
+ public int encodedLength() {
+ // The integer (a.k.a. the body size) is not really used, since that information is already
+ // encoded in the frame length. But this maintains backwards compatibility with versions of
+ // RpcRequest that use Encoders.ByteArrays.
+ return 8 + 4;
+ }
@Override
public void encode(ByteBuf buf) {
buf.writeLong(requestId);
- Encoders.ByteArrays.encode(buf, response);
+ // See comment in encodedLength().
+ buf.writeInt((int) body().size());
+ }
+
+ @Override
+ public ResponseMessage createFailureResponse(String error) {
+ return new RpcFailure(requestId, error);
}
public static RpcResponse decode(ByteBuf buf) {
long requestId = buf.readLong();
- byte[] response = Encoders.ByteArrays.decode(buf);
- return new RpcResponse(requestId, response);
+ // See comment in encodedLength().
+ buf.readInt();
+ return new RpcResponse(requestId, new NettyManagedBuffer(buf.retain()));
}
@Override
public int hashCode() {
- return Objects.hashCode(requestId, Arrays.hashCode(response));
+ return Objects.hashCode(requestId, body());
}
@Override
public boolean equals(Object other) {
if (other instanceof RpcResponse) {
RpcResponse o = (RpcResponse) other;
- return requestId == o.requestId && Arrays.equals(response, o.response);
+ return requestId == o.requestId && super.equals(o);
}
return false;
}
@@ -68,7 +81,7 @@ public final class RpcResponse implements ResponseMessage {
public String toString() {
return Objects.toStringHelper(this)
.add("requestId", requestId)
- .add("response", response)
+ .add("body", body())
.toString();
}
}
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java b/network/common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java
index e3dade2ebf..26747ee55b 100644
--- a/network/common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java
@@ -26,7 +26,7 @@ import org.apache.spark.network.buffer.NettyManagedBuffer;
/**
* Message indicating an error when transferring a stream.
*/
-public final class StreamFailure implements ResponseMessage {
+public final class StreamFailure extends AbstractMessage implements ResponseMessage {
public final String streamId;
public final String error;
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java b/network/common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java
index 821e8f5388..35af5a84ba 100644
--- a/network/common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java
@@ -29,7 +29,7 @@ import org.apache.spark.network.buffer.NettyManagedBuffer;
* The stream ID is an arbitrary string that needs to be negotiated between the two endpoints before
* the data can be streamed.
*/
-public final class StreamRequest implements RequestMessage {
+public final class StreamRequest extends AbstractMessage implements RequestMessage {
public final String streamId;
public StreamRequest(String streamId) {
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java b/network/common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java
index ac5ab9a323..51b899930f 100644
--- a/network/common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java
@@ -30,15 +30,15 @@ import org.apache.spark.network.buffer.NettyManagedBuffer;
* sender. The receiver is expected to set a temporary channel handler that will consume the
* number of bytes this message says the stream has.
*/
-public final class StreamResponse extends ResponseWithBody {
- public final String streamId;
- public final long byteCount;
+public final class StreamResponse extends AbstractResponseMessage {
+ public final String streamId;
+ public final long byteCount;
- public StreamResponse(String streamId, long byteCount, ManagedBuffer buffer) {
- super(buffer, false);
- this.streamId = streamId;
- this.byteCount = byteCount;
- }
+ public StreamResponse(String streamId, long byteCount, ManagedBuffer buffer) {
+ super(buffer, false);
+ this.streamId = streamId;
+ this.byteCount = byteCount;
+ }
@Override
public Type type() { return Type.StreamResponse; }
@@ -68,7 +68,7 @@ public final class StreamResponse extends ResponseWithBody {
@Override
public int hashCode() {
- return Objects.hashCode(byteCount, streamId);
+ return Objects.hashCode(byteCount, streamId, body());
}
@Override
@@ -85,6 +85,7 @@ public final class StreamResponse extends ResponseWithBody {
return Objects.toStringHelper(this)
.add("streamId", streamId)
.add("byteCount", byteCount)
+ .add("body", body())
.toString();
}
diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java b/network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java
index 69923769d4..68381037d6 100644
--- a/network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java
+++ b/network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java
@@ -17,6 +17,8 @@
package org.apache.spark.network.sasl;
+import java.io.IOException;
+import java.nio.ByteBuffer;
import javax.security.sasl.Sasl;
import javax.security.sasl.SaslException;
@@ -28,6 +30,7 @@ import org.slf4j.LoggerFactory;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.client.TransportClientBootstrap;
+import org.apache.spark.network.util.JavaUtils;
import org.apache.spark.network.util.TransportConf;
/**
@@ -70,11 +73,12 @@ public class SaslClientBootstrap implements TransportClientBootstrap {
while (!saslClient.isComplete()) {
SaslMessage msg = new SaslMessage(appId, payload);
- ByteBuf buf = Unpooled.buffer(msg.encodedLength());
+ ByteBuf buf = Unpooled.buffer(msg.encodedLength() + (int) msg.body().size());
msg.encode(buf);
+ buf.writeBytes(msg.body().nioByteBuffer());
- byte[] response = client.sendRpcSync(buf.array(), conf.saslRTTimeoutMs());
- payload = saslClient.response(response);
+ ByteBuffer response = client.sendRpcSync(buf.nioBuffer(), conf.saslRTTimeoutMs());
+ payload = saslClient.response(JavaUtils.bufferToArray(response));
}
client.setClientId(appId);
@@ -88,6 +92,8 @@ public class SaslClientBootstrap implements TransportClientBootstrap {
saslClient = null;
logger.debug("Channel {} configured for SASL encryption.", client);
}
+ } catch (IOException ioe) {
+ throw new RuntimeException(ioe);
} finally {
if (saslClient != null) {
try {
diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java b/network/common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java
index cad76ab7aa..e52b526f09 100644
--- a/network/common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java
+++ b/network/common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java
@@ -18,38 +18,50 @@
package org.apache.spark.network.sasl;
import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
-import org.apache.spark.network.protocol.Encodable;
+import org.apache.spark.network.buffer.NettyManagedBuffer;
import org.apache.spark.network.protocol.Encoders;
+import org.apache.spark.network.protocol.AbstractMessage;
/**
* Encodes a Sasl-related message which is attempting to authenticate using some credentials tagged
* with the given appId. This appId allows a single SaslRpcHandler to multiplex different
* applications which may be using different sets of credentials.
*/
-class SaslMessage implements Encodable {
+class SaslMessage extends AbstractMessage {
/** Serialization tag used to catch incorrect payloads. */
private static final byte TAG_BYTE = (byte) 0xEA;
public final String appId;
- public final byte[] payload;
- public SaslMessage(String appId, byte[] payload) {
+ public SaslMessage(String appId, byte[] message) {
+ this(appId, Unpooled.wrappedBuffer(message));
+ }
+
+ public SaslMessage(String appId, ByteBuf message) {
+ super(new NettyManagedBuffer(message), true);
this.appId = appId;
- this.payload = payload;
}
@Override
+ public Type type() { return Type.User; }
+
+ @Override
public int encodedLength() {
- return 1 + Encoders.Strings.encodedLength(appId) + Encoders.ByteArrays.encodedLength(payload);
+ // The integer (a.k.a. the body size) is not really used, since that information is already
+ // encoded in the frame length. But this maintains backwards compatibility with versions of
+ // RpcRequest that use Encoders.ByteArrays.
+ return 1 + Encoders.Strings.encodedLength(appId) + 4;
}
@Override
public void encode(ByteBuf buf) {
buf.writeByte(TAG_BYTE);
Encoders.Strings.encode(buf, appId);
- Encoders.ByteArrays.encode(buf, payload);
+ // See comment in encodedLength().
+ buf.writeInt((int) body().size());
}
public static SaslMessage decode(ByteBuf buf) {
@@ -59,7 +71,8 @@ class SaslMessage implements Encodable {
}
String appId = Encoders.Strings.decode(buf);
- byte[] payload = Encoders.ByteArrays.decode(buf);
- return new SaslMessage(appId, payload);
+ // See comment in encodedLength().
+ buf.readInt();
+ return new SaslMessage(appId, buf.retain());
}
}
diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java b/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java
index 830db94b89..c215bd9d15 100644
--- a/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java
@@ -17,8 +17,11 @@
package org.apache.spark.network.sasl;
+import java.io.IOException;
+import java.nio.ByteBuffer;
import javax.security.sasl.Sasl;
+import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import org.slf4j.Logger;
@@ -28,6 +31,7 @@ import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.server.RpcHandler;
import org.apache.spark.network.server.StreamManager;
+import org.apache.spark.network.util.JavaUtils;
import org.apache.spark.network.util.TransportConf;
/**
@@ -70,14 +74,20 @@ class SaslRpcHandler extends RpcHandler {
}
@Override
- public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) {
+ public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) {
if (isComplete) {
// Authentication complete, delegate to base handler.
delegate.receive(client, message, callback);
return;
}
- SaslMessage saslMessage = SaslMessage.decode(Unpooled.wrappedBuffer(message));
+ ByteBuf nettyBuf = Unpooled.wrappedBuffer(message);
+ SaslMessage saslMessage;
+ try {
+ saslMessage = SaslMessage.decode(nettyBuf);
+ } finally {
+ nettyBuf.release();
+ }
if (saslServer == null) {
// First message in the handshake, setup the necessary state.
@@ -86,8 +96,14 @@ class SaslRpcHandler extends RpcHandler {
conf.saslServerAlwaysEncrypt());
}
- byte[] response = saslServer.response(saslMessage.payload);
- callback.onSuccess(response);
+ byte[] response;
+ try {
+ response = saslServer.response(JavaUtils.bufferToArray(
+ saslMessage.body().nioByteBuffer()));
+ } catch (IOException ioe) {
+ throw new RuntimeException(ioe);
+ }
+ callback.onSuccess(ByteBuffer.wrap(response));
// Setup encryption after the SASL response is sent, otherwise the client can't parse the
// response. It's ok to change the channel pipeline here since we are processing an incoming
@@ -109,7 +125,7 @@ class SaslRpcHandler extends RpcHandler {
}
@Override
- public void receive(TransportClient client, byte[] message) {
+ public void receive(TransportClient client, ByteBuffer message) {
delegate.receive(client, message);
}
diff --git a/network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java b/network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java
index b80c15106e..3843406b27 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java
@@ -26,7 +26,7 @@ import org.apache.spark.network.protocol.Message;
*/
public abstract class MessageHandler<T extends Message> {
/** Handles the receipt of a single message. */
- public abstract void handle(T message);
+ public abstract void handle(T message) throws Exception;
/** Invoked when an exception was caught on the Channel. */
public abstract void exceptionCaught(Throwable cause);
diff --git a/network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java b/network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java
index 1502b7489e..6ed61da5c7 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java
@@ -1,5 +1,3 @@
-package org.apache.spark.network.server;
-
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
@@ -17,6 +15,10 @@ package org.apache.spark.network.server;
* limitations under the License.
*/
+package org.apache.spark.network.server;
+
+import java.nio.ByteBuffer;
+
import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.TransportClient;
@@ -29,7 +31,7 @@ public class NoOpRpcHandler extends RpcHandler {
}
@Override
- public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) {
+ public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) {
throw new UnsupportedOperationException("Cannot handle messages");
}
diff --git a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java
index 1a11f7b382..ee1c683699 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java
@@ -17,6 +17,8 @@
package org.apache.spark.network.server;
+import java.nio.ByteBuffer;
+
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -44,7 +46,7 @@ public abstract class RpcHandler {
*/
public abstract void receive(
TransportClient client,
- byte[] message,
+ ByteBuffer message,
RpcResponseCallback callback);
/**
@@ -62,7 +64,7 @@ public abstract class RpcHandler {
* of this RPC. This will always be the exact same object for a particular channel.
* @param message The serialized bytes of the RPC.
*/
- public void receive(TransportClient client, byte[] message) {
+ public void receive(TransportClient client, ByteBuffer message) {
receive(client, message, ONE_WAY_CALLBACK);
}
@@ -79,7 +81,7 @@ public abstract class RpcHandler {
private final Logger logger = LoggerFactory.getLogger(OneWayRpcCallback.class);
@Override
- public void onSuccess(byte[] response) {
+ public void onSuccess(ByteBuffer response) {
logger.warn("Response provided for one-way RPC.");
}
diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java
index 3164e00679..09435bcbab 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java
@@ -99,7 +99,7 @@ public class TransportChannelHandler extends SimpleChannelInboundHandler<Message
}
@Override
- public void channelRead0(ChannelHandlerContext ctx, Message request) {
+ public void channelRead0(ChannelHandlerContext ctx, Message request) throws Exception {
if (request instanceof RequestMessage) {
requestHandler.handle((RequestMessage) request);
} else {
diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
index db18ea77d1..c864d7ce16 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
@@ -17,6 +17,8 @@
package org.apache.spark.network.server;
+import java.nio.ByteBuffer;
+
import com.google.common.base.Preconditions;
import com.google.common.base.Throwables;
import io.netty.channel.Channel;
@@ -26,6 +28,7 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.buffer.NioManagedBuffer;
import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.protocol.ChunkFetchRequest;
@@ -143,10 +146,10 @@ public class TransportRequestHandler extends MessageHandler<RequestMessage> {
private void processRpcRequest(final RpcRequest req) {
try {
- rpcHandler.receive(reverseClient, req.message, new RpcResponseCallback() {
+ rpcHandler.receive(reverseClient, req.body().nioByteBuffer(), new RpcResponseCallback() {
@Override
- public void onSuccess(byte[] response) {
- respond(new RpcResponse(req.requestId, response));
+ public void onSuccess(ByteBuffer response) {
+ respond(new RpcResponse(req.requestId, new NioManagedBuffer(response)));
}
@Override
@@ -157,14 +160,18 @@ public class TransportRequestHandler extends MessageHandler<RequestMessage> {
} catch (Exception e) {
logger.error("Error while invoking RpcHandler#receive() on RPC id " + req.requestId, e);
respond(new RpcFailure(req.requestId, Throwables.getStackTraceAsString(e)));
+ } finally {
+ req.body().release();
}
}
private void processOneWayMessage(OneWayMessage req) {
try {
- rpcHandler.receive(reverseClient, req.message);
+ rpcHandler.receive(reverseClient, req.body().nioByteBuffer());
} catch (Exception e) {
logger.error("Error while invoking RpcHandler#receive() for one-way message.", e);
+ } finally {
+ req.body().release();
}
}
diff --git a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java
index 7d27439cfd..b3d8e0cd7c 100644
--- a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java
+++ b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java
@@ -132,7 +132,7 @@ public class JavaUtils {
return !fileInCanonicalDir.getCanonicalFile().equals(fileInCanonicalDir.getAbsoluteFile());
}
- private static final ImmutableMap<String, TimeUnit> timeSuffixes =
+ private static final ImmutableMap<String, TimeUnit> timeSuffixes =
ImmutableMap.<String, TimeUnit>builder()
.put("us", TimeUnit.MICROSECONDS)
.put("ms", TimeUnit.MILLISECONDS)
@@ -164,32 +164,32 @@ public class JavaUtils {
*/
private static long parseTimeString(String str, TimeUnit unit) {
String lower = str.toLowerCase().trim();
-
+
try {
Matcher m = Pattern.compile("(-?[0-9]+)([a-z]+)?").matcher(lower);
if (!m.matches()) {
throw new NumberFormatException("Failed to parse time string: " + str);
}
-
+
long val = Long.parseLong(m.group(1));
String suffix = m.group(2);
-
+
// Check for invalid suffixes
if (suffix != null && !timeSuffixes.containsKey(suffix)) {
throw new NumberFormatException("Invalid suffix: \"" + suffix + "\"");
}
-
+
// If suffix is valid use that, otherwise none was provided and use the default passed
return unit.convert(val, suffix != null ? timeSuffixes.get(suffix) : unit);
} catch (NumberFormatException e) {
String timeError = "Time must be specified as seconds (s), " +
"milliseconds (ms), microseconds (us), minutes (m or min), hour (h), or day (d). " +
"E.g. 50s, 100ms, or 250us.";
-
+
throw new NumberFormatException(timeError + "\n" + e.getMessage());
}
}
-
+
/**
* Convert a time parameter such as (50s, 100ms, or 250us) to milliseconds for internal use. If
* no suffix is provided, the passed number is assumed to be in ms.
@@ -205,10 +205,10 @@ public class JavaUtils {
public static long timeStringAsSec(String str) {
return parseTimeString(str, TimeUnit.SECONDS);
}
-
+
/**
* Convert a passed byte string (e.g. 50b, 100kb, or 250mb) to a ByteUnit for
- * internal use. If no suffix is provided a direct conversion of the provided default is
+ * internal use. If no suffix is provided a direct conversion of the provided default is
* attempted.
*/
private static long parseByteString(String str, ByteUnit unit) {
@@ -217,7 +217,7 @@ public class JavaUtils {
try {
Matcher m = Pattern.compile("([0-9]+)([a-z]+)?").matcher(lower);
Matcher fractionMatcher = Pattern.compile("([0-9]+\\.[0-9]+)([a-z]+)?").matcher(lower);
-
+
if (m.matches()) {
long val = Long.parseLong(m.group(1));
String suffix = m.group(2);
@@ -228,14 +228,14 @@ public class JavaUtils {
}
// If suffix is valid use that, otherwise none was provided and use the default passed
- return unit.convertFrom(val, suffix != null ? byteSuffixes.get(suffix) : unit);
+ return unit.convertFrom(val, suffix != null ? byteSuffixes.get(suffix) : unit);
} else if (fractionMatcher.matches()) {
- throw new NumberFormatException("Fractional values are not supported. Input was: "
+ throw new NumberFormatException("Fractional values are not supported. Input was: "
+ fractionMatcher.group(1));
} else {
- throw new NumberFormatException("Failed to parse byte string: " + str);
+ throw new NumberFormatException("Failed to parse byte string: " + str);
}
-
+
} catch (NumberFormatException e) {
String timeError = "Size must be specified as bytes (b), " +
"kibibytes (k), mebibytes (m), gibibytes (g), tebibytes (t), or pebibytes(p). " +
@@ -248,7 +248,7 @@ public class JavaUtils {
/**
* Convert a passed byte string (e.g. 50b, 100k, or 250m) to bytes for
* internal use.
- *
+ *
* If no suffix is provided, the passed number is assumed to be in bytes.
*/
public static long byteStringAsBytes(String str) {
@@ -264,7 +264,7 @@ public class JavaUtils {
public static long byteStringAsKb(String str) {
return parseByteString(str, ByteUnit.KiB);
}
-
+
/**
* Convert a passed byte string (e.g. 50b, 100k, or 250m) to mebibytes for
* internal use.
@@ -284,4 +284,20 @@ public class JavaUtils {
public static long byteStringAsGb(String str) {
return parseByteString(str, ByteUnit.GiB);
}
+
+ /**
+ * Returns a byte array with the buffer's contents, trying to avoid copying the data if
+ * possible.
+ */
+ public static byte[] bufferToArray(ByteBuffer buffer) {
+ if (buffer.hasArray() && buffer.arrayOffset() == 0 &&
+ buffer.array().length == buffer.remaining()) {
+ return buffer.array();
+ } else {
+ byte[] bytes = new byte[buffer.remaining()];
+ buffer.get(bytes);
+ return bytes;
+ }
+ }
+
}
diff --git a/network/common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java b/network/common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java
index 5889562dd9..a466c72915 100644
--- a/network/common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java
+++ b/network/common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java
@@ -17,9 +17,13 @@
package org.apache.spark.network.util;
+import java.util.Iterator;
+import java.util.LinkedList;
+
import com.google.common.base.Preconditions;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.CompositeByteBuf;
+import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
@@ -44,84 +48,138 @@ public class TransportFrameDecoder extends ChannelInboundHandlerAdapter {
public static final String HANDLER_NAME = "frameDecoder";
private static final int LENGTH_SIZE = 8;
private static final int MAX_FRAME_SIZE = Integer.MAX_VALUE;
+ private static final int UNKNOWN_FRAME_SIZE = -1;
+
+ private final LinkedList<ByteBuf> buffers = new LinkedList<>();
+ private final ByteBuf frameLenBuf = Unpooled.buffer(LENGTH_SIZE, LENGTH_SIZE);
- private CompositeByteBuf buffer;
+ private long totalSize = 0;
+ private long nextFrameSize = UNKNOWN_FRAME_SIZE;
private volatile Interceptor interceptor;
@Override
public void channelRead(ChannelHandlerContext ctx, Object data) throws Exception {
ByteBuf in = (ByteBuf) data;
+ buffers.add(in);
+ totalSize += in.readableBytes();
+
+ while (!buffers.isEmpty()) {
+ // First, feed the interceptor, and if it's still, active, try again.
+ if (interceptor != null) {
+ ByteBuf first = buffers.getFirst();
+ int available = first.readableBytes();
+ if (feedInterceptor(first)) {
+ assert !first.isReadable() : "Interceptor still active but buffer has data.";
+ }
- if (buffer == null) {
- buffer = in.alloc().compositeBuffer();
- }
-
- buffer.addComponent(in).writerIndex(buffer.writerIndex() + in.readableBytes());
-
- while (buffer.isReadable()) {
- discardReadBytes();
- if (!feedInterceptor()) {
+ int read = available - first.readableBytes();
+ if (read == available) {
+ buffers.removeFirst().release();
+ }
+ totalSize -= read;
+ } else {
+ // Interceptor is not active, so try to decode one frame.
ByteBuf frame = decodeNext();
if (frame == null) {
break;
}
-
ctx.fireChannelRead(frame);
}
}
-
- discardReadBytes();
}
- private void discardReadBytes() {
- // If the buffer's been retained by downstream code, then make a copy of the remaining
- // bytes into a new buffer. Otherwise, just discard stale components.
- if (buffer.refCnt() > 1) {
- CompositeByteBuf newBuffer = buffer.alloc().compositeBuffer();
+ private long decodeFrameSize() {
+ if (nextFrameSize != UNKNOWN_FRAME_SIZE || totalSize < LENGTH_SIZE) {
+ return nextFrameSize;
+ }
- if (buffer.readableBytes() > 0) {
- ByteBuf spillBuf = buffer.alloc().buffer(buffer.readableBytes());
- spillBuf.writeBytes(buffer);
- newBuffer.addComponent(spillBuf).writerIndex(spillBuf.readableBytes());
+ // We know there's enough data. If the first buffer contains all the data, great. Otherwise,
+ // hold the bytes for the frame length in a composite buffer until we have enough data to read
+ // the frame size. Normally, it should be rare to need more than one buffer to read the frame
+ // size.
+ ByteBuf first = buffers.getFirst();
+ if (first.readableBytes() >= LENGTH_SIZE) {
+ nextFrameSize = first.readLong() - LENGTH_SIZE;
+ totalSize -= LENGTH_SIZE;
+ if (!first.isReadable()) {
+ buffers.removeFirst().release();
}
+ return nextFrameSize;
+ }
- buffer.release();
- buffer = newBuffer;
- } else {
- buffer.discardReadComponents();
+ while (frameLenBuf.readableBytes() < LENGTH_SIZE) {
+ ByteBuf next = buffers.getFirst();
+ int toRead = Math.min(next.readableBytes(), LENGTH_SIZE - frameLenBuf.readableBytes());
+ frameLenBuf.writeBytes(next, toRead);
+ if (!next.isReadable()) {
+ buffers.removeFirst().release();
+ }
}
+
+ nextFrameSize = frameLenBuf.readLong() - LENGTH_SIZE;
+ totalSize -= LENGTH_SIZE;
+ frameLenBuf.clear();
+ return nextFrameSize;
}
private ByteBuf decodeNext() throws Exception {
- if (buffer.readableBytes() < LENGTH_SIZE) {
+ long frameSize = decodeFrameSize();
+ if (frameSize == UNKNOWN_FRAME_SIZE || totalSize < frameSize) {
return null;
}
- int frameLen = (int) buffer.readLong() - LENGTH_SIZE;
- if (buffer.readableBytes() < frameLen) {
- buffer.readerIndex(buffer.readerIndex() - LENGTH_SIZE);
- return null;
+ // Reset size for next frame.
+ nextFrameSize = UNKNOWN_FRAME_SIZE;
+
+ Preconditions.checkArgument(frameSize < MAX_FRAME_SIZE, "Too large frame: %s", frameSize);
+ Preconditions.checkArgument(frameSize > 0, "Frame length should be positive: %s", frameSize);
+
+ // If the first buffer holds the entire frame, return it.
+ int remaining = (int) frameSize;
+ if (buffers.getFirst().readableBytes() >= remaining) {
+ return nextBufferForFrame(remaining);
}
- Preconditions.checkArgument(frameLen < MAX_FRAME_SIZE, "Too large frame: %s", frameLen);
- Preconditions.checkArgument(frameLen > 0, "Frame length should be positive: %s", frameLen);
+ // Otherwise, create a composite buffer.
+ CompositeByteBuf frame = buffers.getFirst().alloc().compositeBuffer();
+ while (remaining > 0) {
+ ByteBuf next = nextBufferForFrame(remaining);
+ remaining -= next.readableBytes();
+ frame.addComponent(next).writerIndex(frame.writerIndex() + next.readableBytes());
+ }
+ assert remaining == 0;
+ return frame;
+ }
+
+ /**
+ * Takes the first buffer in the internal list, and either adjust it to fit in the frame
+ * (by taking a slice out of it) or remove it from the internal list.
+ */
+ private ByteBuf nextBufferForFrame(int bytesToRead) {
+ ByteBuf buf = buffers.getFirst();
+ ByteBuf frame;
+
+ if (buf.readableBytes() > bytesToRead) {
+ frame = buf.retain().readSlice(bytesToRead);
+ totalSize -= bytesToRead;
+ } else {
+ frame = buf;
+ buffers.removeFirst();
+ totalSize -= frame.readableBytes();
+ }
- ByteBuf frame = buffer.readSlice(frameLen);
- frame.retain();
return frame;
}
@Override
public void channelInactive(ChannelHandlerContext ctx) throws Exception {
- if (buffer != null) {
- if (buffer.isReadable()) {
- feedInterceptor();
- }
- buffer.release();
+ for (ByteBuf b : buffers) {
+ b.release();
}
if (interceptor != null) {
interceptor.channelInactive();
}
+ frameLenBuf.release();
super.channelInactive(ctx);
}
@@ -141,8 +199,8 @@ public class TransportFrameDecoder extends ChannelInboundHandlerAdapter {
/**
* @return Whether the interceptor is still active after processing the data.
*/
- private boolean feedInterceptor() throws Exception {
- if (interceptor != null && !interceptor.handle(buffer)) {
+ private boolean feedInterceptor(ByteBuf buf) throws Exception {
+ if (interceptor != null && !interceptor.handle(buf)) {
interceptor = null;
}
return interceptor != null;
diff --git a/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java
index 50a324e293..70c849d60e 100644
--- a/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java
+++ b/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java
@@ -107,7 +107,10 @@ public class ChunkFetchIntegrationSuite {
};
RpcHandler handler = new RpcHandler() {
@Override
- public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) {
+ public void receive(
+ TransportClient client,
+ ByteBuffer message,
+ RpcResponseCallback callback) {
throw new UnsupportedOperationException();
}
diff --git a/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java b/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java
index 1aa20900ff..6c8dd742f4 100644
--- a/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java
+++ b/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java
@@ -82,10 +82,10 @@ public class ProtocolSuite {
@Test
public void requests() {
testClientToServer(new ChunkFetchRequest(new StreamChunkId(1, 2)));
- testClientToServer(new RpcRequest(12345, new byte[0]));
- testClientToServer(new RpcRequest(12345, new byte[100]));
+ testClientToServer(new RpcRequest(12345, new TestManagedBuffer(0)));
+ testClientToServer(new RpcRequest(12345, new TestManagedBuffer(10)));
testClientToServer(new StreamRequest("abcde"));
- testClientToServer(new OneWayMessage(new byte[100]));
+ testClientToServer(new OneWayMessage(new TestManagedBuffer(10)));
}
@Test
@@ -94,8 +94,8 @@ public class ProtocolSuite {
testServerToClient(new ChunkFetchSuccess(new StreamChunkId(1, 2), new TestManagedBuffer(0)));
testServerToClient(new ChunkFetchFailure(new StreamChunkId(1, 2), "this is an error"));
testServerToClient(new ChunkFetchFailure(new StreamChunkId(1, 2), ""));
- testServerToClient(new RpcResponse(12345, new byte[0]));
- testServerToClient(new RpcResponse(12345, new byte[1000]));
+ testServerToClient(new RpcResponse(12345, new TestManagedBuffer(0)));
+ testServerToClient(new RpcResponse(12345, new TestManagedBuffer(100)));
testServerToClient(new RpcFailure(0, "this is an error"));
testServerToClient(new RpcFailure(0, ""));
// Note: buffer size must be "0" since StreamResponse's buffer is written differently to the
diff --git a/network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java
index 42955ef692..f9b5bf96d6 100644
--- a/network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java
+++ b/network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java
@@ -31,6 +31,7 @@ import org.apache.spark.network.server.TransportServer;
import org.apache.spark.network.util.MapConfigProvider;
import org.apache.spark.network.util.TransportConf;
import org.junit.*;
+import static org.junit.Assert.*;
import java.io.IOException;
import java.nio.ByteBuffer;
@@ -84,13 +85,16 @@ public class RequestTimeoutIntegrationSuite {
@Test
public void timeoutInactiveRequests() throws Exception {
final Semaphore semaphore = new Semaphore(1);
- final byte[] response = new byte[16];
+ final int responseSize = 16;
RpcHandler handler = new RpcHandler() {
@Override
- public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) {
+ public void receive(
+ TransportClient client,
+ ByteBuffer message,
+ RpcResponseCallback callback) {
try {
semaphore.tryAcquire(FOREVER, TimeUnit.MILLISECONDS);
- callback.onSuccess(response);
+ callback.onSuccess(ByteBuffer.allocate(responseSize));
} catch (InterruptedException e) {
// do nothing
}
@@ -110,15 +114,15 @@ public class RequestTimeoutIntegrationSuite {
// First completes quickly (semaphore starts at 1).
TestCallback callback0 = new TestCallback();
synchronized (callback0) {
- client.sendRpc(new byte[0], callback0);
+ client.sendRpc(ByteBuffer.allocate(0), callback0);
callback0.wait(FOREVER);
- assert (callback0.success.length == response.length);
+ assertEquals(responseSize, callback0.successLength);
}
// Second times out after 2 seconds, with slack. Must be IOException.
TestCallback callback1 = new TestCallback();
synchronized (callback1) {
- client.sendRpc(new byte[0], callback1);
+ client.sendRpc(ByteBuffer.allocate(0), callback1);
callback1.wait(4 * 1000);
assert (callback1.failure != null);
assert (callback1.failure instanceof IOException);
@@ -131,13 +135,16 @@ public class RequestTimeoutIntegrationSuite {
@Test
public void timeoutCleanlyClosesClient() throws Exception {
final Semaphore semaphore = new Semaphore(0);
- final byte[] response = new byte[16];
+ final int responseSize = 16;
RpcHandler handler = new RpcHandler() {
@Override
- public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) {
+ public void receive(
+ TransportClient client,
+ ByteBuffer message,
+ RpcResponseCallback callback) {
try {
semaphore.tryAcquire(FOREVER, TimeUnit.MILLISECONDS);
- callback.onSuccess(response);
+ callback.onSuccess(ByteBuffer.allocate(responseSize));
} catch (InterruptedException e) {
// do nothing
}
@@ -158,7 +165,7 @@ public class RequestTimeoutIntegrationSuite {
clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
TestCallback callback0 = new TestCallback();
synchronized (callback0) {
- client0.sendRpc(new byte[0], callback0);
+ client0.sendRpc(ByteBuffer.allocate(0), callback0);
callback0.wait(FOREVER);
assert (callback0.failure instanceof IOException);
assert (!client0.isActive());
@@ -170,10 +177,10 @@ public class RequestTimeoutIntegrationSuite {
clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
TestCallback callback1 = new TestCallback();
synchronized (callback1) {
- client1.sendRpc(new byte[0], callback1);
+ client1.sendRpc(ByteBuffer.allocate(0), callback1);
callback1.wait(FOREVER);
- assert (callback1.success.length == response.length);
- assert (callback1.failure == null);
+ assertEquals(responseSize, callback1.successLength);
+ assertNull(callback1.failure);
}
}
@@ -191,7 +198,10 @@ public class RequestTimeoutIntegrationSuite {
};
RpcHandler handler = new RpcHandler() {
@Override
- public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) {
+ public void receive(
+ TransportClient client,
+ ByteBuffer message,
+ RpcResponseCallback callback) {
throw new UnsupportedOperationException();
}
@@ -218,9 +228,10 @@ public class RequestTimeoutIntegrationSuite {
synchronized (callback0) {
// not complete yet, but should complete soon
- assert (callback0.success == null && callback0.failure == null);
+ assertEquals(-1, callback0.successLength);
+ assertNull(callback0.failure);
callback0.wait(2 * 1000);
- assert (callback0.failure instanceof IOException);
+ assertTrue(callback0.failure instanceof IOException);
}
synchronized (callback1) {
@@ -235,13 +246,13 @@ public class RequestTimeoutIntegrationSuite {
*/
class TestCallback implements RpcResponseCallback, ChunkReceivedCallback {
- byte[] success;
+ int successLength = -1;
Throwable failure;
@Override
- public void onSuccess(byte[] response) {
+ public void onSuccess(ByteBuffer response) {
synchronized(this) {
- success = response;
+ successLength = response.remaining();
this.notifyAll();
}
}
@@ -258,7 +269,7 @@ public class RequestTimeoutIntegrationSuite {
public void onSuccess(int chunkIndex, ManagedBuffer buffer) {
synchronized(this) {
try {
- success = buffer.nioByteBuffer().array();
+ successLength = buffer.nioByteBuffer().remaining();
this.notifyAll();
} catch (IOException e) {
// weird
diff --git a/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java
index 88fa2258bb..9e9be98c14 100644
--- a/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java
+++ b/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java
@@ -17,6 +17,7 @@
package org.apache.spark.network;
+import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
@@ -26,7 +27,6 @@ import java.util.Set;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
-import com.google.common.base.Charsets;
import com.google.common.collect.Sets;
import org.junit.AfterClass;
import org.junit.BeforeClass;
@@ -41,6 +41,7 @@ import org.apache.spark.network.server.OneForOneStreamManager;
import org.apache.spark.network.server.RpcHandler;
import org.apache.spark.network.server.StreamManager;
import org.apache.spark.network.server.TransportServer;
+import org.apache.spark.network.util.JavaUtils;
import org.apache.spark.network.util.SystemPropertyConfigProvider;
import org.apache.spark.network.util.TransportConf;
@@ -55,11 +56,14 @@ public class RpcIntegrationSuite {
TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider());
rpcHandler = new RpcHandler() {
@Override
- public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) {
- String msg = new String(message, Charsets.UTF_8);
+ public void receive(
+ TransportClient client,
+ ByteBuffer message,
+ RpcResponseCallback callback) {
+ String msg = JavaUtils.bytesToString(message);
String[] parts = msg.split("/");
if (parts[0].equals("hello")) {
- callback.onSuccess(("Hello, " + parts[1] + "!").getBytes(Charsets.UTF_8));
+ callback.onSuccess(JavaUtils.stringToBytes("Hello, " + parts[1] + "!"));
} else if (parts[0].equals("return error")) {
callback.onFailure(new RuntimeException("Returned: " + parts[1]));
} else if (parts[0].equals("throw error")) {
@@ -68,9 +72,8 @@ public class RpcIntegrationSuite {
}
@Override
- public void receive(TransportClient client, byte[] message) {
- String msg = new String(message, Charsets.UTF_8);
- oneWayMsgs.add(msg);
+ public void receive(TransportClient client, ByteBuffer message) {
+ oneWayMsgs.add(JavaUtils.bytesToString(message));
}
@Override
@@ -103,8 +106,9 @@ public class RpcIntegrationSuite {
RpcResponseCallback callback = new RpcResponseCallback() {
@Override
- public void onSuccess(byte[] message) {
- res.successMessages.add(new String(message, Charsets.UTF_8));
+ public void onSuccess(ByteBuffer message) {
+ String response = JavaUtils.bytesToString(message);
+ res.successMessages.add(response);
sem.release();
}
@@ -116,7 +120,7 @@ public class RpcIntegrationSuite {
};
for (String command : commands) {
- client.sendRpc(command.getBytes(Charsets.UTF_8), callback);
+ client.sendRpc(JavaUtils.stringToBytes(command), callback);
}
if (!sem.tryAcquire(commands.length, 5, TimeUnit.SECONDS)) {
@@ -173,7 +177,7 @@ public class RpcIntegrationSuite {
final String message = "no reply";
TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
try {
- client.send(message.getBytes(Charsets.UTF_8));
+ client.send(JavaUtils.stringToBytes(message));
assertEquals(0, client.getHandler().numOutstandingRequests());
// Make sure the message arrives.
diff --git a/network/common/src/test/java/org/apache/spark/network/StreamSuite.java b/network/common/src/test/java/org/apache/spark/network/StreamSuite.java
index 538f3efe8d..9c49556927 100644
--- a/network/common/src/test/java/org/apache/spark/network/StreamSuite.java
+++ b/network/common/src/test/java/org/apache/spark/network/StreamSuite.java
@@ -116,7 +116,10 @@ public class StreamSuite {
};
RpcHandler handler = new RpcHandler() {
@Override
- public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) {
+ public void receive(
+ TransportClient client,
+ ByteBuffer message,
+ RpcResponseCallback callback) {
throw new UnsupportedOperationException();
}
diff --git a/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java b/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java
index 30144f4a9f..128f7cba74 100644
--- a/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java
+++ b/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java
@@ -17,6 +17,8 @@
package org.apache.spark.network;
+import java.nio.ByteBuffer;
+
import io.netty.channel.Channel;
import io.netty.channel.local.LocalChannel;
import org.junit.Test;
@@ -27,6 +29,7 @@ import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.*;
import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.buffer.NioManagedBuffer;
import org.apache.spark.network.client.ChunkReceivedCallback;
import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.StreamCallback;
@@ -42,7 +45,7 @@ import org.apache.spark.network.util.TransportFrameDecoder;
public class TransportResponseHandlerSuite {
@Test
- public void handleSuccessfulFetch() {
+ public void handleSuccessfulFetch() throws Exception {
StreamChunkId streamChunkId = new StreamChunkId(1, 0);
TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel());
@@ -56,7 +59,7 @@ public class TransportResponseHandlerSuite {
}
@Test
- public void handleFailedFetch() {
+ public void handleFailedFetch() throws Exception {
StreamChunkId streamChunkId = new StreamChunkId(1, 0);
TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel());
ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class);
@@ -69,7 +72,7 @@ public class TransportResponseHandlerSuite {
}
@Test
- public void clearAllOutstandingRequests() {
+ public void clearAllOutstandingRequests() throws Exception {
TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel());
ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class);
handler.addFetchRequest(new StreamChunkId(1, 0), callback);
@@ -88,23 +91,24 @@ public class TransportResponseHandlerSuite {
}
@Test
- public void handleSuccessfulRPC() {
+ public void handleSuccessfulRPC() throws Exception {
TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel());
RpcResponseCallback callback = mock(RpcResponseCallback.class);
handler.addRpcRequest(12345, callback);
assertEquals(1, handler.numOutstandingRequests());
- handler.handle(new RpcResponse(54321, new byte[7])); // should be ignored
+ // This response should be ignored.
+ handler.handle(new RpcResponse(54321, new NioManagedBuffer(ByteBuffer.allocate(7))));
assertEquals(1, handler.numOutstandingRequests());
- byte[] arr = new byte[10];
- handler.handle(new RpcResponse(12345, arr));
- verify(callback, times(1)).onSuccess(eq(arr));
+ ByteBuffer resp = ByteBuffer.allocate(10);
+ handler.handle(new RpcResponse(12345, new NioManagedBuffer(resp)));
+ verify(callback, times(1)).onSuccess(eq(ByteBuffer.allocate(10)));
assertEquals(0, handler.numOutstandingRequests());
}
@Test
- public void handleFailedRPC() {
+ public void handleFailedRPC() throws Exception {
TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel());
RpcResponseCallback callback = mock(RpcResponseCallback.class);
handler.addRpcRequest(12345, callback);
@@ -119,7 +123,7 @@ public class TransportResponseHandlerSuite {
}
@Test
- public void testActiveStreams() {
+ public void testActiveStreams() throws Exception {
Channel c = new LocalChannel();
c.pipeline().addLast(TransportFrameDecoder.HANDLER_NAME, new TransportFrameDecoder());
TransportResponseHandler handler = new TransportResponseHandler(c);
diff --git a/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java
index a6f180bc40..751516b9d8 100644
--- a/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java
+++ b/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java
@@ -22,7 +22,7 @@ import static org.mockito.Mockito.*;
import java.io.File;
import java.lang.reflect.Method;
-import java.nio.charset.StandardCharsets;
+import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
@@ -57,6 +57,7 @@ import org.apache.spark.network.server.StreamManager;
import org.apache.spark.network.server.TransportServer;
import org.apache.spark.network.server.TransportServerBootstrap;
import org.apache.spark.network.util.ByteArrayWritableChannel;
+import org.apache.spark.network.util.JavaUtils;
import org.apache.spark.network.util.SystemPropertyConfigProvider;
import org.apache.spark.network.util.TransportConf;
@@ -123,39 +124,53 @@ public class SparkSaslSuite {
}
@Test
- public void testSaslAuthentication() throws Exception {
+ public void testSaslAuthentication() throws Throwable {
testBasicSasl(false);
}
@Test
- public void testSaslEncryption() throws Exception {
+ public void testSaslEncryption() throws Throwable {
testBasicSasl(true);
}
- private void testBasicSasl(boolean encrypt) throws Exception {
+ private void testBasicSasl(boolean encrypt) throws Throwable {
RpcHandler rpcHandler = mock(RpcHandler.class);
doAnswer(new Answer<Void>() {
@Override
public Void answer(InvocationOnMock invocation) {
- byte[] message = (byte[]) invocation.getArguments()[1];
+ ByteBuffer message = (ByteBuffer) invocation.getArguments()[1];
RpcResponseCallback cb = (RpcResponseCallback) invocation.getArguments()[2];
- assertEquals("Ping", new String(message, StandardCharsets.UTF_8));
- cb.onSuccess("Pong".getBytes(StandardCharsets.UTF_8));
+ assertEquals("Ping", JavaUtils.bytesToString(message));
+ cb.onSuccess(JavaUtils.stringToBytes("Pong"));
return null;
}
})
.when(rpcHandler)
- .receive(any(TransportClient.class), any(byte[].class), any(RpcResponseCallback.class));
+ .receive(any(TransportClient.class), any(ByteBuffer.class), any(RpcResponseCallback.class));
SaslTestCtx ctx = new SaslTestCtx(rpcHandler, encrypt, false);
try {
- byte[] response = ctx.client.sendRpcSync("Ping".getBytes(StandardCharsets.UTF_8),
- TimeUnit.SECONDS.toMillis(10));
- assertEquals("Pong", new String(response, StandardCharsets.UTF_8));
+ ByteBuffer response = ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"),
+ TimeUnit.SECONDS.toMillis(10));
+ assertEquals("Pong", JavaUtils.bytesToString(response));
} finally {
ctx.close();
// There should be 2 terminated events; one for the client, one for the server.
- verify(rpcHandler, times(2)).connectionTerminated(any(TransportClient.class));
+ Throwable error = null;
+ long deadline = System.nanoTime() + TimeUnit.NANOSECONDS.convert(10, TimeUnit.SECONDS);
+ while (deadline > System.nanoTime()) {
+ try {
+ verify(rpcHandler, times(2)).connectionTerminated(any(TransportClient.class));
+ error = null;
+ break;
+ } catch (Throwable t) {
+ error = t;
+ TimeUnit.MILLISECONDS.sleep(10);
+ }
+ }
+ if (error != null) {
+ throw error;
+ }
}
}
@@ -325,8 +340,8 @@ public class SparkSaslSuite {
SaslTestCtx ctx = null;
try {
ctx = new SaslTestCtx(mock(RpcHandler.class), true, true);
- ctx.client.sendRpcSync("Ping".getBytes(StandardCharsets.UTF_8),
- TimeUnit.SECONDS.toMillis(10));
+ ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"),
+ TimeUnit.SECONDS.toMillis(10));
fail("Should have failed to send RPC to server.");
} catch (Exception e) {
assertFalse(e.getCause() instanceof TimeoutException);
diff --git a/network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java b/network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java
index 19475c21ff..d4de4a941d 100644
--- a/network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java
+++ b/network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java
@@ -118,6 +118,27 @@ public class TransportFrameDecoderSuite {
}
}
+ @Test
+ public void testSplitLengthField() throws Exception {
+ byte[] frame = new byte[1024 * (RND.nextInt(31) + 1)];
+ ByteBuf buf = Unpooled.buffer(frame.length + 8);
+ buf.writeLong(frame.length + 8);
+ buf.writeBytes(frame);
+
+ TransportFrameDecoder decoder = new TransportFrameDecoder();
+ ChannelHandlerContext ctx = mockChannelHandlerContext();
+ try {
+ decoder.channelRead(ctx, buf.readSlice(RND.nextInt(7)).retain());
+ verify(ctx, never()).fireChannelRead(any(ByteBuf.class));
+ decoder.channelRead(ctx, buf);
+ verify(ctx).fireChannelRead(any(ByteBuf.class));
+ assertEquals(0, buf.refCnt());
+ } finally {
+ decoder.channelInactive(ctx);
+ release(buf);
+ }
+ }
+
@Test(expected = IllegalArgumentException.class)
public void testNegativeFrameSize() throws Exception {
testInvalidFrame(-1);
@@ -183,7 +204,7 @@ public class TransportFrameDecoderSuite {
try {
decoder.channelRead(ctx, frame);
} finally {
- frame.release();
+ release(frame);
}
}
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java
index 3ddf5c3c39..f22187a01d 100644
--- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java
@@ -19,6 +19,7 @@ package org.apache.spark.network.shuffle;
import java.io.File;
import java.io.IOException;
+import java.nio.ByteBuffer;
import java.util.List;
import com.google.common.annotations.VisibleForTesting;
@@ -66,8 +67,8 @@ public class ExternalShuffleBlockHandler extends RpcHandler {
}
@Override
- public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) {
- BlockTransferMessage msgObj = BlockTransferMessage.Decoder.fromByteArray(message);
+ public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) {
+ BlockTransferMessage msgObj = BlockTransferMessage.Decoder.fromByteBuffer(message);
handleMessage(msgObj, client, callback);
}
@@ -85,13 +86,13 @@ public class ExternalShuffleBlockHandler extends RpcHandler {
}
long streamId = streamManager.registerStream(client.getClientId(), blocks.iterator());
logger.trace("Registered streamId {} with {} buffers", streamId, msg.blockIds.length);
- callback.onSuccess(new StreamHandle(streamId, msg.blockIds.length).toByteArray());
+ callback.onSuccess(new StreamHandle(streamId, msg.blockIds.length).toByteBuffer());
} else if (msgObj instanceof RegisterExecutor) {
RegisterExecutor msg = (RegisterExecutor) msgObj;
checkAuth(client, msg.appId);
blockManager.registerExecutor(msg.appId, msg.execId, msg.executorInfo);
- callback.onSuccess(new byte[0]);
+ callback.onSuccess(ByteBuffer.wrap(new byte[0]));
} else {
throw new UnsupportedOperationException("Unexpected message: " + msgObj);
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
index ef3a9dcc87..58ca87d9d3 100644
--- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
@@ -18,6 +18,7 @@
package org.apache.spark.network.shuffle;
import java.io.IOException;
+import java.nio.ByteBuffer;
import java.util.List;
import com.google.common.base.Preconditions;
@@ -139,7 +140,7 @@ public class ExternalShuffleClient extends ShuffleClient {
checkInit();
TransportClient client = clientFactory.createUnmanagedClient(host, port);
try {
- byte[] registerMessage = new RegisterExecutor(appId, execId, executorInfo).toByteArray();
+ ByteBuffer registerMessage = new RegisterExecutor(appId, execId, executorInfo).toByteBuffer();
client.sendRpcSync(registerMessage, 5000 /* timeoutMs */);
} finally {
client.close();
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
index e653f5cb14..1b2ddbf1ed 100644
--- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
@@ -17,6 +17,7 @@
package org.apache.spark.network.shuffle;
+import java.nio.ByteBuffer;
import java.util.Arrays;
import org.slf4j.Logger;
@@ -89,11 +90,11 @@ public class OneForOneBlockFetcher {
throw new IllegalArgumentException("Zero-sized blockIds array");
}
- client.sendRpc(openMessage.toByteArray(), new RpcResponseCallback() {
+ client.sendRpc(openMessage.toByteBuffer(), new RpcResponseCallback() {
@Override
- public void onSuccess(byte[] response) {
+ public void onSuccess(ByteBuffer response) {
try {
- streamHandle = (StreamHandle) BlockTransferMessage.Decoder.fromByteArray(response);
+ streamHandle = (StreamHandle) BlockTransferMessage.Decoder.fromByteBuffer(response);
logger.trace("Successfully opened blocks {}, preparing to fetch chunks.", streamHandle);
// Immediately request all chunks -- we expect that the total size of the request is
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java
index 7543b6be4f..675820308b 100644
--- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java
@@ -18,6 +18,7 @@
package org.apache.spark.network.shuffle.mesos;
import java.io.IOException;
+import java.nio.ByteBuffer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -54,11 +55,11 @@ public class MesosExternalShuffleClient extends ExternalShuffleClient {
public void registerDriverWithShuffleService(String host, int port) throws IOException {
checkInit();
- byte[] registerDriver = new RegisterDriver(appId).toByteArray();
+ ByteBuffer registerDriver = new RegisterDriver(appId).toByteBuffer();
TransportClient client = clientFactory.createClient(host, port);
client.sendRpc(registerDriver, new RpcResponseCallback() {
@Override
- public void onSuccess(byte[] response) {
+ public void onSuccess(ByteBuffer response) {
logger.info("Successfully registered app " + appId + " with external shuffle service.");
}
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java
index fcb52363e6..7fbe3384b4 100644
--- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java
@@ -17,6 +17,8 @@
package org.apache.spark.network.shuffle.protocol;
+import java.nio.ByteBuffer;
+
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
@@ -53,7 +55,7 @@ public abstract class BlockTransferMessage implements Encodable {
// NB: Java does not support static methods in interfaces, so we must put this in a static class.
public static class Decoder {
/** Deserializes the 'type' byte followed by the message itself. */
- public static BlockTransferMessage fromByteArray(byte[] msg) {
+ public static BlockTransferMessage fromByteBuffer(ByteBuffer msg) {
ByteBuf buf = Unpooled.wrappedBuffer(msg);
byte type = buf.readByte();
switch (type) {
@@ -68,12 +70,12 @@ public abstract class BlockTransferMessage implements Encodable {
}
/** Serializes the 'type' byte followed by the message itself. */
- public byte[] toByteArray() {
+ public ByteBuffer toByteBuffer() {
// Allow room for encoded message, plus the type byte
ByteBuf buf = Unpooled.buffer(encodedLength() + 1);
buf.writeByte(type().id);
encode(buf);
assert buf.writableBytes() == 0 : "Writable bytes remain: " + buf.writableBytes();
- return buf.array();
+ return buf.nioBuffer();
}
}
diff --git a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java
index 1c2fa4d0d4..19c870aebb 100644
--- a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java
+++ b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java
@@ -18,6 +18,7 @@
package org.apache.spark.network.sasl;
import java.io.IOException;
+import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.concurrent.atomic.AtomicReference;
@@ -52,6 +53,7 @@ import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo;
import org.apache.spark.network.shuffle.protocol.OpenBlocks;
import org.apache.spark.network.shuffle.protocol.RegisterExecutor;
import org.apache.spark.network.shuffle.protocol.StreamHandle;
+import org.apache.spark.network.util.JavaUtils;
import org.apache.spark.network.util.SystemPropertyConfigProvider;
import org.apache.spark.network.util.TransportConf;
@@ -107,8 +109,8 @@ public class SaslIntegrationSuite {
TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
String msg = "Hello, World!";
- byte[] resp = client.sendRpcSync(msg.getBytes(), TIMEOUT_MS);
- assertEquals(msg, new String(resp)); // our rpc handler should just return the given msg
+ ByteBuffer resp = client.sendRpcSync(JavaUtils.stringToBytes(msg), TIMEOUT_MS);
+ assertEquals(msg, JavaUtils.bytesToString(resp));
}
@Test
@@ -136,7 +138,7 @@ public class SaslIntegrationSuite {
TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
try {
- client.sendRpcSync(new byte[13], TIMEOUT_MS);
+ client.sendRpcSync(ByteBuffer.allocate(13), TIMEOUT_MS);
fail("Should have failed");
} catch (Exception e) {
assertTrue(e.getMessage(), e.getMessage().contains("Expected SaslMessage"));
@@ -144,7 +146,7 @@ public class SaslIntegrationSuite {
try {
// Guessing the right tag byte doesn't magically get you in...
- client.sendRpcSync(new byte[] { (byte) 0xEA }, TIMEOUT_MS);
+ client.sendRpcSync(ByteBuffer.wrap(new byte[] { (byte) 0xEA }), TIMEOUT_MS);
fail("Should have failed");
} catch (Exception e) {
assertTrue(e.getMessage(), e.getMessage().contains("java.lang.IndexOutOfBoundsException"));
@@ -222,13 +224,13 @@ public class SaslIntegrationSuite {
new String[] { System.getProperty("java.io.tmpdir") }, 1,
"org.apache.spark.shuffle.sort.SortShuffleManager");
RegisterExecutor regmsg = new RegisterExecutor("app-1", "0", executorInfo);
- client1.sendRpcSync(regmsg.toByteArray(), TIMEOUT_MS);
+ client1.sendRpcSync(regmsg.toByteBuffer(), TIMEOUT_MS);
// Make a successful request to fetch blocks, which creates a new stream. But do not actually
// fetch any blocks, to keep the stream open.
OpenBlocks openMessage = new OpenBlocks("app-1", "0", blockIds);
- byte[] response = client1.sendRpcSync(openMessage.toByteArray(), TIMEOUT_MS);
- StreamHandle stream = (StreamHandle) BlockTransferMessage.Decoder.fromByteArray(response);
+ ByteBuffer response = client1.sendRpcSync(openMessage.toByteBuffer(), TIMEOUT_MS);
+ StreamHandle stream = (StreamHandle) BlockTransferMessage.Decoder.fromByteBuffer(response);
long streamId = stream.streamId;
// Create a second client, authenticated with a different app ID, and try to read from
@@ -275,7 +277,7 @@ public class SaslIntegrationSuite {
/** RPC handler which simply responds with the message it received. */
public static class TestRpcHandler extends RpcHandler {
@Override
- public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) {
+ public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) {
callback.onSuccess(message);
}
diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java
index d65de9ca55..86c8609e70 100644
--- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java
+++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java
@@ -36,7 +36,7 @@ public class BlockTransferMessagesSuite {
}
private void checkSerializeDeserialize(BlockTransferMessage msg) {
- BlockTransferMessage msg2 = BlockTransferMessage.Decoder.fromByteArray(msg.toByteArray());
+ BlockTransferMessage msg2 = BlockTransferMessage.Decoder.fromByteBuffer(msg.toByteBuffer());
assertEquals(msg, msg2);
assertEquals(msg.hashCode(), msg2.hashCode());
assertEquals(msg.toString(), msg2.toString());
diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java
index e61390cf57..9379412155 100644
--- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java
+++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java
@@ -60,12 +60,12 @@ public class ExternalShuffleBlockHandlerSuite {
RpcResponseCallback callback = mock(RpcResponseCallback.class);
ExecutorShuffleInfo config = new ExecutorShuffleInfo(new String[] {"/a", "/b"}, 16, "sort");
- byte[] registerMessage = new RegisterExecutor("app0", "exec1", config).toByteArray();
+ ByteBuffer registerMessage = new RegisterExecutor("app0", "exec1", config).toByteBuffer();
handler.receive(client, registerMessage, callback);
verify(blockResolver, times(1)).registerExecutor("app0", "exec1", config);
- verify(callback, times(1)).onSuccess((byte[]) any());
- verify(callback, never()).onFailure((Throwable) any());
+ verify(callback, times(1)).onSuccess(any(ByteBuffer.class));
+ verify(callback, never()).onFailure(any(Throwable.class));
}
@SuppressWarnings("unchecked")
@@ -77,17 +77,18 @@ public class ExternalShuffleBlockHandlerSuite {
ManagedBuffer block1Marker = new NioManagedBuffer(ByteBuffer.wrap(new byte[7]));
when(blockResolver.getBlockData("app0", "exec1", "b0")).thenReturn(block0Marker);
when(blockResolver.getBlockData("app0", "exec1", "b1")).thenReturn(block1Marker);
- byte[] openBlocks = new OpenBlocks("app0", "exec1", new String[] { "b0", "b1" }).toByteArray();
+ ByteBuffer openBlocks = new OpenBlocks("app0", "exec1", new String[] { "b0", "b1" })
+ .toByteBuffer();
handler.receive(client, openBlocks, callback);
verify(blockResolver, times(1)).getBlockData("app0", "exec1", "b0");
verify(blockResolver, times(1)).getBlockData("app0", "exec1", "b1");
- ArgumentCaptor<byte[]> response = ArgumentCaptor.forClass(byte[].class);
+ ArgumentCaptor<ByteBuffer> response = ArgumentCaptor.forClass(ByteBuffer.class);
verify(callback, times(1)).onSuccess(response.capture());
verify(callback, never()).onFailure((Throwable) any());
StreamHandle handle =
- (StreamHandle) BlockTransferMessage.Decoder.fromByteArray(response.getValue());
+ (StreamHandle) BlockTransferMessage.Decoder.fromByteBuffer(response.getValue());
assertEquals(2, handle.numChunks);
@SuppressWarnings("unchecked")
@@ -104,7 +105,7 @@ public class ExternalShuffleBlockHandlerSuite {
public void testBadMessages() {
RpcResponseCallback callback = mock(RpcResponseCallback.class);
- byte[] unserializableMsg = new byte[] { 0x12, 0x34, 0x56 };
+ ByteBuffer unserializableMsg = ByteBuffer.wrap(new byte[] { 0x12, 0x34, 0x56 });
try {
handler.receive(client, unserializableMsg, callback);
fail("Should have thrown");
@@ -112,7 +113,7 @@ public class ExternalShuffleBlockHandlerSuite {
// pass
}
- byte[] unexpectedMsg = new UploadBlock("a", "e", "b", new byte[1], new byte[2]).toByteArray();
+ ByteBuffer unexpectedMsg = new UploadBlock("a", "e", "b", new byte[1], new byte[2]).toByteBuffer();
try {
handler.receive(client, unexpectedMsg, callback);
fail("Should have thrown");
@@ -120,7 +121,7 @@ public class ExternalShuffleBlockHandlerSuite {
// pass
}
- verify(callback, never()).onSuccess((byte[]) any());
- verify(callback, never()).onFailure((Throwable) any());
+ verify(callback, never()).onSuccess(any(ByteBuffer.class));
+ verify(callback, never()).onFailure(any(Throwable.class));
}
}
diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java
index b35a6d685d..2590b9ce4c 100644
--- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java
+++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java
@@ -134,14 +134,14 @@ public class OneForOneBlockFetcherSuite {
doAnswer(new Answer<Void>() {
@Override
public Void answer(InvocationOnMock invocationOnMock) throws Throwable {
- BlockTransferMessage message = BlockTransferMessage.Decoder.fromByteArray(
- (byte[]) invocationOnMock.getArguments()[0]);
+ BlockTransferMessage message = BlockTransferMessage.Decoder.fromByteBuffer(
+ (ByteBuffer) invocationOnMock.getArguments()[0]);
RpcResponseCallback callback = (RpcResponseCallback) invocationOnMock.getArguments()[1];
- callback.onSuccess(new StreamHandle(123, blocks.size()).toByteArray());
+ callback.onSuccess(new StreamHandle(123, blocks.size()).toByteBuffer());
assertEquals(new OpenBlocks("app-id", "exec-id", blockIds), message);
return null;
}
- }).when(client).sendRpc((byte[]) any(), (RpcResponseCallback) any());
+ }).when(client).sendRpc(any(ByteBuffer.class), any(RpcResponseCallback.class));
// Respond to each chunk request with a single buffer from our blocks array.
final AtomicInteger expectedChunkIndex = new AtomicInteger(0);