about summary refs log tree commit diff
path: root/network/shuffle
diff options
context:
space:
mode:
authorMarcelo Vanzin <vanzin@cloudera.com>2015-11-30 17:22:05 -0800
committerAndrew Or <andrew@databricks.com>2015-11-30 17:22:05 -0800
commit9bf2120672ae0f620a217ccd96bef189ab75e0d6 (patch)
tree6ad03e91cd0c3679e13e41274cdac07c002432b7 /network/shuffle
parent0a46e4377216a1f7de478f220c3b3042a77789e2 (diff)
downloadspark-9bf2120672ae0f620a217ccd96bef189ab75e0d6.tar.gz
spark-9bf2120672ae0f620a217ccd96bef189ab75e0d6.tar.bz2
spark-9bf2120672ae0f620a217ccd96bef189ab75e0d6.zip
[SPARK-12007][NETWORK] Avoid copies in the network lib's RPC layer.
This change seems large, but most of it is just replacing `byte[]` with `ByteBuffer` and `new byte[]` with `ByteBuffer.allocate()`, since it changes the network library's API. The following are parts of the code that actually have meaningful changes:

- The Message implementations were changed to inherit from a new AbstractMessage that can optionally hold a reference to a body (in the form of a ManagedBuffer); this is similar to how ResponseWithBody worked before, except now it's not restricted to just responses.

- The TransportFrameDecoder was pretty much rewritten to avoid copies as much as possible; it doesn't rely on CompositeByteBuf to accumulate incoming data anymore, since CompositeByteBuf has issues when slices are retained. The code now is able to create frames without having to resort to copying bytes except for a few bytes (containing the frame length) in very rare cases.

- Some minor changes in the SASL layer to convert things back to `byte[]` since the JDK SASL API operates on those.

Author: Marcelo Vanzin <vanzin@cloudera.com>

