author    Aaron Davidson <aaron@databricks.com>    2014-11-07 09:42:21 -0800
committer Reynold Xin <rxin@databricks.com>    2014-11-07 09:42:21 -0800
commit    d4fa04e50d299e9cad349b3781772956453a696b (patch)
tree      d561c5ac49c49b56aba2fa7964587a9e72c253a9 /network
parent    3abdb1b24aa48f21e7eed1232c01d3933873688c (diff)
[SPARK-4187] [Core] Switch to binary protocol for external shuffle service messages
This PR elimiantes the network package's usage of the Java serializer and replaces it with Encodable, which is a lightweight binary protocol. Each message is preceded by a type id, which will allow us to change messages (by only adding new ones), or to change the format entirely by switching to a special id (such as -1). This protocol has the advantage over Java that we can guarantee that messages will remain compatible across compiled versions and JVMs, though it does not provide a clean way to do schema migration. In the future, it may be good to use a more heavy-weight serialization format like protobuf, thrift, or avro, but these all add several dependencies which are unnecessary at the present time. Additionally this unifies the RPC messages of NettyBlockTransferService and ExternalShuffleClient. Author: Aaron Davidson <aaron@databricks.com> Closes #3146 from aarondav/free and squashes the following commits: ed1102a [Aaron Davidson] Remove some unused imports b8e2a49 [Aaron Davidson] Add appId to test 538f2a3 [Aaron Davidson] [SPARK-4187] [Core] Switch to binary protocol for external shuffle service messages
Diffstat (limited to 'network')
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java | 12
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/protocol/Encoders.java | 93
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java | 12
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java | 9
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java | 9
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java | 27
-rw-r--r--  network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslMessage.java | 24
-rw-r--r--  network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java | 21
-rw-r--r--  network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManager.java | 1
-rw-r--r--  network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java | 12
-rw-r--r--  network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleMessages.java | 106
-rw-r--r--  network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java | 17
-rw-r--r--  network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java | 76
-rw-r--r--  network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExecutorShuffleInfo.java (renamed from network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExecutorShuffleInfo.java) | 36
-rw-r--r--  network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenBlocks.java | 87
-rw-r--r--  network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java | 91
-rw-r--r--  network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/StreamHandle.java (renamed from network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleStreamHandle.java) | 34
-rw-r--r--  network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadBlock.java | 113
-rw-r--r--  network/shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java (renamed from network/shuffle/src/test/java/org/apache/spark/network/shuffle/ShuffleMessagesSuite.java) | 33
-rw-r--r--  network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java | 29
-rw-r--r--  network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java | 1
-rw-r--r--  network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java | 1
-rw-r--r--  network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java | 18
-rw-r--r--  network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java | 2
24 files changed, 608 insertions(+), 256 deletions(-)
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java
index 152af98ced..986957c150 100644
--- a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java
@@ -38,23 +38,19 @@ public final class ChunkFetchFailure implements ResponseMessage {
@Override
public int encodedLength() {
- return streamChunkId.encodedLength() + 4 + errorString.getBytes(Charsets.UTF_8).length;
+ return streamChunkId.encodedLength() + Encoders.Strings.encodedLength(errorString);
}
@Override
public void encode(ByteBuf buf) {
streamChunkId.encode(buf);
- byte[] errorBytes = errorString.getBytes(Charsets.UTF_8);
- buf.writeInt(errorBytes.length);
- buf.writeBytes(errorBytes);
+ Encoders.Strings.encode(buf, errorString);
}
public static ChunkFetchFailure decode(ByteBuf buf) {
StreamChunkId streamChunkId = StreamChunkId.decode(buf);
- int numErrorStringBytes = buf.readInt();
- byte[] errorBytes = new byte[numErrorStringBytes];
- buf.readBytes(errorBytes);
- return new ChunkFetchFailure(streamChunkId, new String(errorBytes, Charsets.UTF_8));
+ String errorString = Encoders.Strings.decode(buf);
+ return new ChunkFetchFailure(streamChunkId, errorString);
}
@Override
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/Encoders.java b/network/common/src/main/java/org/apache/spark/network/protocol/Encoders.java
new file mode 100644
index 0000000000..873c694250
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/Encoders.java
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+
+import com.google.common.base.Charsets;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+
+/** Provides a canonical set of Encoders for simple types. */
+public class Encoders {
+
+ /** Strings are encoded with their length followed by UTF-8 bytes. */
+ public static class Strings {
+ public static int encodedLength(String s) {
+ return 4 + s.getBytes(Charsets.UTF_8).length;
+ }
+
+ public static void encode(ByteBuf buf, String s) {
+ byte[] bytes = s.getBytes(Charsets.UTF_8);
+ buf.writeInt(bytes.length);
+ buf.writeBytes(bytes);
+ }
+
+ public static String decode(ByteBuf buf) {
+ int length = buf.readInt();
+ byte[] bytes = new byte[length];
+ buf.readBytes(bytes);
+ return new String(bytes, Charsets.UTF_8);
+ }
+ }
+
+ /** Byte arrays are encoded with their length followed by bytes. */
+ public static class ByteArrays {
+ public static int encodedLength(byte[] arr) {
+ return 4 + arr.length;
+ }
+
+ public static void encode(ByteBuf buf, byte[] arr) {
+ buf.writeInt(arr.length);
+ buf.writeBytes(arr);
+ }
+
+ public static byte[] decode(ByteBuf buf) {
+ int length = buf.readInt();
+ byte[] bytes = new byte[length];
+ buf.readBytes(bytes);
+ return bytes;
+ }
+ }
+
+ /** String arrays are encoded with the number of strings followed by per-String encoding. */
+ public static class StringArrays {
+ public static int encodedLength(String[] strings) {
+ int totalLength = 4;
+ for (String s : strings) {
+ totalLength += Strings.encodedLength(s);
+ }
+ return totalLength;
+ }
+
+ public static void encode(ByteBuf buf, String[] strings) {
+ buf.writeInt(strings.length);
+ for (String s : strings) {
+ Strings.encode(buf, s);
+ }
+ }
+
+ public static String[] decode(ByteBuf buf) {
+ int numStrings = buf.readInt();
+ String[] strings = new String[numStrings];
+ for (int i = 0; i < strings.length; i++) {
+ strings[i] = Strings.decode(buf);
+ }
+ return strings;
+ }
+ }
+}
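A minimal usage sketch for the Encoders helpers above (the example class is illustrative, not part of the patch):

    import io.netty.buffer.ByteBuf;
    import io.netty.buffer.Unpooled;
    import org.apache.spark.network.protocol.Encoders;

    public class EncodersExample {
      public static void main(String[] args) {
        String blockId = "shuffle_0_0_0";
        // encodedLength() sizes the buffer exactly: a 4-byte length prefix plus UTF-8 bytes.
        ByteBuf buf = Unpooled.buffer(Encoders.Strings.encodedLength(blockId));
        Encoders.Strings.encode(buf, blockId);
        System.out.println(Encoders.Strings.decode(buf));  // prints shuffle_0_0_0
      }
    }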
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java b/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java
index e239d4ffbd..ebd764eb5e 100644
--- a/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java
@@ -36,23 +36,19 @@ public final class RpcFailure implements ResponseMessage {
@Override
public int encodedLength() {
- return 8 + 4 + errorString.getBytes(Charsets.UTF_8).length;
+ return 8 + Encoders.Strings.encodedLength(errorString);
}
@Override
public void encode(ByteBuf buf) {
buf.writeLong(requestId);
- byte[] errorBytes = errorString.getBytes(Charsets.UTF_8);
- buf.writeInt(errorBytes.length);
- buf.writeBytes(errorBytes);
+ Encoders.Strings.encode(buf, errorString);
}
public static RpcFailure decode(ByteBuf buf) {
long requestId = buf.readLong();
- int numErrorStringBytes = buf.readInt();
- byte[] errorBytes = new byte[numErrorStringBytes];
- buf.readBytes(errorBytes);
- return new RpcFailure(requestId, new String(errorBytes, Charsets.UTF_8));
+ String errorString = Encoders.Strings.decode(buf);
+ return new RpcFailure(requestId, errorString);
}
@Override
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java b/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java
index 099e934ae0..cdee0b0e03 100644
--- a/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java
@@ -44,21 +44,18 @@ public final class RpcRequest implements RequestMessage {
@Override
public int encodedLength() {
- return 8 + 4 + message.length;
+ return 8 + Encoders.ByteArrays.encodedLength(message);
}
@Override
public void encode(ByteBuf buf) {
buf.writeLong(requestId);
- buf.writeInt(message.length);
- buf.writeBytes(message);
+ Encoders.ByteArrays.encode(buf, message);
}
public static RpcRequest decode(ByteBuf buf) {
long requestId = buf.readLong();
- int messageLen = buf.readInt();
- byte[] message = new byte[messageLen];
- buf.readBytes(message);
+ byte[] message = Encoders.ByteArrays.decode(buf);
return new RpcRequest(requestId, message);
}
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java b/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java
index ed47947832..0a62e09a81 100644
--- a/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java
@@ -36,20 +36,17 @@ public final class RpcResponse implements ResponseMessage {
public Type type() { return Type.RpcResponse; }
@Override
- public int encodedLength() { return 8 + 4 + response.length; }
+ public int encodedLength() { return 8 + Encoders.ByteArrays.encodedLength(response); }
@Override
public void encode(ByteBuf buf) {
buf.writeLong(requestId);
- buf.writeInt(response.length);
- buf.writeBytes(response);
+ Encoders.ByteArrays.encode(buf, response);
}
public static RpcResponse decode(ByteBuf buf) {
long requestId = buf.readLong();
- int responseLen = buf.readInt();
- byte[] response = new byte[responseLen];
- buf.readBytes(response);
+ byte[] response = Encoders.ByteArrays.decode(buf);
return new RpcResponse(requestId, response);
}
diff --git a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java
index 75c4a3981a..009dbcf013 100644
--- a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java
+++ b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java
@@ -50,33 +50,6 @@ public class JavaUtils {
}
}
- // TODO: Make this configurable, do not use Java serialization!
- public static <T> T deserialize(byte[] bytes) {
- try {
- ObjectInputStream is = new ObjectInputStream(new ByteArrayInputStream(bytes));
- Object out = is.readObject();
- is.close();
- return (T) out;
- } catch (ClassNotFoundException e) {
- throw new RuntimeException("Could not deserialize object", e);
- } catch (IOException e) {
- throw new RuntimeException("Could not deserialize object", e);
- }
- }
-
- // TODO: Make this configurable, do not use Java serialization!
- public static byte[] serialize(Object object) {
- try {
- ByteArrayOutputStream baos = new ByteArrayOutputStream();
- ObjectOutputStream os = new ObjectOutputStream(baos);
- os.writeObject(object);
- os.close();
- return baos.toByteArray();
- } catch (IOException e) {
- throw new RuntimeException("Could not serialize object", e);
- }
- }
-
/** Returns a hash consistent with Spark's Utils.nonNegativeHash(). */
public static int nonNegativeHash(Object obj) {
if (obj == null) { return 0; }
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslMessage.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslMessage.java
index 599cc6428c..cad76ab7aa 100644
--- a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslMessage.java
+++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslMessage.java
@@ -17,10 +17,10 @@
package org.apache.spark.network.sasl;
-import com.google.common.base.Charsets;
import io.netty.buffer.ByteBuf;
import org.apache.spark.network.protocol.Encodable;
+import org.apache.spark.network.protocol.Encoders;
/**
* Encodes a Sasl-related message which is attempting to authenticate using some credentials tagged
@@ -42,18 +42,14 @@ class SaslMessage implements Encodable {
@Override
public int encodedLength() {
- // tag + appIdLength + appId + payloadLength + payload
- return 1 + 4 + appId.getBytes(Charsets.UTF_8).length + 4 + payload.length;
+ return 1 + Encoders.Strings.encodedLength(appId) + Encoders.ByteArrays.encodedLength(payload);
}
@Override
public void encode(ByteBuf buf) {
buf.writeByte(TAG_BYTE);
- byte[] idBytes = appId.getBytes(Charsets.UTF_8);
- buf.writeInt(idBytes.length);
- buf.writeBytes(idBytes);
- buf.writeInt(payload.length);
- buf.writeBytes(payload);
+ Encoders.Strings.encode(buf, appId);
+ Encoders.ByteArrays.encode(buf, payload);
}
public static SaslMessage decode(ByteBuf buf) {
@@ -62,14 +58,8 @@ class SaslMessage implements Encodable {
+ " (maybe your client does not have SASL enabled?)");
}
- int idLength = buf.readInt();
- byte[] idBytes = new byte[idLength];
- buf.readBytes(idBytes);
-
- int payloadLength = buf.readInt();
- byte[] payload = new byte[payloadLength];
- buf.readBytes(payload);
-
- return new SaslMessage(new String(idBytes, Charsets.UTF_8), payload);
+ String appId = Encoders.Strings.decode(buf);
+ byte[] payload = Encoders.ByteArrays.decode(buf);
+ return new SaslMessage(appId, payload);
}
}
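Read off the encode() method above, the resulting SaslMessage wire layout is:

    tag (1 byte) | appId length (4 bytes) | appId (UTF-8 bytes) | payload length (4 bytes) | payload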
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java
index 75ebf8c7b0..a6db4b2abd 100644
--- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java
@@ -24,15 +24,16 @@ import com.google.common.collect.Lists;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-import static org.apache.spark.network.shuffle.ExternalShuffleMessages.*;
-
import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.server.OneForOneStreamManager;
import org.apache.spark.network.server.RpcHandler;
import org.apache.spark.network.server.StreamManager;
-import org.apache.spark.network.util.JavaUtils;
+import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
+import org.apache.spark.network.shuffle.protocol.OpenBlocks;
+import org.apache.spark.network.shuffle.protocol.RegisterExecutor;
+import org.apache.spark.network.shuffle.protocol.StreamHandle;
/**
* RPC Handler for a server which can serve shuffle blocks from outside of an Executor process.
@@ -62,12 +63,10 @@ public class ExternalShuffleBlockHandler extends RpcHandler {
@Override
public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) {
- Object msgObj = JavaUtils.deserialize(message);
-
- logger.trace("Received message: " + msgObj);
+ BlockTransferMessage msgObj = BlockTransferMessage.Decoder.fromByteArray(message);
- if (msgObj instanceof OpenShuffleBlocks) {
- OpenShuffleBlocks msg = (OpenShuffleBlocks) msgObj;
+ if (msgObj instanceof OpenBlocks) {
+ OpenBlocks msg = (OpenBlocks) msgObj;
List<ManagedBuffer> blocks = Lists.newArrayList();
for (String blockId : msg.blockIds) {
@@ -75,8 +74,7 @@ public class ExternalShuffleBlockHandler extends RpcHandler {
}
long streamId = streamManager.registerStream(blocks.iterator());
logger.trace("Registered streamId {} with {} buffers", streamId, msg.blockIds.length);
- callback.onSuccess(JavaUtils.serialize(
- new ShuffleStreamHandle(streamId, msg.blockIds.length)));
+ callback.onSuccess(new StreamHandle(streamId, msg.blockIds.length).toByteArray());
} else if (msgObj instanceof RegisterExecutor) {
RegisterExecutor msg = (RegisterExecutor) msgObj;
@@ -84,8 +82,7 @@ public class ExternalShuffleBlockHandler extends RpcHandler {
callback.onSuccess(new byte[0]);
} else {
- throw new UnsupportedOperationException(String.format(
- "Unexpected message: %s (class = %s)", msgObj, msgObj.getClass()));
+ throw new UnsupportedOperationException("Unexpected message: " + msgObj);
}
}
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManager.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManager.java
index 98fcfb82aa..ffb7faa3db 100644
--- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManager.java
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManager.java
@@ -35,6 +35,7 @@ import org.slf4j.LoggerFactory;
import org.apache.spark.network.buffer.FileSegmentManagedBuffer;
import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo;
import org.apache.spark.network.util.JavaUtils;
/**
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
index 27884b82c8..6e8018b723 100644
--- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
@@ -31,8 +31,8 @@ import org.apache.spark.network.client.TransportClientFactory;
import org.apache.spark.network.sasl.SaslClientBootstrap;
import org.apache.spark.network.sasl.SecretKeyHolder;
import org.apache.spark.network.server.NoOpRpcHandler;
-import org.apache.spark.network.shuffle.ExternalShuffleMessages.RegisterExecutor;
-import org.apache.spark.network.util.JavaUtils;
+import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo;
+import org.apache.spark.network.shuffle.protocol.RegisterExecutor;
import org.apache.spark.network.util.TransportConf;
/**
@@ -91,8 +91,7 @@ public class ExternalShuffleClient extends ShuffleClient {
public void createAndStart(String[] blockIds, BlockFetchingListener listener)
throws IOException {
TransportClient client = clientFactory.createClient(host, port);
- new OneForOneBlockFetcher(client, blockIds, listener)
- .start(new ExternalShuffleMessages.OpenShuffleBlocks(appId, execId, blockIds));
+ new OneForOneBlockFetcher(client, appId, execId, blockIds, listener).start();
}
};
@@ -128,9 +127,8 @@ public class ExternalShuffleClient extends ShuffleClient {
ExecutorShuffleInfo executorInfo) throws IOException {
assert appId != null : "Called before init()";
TransportClient client = clientFactory.createClient(host, port);
- byte[] registerExecutorMessage =
- JavaUtils.serialize(new RegisterExecutor(appId, execId, executorInfo));
- client.sendRpcSync(registerExecutorMessage, 5000 /* timeoutMs */);
+ byte[] registerMessage = new RegisterExecutor(appId, execId, executorInfo).toByteArray();
+ client.sendRpcSync(registerMessage, 5000 /* timeoutMs */);
}
@Override
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleMessages.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleMessages.java
deleted file mode 100644
index e79420ed82..0000000000
--- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleMessages.java
+++ /dev/null
@@ -1,106 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.shuffle;
-
-import java.io.Serializable;
-import java.util.Arrays;
-
-import com.google.common.base.Objects;
-
-/** Messages handled by the {@link ExternalShuffleBlockHandler}. */
-public class ExternalShuffleMessages {
-
- /** Request to read a set of shuffle blocks. Returns [[ShuffleStreamHandle]]. */
- public static class OpenShuffleBlocks implements Serializable {
- public final String appId;
- public final String execId;
- public final String[] blockIds;
-
- public OpenShuffleBlocks(String appId, String execId, String[] blockIds) {
- this.appId = appId;
- this.execId = execId;
- this.blockIds = blockIds;
- }
-
- @Override
- public int hashCode() {
- return Objects.hashCode(appId, execId) * 41 + Arrays.hashCode(blockIds);
- }
-
- @Override
- public String toString() {
- return Objects.toStringHelper(this)
- .add("appId", appId)
- .add("execId", execId)
- .add("blockIds", Arrays.toString(blockIds))
- .toString();
- }
-
- @Override
- public boolean equals(Object other) {
- if (other != null && other instanceof OpenShuffleBlocks) {
- OpenShuffleBlocks o = (OpenShuffleBlocks) other;
- return Objects.equal(appId, o.appId)
- && Objects.equal(execId, o.execId)
- && Arrays.equals(blockIds, o.blockIds);
- }
- return false;
- }
- }
-
- /** Initial registration message between an executor and its local shuffle server. */
- public static class RegisterExecutor implements Serializable {
- public final String appId;
- public final String execId;
- public final ExecutorShuffleInfo executorInfo;
-
- public RegisterExecutor(
- String appId,
- String execId,
- ExecutorShuffleInfo executorInfo) {
- this.appId = appId;
- this.execId = execId;
- this.executorInfo = executorInfo;
- }
-
- @Override
- public int hashCode() {
- return Objects.hashCode(appId, execId, executorInfo);
- }
-
- @Override
- public String toString() {
- return Objects.toStringHelper(this)
- .add("appId", appId)
- .add("execId", execId)
- .add("executorInfo", executorInfo)
- .toString();
- }
-
- @Override
- public boolean equals(Object other) {
- if (other != null && other instanceof RegisterExecutor) {
- RegisterExecutor o = (RegisterExecutor) other;
- return Objects.equal(appId, o.appId)
- && Objects.equal(execId, o.execId)
- && Objects.equal(executorInfo, o.executorInfo);
- }
- return false;
- }
- }
-}
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
index 9e77a1f68c..8ed2e0b39a 100644
--- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
@@ -26,6 +26,9 @@ import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.client.ChunkReceivedCallback;
import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
+import org.apache.spark.network.shuffle.protocol.OpenBlocks;
+import org.apache.spark.network.shuffle.protocol.StreamHandle;
import org.apache.spark.network.util.JavaUtils;
/**
@@ -41,17 +44,21 @@ public class OneForOneBlockFetcher {
private final Logger logger = LoggerFactory.getLogger(OneForOneBlockFetcher.class);
private final TransportClient client;
+ private final OpenBlocks openMessage;
private final String[] blockIds;
private final BlockFetchingListener listener;
private final ChunkReceivedCallback chunkCallback;
- private ShuffleStreamHandle streamHandle = null;
+ private StreamHandle streamHandle = null;
public OneForOneBlockFetcher(
TransportClient client,
+ String appId,
+ String execId,
String[] blockIds,
BlockFetchingListener listener) {
this.client = client;
+ this.openMessage = new OpenBlocks(appId, execId, blockIds);
this.blockIds = blockIds;
this.listener = listener;
this.chunkCallback = new ChunkCallback();
@@ -76,18 +83,18 @@ public class OneForOneBlockFetcher {
/**
* Begins the fetching process, calling the listener with every block fetched.
* The given message will be serialized with the Java serializer, and the RPC must return a
- * {@link ShuffleStreamHandle}. We will send all fetch requests immediately, without throttling.
+ * {@link StreamHandle}. We will send all fetch requests immediately, without throttling.
*/
- public void start(Object openBlocksMessage) {
+ public void start() {
if (blockIds.length == 0) {
throw new IllegalArgumentException("Zero-sized blockIds array");
}
- client.sendRpc(JavaUtils.serialize(openBlocksMessage), new RpcResponseCallback() {
+ client.sendRpc(openMessage.toByteArray(), new RpcResponseCallback() {
@Override
public void onSuccess(byte[] response) {
try {
- streamHandle = JavaUtils.deserialize(response);
+ streamHandle = (StreamHandle) BlockTransferMessage.Decoder.fromByteArray(response);
logger.trace("Successfully opened blocks {}, preparing to fetch chunks.", streamHandle);
// Immediately request all chunks -- we expect that the total size of the request is
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java
new file mode 100644
index 0000000000..b4b13b8a6e
--- /dev/null
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle.protocol;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+
+import org.apache.spark.network.protocol.Encodable;
+
+/**
+ * Messages handled by the {@link org.apache.spark.network.shuffle.ExternalShuffleBlockHandler}, or
+ * by Spark's NettyBlockTransferService.
+ *
+ * At a high level:
+ * - OpenBlock is handled by both services, but only services shuffle files for the external
+ * shuffle service. It returns a StreamHandle.
+ * - UploadBlock is only handled by the NettyBlockTransferService.
+ * - RegisterExecutor is only handled by the external shuffle service.
+ */
+public abstract class BlockTransferMessage implements Encodable {
+ protected abstract Type type();
+
+ /** Preceding every serialized message is its type, which allows us to deserialize it. */
+ public static enum Type {
+ OPEN_BLOCKS(0), UPLOAD_BLOCK(1), REGISTER_EXECUTOR(2), STREAM_HANDLE(3);
+
+ private final byte id;
+
+ private Type(int id) {
+ assert id < 128 : "Cannot have more than 128 message types";
+ this.id = (byte) id;
+ }
+
+ public byte id() { return id; }
+ }
+
+ // NB: Java does not support static methods in interfaces, so we must put this in a static class.
+ public static class Decoder {
+ /** Deserializes the 'type' byte followed by the message itself. */
+ public static BlockTransferMessage fromByteArray(byte[] msg) {
+ ByteBuf buf = Unpooled.wrappedBuffer(msg);
+ byte type = buf.readByte();
+ switch (type) {
+ case 0: return OpenBlocks.decode(buf);
+ case 1: return UploadBlock.decode(buf);
+ case 2: return RegisterExecutor.decode(buf);
+ case 3: return StreamHandle.decode(buf);
+ default: throw new IllegalArgumentException("Unknown message type: " + type);
+ }
+ }
+ }
+
+ /** Serializes the 'type' byte followed by the message itself. */
+ public byte[] toByteArray() {
+ ByteBuf buf = Unpooled.buffer(encodedLength());
+ buf.writeByte(type().id);
+ encode(buf);
+ assert buf.writableBytes() == 0 : "Writable bytes remain: " + buf.writableBytes();
+ return buf.array();
+ }
+}
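As the commit message notes, the protocol evolves by adding new type ids. A hypothetical extension (the CloseBlocks name is purely illustrative) would touch three places:

    // 1. Add an enum constant:         CLOSE_BLOCKS(4) in BlockTransferMessage.Type
    // 2. Add a decoder dispatch case:  case 4: return CloseBlocks.decode(buf);
    // 3. Implement the message:        CloseBlocks extends BlockTransferMessage, providing
    //    type(), encodedLength(), encode(ByteBuf) and a static decode(ByteBuf),
    //    mirroring OpenBlocks below.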
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExecutorShuffleInfo.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExecutorShuffleInfo.java
index d45e64656a..cadc8e8369 100644
--- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExecutorShuffleInfo.java
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExecutorShuffleInfo.java
@@ -15,21 +15,24 @@
* limitations under the License.
*/
-package org.apache.spark.network.shuffle;
+package org.apache.spark.network.shuffle.protocol;
-import java.io.Serializable;
import java.util.Arrays;
import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+
+import org.apache.spark.network.protocol.Encodable;
+import org.apache.spark.network.protocol.Encoders;
/** Contains all configuration necessary for locating the shuffle files of an executor. */
-public class ExecutorShuffleInfo implements Serializable {
+public class ExecutorShuffleInfo implements Encodable {
/** The base set of local directories that the executor stores its shuffle files in. */
- final String[] localDirs;
+ public final String[] localDirs;
/** Number of subdirectories created within each localDir. */
- final int subDirsPerLocalDir;
+ public final int subDirsPerLocalDir;
/** Shuffle manager (SortShuffleManager or HashShuffleManager) that the executor is using. */
- final String shuffleManager;
+ public final String shuffleManager;
public ExecutorShuffleInfo(String[] localDirs, int subDirsPerLocalDir, String shuffleManager) {
this.localDirs = localDirs;
@@ -61,4 +64,25 @@ public class ExecutorShuffleInfo implements Serializable {
}
return false;
}
+
+ @Override
+ public int encodedLength() {
+ return Encoders.StringArrays.encodedLength(localDirs)
+ + 4 // int
+ + Encoders.Strings.encodedLength(shuffleManager);
+ }
+
+ @Override
+ public void encode(ByteBuf buf) {
+ Encoders.StringArrays.encode(buf, localDirs);
+ buf.writeInt(subDirsPerLocalDir);
+ Encoders.Strings.encode(buf, shuffleManager);
+ }
+
+ public static ExecutorShuffleInfo decode(ByteBuf buf) {
+ String[] localDirs = Encoders.StringArrays.decode(buf);
+ int subDirsPerLocalDir = buf.readInt();
+ String shuffleManager = Encoders.Strings.decode(buf);
+ return new ExecutorShuffleInfo(localDirs, subDirsPerLocalDir, shuffleManager);
+ }
}
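Derived from the encode() method above, an ExecutorShuffleInfo serializes as:

    localDirs count (4 bytes) | per dir: length (4 bytes) + UTF-8 bytes
    | subDirsPerLocalDir (4 bytes) | shuffleManager: length (4 bytes) + UTF-8 bytes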
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenBlocks.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenBlocks.java
new file mode 100644
index 0000000000..60485bace6
--- /dev/null
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenBlocks.java
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle.protocol;
+
+import java.util.Arrays;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+
+import org.apache.spark.network.protocol.Encoders;
+
+/** Request to read a set of blocks. Returns {@link StreamHandle}. */
+public class OpenBlocks extends BlockTransferMessage {
+ public final String appId;
+ public final String execId;
+ public final String[] blockIds;
+
+ public OpenBlocks(String appId, String execId, String[] blockIds) {
+ this.appId = appId;
+ this.execId = execId;
+ this.blockIds = blockIds;
+ }
+
+ @Override
+ protected Type type() { return Type.OPEN_BLOCKS; }
+
+ @Override
+ public int hashCode() {
+ return Objects.hashCode(appId, execId) * 41 + Arrays.hashCode(blockIds);
+ }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("appId", appId)
+ .add("execId", execId)
+ .add("blockIds", Arrays.toString(blockIds))
+ .toString();
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (other != null && other instanceof OpenBlocks) {
+ OpenBlocks o = (OpenBlocks) other;
+ return Objects.equal(appId, o.appId)
+ && Objects.equal(execId, o.execId)
+ && Arrays.equals(blockIds, o.blockIds);
+ }
+ return false;
+ }
+
+ @Override
+ public int encodedLength() {
+ return Encoders.Strings.encodedLength(appId)
+ + Encoders.Strings.encodedLength(execId)
+ + Encoders.StringArrays.encodedLength(blockIds);
+ }
+
+ @Override
+ public void encode(ByteBuf buf) {
+ Encoders.Strings.encode(buf, appId);
+ Encoders.Strings.encode(buf, execId);
+ Encoders.StringArrays.encode(buf, blockIds);
+ }
+
+ public static OpenBlocks decode(ByteBuf buf) {
+ String appId = Encoders.Strings.decode(buf);
+ String execId = Encoders.Strings.decode(buf);
+ String[] blockIds = Encoders.StringArrays.decode(buf);
+ return new OpenBlocks(appId, execId, blockIds);
+ }
+}
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java
new file mode 100644
index 0000000000..38acae3b31
--- /dev/null
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle.protocol;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+
+import org.apache.spark.network.protocol.Encoders;
+
+/**
+ * Initial registration message between an executor and its local shuffle server.
+ * Returns nothing (empty byte array).
+ */
+public class RegisterExecutor extends BlockTransferMessage {
+ public final String appId;
+ public final String execId;
+ public final ExecutorShuffleInfo executorInfo;
+
+ public RegisterExecutor(
+ String appId,
+ String execId,
+ ExecutorShuffleInfo executorInfo) {
+ this.appId = appId;
+ this.execId = execId;
+ this.executorInfo = executorInfo;
+ }
+
+ @Override
+ protected Type type() { return Type.REGISTER_EXECUTOR; }
+
+ @Override
+ public int hashCode() {
+ return Objects.hashCode(appId, execId, executorInfo);
+ }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("appId", appId)
+ .add("execId", execId)
+ .add("executorInfo", executorInfo)
+ .toString();
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (other != null && other instanceof RegisterExecutor) {
+ RegisterExecutor o = (RegisterExecutor) other;
+ return Objects.equal(appId, o.appId)
+ && Objects.equal(execId, o.execId)
+ && Objects.equal(executorInfo, o.executorInfo);
+ }
+ return false;
+ }
+
+ @Override
+ public int encodedLength() {
+ return Encoders.Strings.encodedLength(appId)
+ + Encoders.Strings.encodedLength(execId)
+ + executorInfo.encodedLength();
+ }
+
+ @Override
+ public void encode(ByteBuf buf) {
+ Encoders.Strings.encode(buf, appId);
+ Encoders.Strings.encode(buf, execId);
+ executorInfo.encode(buf);
+ }
+
+ public static RegisterExecutor decode(ByteBuf buf) {
+ String appId = Encoders.Strings.decode(buf);
+ String execId = Encoders.Strings.decode(buf);
+ ExecutorShuffleInfo executorShuffleInfo = ExecutorShuffleInfo.decode(buf);
+ return new RegisterExecutor(appId, execId, executorShuffleInfo);
+ }
+}
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleStreamHandle.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/StreamHandle.java
index 9c94691224..21369c8cfb 100644
--- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleStreamHandle.java
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/StreamHandle.java
@@ -15,27 +15,30 @@
* limitations under the License.
*/
-package org.apache.spark.network.shuffle;
+package org.apache.spark.network.shuffle.protocol;
import java.io.Serializable;
-import java.util.Arrays;
import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
/**
* Identifier for a fixed number of chunks to read from a stream created by an "open blocks"
- * message. This is used by {@link OneForOneBlockFetcher}.
+ * message. This is used by {@link org.apache.spark.network.shuffle.OneForOneBlockFetcher}.
*/
-public class ShuffleStreamHandle implements Serializable {
+public class StreamHandle extends BlockTransferMessage {
public final long streamId;
public final int numChunks;
- public ShuffleStreamHandle(long streamId, int numChunks) {
+ public StreamHandle(long streamId, int numChunks) {
this.streamId = streamId;
this.numChunks = numChunks;
}
@Override
+ protected Type type() { return Type.STREAM_HANDLE; }
+
+ @Override
public int hashCode() {
return Objects.hashCode(streamId, numChunks);
}
@@ -50,11 +53,28 @@ public class ShuffleStreamHandle implements Serializable {
@Override
public boolean equals(Object other) {
- if (other != null && other instanceof ShuffleStreamHandle) {
- ShuffleStreamHandle o = (ShuffleStreamHandle) other;
+ if (other != null && other instanceof StreamHandle) {
+ StreamHandle o = (StreamHandle) other;
return Objects.equal(streamId, o.streamId)
&& Objects.equal(numChunks, o.numChunks);
}
return false;
}
+
+ @Override
+ public int encodedLength() {
+ return 8 + 4;
+ }
+
+ @Override
+ public void encode(ByteBuf buf) {
+ buf.writeLong(streamId);
+ buf.writeInt(numChunks);
+ }
+
+ public static StreamHandle decode(ByteBuf buf) {
+ long streamId = buf.readLong();
+ int numChunks = buf.readInt();
+ return new StreamHandle(streamId, numChunks);
+ }
}
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadBlock.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadBlock.java
new file mode 100644
index 0000000000..38abe29cc5
--- /dev/null
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadBlock.java
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle.protocol;
+
+import java.util.Arrays;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+
+import org.apache.spark.network.protocol.Encoders;
+
+/** Request to upload a block with a certain StorageLevel. Returns nothing (empty byte array). */
+public class UploadBlock extends BlockTransferMessage {
+ public final String appId;
+ public final String execId;
+ public final String blockId;
+ // TODO: StorageLevel is serialized separately in here because StorageLevel is not available in
+ // this package. We should avoid this hack.
+ public final byte[] metadata;
+ public final byte[] blockData;
+
+ /**
+ * @param metadata Meta-information about block, typically StorageLevel.
+ * @param blockData The actual block's bytes.
+ */
+ public UploadBlock(
+ String appId,
+ String execId,
+ String blockId,
+ byte[] metadata,
+ byte[] blockData) {
+ this.appId = appId;
+ this.execId = execId;
+ this.blockId = blockId;
+ this.metadata = metadata;
+ this.blockData = blockData;
+ }
+
+ @Override
+ protected Type type() { return Type.UPLOAD_BLOCK; }
+
+ @Override
+ public int hashCode() {
+ int objectsHashCode = Objects.hashCode(appId, execId, blockId);
+ return (objectsHashCode * 41 + Arrays.hashCode(metadata)) * 41 + Arrays.hashCode(blockData);
+ }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("appId", appId)
+ .add("execId", execId)
+ .add("blockId", blockId)
+ .add("metadata size", metadata.length)
+ .add("block size", blockData.length)
+ .toString();
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (other != null && other instanceof UploadBlock) {
+ UploadBlock o = (UploadBlock) other;
+ return Objects.equal(appId, o.appId)
+ && Objects.equal(execId, o.execId)
+ && Objects.equal(blockId, o.blockId)
+ && Arrays.equals(metadata, o.metadata)
+ && Arrays.equals(blockData, o.blockData);
+ }
+ return false;
+ }
+
+ @Override
+ public int encodedLength() {
+ return Encoders.Strings.encodedLength(appId)
+ + Encoders.Strings.encodedLength(execId)
+ + Encoders.Strings.encodedLength(blockId)
+ + Encoders.ByteArrays.encodedLength(metadata)
+ + Encoders.ByteArrays.encodedLength(blockData);
+ }
+
+ @Override
+ public void encode(ByteBuf buf) {
+ Encoders.Strings.encode(buf, appId);
+ Encoders.Strings.encode(buf, execId);
+ Encoders.Strings.encode(buf, blockId);
+ Encoders.ByteArrays.encode(buf, metadata);
+ Encoders.ByteArrays.encode(buf, blockData);
+ }
+
+ public static UploadBlock decode(ByteBuf buf) {
+ String appId = Encoders.Strings.decode(buf);
+ String execId = Encoders.Strings.decode(buf);
+ String blockId = Encoders.Strings.decode(buf);
+ byte[] metadata = Encoders.ByteArrays.decode(buf);
+ byte[] blockData = Encoders.ByteArrays.decode(buf);
+ return new UploadBlock(appId, execId, blockId, metadata, blockData);
+ }
+}
diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ShuffleMessagesSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java
index ee9482b49c..d65de9ca55 100644
--- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ShuffleMessagesSuite.java
+++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java
@@ -21,31 +21,24 @@ import org.junit.Test;
import static org.junit.Assert.*;
-import org.apache.spark.network.util.JavaUtils;
+import org.apache.spark.network.shuffle.protocol.*;
-import static org.apache.spark.network.shuffle.ExternalShuffleMessages.*;
-
-public class ShuffleMessagesSuite {
+/** Verifies that all BlockTransferMessages can be serialized correctly. */
+public class BlockTransferMessagesSuite {
@Test
public void serializeOpenShuffleBlocks() {
- OpenShuffleBlocks msg = new OpenShuffleBlocks("app-1", "exec-2",
- new String[] { "block0", "block1" });
- OpenShuffleBlocks msg2 = JavaUtils.deserialize(JavaUtils.serialize(msg));
- assertEquals(msg, msg2);
+ checkSerializeDeserialize(new OpenBlocks("app-1", "exec-2", new String[] { "b1", "b2" }));
+ checkSerializeDeserialize(new RegisterExecutor("app-1", "exec-2", new ExecutorShuffleInfo(
+ new String[] { "/local1", "/local2" }, 32, "MyShuffleManager")));
+ checkSerializeDeserialize(new UploadBlock("app-1", "exec-2", "block-3", new byte[] { 1, 2 },
+ new byte[] { 4, 5, 6, 7} ));
+ checkSerializeDeserialize(new StreamHandle(12345, 16));
}
- @Test
- public void serializeRegisterExecutor() {
- RegisterExecutor msg = new RegisterExecutor("app-1", "exec-2", new ExecutorShuffleInfo(
- new String[] { "/local1", "/local2" }, 32, "MyShuffleManager"));
- RegisterExecutor msg2 = JavaUtils.deserialize(JavaUtils.serialize(msg));
- assertEquals(msg, msg2);
- }
-
- @Test
- public void serializeShuffleStreamHandle() {
- ShuffleStreamHandle msg = new ShuffleStreamHandle(12345, 16);
- ShuffleStreamHandle msg2 = JavaUtils.deserialize(JavaUtils.serialize(msg));
+ private void checkSerializeDeserialize(BlockTransferMessage msg) {
+ BlockTransferMessage msg2 = BlockTransferMessage.Decoder.fromByteArray(msg.toByteArray());
assertEquals(msg, msg2);
+ assertEquals(msg.hashCode(), msg2.hashCode());
+ assertEquals(msg.toString(), msg2.toString());
}
}
diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java
index 7939cb4d32..3f9fe1681c 100644
--- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java
+++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java
@@ -24,8 +24,6 @@ import org.junit.Before;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
-import static org.apache.spark.network.shuffle.ExternalShuffleMessages.OpenShuffleBlocks;
-import static org.apache.spark.network.shuffle.ExternalShuffleMessages.RegisterExecutor;
import static org.junit.Assert.*;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.*;
@@ -36,7 +34,12 @@ import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.server.OneForOneStreamManager;
import org.apache.spark.network.server.RpcHandler;
-import org.apache.spark.network.util.JavaUtils;
+import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
+import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo;
+import org.apache.spark.network.shuffle.protocol.OpenBlocks;
+import org.apache.spark.network.shuffle.protocol.RegisterExecutor;
+import org.apache.spark.network.shuffle.protocol.StreamHandle;
+import org.apache.spark.network.shuffle.protocol.UploadBlock;
public class ExternalShuffleBlockHandlerSuite {
TransportClient client = mock(TransportClient.class);
@@ -57,8 +60,7 @@ public class ExternalShuffleBlockHandlerSuite {
RpcResponseCallback callback = mock(RpcResponseCallback.class);
ExecutorShuffleInfo config = new ExecutorShuffleInfo(new String[] {"/a", "/b"}, 16, "sort");
- byte[] registerMessage = JavaUtils.serialize(
- new RegisterExecutor("app0", "exec1", config));
+ byte[] registerMessage = new RegisterExecutor("app0", "exec1", config).toByteArray();
handler.receive(client, registerMessage, callback);
verify(blockManager, times(1)).registerExecutor("app0", "exec1", config);
@@ -75,9 +77,8 @@ public class ExternalShuffleBlockHandlerSuite {
ManagedBuffer block1Marker = new NioManagedBuffer(ByteBuffer.wrap(new byte[7]));
when(blockManager.getBlockData("app0", "exec1", "b0")).thenReturn(block0Marker);
when(blockManager.getBlockData("app0", "exec1", "b1")).thenReturn(block1Marker);
- byte[] openBlocksMessage = JavaUtils.serialize(
- new OpenShuffleBlocks("app0", "exec1", new String[] { "b0", "b1" }));
- handler.receive(client, openBlocksMessage, callback);
+ byte[] openBlocks = new OpenBlocks("app0", "exec1", new String[] { "b0", "b1" }).toByteArray();
+ handler.receive(client, openBlocks, callback);
verify(blockManager, times(1)).getBlockData("app0", "exec1", "b0");
verify(blockManager, times(1)).getBlockData("app0", "exec1", "b1");
@@ -85,7 +86,8 @@ public class ExternalShuffleBlockHandlerSuite {
verify(callback, times(1)).onSuccess(response.capture());
verify(callback, never()).onFailure((Throwable) any());
- ShuffleStreamHandle handle = JavaUtils.deserialize(response.getValue());
+ StreamHandle handle =
+ (StreamHandle) BlockTransferMessage.Decoder.fromByteArray(response.getValue());
assertEquals(2, handle.numChunks);
ArgumentCaptor<Iterator> stream = ArgumentCaptor.forClass(Iterator.class);
@@ -100,18 +102,17 @@ public class ExternalShuffleBlockHandlerSuite {
public void testBadMessages() {
RpcResponseCallback callback = mock(RpcResponseCallback.class);
- byte[] unserializableMessage = new byte[] { 0x12, 0x34, 0x56 };
+ byte[] unserializableMsg = new byte[] { 0x12, 0x34, 0x56 };
try {
- handler.receive(client, unserializableMessage, callback);
+ handler.receive(client, unserializableMsg, callback);
fail("Should have thrown");
} catch (Exception e) {
// pass
}
- byte[] unexpectedMessage = JavaUtils.serialize(
- new ExecutorShuffleInfo(new String[] {"/a", "/b"}, 16, "sort"));
+ byte[] unexpectedMsg = new UploadBlock("a", "e", "b", new byte[1], new byte[2]).toByteArray();
try {
- handler.receive(client, unexpectedMessage, callback);
+ handler.receive(client, unexpectedMsg, callback);
fail("Should have thrown");
} catch (UnsupportedOperationException e) {
// pass
diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java
index 3bea5b0f25..687bde59fd 100644
--- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java
+++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java
@@ -42,6 +42,7 @@ import org.apache.spark.network.TransportContext;
import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.buffer.NioManagedBuffer;
import org.apache.spark.network.server.TransportServer;
+import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo;
import org.apache.spark.network.util.SystemPropertyConfigProvider;
import org.apache.spark.network.util.TransportConf;
diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java
index 848c88f743..8afceab1d5 100644
--- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java
+++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java
@@ -31,6 +31,7 @@ import org.apache.spark.network.sasl.SaslRpcHandler;
import org.apache.spark.network.sasl.SecretKeyHolder;
import org.apache.spark.network.server.RpcHandler;
import org.apache.spark.network.server.TransportServer;
+import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo;
import org.apache.spark.network.util.SystemPropertyConfigProvider;
import org.apache.spark.network.util.TransportConf;
diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java
index c18346f696..842741e3d3 100644
--- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java
+++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java
@@ -40,7 +40,9 @@ import org.apache.spark.network.buffer.NioManagedBuffer;
import org.apache.spark.network.client.ChunkReceivedCallback;
import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.TransportClient;
-import org.apache.spark.network.util.JavaUtils;
+import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
+import org.apache.spark.network.shuffle.protocol.OpenBlocks;
+import org.apache.spark.network.shuffle.protocol.StreamHandle;
public class OneForOneBlockFetcherSuite {
@Test
@@ -119,17 +121,19 @@ public class OneForOneBlockFetcherSuite {
private BlockFetchingListener fetchBlocks(final LinkedHashMap<String, ManagedBuffer> blocks) {
TransportClient client = mock(TransportClient.class);
BlockFetchingListener listener = mock(BlockFetchingListener.class);
- String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]);
- OneForOneBlockFetcher fetcher = new OneForOneBlockFetcher(client, blockIds, listener);
+ final String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]);
+ OneForOneBlockFetcher fetcher =
+ new OneForOneBlockFetcher(client, "app-id", "exec-id", blockIds, listener);
// Respond to the "OpenBlocks" message with an appropriate StreamHandle with streamId 123
doAnswer(new Answer<Void>() {
@Override
public Void answer(InvocationOnMock invocationOnMock) throws Throwable {
- String message = JavaUtils.deserialize((byte[]) invocationOnMock.getArguments()[0]);
+ BlockTransferMessage message = BlockTransferMessage.Decoder.fromByteArray(
+ (byte[]) invocationOnMock.getArguments()[0]);
RpcResponseCallback callback = (RpcResponseCallback) invocationOnMock.getArguments()[1];
- callback.onSuccess(JavaUtils.serialize(new ShuffleStreamHandle(123, blocks.size())));
- assertEquals("OpenZeBlocks", message);
+ callback.onSuccess(new StreamHandle(123, blocks.size()).toByteArray());
+ assertEquals(new OpenBlocks("app-id", "exec-id", blockIds), message);
return null;
}
}).when(client).sendRpc((byte[]) any(), (RpcResponseCallback) any());
@@ -161,7 +165,7 @@ public class OneForOneBlockFetcherSuite {
}
}).when(client).fetchChunk(anyLong(), anyInt(), (ChunkReceivedCallback) any());
- fetcher.start("OpenZeBlocks");
+ fetcher.start();
return listener;
}
}
diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java
index 337b5c7bdb..76639114df 100644
--- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java
+++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java
@@ -25,6 +25,8 @@ import java.io.OutputStream;
import com.google.common.io.Files;
+import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo;
+
/**
* Manages some sort- and hash-based shuffle data, including the creation
* and cleanup of directories that can be read by the {@link ExternalShuffleBlockManager}.