author     Aaron Davidson <aaron@databricks.com>  2014-11-07 09:42:21 -0800
committer  Reynold Xin <rxin@databricks.com>      2014-11-07 09:42:21 -0800
commit     d4fa04e50d299e9cad349b3781772956453a696b (patch)
tree       d561c5ac49c49b56aba2fa7964587a9e72c253a9 /network/common
parent     3abdb1b24aa48f21e7eed1232c01d3933873688c (diff)
[SPARK-4187] [Core] Switch to binary protocol for external shuffle service messages
This PR eliminates the network package's usage of the Java serializer and replaces it with Encodable, a lightweight binary protocol. Each message is preceded by a type id, which will allow us to change messages (by only adding new ones), or to change the format entirely by switching to a special id (such as -1).

This protocol has the advantage over Java serialization that we can guarantee messages will remain compatible across compiled versions and JVMs, though it does not provide a clean way to do schema migration. In the future, it may be worthwhile to use a more heavyweight serialization format like protobuf, thrift, or avro, but these all add several dependencies that are unnecessary at the present time.

Additionally, this unifies the RPC messages of NettyBlockTransferService and ExternalShuffleClient.

Author: Aaron Davidson <aaron@databricks.com>

Closes #3146 from aarondav/free and squashes the following commits:

ed1102a [Aaron Davidson] Remove some unused imports
b8e2a49 [Aaron Davidson] Add appId to test
538f2a3 [Aaron Davidson] [SPARK-4187] [Core] Switch to binary protocol for external shuffle service messages
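As a rough sketch of the framing described above (illustrative only; the TypedFraming class and its type-id constants are hypothetical stand-ins, not code from this commit), a type id is written before each length-prefixed body so a decoder can dispatch on it:

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;

// Hypothetical sketch of the type-id framing described in the commit message.
public class TypedFraming {
  public static final byte OPEN_BLOCKS = 0;    // hypothetical message type id
  public static final byte FORMAT_SWITCH = -1; // a special id reserved for a future format change

  public static ByteBuf frame(byte typeId, byte[] body) {
    ByteBuf buf = Unpooled.buffer(1 + 4 + body.length);
    buf.writeByte(typeId);      // type id first, so decoders can dispatch on it
    buf.writeInt(body.length);  // length-prefixed body, matching the Encoders convention
    buf.writeBytes(body);
    return buf;
  }
}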
Diffstat (limited to 'network/common')
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java  12
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/protocol/Encoders.java           93
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java         12
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java          9
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java         9
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java              27
6 files changed, 107 insertions, 55 deletions
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java
index 152af98ced..986957c150 100644
--- a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java
@@ -38,23 +38,19 @@ public final class ChunkFetchFailure implements ResponseMessage {
@Override
public int encodedLength() {
- return streamChunkId.encodedLength() + 4 + errorString.getBytes(Charsets.UTF_8).length;
+ return streamChunkId.encodedLength() + Encoders.Strings.encodedLength(errorString);
}
@Override
public void encode(ByteBuf buf) {
streamChunkId.encode(buf);
- byte[] errorBytes = errorString.getBytes(Charsets.UTF_8);
- buf.writeInt(errorBytes.length);
- buf.writeBytes(errorBytes);
+ Encoders.Strings.encode(buf, errorString);
}
public static ChunkFetchFailure decode(ByteBuf buf) {
StreamChunkId streamChunkId = StreamChunkId.decode(buf);
- int numErrorStringBytes = buf.readInt();
- byte[] errorBytes = new byte[numErrorStringBytes];
- buf.readBytes(errorBytes);
- return new ChunkFetchFailure(streamChunkId, new String(errorBytes, Charsets.UTF_8));
+ String errorString = Encoders.Strings.decode(buf);
+ return new ChunkFetchFailure(streamChunkId, errorString);
}
@Override
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/Encoders.java b/network/common/src/main/java/org/apache/spark/network/protocol/Encoders.java
new file mode 100644
index 0000000000..873c694250
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/Encoders.java
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+
+import com.google.common.base.Charsets;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+
+/** Provides a canonical set of Encoders for simple types. */
+public class Encoders {
+
+ /** Strings are encoded with their length followed by UTF-8 bytes. */
+ public static class Strings {
+ public static int encodedLength(String s) {
+ return 4 + s.getBytes(Charsets.UTF_8).length;
+ }
+
+ public static void encode(ByteBuf buf, String s) {
+ byte[] bytes = s.getBytes(Charsets.UTF_8);
+ buf.writeInt(bytes.length);
+ buf.writeBytes(bytes);
+ }
+
+ public static String decode(ByteBuf buf) {
+ int length = buf.readInt();
+ byte[] bytes = new byte[length];
+ buf.readBytes(bytes);
+ return new String(bytes, Charsets.UTF_8);
+ }
+ }
+
+ /** Byte arrays are encoded with their length followed by bytes. */
+ public static class ByteArrays {
+ public static int encodedLength(byte[] arr) {
+ return 4 + arr.length;
+ }
+
+ public static void encode(ByteBuf buf, byte[] arr) {
+ buf.writeInt(arr.length);
+ buf.writeBytes(arr);
+ }
+
+ public static byte[] decode(ByteBuf buf) {
+ int length = buf.readInt();
+ byte[] bytes = new byte[length];
+ buf.readBytes(bytes);
+ return bytes;
+ }
+ }
+
+ /** String arrays are encoded with the number of strings followed by per-String encoding. */
+ public static class StringArrays {
+ public static int encodedLength(String[] strings) {
+ int totalLength = 4;
+ for (String s : strings) {
+ totalLength += Strings.encodedLength(s);
+ }
+ return totalLength;
+ }
+
+ public static void encode(ByteBuf buf, String[] strings) {
+ buf.writeInt(strings.length);
+ for (String s : strings) {
+ Strings.encode(buf, s);
+ }
+ }
+
+ public static String[] decode(ByteBuf buf) {
+ int numStrings = buf.readInt();
+ String[] strings = new String[numStrings];
+ for (int i = 0; i < strings.length; i ++) {
+ strings[i] = Strings.decode(buf);
+ }
+ return strings;
+ }
+ }
+}
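For context, a minimal usage sketch of the new Encoders added above (not part of this diff; the example values are made up):

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import org.apache.spark.network.protocol.Encoders;

// Round-trip a string and a byte array through the new Encoders.
public class EncodersExample {
  public static void main(String[] args) {
    ByteBuf buf = Unpooled.buffer();
    Encoders.Strings.encode(buf, "shuffle-app-1");
    Encoders.ByteArrays.encode(buf, new byte[] {1, 2, 3});

    String appId = Encoders.Strings.decode(buf);      // "shuffle-app-1"
    byte[] payload = Encoders.ByteArrays.decode(buf); // {1, 2, 3}
    System.out.println(appId + " / " + payload.length + " bytes");
  }
}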
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java b/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java
index e239d4ffbd..ebd764eb5e 100644
--- a/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java
@@ -36,23 +36,19 @@ public final class RpcFailure implements ResponseMessage {
@Override
public int encodedLength() {
- return 8 + 4 + errorString.getBytes(Charsets.UTF_8).length;
+ return 8 + Encoders.Strings.encodedLength(errorString);
}
@Override
public void encode(ByteBuf buf) {
buf.writeLong(requestId);
- byte[] errorBytes = errorString.getBytes(Charsets.UTF_8);
- buf.writeInt(errorBytes.length);
- buf.writeBytes(errorBytes);
+ Encoders.Strings.encode(buf, errorString);
}
public static RpcFailure decode(ByteBuf buf) {
long requestId = buf.readLong();
- int numErrorStringBytes = buf.readInt();
- byte[] errorBytes = new byte[numErrorStringBytes];
- buf.readBytes(errorBytes);
- return new RpcFailure(requestId, new String(errorBytes, Charsets.UTF_8));
+ String errorString = Encoders.Strings.decode(buf);
+ return new RpcFailure(requestId, errorString);
}
@Override
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java b/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java
index 099e934ae0..cdee0b0e03 100644
--- a/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java
@@ -44,21 +44,18 @@ public final class RpcRequest implements RequestMessage {
@Override
public int encodedLength() {
- return 8 + 4 + message.length;
+ return 8 + Encoders.ByteArrays.encodedLength(message);
}
@Override
public void encode(ByteBuf buf) {
buf.writeLong(requestId);
- buf.writeInt(message.length);
- buf.writeBytes(message);
+ Encoders.ByteArrays.encode(buf, message);
}
public static RpcRequest decode(ByteBuf buf) {
long requestId = buf.readLong();
- int messageLen = buf.readInt();
- byte[] message = new byte[messageLen];
- buf.readBytes(message);
+ byte[] message = Encoders.ByteArrays.decode(buf);
return new RpcRequest(requestId, message);
}
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java b/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java
index ed47947832..0a62e09a81 100644
--- a/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java
@@ -36,20 +36,17 @@ public final class RpcResponse implements ResponseMessage {
public Type type() { return Type.RpcResponse; }
@Override
- public int encodedLength() { return 8 + 4 + response.length; }
+ public int encodedLength() { return 8 + Encoders.ByteArrays.encodedLength(response); }
@Override
public void encode(ByteBuf buf) {
buf.writeLong(requestId);
- buf.writeInt(response.length);
- buf.writeBytes(response);
+ Encoders.ByteArrays.encode(buf, response);
}
public static RpcResponse decode(ByteBuf buf) {
long requestId = buf.readLong();
- int responseLen = buf.readInt();
- byte[] response = new byte[responseLen];
- buf.readBytes(response);
+ byte[] response = Encoders.ByteArrays.decode(buf);
return new RpcResponse(requestId, response);
}
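The same refactoring pattern applies to all four message types above; as a sanity-check sketch (not part of this diff), an RpcResponse survives an encode/decode round trip:

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import org.apache.spark.network.protocol.RpcResponse;

// Encode an RpcResponse into a buffer sized by encodedLength(), then decode it back.
public class RpcResponseRoundTrip {
  public static void main(String[] args) {
    RpcResponse out = new RpcResponse(42L, new byte[] {7, 8, 9});
    ByteBuf buf = Unpooled.buffer(out.encodedLength());
    out.encode(buf);
    RpcResponse in = RpcResponse.decode(buf);
    // in carries the same requestId (42) and a 3-byte response payload.
  }
}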
diff --git a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java
index 75c4a3981a..009dbcf013 100644
--- a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java
+++ b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java
@@ -50,33 +50,6 @@ public class JavaUtils {
}
}
- // TODO: Make this configurable, do not use Java serialization!
- public static <T> T deserialize(byte[] bytes) {
- try {
- ObjectInputStream is = new ObjectInputStream(new ByteArrayInputStream(bytes));
- Object out = is.readObject();
- is.close();
- return (T) out;
- } catch (ClassNotFoundException e) {
- throw new RuntimeException("Could not deserialize object", e);
- } catch (IOException e) {
- throw new RuntimeException("Could not deserialize object", e);
- }
- }
-
- // TODO: Make this configurable, do not use Java serialization!
- public static byte[] serialize(Object object) {
- try {
- ByteArrayOutputStream baos = new ByteArrayOutputStream();
- ObjectOutputStream os = new ObjectOutputStream(baos);
- os.writeObject(object);
- os.close();
- return baos.toByteArray();
- } catch (IOException e) {
- throw new RuntimeException("Could not serialize object", e);
- }
- }
-
/** Returns a hash consistent with Spark's Utils.nonNegativeHash(). */
public static int nonNegativeHash(Object obj) {
if (obj == null) { return 0; }