diff options
Diffstat (limited to 'network/shuffle')
18 files changed, 501 insertions, 201 deletions
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslMessage.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslMessage.java index 599cc6428c..cad76ab7aa 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslMessage.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslMessage.java @@ -17,10 +17,10 @@ package org.apache.spark.network.sasl; -import com.google.common.base.Charsets; import io.netty.buffer.ByteBuf; import org.apache.spark.network.protocol.Encodable; +import org.apache.spark.network.protocol.Encoders; /** * Encodes a Sasl-related message which is attempting to authenticate using some credentials tagged @@ -42,18 +42,14 @@ class SaslMessage implements Encodable { @Override public int encodedLength() { - // tag + appIdLength + appId + payloadLength + payload - return 1 + 4 + appId.getBytes(Charsets.UTF_8).length + 4 + payload.length; + return 1 + Encoders.Strings.encodedLength(appId) + Encoders.ByteArrays.encodedLength(payload); } @Override public void encode(ByteBuf buf) { buf.writeByte(TAG_BYTE); - byte[] idBytes = appId.getBytes(Charsets.UTF_8); - buf.writeInt(idBytes.length); - buf.writeBytes(idBytes); - buf.writeInt(payload.length); - buf.writeBytes(payload); + Encoders.Strings.encode(buf, appId); + Encoders.ByteArrays.encode(buf, payload); } public static SaslMessage decode(ByteBuf buf) { @@ -62,14 +58,8 @@ class SaslMessage implements Encodable { + " (maybe your client does not have SASL enabled?)"); } - int idLength = buf.readInt(); - byte[] idBytes = new byte[idLength]; - buf.readBytes(idBytes); - - int payloadLength = buf.readInt(); - byte[] payload = new byte[payloadLength]; - buf.readBytes(payload); - - return new SaslMessage(new String(idBytes, Charsets.UTF_8), payload); + String appId = Encoders.Strings.decode(buf); + byte[] payload = Encoders.ByteArrays.decode(buf); + return new SaslMessage(appId, payload); } } diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java index 75ebf8c7b0..a6db4b2abd 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java @@ -24,15 +24,16 @@ import com.google.common.collect.Lists; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import static org.apache.spark.network.shuffle.ExternalShuffleMessages.*; - import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.server.OneForOneStreamManager; import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.StreamManager; -import org.apache.spark.network.util.JavaUtils; +import org.apache.spark.network.shuffle.protocol.BlockTransferMessage; +import org.apache.spark.network.shuffle.protocol.OpenBlocks; +import org.apache.spark.network.shuffle.protocol.RegisterExecutor; +import org.apache.spark.network.shuffle.protocol.StreamHandle; /** * RPC Handler for a server which can serve shuffle blocks from outside of an Executor process. @@ -62,12 +63,10 @@ public class ExternalShuffleBlockHandler extends RpcHandler { @Override public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { - Object msgObj = JavaUtils.deserialize(message); - - logger.trace("Received message: " + msgObj); + BlockTransferMessage msgObj = BlockTransferMessage.Decoder.fromByteArray(message); - if (msgObj instanceof OpenShuffleBlocks) { - OpenShuffleBlocks msg = (OpenShuffleBlocks) msgObj; + if (msgObj instanceof OpenBlocks) { + OpenBlocks msg = (OpenBlocks) msgObj; List<ManagedBuffer> blocks = Lists.newArrayList(); for (String blockId : msg.blockIds) { @@ -75,8 +74,7 @@ public class ExternalShuffleBlockHandler extends RpcHandler { } long streamId = streamManager.registerStream(blocks.iterator()); logger.trace("Registered streamId {} with {} buffers", streamId, msg.blockIds.length); - callback.onSuccess(JavaUtils.serialize( - new ShuffleStreamHandle(streamId, msg.blockIds.length))); + callback.onSuccess(new StreamHandle(streamId, msg.blockIds.length).toByteArray()); } else if (msgObj instanceof RegisterExecutor) { RegisterExecutor msg = (RegisterExecutor) msgObj; @@ -84,8 +82,7 @@ public class ExternalShuffleBlockHandler extends RpcHandler { callback.onSuccess(new byte[0]); } else { - throw new UnsupportedOperationException(String.format( - "Unexpected message: %s (class = %s)", msgObj, msgObj.getClass())); + throw new UnsupportedOperationException("Unexpected message: " + msgObj); } } diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManager.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManager.java index 98fcfb82aa..ffb7faa3db 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManager.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManager.java @@ -35,6 +35,7 @@ import org.slf4j.LoggerFactory; import org.apache.spark.network.buffer.FileSegmentManagedBuffer; import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; import org.apache.spark.network.util.JavaUtils; /** diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java index 27884b82c8..6e8018b723 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java @@ -31,8 +31,8 @@ import org.apache.spark.network.client.TransportClientFactory; import org.apache.spark.network.sasl.SaslClientBootstrap; import org.apache.spark.network.sasl.SecretKeyHolder; import org.apache.spark.network.server.NoOpRpcHandler; -import org.apache.spark.network.shuffle.ExternalShuffleMessages.RegisterExecutor; -import org.apache.spark.network.util.JavaUtils; +import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; +import org.apache.spark.network.shuffle.protocol.RegisterExecutor; import org.apache.spark.network.util.TransportConf; /** @@ -91,8 +91,7 @@ public class ExternalShuffleClient extends ShuffleClient { public void createAndStart(String[] blockIds, BlockFetchingListener listener) throws IOException { TransportClient client = clientFactory.createClient(host, port); - new OneForOneBlockFetcher(client, blockIds, listener) - .start(new ExternalShuffleMessages.OpenShuffleBlocks(appId, execId, blockIds)); + new OneForOneBlockFetcher(client, appId, execId, blockIds, listener).start(); } }; @@ -128,9 +127,8 @@ public class ExternalShuffleClient extends ShuffleClient { ExecutorShuffleInfo executorInfo) throws IOException { assert appId != null : "Called before init()"; TransportClient client = clientFactory.createClient(host, port); - byte[] registerExecutorMessage = - JavaUtils.serialize(new RegisterExecutor(appId, execId, executorInfo)); - client.sendRpcSync(registerExecutorMessage, 5000 /* timeoutMs */); + byte[] registerMessage = new RegisterExecutor(appId, execId, executorInfo).toByteArray(); + client.sendRpcSync(registerMessage, 5000 /* timeoutMs */); } @Override diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleMessages.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleMessages.java deleted file mode 100644 index e79420ed82..0000000000 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleMessages.java +++ /dev/null @@ -1,106 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.shuffle; - -import java.io.Serializable; -import java.util.Arrays; - -import com.google.common.base.Objects; - -/** Messages handled by the {@link ExternalShuffleBlockHandler}. */ -public class ExternalShuffleMessages { - - /** Request to read a set of shuffle blocks. Returns [[ShuffleStreamHandle]]. */ - public static class OpenShuffleBlocks implements Serializable { - public final String appId; - public final String execId; - public final String[] blockIds; - - public OpenShuffleBlocks(String appId, String execId, String[] blockIds) { - this.appId = appId; - this.execId = execId; - this.blockIds = blockIds; - } - - @Override - public int hashCode() { - return Objects.hashCode(appId, execId) * 41 + Arrays.hashCode(blockIds); - } - - @Override - public String toString() { - return Objects.toStringHelper(this) - .add("appId", appId) - .add("execId", execId) - .add("blockIds", Arrays.toString(blockIds)) - .toString(); - } - - @Override - public boolean equals(Object other) { - if (other != null && other instanceof OpenShuffleBlocks) { - OpenShuffleBlocks o = (OpenShuffleBlocks) other; - return Objects.equal(appId, o.appId) - && Objects.equal(execId, o.execId) - && Arrays.equals(blockIds, o.blockIds); - } - return false; - } - } - - /** Initial registration message between an executor and its local shuffle server. */ - public static class RegisterExecutor implements Serializable { - public final String appId; - public final String execId; - public final ExecutorShuffleInfo executorInfo; - - public RegisterExecutor( - String appId, - String execId, - ExecutorShuffleInfo executorInfo) { - this.appId = appId; - this.execId = execId; - this.executorInfo = executorInfo; - } - - @Override - public int hashCode() { - return Objects.hashCode(appId, execId, executorInfo); - } - - @Override - public String toString() { - return Objects.toStringHelper(this) - .add("appId", appId) - .add("execId", execId) - .add("executorInfo", executorInfo) - .toString(); - } - - @Override - public boolean equals(Object other) { - if (other != null && other instanceof RegisterExecutor) { - RegisterExecutor o = (RegisterExecutor) other; - return Objects.equal(appId, o.appId) - && Objects.equal(execId, o.execId) - && Objects.equal(executorInfo, o.executorInfo); - } - return false; - } - } -} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java index 9e77a1f68c..8ed2e0b39a 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java @@ -26,6 +26,9 @@ import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.client.ChunkReceivedCallback; import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.shuffle.protocol.BlockTransferMessage; +import org.apache.spark.network.shuffle.protocol.OpenBlocks; +import org.apache.spark.network.shuffle.protocol.StreamHandle; import org.apache.spark.network.util.JavaUtils; /** @@ -41,17 +44,21 @@ public class OneForOneBlockFetcher { private final Logger logger = LoggerFactory.getLogger(OneForOneBlockFetcher.class); private final TransportClient client; + private final OpenBlocks openMessage; private final String[] blockIds; private final BlockFetchingListener listener; private final ChunkReceivedCallback chunkCallback; - private ShuffleStreamHandle streamHandle = null; + private StreamHandle streamHandle = null; public OneForOneBlockFetcher( TransportClient client, + String appId, + String execId, String[] blockIds, BlockFetchingListener listener) { this.client = client; + this.openMessage = new OpenBlocks(appId, execId, blockIds); this.blockIds = blockIds; this.listener = listener; this.chunkCallback = new ChunkCallback(); @@ -76,18 +83,18 @@ public class OneForOneBlockFetcher { /** * Begins the fetching process, calling the listener with every block fetched. * The given message will be serialized with the Java serializer, and the RPC must return a - * {@link ShuffleStreamHandle}. We will send all fetch requests immediately, without throttling. + * {@link StreamHandle}. We will send all fetch requests immediately, without throttling. */ - public void start(Object openBlocksMessage) { + public void start() { if (blockIds.length == 0) { throw new IllegalArgumentException("Zero-sized blockIds array"); } - client.sendRpc(JavaUtils.serialize(openBlocksMessage), new RpcResponseCallback() { + client.sendRpc(openMessage.toByteArray(), new RpcResponseCallback() { @Override public void onSuccess(byte[] response) { try { - streamHandle = JavaUtils.deserialize(response); + streamHandle = (StreamHandle) BlockTransferMessage.Decoder.fromByteArray(response); logger.trace("Successfully opened blocks {}, preparing to fetch chunks.", streamHandle); // Immediately request all chunks -- we expect that the total size of the request is diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java new file mode 100644 index 0000000000..b4b13b8a6e --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle.protocol; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; + +import org.apache.spark.network.protocol.Encodable; + +/** + * Messages handled by the {@link org.apache.spark.network.shuffle.ExternalShuffleBlockHandler}, or + * by Spark's NettyBlockTransferService. + * + * At a high level: + * - OpenBlock is handled by both services, but only services shuffle files for the external + * shuffle service. It returns a StreamHandle. + * - UploadBlock is only handled by the NettyBlockTransferService. + * - RegisterExecutor is only handled by the external shuffle service. + */ +public abstract class BlockTransferMessage implements Encodable { + protected abstract Type type(); + + /** Preceding every serialized message is its type, which allows us to deserialize it. */ + public static enum Type { + OPEN_BLOCKS(0), UPLOAD_BLOCK(1), REGISTER_EXECUTOR(2), STREAM_HANDLE(3); + + private final byte id; + + private Type(int id) { + assert id < 128 : "Cannot have more than 128 message types"; + this.id = (byte) id; + } + + public byte id() { return id; } + } + + // NB: Java does not support static methods in interfaces, so we must put this in a static class. + public static class Decoder { + /** Deserializes the 'type' byte followed by the message itself. */ + public static BlockTransferMessage fromByteArray(byte[] msg) { + ByteBuf buf = Unpooled.wrappedBuffer(msg); + byte type = buf.readByte(); + switch (type) { + case 0: return OpenBlocks.decode(buf); + case 1: return UploadBlock.decode(buf); + case 2: return RegisterExecutor.decode(buf); + case 3: return StreamHandle.decode(buf); + default: throw new IllegalArgumentException("Unknown message type: " + type); + } + } + } + + /** Serializes the 'type' byte followed by the message itself. */ + public byte[] toByteArray() { + ByteBuf buf = Unpooled.buffer(encodedLength()); + buf.writeByte(type().id); + encode(buf); + assert buf.writableBytes() == 0 : "Writable bytes remain: " + buf.writableBytes(); + return buf.array(); + } +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExecutorShuffleInfo.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExecutorShuffleInfo.java index d45e64656a..cadc8e8369 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExecutorShuffleInfo.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExecutorShuffleInfo.java @@ -15,21 +15,24 @@ * limitations under the License. */ -package org.apache.spark.network.shuffle; +package org.apache.spark.network.shuffle.protocol; -import java.io.Serializable; import java.util.Arrays; import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +import org.apache.spark.network.protocol.Encodable; +import org.apache.spark.network.protocol.Encoders; /** Contains all configuration necessary for locating the shuffle files of an executor. */ -public class ExecutorShuffleInfo implements Serializable { +public class ExecutorShuffleInfo implements Encodable { /** The base set of local directories that the executor stores its shuffle files in. */ - final String[] localDirs; + public final String[] localDirs; /** Number of subdirectories created within each localDir. */ - final int subDirsPerLocalDir; + public final int subDirsPerLocalDir; /** Shuffle manager (SortShuffleManager or HashShuffleManager) that the executor is using. */ - final String shuffleManager; + public final String shuffleManager; public ExecutorShuffleInfo(String[] localDirs, int subDirsPerLocalDir, String shuffleManager) { this.localDirs = localDirs; @@ -61,4 +64,25 @@ public class ExecutorShuffleInfo implements Serializable { } return false; } + + @Override + public int encodedLength() { + return Encoders.StringArrays.encodedLength(localDirs) + + 4 // int + + Encoders.Strings.encodedLength(shuffleManager); + } + + @Override + public void encode(ByteBuf buf) { + Encoders.StringArrays.encode(buf, localDirs); + buf.writeInt(subDirsPerLocalDir); + Encoders.Strings.encode(buf, shuffleManager); + } + + public static ExecutorShuffleInfo decode(ByteBuf buf) { + String[] localDirs = Encoders.StringArrays.decode(buf); + int subDirsPerLocalDir = buf.readInt(); + String shuffleManager = Encoders.Strings.decode(buf); + return new ExecutorShuffleInfo(localDirs, subDirsPerLocalDir, shuffleManager); + } } diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenBlocks.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenBlocks.java new file mode 100644 index 0000000000..60485bace6 --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenBlocks.java @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle.protocol; + +import java.util.Arrays; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +import org.apache.spark.network.protocol.Encoders; + +/** Request to read a set of blocks. Returns {@link StreamHandle}. */ +public class OpenBlocks extends BlockTransferMessage { + public final String appId; + public final String execId; + public final String[] blockIds; + + public OpenBlocks(String appId, String execId, String[] blockIds) { + this.appId = appId; + this.execId = execId; + this.blockIds = blockIds; + } + + @Override + protected Type type() { return Type.OPEN_BLOCKS; } + + @Override + public int hashCode() { + return Objects.hashCode(appId, execId) * 41 + Arrays.hashCode(blockIds); + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("appId", appId) + .add("execId", execId) + .add("blockIds", Arrays.toString(blockIds)) + .toString(); + } + + @Override + public boolean equals(Object other) { + if (other != null && other instanceof OpenBlocks) { + OpenBlocks o = (OpenBlocks) other; + return Objects.equal(appId, o.appId) + && Objects.equal(execId, o.execId) + && Arrays.equals(blockIds, o.blockIds); + } + return false; + } + + @Override + public int encodedLength() { + return Encoders.Strings.encodedLength(appId) + + Encoders.Strings.encodedLength(execId) + + Encoders.StringArrays.encodedLength(blockIds); + } + + @Override + public void encode(ByteBuf buf) { + Encoders.Strings.encode(buf, appId); + Encoders.Strings.encode(buf, execId); + Encoders.StringArrays.encode(buf, blockIds); + } + + public static OpenBlocks decode(ByteBuf buf) { + String appId = Encoders.Strings.decode(buf); + String execId = Encoders.Strings.decode(buf); + String[] blockIds = Encoders.StringArrays.decode(buf); + return new OpenBlocks(appId, execId, blockIds); + } +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java new file mode 100644 index 0000000000..38acae3b31 --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle.protocol; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +import org.apache.spark.network.protocol.Encoders; + +/** + * Initial registration message between an executor and its local shuffle server. + * Returns nothing (empty bye array). + */ +public class RegisterExecutor extends BlockTransferMessage { + public final String appId; + public final String execId; + public final ExecutorShuffleInfo executorInfo; + + public RegisterExecutor( + String appId, + String execId, + ExecutorShuffleInfo executorInfo) { + this.appId = appId; + this.execId = execId; + this.executorInfo = executorInfo; + } + + @Override + protected Type type() { return Type.REGISTER_EXECUTOR; } + + @Override + public int hashCode() { + return Objects.hashCode(appId, execId, executorInfo); + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("appId", appId) + .add("execId", execId) + .add("executorInfo", executorInfo) + .toString(); + } + + @Override + public boolean equals(Object other) { + if (other != null && other instanceof RegisterExecutor) { + RegisterExecutor o = (RegisterExecutor) other; + return Objects.equal(appId, o.appId) + && Objects.equal(execId, o.execId) + && Objects.equal(executorInfo, o.executorInfo); + } + return false; + } + + @Override + public int encodedLength() { + return Encoders.Strings.encodedLength(appId) + + Encoders.Strings.encodedLength(execId) + + executorInfo.encodedLength(); + } + + @Override + public void encode(ByteBuf buf) { + Encoders.Strings.encode(buf, appId); + Encoders.Strings.encode(buf, execId); + executorInfo.encode(buf); + } + + public static RegisterExecutor decode(ByteBuf buf) { + String appId = Encoders.Strings.decode(buf); + String execId = Encoders.Strings.decode(buf); + ExecutorShuffleInfo executorShuffleInfo = ExecutorShuffleInfo.decode(buf); + return new RegisterExecutor(appId, execId, executorShuffleInfo); + } +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleStreamHandle.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/StreamHandle.java index 9c94691224..21369c8cfb 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleStreamHandle.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/StreamHandle.java @@ -15,27 +15,30 @@ * limitations under the License. */ -package org.apache.spark.network.shuffle; +package org.apache.spark.network.shuffle.protocol; import java.io.Serializable; -import java.util.Arrays; import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; /** * Identifier for a fixed number of chunks to read from a stream created by an "open blocks" - * message. This is used by {@link OneForOneBlockFetcher}. + * message. This is used by {@link org.apache.spark.network.shuffle.OneForOneBlockFetcher}. */ -public class ShuffleStreamHandle implements Serializable { +public class StreamHandle extends BlockTransferMessage { public final long streamId; public final int numChunks; - public ShuffleStreamHandle(long streamId, int numChunks) { + public StreamHandle(long streamId, int numChunks) { this.streamId = streamId; this.numChunks = numChunks; } @Override + protected Type type() { return Type.STREAM_HANDLE; } + + @Override public int hashCode() { return Objects.hashCode(streamId, numChunks); } @@ -50,11 +53,28 @@ public class ShuffleStreamHandle implements Serializable { @Override public boolean equals(Object other) { - if (other != null && other instanceof ShuffleStreamHandle) { - ShuffleStreamHandle o = (ShuffleStreamHandle) other; + if (other != null && other instanceof StreamHandle) { + StreamHandle o = (StreamHandle) other; return Objects.equal(streamId, o.streamId) && Objects.equal(numChunks, o.numChunks); } return false; } + + @Override + public int encodedLength() { + return 8 + 4; + } + + @Override + public void encode(ByteBuf buf) { + buf.writeLong(streamId); + buf.writeInt(numChunks); + } + + public static StreamHandle decode(ByteBuf buf) { + long streamId = buf.readLong(); + int numChunks = buf.readInt(); + return new StreamHandle(streamId, numChunks); + } } diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadBlock.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadBlock.java new file mode 100644 index 0000000000..38abe29cc5 --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadBlock.java @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle.protocol; + +import java.util.Arrays; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +import org.apache.spark.network.protocol.Encoders; + +/** Request to upload a block with a certain StorageLevel. Returns nothing (empty byte array). */ +public class UploadBlock extends BlockTransferMessage { + public final String appId; + public final String execId; + public final String blockId; + // TODO: StorageLevel is serialized separately in here because StorageLevel is not available in + // this package. We should avoid this hack. + public final byte[] metadata; + public final byte[] blockData; + + /** + * @param metadata Meta-information about block, typically StorageLevel. + * @param blockData The actual block's bytes. + */ + public UploadBlock( + String appId, + String execId, + String blockId, + byte[] metadata, + byte[] blockData) { + this.appId = appId; + this.execId = execId; + this.blockId = blockId; + this.metadata = metadata; + this.blockData = blockData; + } + + @Override + protected Type type() { return Type.UPLOAD_BLOCK; } + + @Override + public int hashCode() { + int objectsHashCode = Objects.hashCode(appId, execId, blockId); + return (objectsHashCode * 41 + Arrays.hashCode(metadata)) * 41 + Arrays.hashCode(blockData); + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("appId", appId) + .add("execId", execId) + .add("blockId", blockId) + .add("metadata size", metadata.length) + .add("block size", blockData.length) + .toString(); + } + + @Override + public boolean equals(Object other) { + if (other != null && other instanceof UploadBlock) { + UploadBlock o = (UploadBlock) other; + return Objects.equal(appId, o.appId) + && Objects.equal(execId, o.execId) + && Objects.equal(blockId, o.blockId) + && Arrays.equals(metadata, o.metadata) + && Arrays.equals(blockData, o.blockData); + } + return false; + } + + @Override + public int encodedLength() { + return Encoders.Strings.encodedLength(appId) + + Encoders.Strings.encodedLength(execId) + + Encoders.Strings.encodedLength(blockId) + + Encoders.ByteArrays.encodedLength(metadata) + + Encoders.ByteArrays.encodedLength(blockData); + } + + @Override + public void encode(ByteBuf buf) { + Encoders.Strings.encode(buf, appId); + Encoders.Strings.encode(buf, execId); + Encoders.Strings.encode(buf, blockId); + Encoders.ByteArrays.encode(buf, metadata); + Encoders.ByteArrays.encode(buf, blockData); + } + + public static UploadBlock decode(ByteBuf buf) { + String appId = Encoders.Strings.decode(buf); + String execId = Encoders.Strings.decode(buf); + String blockId = Encoders.Strings.decode(buf); + byte[] metadata = Encoders.ByteArrays.decode(buf); + byte[] blockData = Encoders.ByteArrays.decode(buf); + return new UploadBlock(appId, execId, blockId, metadata, blockData); + } +} diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ShuffleMessagesSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java index ee9482b49c..d65de9ca55 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ShuffleMessagesSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java @@ -21,31 +21,24 @@ import org.junit.Test; import static org.junit.Assert.*; -import org.apache.spark.network.util.JavaUtils; +import org.apache.spark.network.shuffle.protocol.*; -import static org.apache.spark.network.shuffle.ExternalShuffleMessages.*; - -public class ShuffleMessagesSuite { +/** Verifies that all BlockTransferMessages can be serialized correctly. */ +public class BlockTransferMessagesSuite { @Test public void serializeOpenShuffleBlocks() { - OpenShuffleBlocks msg = new OpenShuffleBlocks("app-1", "exec-2", - new String[] { "block0", "block1" }); - OpenShuffleBlocks msg2 = JavaUtils.deserialize(JavaUtils.serialize(msg)); - assertEquals(msg, msg2); + checkSerializeDeserialize(new OpenBlocks("app-1", "exec-2", new String[] { "b1", "b2" })); + checkSerializeDeserialize(new RegisterExecutor("app-1", "exec-2", new ExecutorShuffleInfo( + new String[] { "/local1", "/local2" }, 32, "MyShuffleManager"))); + checkSerializeDeserialize(new UploadBlock("app-1", "exec-2", "block-3", new byte[] { 1, 2 }, + new byte[] { 4, 5, 6, 7} )); + checkSerializeDeserialize(new StreamHandle(12345, 16)); } - @Test - public void serializeRegisterExecutor() { - RegisterExecutor msg = new RegisterExecutor("app-1", "exec-2", new ExecutorShuffleInfo( - new String[] { "/local1", "/local2" }, 32, "MyShuffleManager")); - RegisterExecutor msg2 = JavaUtils.deserialize(JavaUtils.serialize(msg)); - assertEquals(msg, msg2); - } - - @Test - public void serializeShuffleStreamHandle() { - ShuffleStreamHandle msg = new ShuffleStreamHandle(12345, 16); - ShuffleStreamHandle msg2 = JavaUtils.deserialize(JavaUtils.serialize(msg)); + private void checkSerializeDeserialize(BlockTransferMessage msg) { + BlockTransferMessage msg2 = BlockTransferMessage.Decoder.fromByteArray(msg.toByteArray()); assertEquals(msg, msg2); + assertEquals(msg.hashCode(), msg2.hashCode()); + assertEquals(msg.toString(), msg2.toString()); } } diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java index 7939cb4d32..3f9fe1681c 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java @@ -24,8 +24,6 @@ import org.junit.Before; import org.junit.Test; import org.mockito.ArgumentCaptor; -import static org.apache.spark.network.shuffle.ExternalShuffleMessages.OpenShuffleBlocks; -import static org.apache.spark.network.shuffle.ExternalShuffleMessages.RegisterExecutor; import static org.junit.Assert.*; import static org.mockito.Matchers.any; import static org.mockito.Mockito.*; @@ -36,7 +34,12 @@ import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.server.OneForOneStreamManager; import org.apache.spark.network.server.RpcHandler; -import org.apache.spark.network.util.JavaUtils; +import org.apache.spark.network.shuffle.protocol.BlockTransferMessage; +import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; +import org.apache.spark.network.shuffle.protocol.OpenBlocks; +import org.apache.spark.network.shuffle.protocol.RegisterExecutor; +import org.apache.spark.network.shuffle.protocol.StreamHandle; +import org.apache.spark.network.shuffle.protocol.UploadBlock; public class ExternalShuffleBlockHandlerSuite { TransportClient client = mock(TransportClient.class); @@ -57,8 +60,7 @@ public class ExternalShuffleBlockHandlerSuite { RpcResponseCallback callback = mock(RpcResponseCallback.class); ExecutorShuffleInfo config = new ExecutorShuffleInfo(new String[] {"/a", "/b"}, 16, "sort"); - byte[] registerMessage = JavaUtils.serialize( - new RegisterExecutor("app0", "exec1", config)); + byte[] registerMessage = new RegisterExecutor("app0", "exec1", config).toByteArray(); handler.receive(client, registerMessage, callback); verify(blockManager, times(1)).registerExecutor("app0", "exec1", config); @@ -75,9 +77,8 @@ public class ExternalShuffleBlockHandlerSuite { ManagedBuffer block1Marker = new NioManagedBuffer(ByteBuffer.wrap(new byte[7])); when(blockManager.getBlockData("app0", "exec1", "b0")).thenReturn(block0Marker); when(blockManager.getBlockData("app0", "exec1", "b1")).thenReturn(block1Marker); - byte[] openBlocksMessage = JavaUtils.serialize( - new OpenShuffleBlocks("app0", "exec1", new String[] { "b0", "b1" })); - handler.receive(client, openBlocksMessage, callback); + byte[] openBlocks = new OpenBlocks("app0", "exec1", new String[] { "b0", "b1" }).toByteArray(); + handler.receive(client, openBlocks, callback); verify(blockManager, times(1)).getBlockData("app0", "exec1", "b0"); verify(blockManager, times(1)).getBlockData("app0", "exec1", "b1"); @@ -85,7 +86,8 @@ public class ExternalShuffleBlockHandlerSuite { verify(callback, times(1)).onSuccess(response.capture()); verify(callback, never()).onFailure((Throwable) any()); - ShuffleStreamHandle handle = JavaUtils.deserialize(response.getValue()); + StreamHandle handle = + (StreamHandle) BlockTransferMessage.Decoder.fromByteArray(response.getValue()); assertEquals(2, handle.numChunks); ArgumentCaptor<Iterator> stream = ArgumentCaptor.forClass(Iterator.class); @@ -100,18 +102,17 @@ public class ExternalShuffleBlockHandlerSuite { public void testBadMessages() { RpcResponseCallback callback = mock(RpcResponseCallback.class); - byte[] unserializableMessage = new byte[] { 0x12, 0x34, 0x56 }; + byte[] unserializableMsg = new byte[] { 0x12, 0x34, 0x56 }; try { - handler.receive(client, unserializableMessage, callback); + handler.receive(client, unserializableMsg, callback); fail("Should have thrown"); } catch (Exception e) { // pass } - byte[] unexpectedMessage = JavaUtils.serialize( - new ExecutorShuffleInfo(new String[] {"/a", "/b"}, 16, "sort")); + byte[] unexpectedMsg = new UploadBlock("a", "e", "b", new byte[1], new byte[2]).toByteArray(); try { - handler.receive(client, unexpectedMessage, callback); + handler.receive(client, unexpectedMsg, callback); fail("Should have thrown"); } catch (UnsupportedOperationException e) { // pass diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java index 3bea5b0f25..687bde59fd 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java @@ -42,6 +42,7 @@ import org.apache.spark.network.TransportContext; import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.buffer.NioManagedBuffer; import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; import org.apache.spark.network.util.SystemPropertyConfigProvider; import org.apache.spark.network.util.TransportConf; diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java index 848c88f743..8afceab1d5 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java @@ -31,6 +31,7 @@ import org.apache.spark.network.sasl.SaslRpcHandler; import org.apache.spark.network.sasl.SecretKeyHolder; import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; import org.apache.spark.network.util.SystemPropertyConfigProvider; import org.apache.spark.network.util.TransportConf; diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java index c18346f696..842741e3d3 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java @@ -40,7 +40,9 @@ import org.apache.spark.network.buffer.NioManagedBuffer; import org.apache.spark.network.client.ChunkReceivedCallback; import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; -import org.apache.spark.network.util.JavaUtils; +import org.apache.spark.network.shuffle.protocol.BlockTransferMessage; +import org.apache.spark.network.shuffle.protocol.OpenBlocks; +import org.apache.spark.network.shuffle.protocol.StreamHandle; public class OneForOneBlockFetcherSuite { @Test @@ -119,17 +121,19 @@ public class OneForOneBlockFetcherSuite { private BlockFetchingListener fetchBlocks(final LinkedHashMap<String, ManagedBuffer> blocks) { TransportClient client = mock(TransportClient.class); BlockFetchingListener listener = mock(BlockFetchingListener.class); - String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]); - OneForOneBlockFetcher fetcher = new OneForOneBlockFetcher(client, blockIds, listener); + final String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]); + OneForOneBlockFetcher fetcher = + new OneForOneBlockFetcher(client, "app-id", "exec-id", blockIds, listener); // Respond to the "OpenBlocks" message with an appropirate ShuffleStreamHandle with streamId 123 doAnswer(new Answer<Void>() { @Override public Void answer(InvocationOnMock invocationOnMock) throws Throwable { - String message = JavaUtils.deserialize((byte[]) invocationOnMock.getArguments()[0]); + BlockTransferMessage message = BlockTransferMessage.Decoder.fromByteArray( + (byte[]) invocationOnMock.getArguments()[0]); RpcResponseCallback callback = (RpcResponseCallback) invocationOnMock.getArguments()[1]; - callback.onSuccess(JavaUtils.serialize(new ShuffleStreamHandle(123, blocks.size()))); - assertEquals("OpenZeBlocks", message); + callback.onSuccess(new StreamHandle(123, blocks.size()).toByteArray()); + assertEquals(new OpenBlocks("app-id", "exec-id", blockIds), message); return null; } }).when(client).sendRpc((byte[]) any(), (RpcResponseCallback) any()); @@ -161,7 +165,7 @@ public class OneForOneBlockFetcherSuite { } }).when(client).fetchChunk(anyLong(), anyInt(), (ChunkReceivedCallback) any()); - fetcher.start("OpenZeBlocks"); + fetcher.start(); return listener; } } diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java index 337b5c7bdb..76639114df 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java @@ -25,6 +25,8 @@ import java.io.OutputStream; import com.google.common.io.Files; +import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; + /** * Manages some sort- and hash-based shuffle data, including the creation * and cleanup of directories that can be read by the {@link ExternalShuffleBlockManager}. |