aboutsummaryrefslogtreecommitdiff
path: root/network/common
diff options
context:
space:
mode:
Diffstat (limited to 'network/common')
-rw-r--r--network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java12
-rw-r--r--network/common/src/main/java/org/apache/spark/network/protocol/Encoders.java93
-rw-r--r--network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java12
-rw-r--r--network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java9
-rw-r--r--network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java9
-rw-r--r--network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java27
6 files changed, 107 insertions, 55 deletions
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java
index 152af98ced..986957c150 100644
--- a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java
@@ -38,23 +38,19 @@ public final class ChunkFetchFailure implements ResponseMessage {
@Override
public int encodedLength() {
- return streamChunkId.encodedLength() + 4 + errorString.getBytes(Charsets.UTF_8).length;
+ return streamChunkId.encodedLength() + Encoders.Strings.encodedLength(errorString);
}
@Override
public void encode(ByteBuf buf) {
streamChunkId.encode(buf);
- byte[] errorBytes = errorString.getBytes(Charsets.UTF_8);
- buf.writeInt(errorBytes.length);
- buf.writeBytes(errorBytes);
+ Encoders.Strings.encode(buf, errorString);
}
public static ChunkFetchFailure decode(ByteBuf buf) {
StreamChunkId streamChunkId = StreamChunkId.decode(buf);
- int numErrorStringBytes = buf.readInt();
- byte[] errorBytes = new byte[numErrorStringBytes];
- buf.readBytes(errorBytes);
- return new ChunkFetchFailure(streamChunkId, new String(errorBytes, Charsets.UTF_8));
+ String errorString = Encoders.Strings.decode(buf);
+ return new ChunkFetchFailure(streamChunkId, errorString);
}
@Override
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/Encoders.java b/network/common/src/main/java/org/apache/spark/network/protocol/Encoders.java
new file mode 100644
index 0000000000..873c694250
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/Encoders.java
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+
+import com.google.common.base.Charsets;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+
+/** Provides a canonical set of Encoders for simple types. */
+public class Encoders {
+
+ /** Strings are encoded with their length followed by UTF-8 bytes. */
+ public static class Strings {
+ public static int encodedLength(String s) {
+ return 4 + s.getBytes(Charsets.UTF_8).length;
+ }
+
+ public static void encode(ByteBuf buf, String s) {
+ byte[] bytes = s.getBytes(Charsets.UTF_8);
+ buf.writeInt(bytes.length);
+ buf.writeBytes(bytes);
+ }
+
+ public static String decode(ByteBuf buf) {
+ int length = buf.readInt();
+ byte[] bytes = new byte[length];
+ buf.readBytes(bytes);
+ return new String(bytes, Charsets.UTF_8);
+ }
+ }
+
+ /** Byte arrays are encoded with their length followed by bytes. */
+ public static class ByteArrays {
+ public static int encodedLength(byte[] arr) {
+ return 4 + arr.length;
+ }
+
+ public static void encode(ByteBuf buf, byte[] arr) {
+ buf.writeInt(arr.length);
+ buf.writeBytes(arr);
+ }
+
+ public static byte[] decode(ByteBuf buf) {
+ int length = buf.readInt();
+ byte[] bytes = new byte[length];
+ buf.readBytes(bytes);
+ return bytes;
+ }
+ }
+
+ /** String arrays are encoded with the number of strings followed by per-String encoding. */
+ public static class StringArrays {
+ public static int encodedLength(String[] strings) {
+ int totalLength = 4;
+ for (String s : strings) {
+ totalLength += Strings.encodedLength(s);
+ }
+ return totalLength;
+ }
+
+ public static void encode(ByteBuf buf, String[] strings) {
+ buf.writeInt(strings.length);
+ for (String s : strings) {
+ Strings.encode(buf, s);
+ }
+ }
+
+ public static String[] decode(ByteBuf buf) {
+ int numStrings = buf.readInt();
+ String[] strings = new String[numStrings];
+ for (int i = 0; i < strings.length; i ++) {
+ strings[i] = Strings.decode(buf);
+ }
+ return strings;
+ }
+ }
+}
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java b/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java
index e239d4ffbd..ebd764eb5e 100644
--- a/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java
@@ -36,23 +36,19 @@ public final class RpcFailure implements ResponseMessage {
@Override
public int encodedLength() {
- return 8 + 4 + errorString.getBytes(Charsets.UTF_8).length;
+ return 8 + Encoders.Strings.encodedLength(errorString);
}
@Override
public void encode(ByteBuf buf) {
buf.writeLong(requestId);
- byte[] errorBytes = errorString.getBytes(Charsets.UTF_8);
- buf.writeInt(errorBytes.length);
- buf.writeBytes(errorBytes);
+ Encoders.Strings.encode(buf, errorString);
}
public static RpcFailure decode(ByteBuf buf) {
long requestId = buf.readLong();
- int numErrorStringBytes = buf.readInt();
- byte[] errorBytes = new byte[numErrorStringBytes];
- buf.readBytes(errorBytes);
- return new RpcFailure(requestId, new String(errorBytes, Charsets.UTF_8));
+ String errorString = Encoders.Strings.decode(buf);
+ return new RpcFailure(requestId, errorString);
}
@Override
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java b/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java
index 099e934ae0..cdee0b0e03 100644
--- a/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java
@@ -44,21 +44,18 @@ public final class RpcRequest implements RequestMessage {
@Override
public int encodedLength() {
- return 8 + 4 + message.length;
+ return 8 + Encoders.ByteArrays.encodedLength(message);
}
@Override
public void encode(ByteBuf buf) {
buf.writeLong(requestId);
- buf.writeInt(message.length);
- buf.writeBytes(message);
+ Encoders.ByteArrays.encode(buf, message);
}
public static RpcRequest decode(ByteBuf buf) {
long requestId = buf.readLong();
- int messageLen = buf.readInt();
- byte[] message = new byte[messageLen];
- buf.readBytes(message);
+ byte[] message = Encoders.ByteArrays.decode(buf);
return new RpcRequest(requestId, message);
}
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java b/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java
index ed47947832..0a62e09a81 100644
--- a/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java
@@ -36,20 +36,17 @@ public final class RpcResponse implements ResponseMessage {
public Type type() { return Type.RpcResponse; }
@Override
- public int encodedLength() { return 8 + 4 + response.length; }
+ public int encodedLength() { return 8 + Encoders.ByteArrays.encodedLength(response); }
@Override
public void encode(ByteBuf buf) {
buf.writeLong(requestId);
- buf.writeInt(response.length);
- buf.writeBytes(response);
+ Encoders.ByteArrays.encode(buf, response);
}
public static RpcResponse decode(ByteBuf buf) {
long requestId = buf.readLong();
- int responseLen = buf.readInt();
- byte[] response = new byte[responseLen];
- buf.readBytes(response);
+ byte[] response = Encoders.ByteArrays.decode(buf);
return new RpcResponse(requestId, response);
}
diff --git a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java
index 75c4a3981a..009dbcf013 100644
--- a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java
+++ b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java
@@ -50,33 +50,6 @@ public class JavaUtils {
}
}
- // TODO: Make this configurable, do not use Java serialization!
- public static <T> T deserialize(byte[] bytes) {
- try {
- ObjectInputStream is = new ObjectInputStream(new ByteArrayInputStream(bytes));
- Object out = is.readObject();
- is.close();
- return (T) out;
- } catch (ClassNotFoundException e) {
- throw new RuntimeException("Could not deserialize object", e);
- } catch (IOException e) {
- throw new RuntimeException("Could not deserialize object", e);
- }
- }
-
- // TODO: Make this configurable, do not use Java serialization!
- public static byte[] serialize(Object object) {
- try {
- ByteArrayOutputStream baos = new ByteArrayOutputStream();
- ObjectOutputStream os = new ObjectOutputStream(baos);
- os.writeObject(object);
- os.close();
- return baos.toByteArray();
- } catch (IOException e) {
- throw new RuntimeException("Could not serialize object", e);
- }
- }
-
/** Returns a hash consistent with Spark's Utils.nonNegativeHash(). */
public static int nonNegativeHash(Object obj) {
if (obj == null) { return 0; }