aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorReynold Xin <rxin@databricks.com>2015-04-02 14:51:00 -0700
committerReynold Xin <rxin@databricks.com>2015-04-02 14:51:00 -0700
commitd82e73239118fe99535eaa68be9bf37e837bebe4 (patch)
treebd9451eeeb48359004fdbd887536b7e0a1d3cdee
parent8fa09a480848faf4eda263cc1e79e0dd56a52605 (diff)
downloadspark-d82e73239118fe99535eaa68be9bf37e837bebe4.tar.gz
spark-d82e73239118fe99535eaa68be9bf37e837bebe4.tar.bz2
spark-d82e73239118fe99535eaa68be9bf37e837bebe4.zip
[SPARK-6578] [core] Fix thread-safety issue in outbound path of network library.
While the inbound path of a netty pipeline is thread-safe, the outbound path is not. That means that multiple threads can compete to write messages to the next stage of the pipeline. The network library sometimes breaks a single RPC message into multiple buffers internally to avoid copying data (see MessageEncoder). This can result in the following scenario (where "FxBy" means "frame x, buffer y"): T1 F1B1 F1B2 \ \ \ \ socket F1B1 F2B1 F1B2 F2B2 / / / / T2 F2B1 F2B2 And the frames now cannot be rebuilt on the receiving side because the different messages have been mixed up on the wire. The fix wraps these multi-buffer messages into a `FileRegion` object so that these messages are written "atomically" to the next pipeline handler. Author: Reynold Xin <rxin@databricks.com> Author: Marcelo Vanzin <vanzin@cloudera.com> Closes #5336 from vanzin/SPARK-6578-1.2 and squashes the following commits: 4d3395e [Reynold Xin] [SPARK-6578] Small rewrite to make the logic more clear in MessageWithHeader.transferTo. 526f230 [Marcelo Vanzin] [SPARK-6578] [core] Fix thread-safety issue in outbound path of network library.
-rw-r--r--network/common/pom.xml5
-rw-r--r--network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java6
-rw-r--r--network/common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java109
-rw-r--r--network/common/src/test/java/org/apache/spark/network/ByteArrayWritableChannel.java55
-rw-r--r--network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java46
-rw-r--r--network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java129
-rw-r--r--network/common/src/test/resources/log4j.properties27
7 files changed, 367 insertions, 10 deletions
diff --git a/network/common/pom.xml b/network/common/pom.xml
index e4b43a0a7c..58d09f0361 100644
--- a/network/common/pom.xml
+++ b/network/common/pom.xml
@@ -80,6 +80,11 @@
<artifactId>scalatest_${scala.binary.version}</artifactId>
<scope>test</scope>
</dependency>
+ <dependency>
+ <groupId>org.slf4j</groupId>
+ <artifactId>slf4j-log4j12</artifactId>
+ <scope>test</scope>
+ </dependency>
</dependencies>
<build>
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java b/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java
index 91d1e8a538..0f999f5dfe 100644
--- a/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java
@@ -72,9 +72,11 @@ public final class MessageEncoder extends MessageToMessageEncoder<Message> {
in.encode(header);
assert header.writableBytes() == 0;
- out.add(header);
if (body != null && bodyLength > 0) {
- out.add(body);
+ out.add(new MessageWithHeader(header, body, bodyLength));
+ } else {
+ out.add(header);
}
}
+
}
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java b/network/common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java
new file mode 100644
index 0000000000..d686a95146
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import java.io.IOException;
+import java.nio.channels.WritableByteChannel;
+
+import com.google.common.base.Preconditions;
+import io.netty.buffer.ByteBuf;
+import io.netty.channel.FileRegion;
+import io.netty.util.AbstractReferenceCounted;
+import io.netty.util.ReferenceCountUtil;
+
+/**
+ * A wrapper message that holds two separate pieces (a header and a body).
+ *
+ * The header must be a ByteBuf, while the body can be a ByteBuf or a FileRegion.
+ */
+class MessageWithHeader extends AbstractReferenceCounted implements FileRegion {
+
+ private final ByteBuf header;
+ private final int headerLength;
+ private final Object body;
+ private final long bodyLength;
+ private long totalBytesTransferred;
+
+ MessageWithHeader(ByteBuf header, Object body, long bodyLength) {
+ Preconditions.checkArgument(body instanceof ByteBuf || body instanceof FileRegion,
+ "Body must be a ByteBuf or a FileRegion.");
+ this.header = header;
+ this.headerLength = header.readableBytes();
+ this.body = body;
+ this.bodyLength = bodyLength;
+ }
+
+ @Override
+ public long count() {
+ return headerLength + bodyLength;
+ }
+
+ @Override
+ public long position() {
+ return 0;
+ }
+
+ @Override
+ public long transfered() {
+ return totalBytesTransferred;
+ }
+
+ /**
+ * This code is more complicated than you would think because we might require multiple
+ * transferTo invocations in order to transfer a single MessageWithHeader to avoid busy waiting.
+ *
+ * The contract is that the caller will ensure position is properly set to the total number
+ * of bytes transferred so far (i.e. value returned by transfered()).
+ */
+ @Override
+ public long transferTo(final WritableByteChannel target, final long position) throws IOException {
+ Preconditions.checkArgument(position == totalBytesTransferred, "Invalid position.");
+ // Bytes written for header in this call.
+ long writtenHeader = 0;
+ if (header.readableBytes() > 0) {
+ writtenHeader = copyByteBuf(header, target);
+ totalBytesTransferred += writtenHeader;
+ if (header.readableBytes() > 0) {
+ return writtenHeader;
+ }
+ }
+
+ // Bytes written for body in this call.
+ long writtenBody = 0;
+ if (body instanceof FileRegion) {
+ writtenBody = ((FileRegion) body).transferTo(target, totalBytesTransferred - headerLength);
+ } else if (body instanceof ByteBuf) {
+ writtenBody = copyByteBuf((ByteBuf) body, target);
+ }
+ totalBytesTransferred += writtenBody;
+
+ return writtenHeader + writtenBody;
+ }
+
+ @Override
+ protected void deallocate() {
+ header.release();
+ ReferenceCountUtil.release(body);
+ }
+
+ private int copyByteBuf(ByteBuf buf, WritableByteChannel target) throws IOException {
+ int written = target.write(buf.nioBuffer());
+ buf.skipBytes(written);
+ return written;
+ }
+}
diff --git a/network/common/src/test/java/org/apache/spark/network/ByteArrayWritableChannel.java b/network/common/src/test/java/org/apache/spark/network/ByteArrayWritableChannel.java
new file mode 100644
index 0000000000..b525ed69fc
--- /dev/null
+++ b/network/common/src/test/java/org/apache/spark/network/ByteArrayWritableChannel.java
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network;
+
+import java.nio.ByteBuffer;
+import java.nio.channels.WritableByteChannel;
+
+public class ByteArrayWritableChannel implements WritableByteChannel {
+
+ private final byte[] data;
+ private int offset;
+
+ public ByteArrayWritableChannel(int size) {
+ this.data = new byte[size];
+ this.offset = 0;
+ }
+
+ public byte[] getData() {
+ return data;
+ }
+
+ @Override
+ public int write(ByteBuffer src) {
+ int available = src.remaining();
+ src.get(data, offset, available);
+ offset += available;
+ return available;
+ }
+
+ @Override
+ public void close() {
+
+ }
+
+ @Override
+ public boolean isOpen() {
+ return true;
+ }
+
+}
diff --git a/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java b/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java
index 43dc0cf8c7..860dd6d9b3 100644
--- a/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java
+++ b/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java
@@ -17,26 +17,34 @@
package org.apache.spark.network;
+import java.util.List;
+
+import com.google.common.primitives.Ints;
+import io.netty.buffer.Unpooled;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.FileRegion;
import io.netty.channel.embedded.EmbeddedChannel;
+import io.netty.handler.codec.MessageToMessageEncoder;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
-import org.apache.spark.network.protocol.Message;
-import org.apache.spark.network.protocol.StreamChunkId;
-import org.apache.spark.network.protocol.ChunkFetchRequest;
import org.apache.spark.network.protocol.ChunkFetchFailure;
+import org.apache.spark.network.protocol.ChunkFetchRequest;
import org.apache.spark.network.protocol.ChunkFetchSuccess;
-import org.apache.spark.network.protocol.RpcRequest;
-import org.apache.spark.network.protocol.RpcFailure;
-import org.apache.spark.network.protocol.RpcResponse;
+import org.apache.spark.network.protocol.Message;
import org.apache.spark.network.protocol.MessageDecoder;
import org.apache.spark.network.protocol.MessageEncoder;
+import org.apache.spark.network.protocol.RpcFailure;
+import org.apache.spark.network.protocol.RpcRequest;
+import org.apache.spark.network.protocol.RpcResponse;
+import org.apache.spark.network.protocol.StreamChunkId;
import org.apache.spark.network.util.NettyUtils;
public class ProtocolSuite {
private void testServerToClient(Message msg) {
- EmbeddedChannel serverChannel = new EmbeddedChannel(new MessageEncoder());
+ EmbeddedChannel serverChannel = new EmbeddedChannel(new FileRegionEncoder(),
+ new MessageEncoder());
serverChannel.writeOutbound(msg);
EmbeddedChannel clientChannel = new EmbeddedChannel(
@@ -51,7 +59,8 @@ public class ProtocolSuite {
}
private void testClientToServer(Message msg) {
- EmbeddedChannel clientChannel = new EmbeddedChannel(new MessageEncoder());
+ EmbeddedChannel clientChannel = new EmbeddedChannel(new FileRegionEncoder(),
+ new MessageEncoder());
clientChannel.writeOutbound(msg);
EmbeddedChannel serverChannel = new EmbeddedChannel(
@@ -83,4 +92,25 @@ public class ProtocolSuite {
testServerToClient(new RpcFailure(0, "this is an error"));
testServerToClient(new RpcFailure(0, ""));
}
+
+ /**
+ * Handler to transform a FileRegion into a byte buffer. EmbeddedChannel doesn't actually transfer
+ * bytes, but messages, so this is needed so that the frame decoder on the receiving side can
+ * understand what MessageWithHeader actually contains.
+ */
+ private static class FileRegionEncoder extends MessageToMessageEncoder<FileRegion> {
+
+ @Override
+ public void encode(ChannelHandlerContext ctx, FileRegion in, List<Object> out)
+ throws Exception {
+
+ ByteArrayWritableChannel channel = new ByteArrayWritableChannel(Ints.checkedCast(in.count()));
+ while (in.transfered() < in.count()) {
+ in.transferTo(channel, in.transfered());
+ }
+ out.add(Unpooled.wrappedBuffer(channel.getData()));
+ }
+
+ }
+
}
diff --git a/network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java b/network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java
new file mode 100644
index 0000000000..ff985096d7
--- /dev/null
+++ b/network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java
@@ -0,0 +1,129 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.channels.WritableByteChannel;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+import io.netty.channel.FileRegion;
+import io.netty.util.AbstractReferenceCounted;
+import org.junit.Test;
+
+import static org.junit.Assert.*;
+
+import org.apache.spark.network.ByteArrayWritableChannel;
+
+public class MessageWithHeaderSuite {
+
+ @Test
+ public void testSingleWrite() throws Exception {
+ testFileRegionBody(8, 8);
+ }
+
+ @Test
+ public void testShortWrite() throws Exception {
+ testFileRegionBody(8, 1);
+ }
+
+ @Test
+ public void testByteBufBody() throws Exception {
+ ByteBuf header = Unpooled.copyLong(42);
+ ByteBuf body = Unpooled.copyLong(84);
+ MessageWithHeader msg = new MessageWithHeader(header, body, body.readableBytes());
+
+ ByteBuf result = doWrite(msg, 1);
+ assertEquals(msg.count(), result.readableBytes());
+ assertEquals(42, result.readLong());
+ assertEquals(84, result.readLong());
+ }
+
+ private void testFileRegionBody(int totalWrites, int writesPerCall) throws Exception {
+ ByteBuf header = Unpooled.copyLong(42);
+ int headerLength = header.readableBytes();
+ TestFileRegion region = new TestFileRegion(totalWrites, writesPerCall);
+ MessageWithHeader msg = new MessageWithHeader(header, region, region.count());
+
+ ByteBuf result = doWrite(msg, totalWrites / writesPerCall);
+ assertEquals(headerLength + region.count(), result.readableBytes());
+ assertEquals(42, result.readLong());
+ for (long i = 0; i < 8; i++) {
+ assertEquals(i, result.readLong());
+ }
+ }
+
+ private ByteBuf doWrite(MessageWithHeader msg, int minExpectedWrites) throws Exception {
+ int writes = 0;
+ ByteArrayWritableChannel channel = new ByteArrayWritableChannel((int) msg.count());
+ while (msg.transfered() < msg.count()) {
+ msg.transferTo(channel, msg.transfered());
+ writes++;
+ }
+ assertTrue("Not enough writes!", minExpectedWrites <= writes);
+ return Unpooled.wrappedBuffer(channel.getData());
+ }
+
+ private static class TestFileRegion extends AbstractReferenceCounted implements FileRegion {
+
+ private final int writeCount;
+ private final int writesPerCall;
+ private int written;
+
+ TestFileRegion(int totalWrites, int writesPerCall) {
+ this.writeCount = totalWrites;
+ this.writesPerCall = writesPerCall;
+ }
+
+ @Override
+ public long count() {
+ return 8 * writeCount;
+ }
+
+ @Override
+ public long position() {
+ return 0;
+ }
+
+ @Override
+ public long transfered() {
+ return 8 * written;
+ }
+
+ @Override
+ public long transferTo(WritableByteChannel target, long position) throws IOException {
+ for (int i = 0; i < writesPerCall; i++) {
+ ByteBuf buf = Unpooled.copyLong((position / 8) + i);
+ ByteBuffer nio = buf.nioBuffer();
+ while (nio.remaining() > 0) {
+ target.write(nio);
+ }
+ buf.release();
+ written++;
+ }
+ return 8 * writesPerCall;
+ }
+
+ @Override
+ protected void deallocate() {
+ }
+
+ }
+
+}
diff --git a/network/common/src/test/resources/log4j.properties b/network/common/src/test/resources/log4j.properties
new file mode 100644
index 0000000000..e8da774f7c
--- /dev/null
+++ b/network/common/src/test/resources/log4j.properties
@@ -0,0 +1,27 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# Set everything to be logged to the file target/unit-tests.log
+log4j.rootCategory=DEBUG, file
+log4j.appender.file=org.apache.log4j.FileAppender
+log4j.appender.file.append=true
+log4j.appender.file.file=target/unit-tests.log
+log4j.appender.file.layout=org.apache.log4j.PatternLayout
+log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n
+
+# Silence verbose logs from 3rd-party libraries.
+log4j.logger.io.netty=INFO