Closes #9987 from vanzin/SPARK-12007.
Diffstat (limited to 'network/shuffle')
-rw-r--r-- network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java | 9
-rw-r--r-- network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java | 3
-rw-r--r-- network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java | 7
-rw-r--r-- network/shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java | 5
-rw-r--r-- network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java | 8
-rw-r--r-- network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java | 18
-rw-r--r-- network/shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java | 2
-rw-r--r-- network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java | 21
-rw-r--r-- network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java | 8
9 files changed, 45 insertions, 36 deletions
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java
index 3ddf5c3c39..f22187a01d 100644
--- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java
@@ -19,6 +19,7 @@ package org.apache.spark.network.shuffle;
import java.io.File;
import java.io.IOException;
+import java.nio.ByteBuffer;
import java.util.List;
import com.google.common.annotations.VisibleForTesting;
@@ -66,8 +67,8 @@ public class ExternalShuffleBlockHandler extends RpcHandler {
}
@Override
- public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) {
- BlockTransferMessage msgObj = BlockTransferMessage.Decoder.fromByteArray(message);
+ public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) {
+ BlockTransferMessage msgObj = BlockTransferMessage.Decoder.fromByteBuffer(message);
handleMessage(msgObj, client, callback);
}
@@ -85,13 +86,13 @@ public class ExternalShuffleBlockHandler extends RpcHandler {
}
long streamId = streamManager.registerStream(client.getClientId(), blocks.iterator());
logger.trace("Registered streamId {} with {} buffers", streamId, msg.blockIds.length);
- callback.onSuccess(new StreamHandle(streamId, msg.blockIds.length).toByteArray());
+ callback.onSuccess(new StreamHandle(streamId, msg.blockIds.length).toByteBuffer());
} else if (msgObj instanceof RegisterExecutor) {
RegisterExecutor msg = (RegisterExecutor) msgObj;
checkAuth(client, msg.appId);
blockManager.registerExecutor(msg.appId, msg.execId, msg.executorInfo);
- callback.onSuccess(new byte[0]);
+ callback.onSuccess(ByteBuffer.wrap(new byte[0]));
} else {
throw new UnsupportedOperationException("Unexpected message: " + msgObj);
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
index ef3a9dcc87..58ca87d9d3 100644
--- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
@@ -18,6 +18,7 @@
package org.apache.spark.network.shuffle;
import java.io.IOException;
+import java.nio.ByteBuffer;
import java.util.List;
import com.google.common.base.Preconditions;
@@ -139,7 +140,7 @@ public class ExternalShuffleClient extends ShuffleClient {
checkInit();
TransportClient client = clientFactory.createUnmanagedClient(host, port);
try {
- byte[] registerMessage = new RegisterExecutor(appId, execId, executorInfo).toByteArray();
+ ByteBuffer registerMessage = new RegisterExecutor(appId, execId, executorInfo).toByteBuffer();
client.sendRpcSync(registerMessage, 5000 /* timeoutMs */);
} finally {
client.close();
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
index e653f5cb14..1b2ddbf1ed 100644
--- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
@@ -17,6 +17,7 @@
package org.apache.spark.network.shuffle;
+import java.nio.ByteBuffer;
import java.util.Arrays;
import org.slf4j.Logger;
@@ -89,11 +90,11 @@ public class OneForOneBlockFetcher {
throw new IllegalArgumentException("Zero-sized blockIds array");
}
- client.sendRpc(openMessage.toByteArray(), new RpcResponseCallback() {
+ client.sendRpc(openMessage.toByteBuffer(), new RpcResponseCallback() {
@Override
- public void onSuccess(byte[] response) {
+ public void onSuccess(ByteBuffer response) {
try {
- streamHandle = (StreamHandle) BlockTransferMessage.Decoder.fromByteArray(response);
+ streamHandle = (StreamHandle) BlockTransferMessage.Decoder.fromByteBuffer(response);
logger.trace("Successfully opened blocks {}, preparing to fetch chunks.", streamHandle);
// Immediately request all chunks -- we expect that the total size of the request is
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java
index 7543b6be4f..675820308b 100644
--- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java
@@ -18,6 +18,7 @@
package org.apache.spark.network.shuffle.mesos;
import java.io.IOException;
+import java.nio.ByteBuffer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -54,11 +55,11 @@ public class MesosExternalShuffleClient extends ExternalShuffleClient {
public void registerDriverWithShuffleService(String host, int port) throws IOException {
checkInit();
- byte[] registerDriver = new RegisterDriver(appId).toByteArray();
+ ByteBuffer registerDriver = new RegisterDriver(appId).toByteBuffer();
TransportClient client = clientFactory.createClient(host, port);
client.sendRpc(registerDriver, new RpcResponseCallback() {
@Override
- public void onSuccess(byte[] response) {
+ public void onSuccess(ByteBuffer response) {
logger.info("Successfully registered app " + appId + " with external shuffle service.");
}
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java
index fcb52363e6..7fbe3384b4 100644
--- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java
@@ -17,6 +17,8 @@
package org.apache.spark.network.shuffle.protocol;
+import java.nio.ByteBuffer;
+
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
@@ -53,7 +55,7 @@ public abstract class BlockTransferMessage implements Encodable {
// NB: Java does not support static methods in interfaces, so we must put this in a static class.
public static class Decoder {
/** Deserializes the 'type' byte followed by the message itself. */
- public static BlockTransferMessage fromByteArray(byte[] msg) {
+ public static BlockTransferMessage fromByteBuffer(ByteBuffer msg) {
ByteBuf buf = Unpooled.wrappedBuffer(msg);
byte type = buf.readByte();
switch (type) {
@@ -68,12 +70,12 @@ public abstract class BlockTransferMessage implements Encodable {
}
/** Serializes the 'type' byte followed by the message itself. */
- public byte[] toByteArray() {
+ public ByteBuffer toByteBuffer() {
// Allow room for encoded message, plus the type byte
ByteBuf buf = Unpooled.buffer(encodedLength() + 1);
buf.writeByte(type().id);
encode(buf);
assert buf.writableBytes() == 0 : "Writable bytes remain: " + buf.writableBytes();
- return buf.array();
+ return buf.nioBuffer();
}
}
diff --git a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java
index 1c2fa4d0d4..19c870aebb 100644
--- a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java
+++ b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java
@@ -18,6 +18,7 @@
package org.apache.spark.network.sasl;
import java.io.IOException;
+import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.concurrent.atomic.AtomicReference;
@@ -52,6 +53,7 @@ import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo;
import org.apache.spark.network.shuffle.protocol.OpenBlocks;
import org.apache.spark.network.shuffle.protocol.RegisterExecutor;
import org.apache.spark.network.shuffle.protocol.StreamHandle;
+import org.apache.spark.network.util.JavaUtils;
import org.apache.spark.network.util.SystemPropertyConfigProvider;
import org.apache.spark.network.util.TransportConf;
@@ -107,8 +109,8 @@ public class SaslIntegrationSuite {
TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
String msg = "Hello, World!";
- byte[] resp = client.sendRpcSync(msg.getBytes(), TIMEOUT_MS);
- assertEquals(msg, new String(resp)); // our rpc handler should just return the given msg
+ ByteBuffer resp = client.sendRpcSync(JavaUtils.stringToBytes(msg), TIMEOUT_MS);
+ assertEquals(msg, JavaUtils.bytesToString(resp));
}
@Test
@@ -136,7 +138,7 @@ public class SaslIntegrationSuite {
TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
try {
- client.sendRpcSync(new byte[13], TIMEOUT_MS);
+ client.sendRpcSync(ByteBuffer.allocate(13), TIMEOUT_MS);
fail("Should have failed");
} catch (Exception e) {
assertTrue(e.getMessage(), e.getMessage().contains("Expected SaslMessage"));
@@ -144,7 +146,7 @@ public class SaslIntegrationSuite {
try {
// Guessing the right tag byte doesn't magically get you in...
- client.sendRpcSync(new byte[] { (byte) 0xEA }, TIMEOUT_MS);
+ client.sendRpcSync(ByteBuffer.wrap(new byte[] { (byte) 0xEA }), TIMEOUT_MS);
fail("Should have failed");
} catch (Exception e) {
assertTrue(e.getMessage(), e.getMessage().contains("java.lang.IndexOutOfBoundsException"));
@@ -222,13 +224,13 @@ public class SaslIntegrationSuite {
new String[] { System.getProperty("java.io.tmpdir") }, 1,
"org.apache.spark.shuffle.sort.SortShuffleManager");
RegisterExecutor regmsg = new RegisterExecutor("app-1", "0", executorInfo);
- client1.sendRpcSync(regmsg.toByteArray(), TIMEOUT_MS);
+ client1.sendRpcSync(regmsg.toByteBuffer(), TIMEOUT_MS);
// Make a successful request to fetch blocks, which creates a new stream. But do not actually
// fetch any blocks, to keep the stream open.
OpenBlocks openMessage = new OpenBlocks("app-1", "0", blockIds);
- byte[] response = client1.sendRpcSync(openMessage.toByteArray(), TIMEOUT_MS);
- StreamHandle stream = (StreamHandle) BlockTransferMessage.Decoder.fromByteArray(response);
+ ByteBuffer response = client1.sendRpcSync(openMessage.toByteBuffer(), TIMEOUT_MS);
+ StreamHandle stream = (StreamHandle) BlockTransferMessage.Decoder.fromByteBuffer(response);
long streamId = stream.streamId;
// Create a second client, authenticated with a different app ID, and try to read from
@@ -275,7 +277,7 @@ public class SaslIntegrationSuite {
/** RPC handler which simply responds with the message it received. */
public static class TestRpcHandler extends RpcHandler {
@Override
- public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) {
+ public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) {
callback.onSuccess(message);
}
diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java
index d65de9ca55..86c8609e70 100644
--- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java
+++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java
@@ -36,7 +36,7 @@ public class BlockTransferMessagesSuite {
}
private void checkSerializeDeserialize(BlockTransferMessage msg) {
- BlockTransferMessage msg2 = BlockTransferMessage.Decoder.fromByteArray(msg.toByteArray());
+ BlockTransferMessage msg2 = BlockTransferMessage.Decoder.fromByteBuffer(msg.toByteBuffer());
assertEquals(msg, msg2);
assertEquals(msg.hashCode(), msg2.hashCode());
assertEquals(msg.toString(), msg2.toString());
diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java
index e61390cf57..9379412155 100644
--- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java
+++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java
@@ -60,12 +60,12 @@ public class ExternalShuffleBlockHandlerSuite {
RpcResponseCallback callback = mock(RpcResponseCallback.class);
ExecutorShuffleInfo config = new ExecutorShuffleInfo(new String[] {"/a", "/b"}, 16, "sort");
- byte[] registerMessage = new RegisterExecutor("app0", "exec1", config).toByteArray();
+ ByteBuffer registerMessage = new RegisterExecutor("app0", "exec1", config).toByteBuffer();
handler.receive(client, registerMessage, callback);
verify(blockResolver, times(1)).registerExecutor("app0", "exec1", config);
- verify(callback, times(1)).onSuccess((byte[]) any());
- verify(callback, never()).onFailure((Throwable) any());
+ verify(callback, times(1)).onSuccess(any(ByteBuffer.class));
+ verify(callback, never()).onFailure(any(Throwable.class));
}
@SuppressWarnings("unchecked")
@@ -77,17 +77,18 @@ public class ExternalShuffleBlockHandlerSuite {
ManagedBuffer block1Marker = new NioManagedBuffer(ByteBuffer.wrap(new byte[7]));
when(blockResolver.getBlockData("app0", "exec1", "b0")).thenReturn(block0Marker);
when(blockResolver.getBlockData("app0", "exec1", "b1")).thenReturn(block1Marker);
- byte[] openBlocks = new OpenBlocks("app0", "exec1", new String[] { "b0", "b1" }).toByteArray();
+ ByteBuffer openBlocks = new OpenBlocks("app0", "exec1", new String[] { "b0", "b1" })
+ .toByteBuffer();
handler.receive(client, openBlocks, callback);
verify(blockResolver, times(1)).getBlockData("app0", "exec1", "b0");
verify(blockResolver, times(1)).getBlockData("app0", "exec1", "b1");
- ArgumentCaptor<byte[]> response = ArgumentCaptor.forClass(byte[].class);
+ ArgumentCaptor<ByteBuffer> response = ArgumentCaptor.forClass(ByteBuffer.class);
verify(callback, times(1)).onSuccess(response.capture());
verify(callback, never()).onFailure((Throwable) any());
StreamHandle handle =
- (StreamHandle) BlockTransferMessage.Decoder.fromByteArray(response.getValue());
+ (StreamHandle) BlockTransferMessage.Decoder.fromByteBuffer(response.getValue());
assertEquals(2, handle.numChunks);
@SuppressWarnings("unchecked")
@@ -104,7 +105,7 @@ public class ExternalShuffleBlockHandlerSuite {
public void testBadMessages() {
RpcResponseCallback callback = mock(RpcResponseCallback.class);
- byte[] unserializableMsg = new byte[] { 0x12, 0x34, 0x56 };
+ ByteBuffer unserializableMsg = ByteBuffer.wrap(new byte[] { 0x12, 0x34, 0x56 });
try {
handler.receive(client, unserializableMsg, callback);
fail("Should have thrown");
@@ -112,7 +113,7 @@ public class ExternalShuffleBlockHandlerSuite {
// pass
}
- byte[] unexpectedMsg = new UploadBlock("a", "e", "b", new byte[1], new byte[2]).toByteArray();
+ ByteBuffer unexpectedMsg = new UploadBlock("a", "e", "b", new byte[1], new byte[2]).toByteBuffer();
try {
handler.receive(client, unexpectedMsg, callback);
fail("Should have thrown");
@@ -120,7 +121,7 @@ public class ExternalShuffleBlockHandlerSuite {
// pass
}
- verify(callback, never()).onSuccess((byte[]) any());
- verify(callback, never()).onFailure((Throwable) any());
+ verify(callback, never()).onSuccess(any(ByteBuffer.class));
+ verify(callback, never()).onFailure(any(Throwable.class));
}
}
diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java
index b35a6d685d..2590b9ce4c 100644
--- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java
+++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java
@@ -134,14 +134,14 @@ public class OneForOneBlockFetcherSuite {
doAnswer(new Answer<Void>() {
@Override
public Void answer(InvocationOnMock invocationOnMock) throws Throwable {
- BlockTransferMessage message = BlockTransferMessage.Decoder.fromByteArray(
- (byte[]) invocationOnMock.getArguments()[0]);
+ BlockTransferMessage message = BlockTransferMessage.Decoder.fromByteBuffer(
+ (ByteBuffer) invocationOnMock.getArguments()[0]);
RpcResponseCallback callback = (RpcResponseCallback) invocationOnMock.getArguments()[1];
- callback.onSuccess(new StreamHandle(123, blocks.size()).toByteArray());
+ callback.onSuccess(new StreamHandle(123, blocks.size()).toByteBuffer());
assertEquals(new OpenBlocks("app-id", "exec-id", blockIds), message);
return null;
}
- }).when(client).sendRpc((byte[]) any(), (RpcResponseCallback) any());
+ }).when(client).sendRpc(any(ByteBuffer.class), any(RpcResponseCallback.class));
// Respond to each chunk request with a single buffer from our blocks array.
final AtomicInteger expectedChunkIndex = new AtomicInteger(0);