aboutsummaryrefslogtreecommitdiff
path: root/network
diff options
context:
space:
mode:
Diffstat (limited to 'network')
-rw-r--r--network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java12
-rw-r--r--network/common/src/main/java/org/apache/spark/network/protocol/Encoders.java93
-rw-r--r--network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java12
-rw-r--r--network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java9
-rw-r--r--network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java9
-rw-r--r--network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java27
-rw-r--r--network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslMessage.java24
-rw-r--r--network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java21
-rw-r--r--network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManager.java1
-rw-r--r--network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java12
-rw-r--r--network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleMessages.java106
-rw-r--r--network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java17
-rw-r--r--network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java76
-rw-r--r--network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExecutorShuffleInfo.java (renamed from network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExecutorShuffleInfo.java)36
-rw-r--r--network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenBlocks.java87
-rw-r--r--network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java91
-rw-r--r--network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/StreamHandle.java (renamed from network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleStreamHandle.java)34
-rw-r--r--network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadBlock.java113
-rw-r--r--network/shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java (renamed from network/shuffle/src/test/java/org/apache/spark/network/shuffle/ShuffleMessagesSuite.java)33
-rw-r--r--network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java29
-rw-r--r--network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java1
-rw-r--r--network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java1
-rw-r--r--network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java18
-rw-r--r--network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java2
24 files changed, 608 insertions, 256 deletions
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java
index 152af98ced..986957c150 100644
--- a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java
@@ -38,23 +38,19 @@ public final class ChunkFetchFailure implements ResponseMessage {
@Override
public int encodedLength() {
- return streamChunkId.encodedLength() + 4 + errorString.getBytes(Charsets.UTF_8).length;
+ return streamChunkId.encodedLength() + Encoders.Strings.encodedLength(errorString);
}
@Override
public void encode(ByteBuf buf) {
streamChunkId.encode(buf);
- byte[] errorBytes = errorString.getBytes(Charsets.UTF_8);
- buf.writeInt(errorBytes.length);
- buf.writeBytes(errorBytes);
+ Encoders.Strings.encode(buf, errorString);
}
public static ChunkFetchFailure decode(ByteBuf buf) {
StreamChunkId streamChunkId = StreamChunkId.decode(buf);
- int numErrorStringBytes = buf.readInt();
- byte[] errorBytes = new byte[numErrorStringBytes];
- buf.readBytes(errorBytes);
- return new ChunkFetchFailure(streamChunkId, new String(errorBytes, Charsets.UTF_8));
+ String errorString = Encoders.Strings.decode(buf);
+ return new ChunkFetchFailure(streamChunkId, errorString);
}
@Override
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/Encoders.java b/network/common/src/main/java/org/apache/spark/network/protocol/Encoders.java
new file mode 100644
index 0000000000..873c694250
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/Encoders.java
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+
+import com.google.common.base.Charsets;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+
+/** Provides a canonical set of Encoders for simple types. */
+public class Encoders {
+
+ /** Strings are encoded with their length followed by UTF-8 bytes. */
+ public static class Strings {
+ public static int encodedLength(String s) {
+ return 4 + s.getBytes(Charsets.UTF_8).length;
+ }
+
+ public static void encode(ByteBuf buf, String s) {
+ byte[] bytes = s.getBytes(Charsets.UTF_8);
+ buf.writeInt(bytes.length);
+ buf.writeBytes(bytes);
+ }
+
+ public static String decode(ByteBuf buf) {
+ int length = buf.readInt();
+ byte[] bytes = new byte[length];
+ buf.readBytes(bytes);
+ return new String(bytes, Charsets.UTF_8);
+ }
+ }
+
+ /** Byte arrays are encoded with their length followed by bytes. */
+ public static class ByteArrays {
+ public static int encodedLength(byte[] arr) {
+ return 4 + arr.length;
+ }
+
+ public static void encode(ByteBuf buf, byte[] arr) {
+ buf.writeInt(arr.length);
+ buf.writeBytes(arr);
+ }
+
+ public static byte[] decode(ByteBuf buf) {
+ int length = buf.readInt();
+ byte[] bytes = new byte[length];
+ buf.readBytes(bytes);
+ return bytes;
+ }
+ }
+
+ /** String arrays are encoded with the number of strings followed by per-String encoding. */
+ public static class StringArrays {
+ public static int encodedLength(String[] strings) {
+ int totalLength = 4;
+ for (String s : strings) {
+ totalLength += Strings.encodedLength(s);
+ }
+ return totalLength;
+ }
+
+ public static void encode(ByteBuf buf, String[] strings) {
+ buf.writeInt(strings.length);
+ for (String s : strings) {
+ Strings.encode(buf, s);
+ }
+ }
+
+ public static String[] decode(ByteBuf buf) {
+ int numStrings = buf.readInt();
+ String[] strings = new String[numStrings];
+ for (int i = 0; i < strings.length; i ++) {
+ strings[i] = Strings.decode(buf);
+ }
+ return strings;
+ }
+ }
+}
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java b/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java
index e239d4ffbd..ebd764eb5e 100644
--- a/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java
@@ -36,23 +36,19 @@ public final class RpcFailure implements ResponseMessage {
@Override
public int encodedLength() {
- return 8 + 4 + errorString.getBytes(Charsets.UTF_8).length;
+ return 8 + Encoders.Strings.encodedLength(errorString);
}
@Override
public void encode(ByteBuf buf) {
buf.writeLong(requestId);
- byte[] errorBytes = errorString.getBytes(Charsets.UTF_8);
- buf.writeInt(errorBytes.length);
- buf.writeBytes(errorBytes);
+ Encoders.Strings.encode(buf, errorString);
}
public static RpcFailure decode(ByteBuf buf) {
long requestId = buf.readLong();
- int numErrorStringBytes = buf.readInt();
- byte[] errorBytes = new byte[numErrorStringBytes];
- buf.readBytes(errorBytes);
- return new RpcFailure(requestId, new String(errorBytes, Charsets.UTF_8));
+ String errorString = Encoders.Strings.decode(buf);
+ return new RpcFailure(requestId, errorString);
}
@Override
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java b/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java
index 099e934ae0..cdee0b0e03 100644
--- a/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java
@@ -44,21 +44,18 @@ public final class RpcRequest implements RequestMessage {
@Override
public int encodedLength() {
- return 8 + 4 + message.length;
+ return 8 + Encoders.ByteArrays.encodedLength(message);
}
@Override
public void encode(ByteBuf buf) {
buf.writeLong(requestId);
- buf.writeInt(message.length);
- buf.writeBytes(message);
+ Encoders.ByteArrays.encode(buf, message);
}
public static RpcRequest decode(ByteBuf buf) {
long requestId = buf.readLong();
- int messageLen = buf.readInt();
- byte[] message = new byte[messageLen];
- buf.readBytes(message);
+ byte[] message = Encoders.ByteArrays.decode(buf);
return new RpcRequest(requestId, message);
}
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java b/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java
index ed47947832..0a62e09a81 100644
--- a/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java
@@ -36,20 +36,17 @@ public final class RpcResponse implements ResponseMessage {
public Type type() { return Type.RpcResponse; }
@Override
- public int encodedLength() { return 8 + 4 + response.length; }
+ public int encodedLength() { return 8 + Encoders.ByteArrays.encodedLength(response); }
@Override
public void encode(ByteBuf buf) {
buf.writeLong(requestId);
- buf.writeInt(response.length);
- buf.writeBytes(response);
+ Encoders.ByteArrays.encode(buf, response);
}
public static RpcResponse decode(ByteBuf buf) {
long requestId = buf.readLong();
- int responseLen = buf.readInt();
- byte[] response = new byte[responseLen];
- buf.readBytes(response);
+ byte[] response = Encoders.ByteArrays.decode(buf);
return new RpcResponse(requestId, response);
}
diff --git a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java
index 75c4a3981a..009dbcf013 100644
--- a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java
+++ b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java
@@ -50,33 +50,6 @@ public class JavaUtils {
}
}
- // TODO: Make this configurable, do not use Java serialization!
- public static <T> T deserialize(byte[] bytes) {
- try {
- ObjectInputStream is = new ObjectInputStream(new ByteArrayInputStream(bytes));
- Object out = is.readObject();
- is.close();
- return (T) out;
- } catch (ClassNotFoundException e) {
- throw new RuntimeException("Could not deserialize object", e);
- } catch (IOException e) {
- throw new RuntimeException("Could not deserialize object", e);
- }
- }
-
- // TODO: Make this configurable, do not use Java serialization!
- public static byte[] serialize(Object object) {
- try {
- ByteArrayOutputStream baos = new ByteArrayOutputStream();
- ObjectOutputStream os = new ObjectOutputStream(baos);
- os.writeObject(object);
- os.close();
- return baos.toByteArray();
- } catch (IOException e) {
- throw new RuntimeException("Could not serialize object", e);
- }
- }
-
/** Returns a hash consistent with Spark's Utils.nonNegativeHash(). */
public static int nonNegativeHash(Object obj) {
if (obj == null) { return 0; }
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslMessage.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslMessage.java
index 599cc6428c..cad76ab7aa 100644
--- a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslMessage.java
+++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslMessage.java
@@ -17,10 +17,10 @@
package org.apache.spark.network.sasl;
-import com.google.common.base.Charsets;
import io.netty.buffer.ByteBuf;
import org.apache.spark.network.protocol.Encodable;
+import org.apache.spark.network.protocol.Encoders;
/**
* Encodes a Sasl-related message which is attempting to authenticate using some credentials tagged
@@ -42,18 +42,14 @@ class SaslMessage implements Encodable {
@Override
public int encodedLength() {
- // tag + appIdLength + appId + payloadLength + payload
- return 1 + 4 + appId.getBytes(Charsets.UTF_8).length + 4 + payload.length;
+ return 1 + Encoders.Strings.encodedLength(appId) + Encoders.ByteArrays.encodedLength(payload);
}
@Override
public void encode(ByteBuf buf) {
buf.writeByte(TAG_BYTE);
- byte[] idBytes = appId.getBytes(Charsets.UTF_8);
- buf.writeInt(idBytes.length);
- buf.writeBytes(idBytes);
- buf.writeInt(payload.length);
- buf.writeBytes(payload);
+ Encoders.Strings.encode(buf, appId);
+ Encoders.ByteArrays.encode(buf, payload);
}
public static SaslMessage decode(ByteBuf buf) {
@@ -62,14 +58,8 @@ class SaslMessage implements Encodable {
+ " (maybe your client does not have SASL enabled?)");
}
- int idLength = buf.readInt();
- byte[] idBytes = new byte[idLength];
- buf.readBytes(idBytes);
-
- int payloadLength = buf.readInt();
- byte[] payload = new byte[payloadLength];
- buf.readBytes(payload);
-
- return new SaslMessage(new String(idBytes, Charsets.UTF_8), payload);
+ String appId = Encoders.Strings.decode(buf);
+ byte[] payload = Encoders.ByteArrays.decode(buf);
+ return new SaslMessage(appId, payload);
}
}
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java
index 75ebf8c7b0..a6db4b2abd 100644
--- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java
@@ -24,15 +24,16 @@ import com.google.common.collect.Lists;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-import static org.apache.spark.network.shuffle.ExternalShuffleMessages.*;
-
import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.server.OneForOneStreamManager;
import org.apache.spark.network.server.RpcHandler;
import org.apache.spark.network.server.StreamManager;
-import org.apache.spark.network.util.JavaUtils;
+import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
+import org.apache.spark.network.shuffle.protocol.OpenBlocks;
+import org.apache.spark.network.shuffle.protocol.RegisterExecutor;
+import org.apache.spark.network.shuffle.protocol.StreamHandle;
/**
* RPC Handler for a server which can serve shuffle blocks from outside of an Executor process.
@@ -62,12 +63,10 @@ public class ExternalShuffleBlockHandler extends RpcHandler {
@Override
public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) {
- Object msgObj = JavaUtils.deserialize(message);
-
- logger.trace("Received message: " + msgObj);
+ BlockTransferMessage msgObj = BlockTransferMessage.Decoder.fromByteArray(message);
- if (msgObj instanceof OpenShuffleBlocks) {
- OpenShuffleBlocks msg = (OpenShuffleBlocks) msgObj;
+ if (msgObj instanceof OpenBlocks) {
+ OpenBlocks msg = (OpenBlocks) msgObj;
List<ManagedBuffer> blocks = Lists.newArrayList();
for (String blockId : msg.blockIds) {
@@ -75,8 +74,7 @@ public class ExternalShuffleBlockHandler extends RpcHandler {
}
long streamId = streamManager.registerStream(blocks.iterator());
logger.trace("Registered streamId {} with {} buffers", streamId, msg.blockIds.length);
- callback.onSuccess(JavaUtils.serialize(
- new ShuffleStreamHandle(streamId, msg.blockIds.length)));
+ callback.onSuccess(new StreamHandle(streamId, msg.blockIds.length).toByteArray());
} else if (msgObj instanceof RegisterExecutor) {
RegisterExecutor msg = (RegisterExecutor) msgObj;
@@ -84,8 +82,7 @@ public class ExternalShuffleBlockHandler extends RpcHandler {
callback.onSuccess(new byte[0]);
} else {
- throw new UnsupportedOperationException(String.format(
- "Unexpected message: %s (class = %s)", msgObj, msgObj.getClass()));
+ throw new UnsupportedOperationException("Unexpected message: " + msgObj);
}
}
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManager.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManager.java
index 98fcfb82aa..ffb7faa3db 100644
--- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManager.java
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManager.java
@@ -35,6 +35,7 @@ import org.slf4j.LoggerFactory;
import org.apache.spark.network.buffer.FileSegmentManagedBuffer;
import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo;
import org.apache.spark.network.util.JavaUtils;
/**
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
index 27884b82c8..6e8018b723 100644
--- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
@@ -31,8 +31,8 @@ import org.apache.spark.network.client.TransportClientFactory;
import org.apache.spark.network.sasl.SaslClientBootstrap;
import org.apache.spark.network.sasl.SecretKeyHolder;
import org.apache.spark.network.server.NoOpRpcHandler;
-import org.apache.spark.network.shuffle.ExternalShuffleMessages.RegisterExecutor;
-import org.apache.spark.network.util.JavaUtils;
+import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo;
+import org.apache.spark.network.shuffle.protocol.RegisterExecutor;
import org.apache.spark.network.util.TransportConf;
/**
@@ -91,8 +91,7 @@ public class ExternalShuffleClient extends ShuffleClient {
public void createAndStart(String[] blockIds, BlockFetchingListener listener)
throws IOException {
TransportClient client = clientFactory.createClient(host, port);
- new OneForOneBlockFetcher(client, blockIds, listener)
- .start(new ExternalShuffleMessages.OpenShuffleBlocks(appId, execId, blockIds));
+ new OneForOneBlockFetcher(client, appId, execId, blockIds, listener).start();
}
};
@@ -128,9 +127,8 @@ public class ExternalShuffleClient extends ShuffleClient {
ExecutorShuffleInfo executorInfo) throws IOException {
assert appId != null : "Called before init()";
TransportClient client = clientFactory.createClient(host, port);
- byte[] registerExecutorMessage =
- JavaUtils.serialize(new RegisterExecutor(appId, execId, executorInfo));
- client.sendRpcSync(registerExecutorMessage, 5000 /* timeoutMs */);
+ byte[] registerMessage = new RegisterExecutor(appId, execId, executorInfo).toByteArray();
+ client.sendRpcSync(registerMessage, 5000 /* timeoutMs */);
}
@Override
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleMessages.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleMessages.java
deleted file mode 100644
index e79420ed82..0000000000
--- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleMessages.java
+++ /dev/null
@@ -1,106 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.shuffle;
-
-import java.io.Serializable;
-import java.util.Arrays;
-
-import com.google.common.base.Objects;
-
-/** Messages handled by the {@link ExternalShuffleBlockHandler}. */
-public class ExternalShuffleMessages {
-
- /** Request to read a set of shuffle blocks. Returns [[ShuffleStreamHandle]]. */
- public static class OpenShuffleBlocks implements Serializable {
- public final String appId;
- public final String execId;
- public final String[] blockIds;
-
- public OpenShuffleBlocks(String appId, String execId, String[] blockIds) {
- this.appId = appId;
- this.execId = execId;
- this.blockIds = blockIds;
- }
-
- @Override
- public int hashCode() {
- return Objects.hashCode(appId, execId) * 41 + Arrays.hashCode(blockIds);
- }
-
- @Override
- public String toString() {
- return Objects.toStringHelper(this)
- .add("appId", appId)
- .add("execId", execId)
- .add("blockIds", Arrays.toString(blockIds))
- .toString();
- }
-
- @Override
- public boolean equals(Object other) {
- if (other != null && other instanceof OpenShuffleBlocks) {
- OpenShuffleBlocks o = (OpenShuffleBlocks) other;
- return Objects.equal(appId, o.appId)
- && Objects.equal(execId, o.execId)
- && Arrays.equals(blockIds, o.blockIds);
- }
- return false;
- }
- }
-
- /** Initial registration message between an executor and its local shuffle server. */
- public static class RegisterExecutor implements Serializable {
- public final String appId;
- public final String execId;
- public final ExecutorShuffleInfo executorInfo;
-
- public RegisterExecutor(
- String appId,
- String execId,
- ExecutorShuffleInfo executorInfo) {
- this.appId = appId;
- this.execId = execId;
- this.executorInfo = executorInfo;
- }
-
- @Override
- public int hashCode() {
- return Objects.hashCode(appId, execId, executorInfo);
- }
-
- @Override
- public String toString() {
- return Objects.toStringHelper(this)
- .add("appId", appId)
- .add("execId", execId)
- .add("executorInfo", executorInfo)
- .toString();
- }
-
- @Override
- public boolean equals(Object other) {
- if (other != null && other instanceof RegisterExecutor) {
- RegisterExecutor o = (RegisterExecutor) other;
- return Objects.equal(appId, o.appId)
- && Objects.equal(execId, o.execId)
- && Objects.equal(executorInfo, o.executorInfo);
- }
- return false;
- }
- }
-}
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
index 9e77a1f68c..8ed2e0b39a 100644
--- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
@@ -26,6 +26,9 @@ import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.client.ChunkReceivedCallback;
import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
+import org.apache.spark.network.shuffle.protocol.OpenBlocks;
+import org.apache.spark.network.shuffle.protocol.StreamHandle;
import org.apache.spark.network.util.JavaUtils;
/**
@@ -41,17 +44,21 @@ public class OneForOneBlockFetcher {
private final Logger logger = LoggerFactory.getLogger(OneForOneBlockFetcher.class);
private final TransportClient client;
+ private final OpenBlocks openMessage;
private final String[] blockIds;
private final BlockFetchingListener listener;
private final ChunkReceivedCallback chunkCallback;
- private ShuffleStreamHandle streamHandle = null;
+ private StreamHandle streamHandle = null;
public OneForOneBlockFetcher(
TransportClient client,
+ String appId,
+ String execId,
String[] blockIds,
BlockFetchingListener listener) {
this.client = client;
+ this.openMessage = new OpenBlocks(appId, execId, blockIds);
this.blockIds = blockIds;
this.listener = listener;
this.chunkCallback = new ChunkCallback();
@@ -76,18 +83,18 @@ public class OneForOneBlockFetcher {
/**
* Begins the fetching process, calling the listener with every block fetched.
* The given message will be serialized with the Java serializer, and the RPC must return a
- * {@link ShuffleStreamHandle}. We will send all fetch requests immediately, without throttling.
+ * {@link StreamHandle}. We will send all fetch requests immediately, without throttling.
*/
- public void start(Object openBlocksMessage) {
+ public void start() {
if (blockIds.length == 0) {
throw new IllegalArgumentException("Zero-sized blockIds array");
}
- client.sendRpc(JavaUtils.serialize(openBlocksMessage), new RpcResponseCallback() {
+ client.sendRpc(openMessage.toByteArray(), new RpcResponseCallback() {
@Override
public void onSuccess(byte[] response) {
try {
- streamHandle = JavaUtils.deserialize(response);
+ streamHandle = (StreamHandle) BlockTransferMessage.Decoder.fromByteArray(response);
logger.trace("Successfully opened blocks {}, preparing to fetch chunks.", streamHandle);
// Immediately request all chunks -- we expect that the total size of the request is
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java
new file mode 100644
index 0000000000..b4b13b8a6e
--- /dev/null
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle.protocol;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+
+import org.apache.spark.network.protocol.Encodable;
+
+/**
+ * Messages handled by the {@link org.apache.spark.network.shuffle.ExternalShuffleBlockHandler}, or
+ * by Spark's NettyBlockTransferService.
+ *
+ * At a high level:
+ * - OpenBlock is handled by both services, but only services shuffle files for the external
+ * shuffle service. It returns a StreamHandle.
+ * - UploadBlock is only handled by the NettyBlockTransferService.
+ * - RegisterExecutor is only handled by the external shuffle service.
+ */
+public abstract class BlockTransferMessage implements Encodable {
+ protected abstract Type type();
+
+ /** Preceding every serialized message is its type, which allows us to deserialize it. */
+ public static enum Type {
+ OPEN_BLOCKS(0), UPLOAD_BLOCK(1), REGISTER_EXECUTOR(2), STREAM_HANDLE(3);
+
+ private final byte id;
+
+ private Type(int id) {
+ assert id < 128 : "Cannot have more than 128 message types";
+ this.id = (byte) id;
+ }
+
+ public byte id() { return id; }
+ }
+
+ // NB: Java does not support static methods in interfaces, so we must put this in a static class.
+ public static class Decoder {
+ /** Deserializes the 'type' byte followed by the message itself. */
+ public static BlockTransferMessage fromByteArray(byte[] msg) {
+ ByteBuf buf = Unpooled.wrappedBuffer(msg);
+ byte type = buf.readByte();
+ switch (type) {
+ case 0: return OpenBlocks.decode(buf);
+ case 1: return UploadBlock.decode(buf);
+ case 2: return RegisterExecutor.decode(buf);
+ case 3: return StreamHandle.decode(buf);
+ default: throw new IllegalArgumentException("Unknown message type: " + type);
+ }
+ }
+ }
+
+ /** Serializes the 'type' byte followed by the message itself. */
+ public byte[] toByteArray() {
+ ByteBuf buf = Unpooled.buffer(encodedLength());
+ buf.writeByte(type().id);
+ encode(buf);
+ assert buf.writableBytes() == 0 : "Writable bytes remain: " + buf.writableBytes();
+ return buf.array();
+ }
+}
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExecutorShuffleInfo.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExecutorShuffleInfo.java
index d45e64656a..cadc8e8369 100644
--- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExecutorShuffleInfo.java
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExecutorShuffleInfo.java
@@ -15,21 +15,24 @@
* limitations under the License.
*/
-package org.apache.spark.network.shuffle;
+package org.apache.spark.network.shuffle.protocol;
-import java.io.Serializable;
import java.util.Arrays;
import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+
+import org.apache.spark.network.protocol.Encodable;
+import org.apache.spark.network.protocol.Encoders;
/** Contains all configuration necessary for locating the shuffle files of an executor. */
-public class ExecutorShuffleInfo implements Serializable {
+public class ExecutorShuffleInfo implements Encodable {
/** The base set of local directories that the executor stores its shuffle files in. */
- final String[] localDirs;
+ public final String[] localDirs;
/** Number of subdirectories created within each localDir. */
- final int subDirsPerLocalDir;
+ public final int subDirsPerLocalDir;
/** Shuffle manager (SortShuffleManager or HashShuffleManager) that the executor is using. */
- final String shuffleManager;
+ public final String shuffleManager;
public ExecutorShuffleInfo(String[] localDirs, int subDirsPerLocalDir, String shuffleManager) {
this.localDirs = localDirs;
@@ -61,4 +64,25 @@ public class ExecutorShuffleInfo implements Serializable {
}
return false;
}
+
+ @Override
+ public int encodedLength() {
+ return Encoders.StringArrays.encodedLength(localDirs)
+ + 4 // int
+ + Encoders.Strings.encodedLength(shuffleManager);
+ }
+
+ @Override
+ public void encode(ByteBuf buf) {
+ Encoders.StringArrays.encode(buf, localDirs);
+ buf.writeInt(subDirsPerLocalDir);
+ Encoders.Strings.encode(buf, shuffleManager);
+ }
+
+ public static ExecutorShuffleInfo decode(ByteBuf buf) {
+ String[] localDirs = Encoders.StringArrays.decode(buf);
+ int subDirsPerLocalDir = buf.readInt();
+ String shuffleManager = Encoders.Strings.decode(buf);
+ return new ExecutorShuffleInfo(localDirs, subDirsPerLocalDir, shuffleManager);
+ }
}
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenBlocks.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenBlocks.java
new file mode 100644
index 0000000000..60485bace6
--- /dev/null
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenBlocks.java
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle.protocol;
+
+import java.util.Arrays;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+
+import org.apache.spark.network.protocol.Encoders;
+
+/** Request to read a set of blocks. Returns {@link StreamHandle}. */
+public class OpenBlocks extends BlockTransferMessage {
+ public final String appId;
+ public final String execId;
+ public final String[] blockIds;
+
+ public OpenBlocks(String appId, String execId, String[] blockIds) {
+ this.appId = appId;
+ this.execId = execId;
+ this.blockIds = blockIds;
+ }
+
+ @Override
+ protected Type type() { return Type.OPEN_BLOCKS; }
+
+ @Override
+ public int hashCode() {
+ return Objects.hashCode(appId, execId) * 41 + Arrays.hashCode(blockIds);
+ }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("appId", appId)
+ .add("execId", execId)
+ .add("blockIds", Arrays.toString(blockIds))
+ .toString();
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (other != null && other instanceof OpenBlocks) {
+ OpenBlocks o = (OpenBlocks) other;
+ return Objects.equal(appId, o.appId)
+ && Objects.equal(execId, o.execId)
+ && Arrays.equals(blockIds, o.blockIds);
+ }
+ return false;
+ }
+
+ @Override
+ public int encodedLength() {
+ return Encoders.Strings.encodedLength(appId)
+ + Encoders.Strings.encodedLength(execId)
+ + Encoders.StringArrays.encodedLength(blockIds);
+ }
+
+ @Override
+ public void encode(ByteBuf buf) {
+ Encoders.Strings.encode(buf, appId);
+ Encoders.Strings.encode(buf, execId);
+ Encoders.StringArrays.encode(buf, blockIds);
+ }
+
+ public static OpenBlocks decode(ByteBuf buf) {
+ String appId = Encoders.Strings.decode(buf);
+ String execId = Encoders.Strings.decode(buf);
+ String[] blockIds = Encoders.StringArrays.decode(buf);
+ return new OpenBlocks(appId, execId, blockIds);
+ }
+}
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java
new file mode 100644
index 0000000000..38acae3b31
--- /dev/null
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle.protocol;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+
+import org.apache.spark.network.protocol.Encoders;
+
+/**
+ * Initial registration message between an executor and its local shuffle server.
+ * Returns nothing (empty bye array).
+ */
+public class RegisterExecutor extends BlockTransferMessage {
+ public final String appId;
+ public final String execId;
+ public final ExecutorShuffleInfo executorInfo;
+
+ public RegisterExecutor(
+ String appId,
+ String execId,
+ ExecutorShuffleInfo executorInfo) {
+ this.appId = appId;
+ this.execId = execId;
+ this.executorInfo = executorInfo;
+ }
+
+ @Override
+ protected Type type() { return Type.REGISTER_EXECUTOR; }
+
+ @Override
+ public int hashCode() {
+ return Objects.hashCode(appId, execId, executorInfo);
+ }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("appId", appId)
+ .add("execId", execId)
+ .add("executorInfo", executorInfo)
+ .toString();
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (other != null && other instanceof RegisterExecutor) {
+ RegisterExecutor o = (RegisterExecutor) other;
+ return Objects.equal(appId, o.appId)
+ && Objects.equal(execId, o.execId)
+ && Objects.equal(executorInfo, o.executorInfo);
+ }
+ return false;
+ }
+
+ @Override
+ public int encodedLength() {
+ return Encoders.Strings.encodedLength(appId)
+ + Encoders.Strings.encodedLength(execId)
+ + executorInfo.encodedLength();
+ }
+
+ @Override
+ public void encode(ByteBuf buf) {
+ Encoders.Strings.encode(buf, appId);
+ Encoders.Strings.encode(buf, execId);
+ executorInfo.encode(buf);
+ }
+
+ public static RegisterExecutor decode(ByteBuf buf) {
+ String appId = Encoders.Strings.decode(buf);
+ String execId = Encoders.Strings.decode(buf);
+ ExecutorShuffleInfo executorShuffleInfo = ExecutorShuffleInfo.decode(buf);
+ return new RegisterExecutor(appId, execId, executorShuffleInfo);
+ }
+}
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleStreamHandle.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/StreamHandle.java
index 9c94691224..21369c8cfb 100644
--- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleStreamHandle.java
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/StreamHandle.java
@@ -15,27 +15,30 @@
* limitations under the License.
*/
-package org.apache.spark.network.shuffle;
+package org.apache.spark.network.shuffle.protocol;
import java.io.Serializable;
-import java.util.Arrays;
import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
/**
* Identifier for a fixed number of chunks to read from a stream created by an "open blocks"
- * message. This is used by {@link OneForOneBlockFetcher}.
+ * message. This is used by {@link org.apache.spark.network.shuffle.OneForOneBlockFetcher}.
*/
-public class ShuffleStreamHandle implements Serializable {
+public class StreamHandle extends BlockTransferMessage {
public final long streamId;
public final int numChunks;
- public ShuffleStreamHandle(long streamId, int numChunks) {
+ public StreamHandle(long streamId, int numChunks) {
this.streamId = streamId;
this.numChunks = numChunks;
}
@Override
+ protected Type type() { return Type.STREAM_HANDLE; }
+
+ @Override
public int hashCode() {
return Objects.hashCode(streamId, numChunks);
}
@@ -50,11 +53,28 @@ public class ShuffleStreamHandle implements Serializable {
@Override
public boolean equals(Object other) {
- if (other != null && other instanceof ShuffleStreamHandle) {
- ShuffleStreamHandle o = (ShuffleStreamHandle) other;
+ if (other != null && other instanceof StreamHandle) {
+ StreamHandle o = (StreamHandle) other;
return Objects.equal(streamId, o.streamId)
&& Objects.equal(numChunks, o.numChunks);
}
return false;
}
+
+ @Override
+ public int encodedLength() {
+ return 8 + 4;
+ }
+
+ @Override
+ public void encode(ByteBuf buf) {
+ buf.writeLong(streamId);
+ buf.writeInt(numChunks);
+ }
+
+ public static StreamHandle decode(ByteBuf buf) {
+ long streamId = buf.readLong();
+ int numChunks = buf.readInt();
+ return new StreamHandle(streamId, numChunks);
+ }
}
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadBlock.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadBlock.java
new file mode 100644
index 0000000000..38abe29cc5
--- /dev/null
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadBlock.java
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle.protocol;
+
+import java.util.Arrays;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+
+import org.apache.spark.network.protocol.Encoders;
+
+/** Request to upload a block with a certain StorageLevel. Returns nothing (empty byte array). */
+public class UploadBlock extends BlockTransferMessage {
+ public final String appId;
+ public final String execId;
+ public final String blockId;
+ // TODO: StorageLevel is serialized separately in here because StorageLevel is not available in
+ // this package. We should avoid this hack.
+ public final byte[] metadata;
+ public final byte[] blockData;
+
+ /**
+ * @param metadata Meta-information about block, typically StorageLevel.
+ * @param blockData The actual block's bytes.
+ */
+ public UploadBlock(
+ String appId,
+ String execId,
+ String blockId,
+ byte[] metadata,
+ byte[] blockData) {
+ this.appId = appId;
+ this.execId = execId;
+ this.blockId = blockId;
+ this.metadata = metadata;
+ this.blockData = blockData;
+ }
+
+ @Override
+ protected Type type() { return Type.UPLOAD_BLOCK; }
+
+ @Override
+ public int hashCode() {
+ int objectsHashCode = Objects.hashCode(appId, execId, blockId);
+ return (objectsHashCode * 41 + Arrays.hashCode(metadata)) * 41 + Arrays.hashCode(blockData);
+ }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("appId", appId)
+ .add("execId", execId)
+ .add("blockId", blockId)
+ .add("metadata size", metadata.length)
+ .add("block size", blockData.length)
+ .toString();
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (other != null && other instanceof UploadBlock) {
+ UploadBlock o = (UploadBlock) other;
+ return Objects.equal(appId, o.appId)
+ && Objects.equal(execId, o.execId)
+ && Objects.equal(blockId, o.blockId)
+ && Arrays.equals(metadata, o.metadata)
+ && Arrays.equals(blockData, o.blockData);
+ }
+ return false;
+ }
+
+ @Override
+ public int encodedLength() {
+ return Encoders.Strings.encodedLength(appId)
+ + Encoders.Strings.encodedLength(execId)
+ + Encoders.Strings.encodedLength(blockId)
+ + Encoders.ByteArrays.encodedLength(metadata)
+ + Encoders.ByteArrays.encodedLength(blockData);
+ }
+
+ @Override
+ public void encode(ByteBuf buf) {
+ Encoders.Strings.encode(buf, appId);
+ Encoders.Strings.encode(buf, execId);
+ Encoders.Strings.encode(buf, blockId);
+ Encoders.ByteArrays.encode(buf, metadata);
+ Encoders.ByteArrays.encode(buf, blockData);
+ }
+
+ public static UploadBlock decode(ByteBuf buf) {
+ String appId = Encoders.Strings.decode(buf);
+ String execId = Encoders.Strings.decode(buf);
+ String blockId = Encoders.Strings.decode(buf);
+ byte[] metadata = Encoders.ByteArrays.decode(buf);
+ byte[] blockData = Encoders.ByteArrays.decode(buf);
+ return new UploadBlock(appId, execId, blockId, metadata, blockData);
+ }
+}
diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ShuffleMessagesSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java
index ee9482b49c..d65de9ca55 100644
--- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ShuffleMessagesSuite.java
+++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java
@@ -21,31 +21,24 @@ import org.junit.Test;
import static org.junit.Assert.*;
-import org.apache.spark.network.util.JavaUtils;
+import org.apache.spark.network.shuffle.protocol.*;
-import static org.apache.spark.network.shuffle.ExternalShuffleMessages.*;
-
-public class ShuffleMessagesSuite {
+/** Verifies that all BlockTransferMessages can be serialized correctly. */
+public class BlockTransferMessagesSuite {
@Test
public void serializeOpenShuffleBlocks() {
- OpenShuffleBlocks msg = new OpenShuffleBlocks("app-1", "exec-2",
- new String[] { "block0", "block1" });
- OpenShuffleBlocks msg2 = JavaUtils.deserialize(JavaUtils.serialize(msg));
- assertEquals(msg, msg2);
+ checkSerializeDeserialize(new OpenBlocks("app-1", "exec-2", new String[] { "b1", "b2" }));
+ checkSerializeDeserialize(new RegisterExecutor("app-1", "exec-2", new ExecutorShuffleInfo(
+ new String[] { "/local1", "/local2" }, 32, "MyShuffleManager")));
+ checkSerializeDeserialize(new UploadBlock("app-1", "exec-2", "block-3", new byte[] { 1, 2 },
+ new byte[] { 4, 5, 6, 7} ));
+ checkSerializeDeserialize(new StreamHandle(12345, 16));
}
- @Test
- public void serializeRegisterExecutor() {
- RegisterExecutor msg = new RegisterExecutor("app-1", "exec-2", new ExecutorShuffleInfo(
- new String[] { "/local1", "/local2" }, 32, "MyShuffleManager"));
- RegisterExecutor msg2 = JavaUtils.deserialize(JavaUtils.serialize(msg));
- assertEquals(msg, msg2);
- }
-
- @Test
- public void serializeShuffleStreamHandle() {
- ShuffleStreamHandle msg = new ShuffleStreamHandle(12345, 16);
- ShuffleStreamHandle msg2 = JavaUtils.deserialize(JavaUtils.serialize(msg));
+ private void checkSerializeDeserialize(BlockTransferMessage msg) {
+ BlockTransferMessage msg2 = BlockTransferMessage.Decoder.fromByteArray(msg.toByteArray());
assertEquals(msg, msg2);
+ assertEquals(msg.hashCode(), msg2.hashCode());
+ assertEquals(msg.toString(), msg2.toString());
}
}
diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java
index 7939cb4d32..3f9fe1681c 100644
--- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java
+++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java
@@ -24,8 +24,6 @@ import org.junit.Before;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
-import static org.apache.spark.network.shuffle.ExternalShuffleMessages.OpenShuffleBlocks;
-import static org.apache.spark.network.shuffle.ExternalShuffleMessages.RegisterExecutor;
import static org.junit.Assert.*;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.*;
@@ -36,7 +34,12 @@ import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.server.OneForOneStreamManager;
import org.apache.spark.network.server.RpcHandler;
-import org.apache.spark.network.util.JavaUtils;
+import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
+import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo;
+import org.apache.spark.network.shuffle.protocol.OpenBlocks;
+import org.apache.spark.network.shuffle.protocol.RegisterExecutor;
+import org.apache.spark.network.shuffle.protocol.StreamHandle;
+import org.apache.spark.network.shuffle.protocol.UploadBlock;
public class ExternalShuffleBlockHandlerSuite {
TransportClient client = mock(TransportClient.class);
@@ -57,8 +60,7 @@ public class ExternalShuffleBlockHandlerSuite {
RpcResponseCallback callback = mock(RpcResponseCallback.class);
ExecutorShuffleInfo config = new ExecutorShuffleInfo(new String[] {"/a", "/b"}, 16, "sort");
- byte[] registerMessage = JavaUtils.serialize(
- new RegisterExecutor("app0", "exec1", config));
+ byte[] registerMessage = new RegisterExecutor("app0", "exec1", config).toByteArray();
handler.receive(client, registerMessage, callback);
verify(blockManager, times(1)).registerExecutor("app0", "exec1", config);
@@ -75,9 +77,8 @@ public class ExternalShuffleBlockHandlerSuite {
ManagedBuffer block1Marker = new NioManagedBuffer(ByteBuffer.wrap(new byte[7]));
when(blockManager.getBlockData("app0", "exec1", "b0")).thenReturn(block0Marker);
when(blockManager.getBlockData("app0", "exec1", "b1")).thenReturn(block1Marker);
- byte[] openBlocksMessage = JavaUtils.serialize(
- new OpenShuffleBlocks("app0", "exec1", new String[] { "b0", "b1" }));
- handler.receive(client, openBlocksMessage, callback);
+ byte[] openBlocks = new OpenBlocks("app0", "exec1", new String[] { "b0", "b1" }).toByteArray();
+ handler.receive(client, openBlocks, callback);
verify(blockManager, times(1)).getBlockData("app0", "exec1", "b0");
verify(blockManager, times(1)).getBlockData("app0", "exec1", "b1");
@@ -85,7 +86,8 @@ public class ExternalShuffleBlockHandlerSuite {
verify(callback, times(1)).onSuccess(response.capture());
verify(callback, never()).onFailure((Throwable) any());
- ShuffleStreamHandle handle = JavaUtils.deserialize(response.getValue());
+ StreamHandle handle =
+ (StreamHandle) BlockTransferMessage.Decoder.fromByteArray(response.getValue());
assertEquals(2, handle.numChunks);
ArgumentCaptor<Iterator> stream = ArgumentCaptor.forClass(Iterator.class);
@@ -100,18 +102,17 @@ public class ExternalShuffleBlockHandlerSuite {
public void testBadMessages() {
RpcResponseCallback callback = mock(RpcResponseCallback.class);
- byte[] unserializableMessage = new byte[] { 0x12, 0x34, 0x56 };
+ byte[] unserializableMsg = new byte[] { 0x12, 0x34, 0x56 };
try {
- handler.receive(client, unserializableMessage, callback);
+ handler.receive(client, unserializableMsg, callback);
fail("Should have thrown");
} catch (Exception e) {
// pass
}
- byte[] unexpectedMessage = JavaUtils.serialize(
- new ExecutorShuffleInfo(new String[] {"/a", "/b"}, 16, "sort"));
+ byte[] unexpectedMsg = new UploadBlock("a", "e", "b", new byte[1], new byte[2]).toByteArray();
try {
- handler.receive(client, unexpectedMessage, callback);
+ handler.receive(client, unexpectedMsg, callback);
fail("Should have thrown");
} catch (UnsupportedOperationException e) {
// pass
diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java
index 3bea5b0f25..687bde59fd 100644
--- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java
+++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java
@@ -42,6 +42,7 @@ import org.apache.spark.network.TransportContext;
import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.buffer.NioManagedBuffer;
import org.apache.spark.network.server.TransportServer;
+import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo;
import org.apache.spark.network.util.SystemPropertyConfigProvider;
import org.apache.spark.network.util.TransportConf;
diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java
index 848c88f743..8afceab1d5 100644
--- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java
+++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java
@@ -31,6 +31,7 @@ import org.apache.spark.network.sasl.SaslRpcHandler;
import org.apache.spark.network.sasl.SecretKeyHolder;
import org.apache.spark.network.server.RpcHandler;
import org.apache.spark.network.server.TransportServer;
+import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo;
import org.apache.spark.network.util.SystemPropertyConfigProvider;
import org.apache.spark.network.util.TransportConf;
diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java
index c18346f696..842741e3d3 100644
--- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java
+++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java
@@ -40,7 +40,9 @@ import org.apache.spark.network.buffer.NioManagedBuffer;
import org.apache.spark.network.client.ChunkReceivedCallback;
import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.TransportClient;
-import org.apache.spark.network.util.JavaUtils;
+import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
+import org.apache.spark.network.shuffle.protocol.OpenBlocks;
+import org.apache.spark.network.shuffle.protocol.StreamHandle;
public class OneForOneBlockFetcherSuite {
@Test
@@ -119,17 +121,19 @@ public class OneForOneBlockFetcherSuite {
private BlockFetchingListener fetchBlocks(final LinkedHashMap<String, ManagedBuffer> blocks) {
TransportClient client = mock(TransportClient.class);
BlockFetchingListener listener = mock(BlockFetchingListener.class);
- String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]);
- OneForOneBlockFetcher fetcher = new OneForOneBlockFetcher(client, blockIds, listener);
+ final String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]);
+ OneForOneBlockFetcher fetcher =
+ new OneForOneBlockFetcher(client, "app-id", "exec-id", blockIds, listener);
// Respond to the "OpenBlocks" message with an appropirate ShuffleStreamHandle with streamId 123
doAnswer(new Answer<Void>() {
@Override
public Void answer(InvocationOnMock invocationOnMock) throws Throwable {
- String message = JavaUtils.deserialize((byte[]) invocationOnMock.getArguments()[0]);
+ BlockTransferMessage message = BlockTransferMessage.Decoder.fromByteArray(
+ (byte[]) invocationOnMock.getArguments()[0]);
RpcResponseCallback callback = (RpcResponseCallback) invocationOnMock.getArguments()[1];
- callback.onSuccess(JavaUtils.serialize(new ShuffleStreamHandle(123, blocks.size())));
- assertEquals("OpenZeBlocks", message);
+ callback.onSuccess(new StreamHandle(123, blocks.size()).toByteArray());
+ assertEquals(new OpenBlocks("app-id", "exec-id", blockIds), message);
return null;
}
}).when(client).sendRpc((byte[]) any(), (RpcResponseCallback) any());
@@ -161,7 +165,7 @@ public class OneForOneBlockFetcherSuite {
}
}).when(client).fetchChunk(anyLong(), anyInt(), (ChunkReceivedCallback) any());
- fetcher.start("OpenZeBlocks");
+ fetcher.start();
return listener;
}
}
diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java
index 337b5c7bdb..76639114df 100644
--- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java
+++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java
@@ -25,6 +25,8 @@ import java.io.OutputStream;
import com.google.common.io.Files;
+import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo;
+
/**
* Manages some sort- and hash-based shuffle data, including the creation
* and cleanup of directories that can be read by the {@link ExternalShuffleBlockManager}.