about summary refs log tree commit diff
path: root/network
diff options
context:
space:
mode:
authorJosh Rosen <joshrosen@databricks.com>2016-02-16 12:06:30 -0800
committerShixiong Zhu <shixiong@databricks.com>2016-02-16 12:06:30 -0800
commit5f37aad48cb729a80c4cc25347460f12aafec9fb (patch)
tree16dbd6eb1b8e5f1ccfafaf4755e785e4b3d72320 /network
parentc7d00a24da317c9601a9239ac1cf185fb6647352 (diff)
downloadspark-5f37aad48cb729a80c4cc25347460f12aafec9fb.tar.gz
spark-5f37aad48cb729a80c4cc25347460f12aafec9fb.tar.bz2
spark-5f37aad48cb729a80c4cc25347460f12aafec9fb.zip
[SPARK-13308] ManagedBuffers passed to OneToOneStreamManager need to be freed in non-error cases
ManagedBuffers that are passed to `OneToOneStreamManager.registerStream` need to be freed by the manager once it's done using them. However, the current code only frees them in certain error-cases and not during typical operation. This isn't a major problem today, but it will cause memory leaks after we implement better locking / pinning in the BlockManager (see #10705). This patch modifies the relevant network code so that the ManagedBuffers are freed as soon as the messages containing them are processed by the lower-level Netty message sending code. /cc zsxwing for review. Author: Josh Rosen <joshrosen@databricks.com> Closes #11193 from JoshRosen/add-missing-release-calls-in-network-layer.
Diffstat (limited to 'network')
-rw-r--r--network/common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java6
-rw-r--r--network/common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java2
-rw-r--r--network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java7
-rw-r--r--network/common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java28
-rw-r--r--network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java1
-rw-r--r--network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java34
-rw-r--r--network/common/src/test/java/org/apache/spark/network/server/OneForOneStreamManagerSuite.java50
7 files changed, 119 insertions, 9 deletions
diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java b/network/common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java
index a415db593a..1861f8d7fd 100644
--- a/network/common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java
+++ b/network/common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java
@@ -65,7 +65,11 @@ public abstract class ManagedBuffer {
public abstract ManagedBuffer release();
/**
- * Convert the buffer into an Netty object, used to write the data out.
+ * Convert the buffer into a Netty object, used to write the data out. The return value is either
+ * a {@link io.netty.buffer.ByteBuf} or a {@link io.netty.channel.FileRegion}.
+ *
+ * If this method returns a ByteBuf, then that buffer's reference count will be incremented and
+ * the caller will be responsible for releasing this new reference.
*/
public abstract Object convertToNetty() throws IOException;
}
diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java b/network/common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java
index c806bfa45b..4c8802af7a 100644
--- a/network/common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java
+++ b/network/common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java
@@ -64,7 +64,7 @@ public final class NettyManagedBuffer extends ManagedBuffer {
@Override
public Object convertToNetty() throws IOException {
- return buf.duplicate();
+ return buf.duplicate().retain();
}
@Override
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java b/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java
index abca22347b..664df57fec 100644
--- a/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java
@@ -54,6 +54,7 @@ public final class MessageEncoder extends MessageToMessageEncoder<Message> {
body = in.body().convertToNetty();
isBodyInFrame = in.isBodyInFrame();
} catch (Exception e) {
+ in.body().release();
if (in instanceof AbstractResponseMessage) {
AbstractResponseMessage resp = (AbstractResponseMessage) in;
// Re-encode this message as a failure response.
@@ -80,8 +81,10 @@ public final class MessageEncoder extends MessageToMessageEncoder<Message> {
in.encode(header);
assert header.writableBytes() == 0;
- if (body != null && bodyLength > 0) {
- out.add(new MessageWithHeader(header, body, bodyLength));
+ if (body != null) {
+ // We transfer ownership of the reference on in.body() to MessageWithHeader.
+ // This reference will be freed when MessageWithHeader.deallocate() is called.
+ out.add(new MessageWithHeader(in.body(), header, body, bodyLength));
} else {
out.add(header);
}
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java b/network/common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java
index d686a95146..66227f96a1 100644
--- a/network/common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java
@@ -19,6 +19,7 @@ package org.apache.spark.network.protocol;
import java.io.IOException;
import java.nio.channels.WritableByteChannel;
+import javax.annotation.Nullable;
import com.google.common.base.Preconditions;
import io.netty.buffer.ByteBuf;
@@ -26,6 +27,8 @@ import io.netty.channel.FileRegion;
import io.netty.util.AbstractReferenceCounted;
import io.netty.util.ReferenceCountUtil;
+import org.apache.spark.network.buffer.ManagedBuffer;
+
/**
* A wrapper message that holds two separate pieces (a header and a body).
*
@@ -33,15 +36,35 @@ import io.netty.util.ReferenceCountUtil;
*/
class MessageWithHeader extends AbstractReferenceCounted implements FileRegion {
+ @Nullable private final ManagedBuffer managedBuffer;
private final ByteBuf header;
private final int headerLength;
private final Object body;
private final long bodyLength;
private long totalBytesTransferred;
- MessageWithHeader(ByteBuf header, Object body, long bodyLength) {
+ /**
+ * Construct a new MessageWithHeader.
+ *
+ * @param managedBuffer the {@link ManagedBuffer} that the message body came from. This needs to
+ * be passed in so that the buffer can be freed when this message is
+ * deallocated. Ownership of the caller's reference to this buffer is
+ * transferred to this class, so if the caller wants to continue to use the
+ * ManagedBuffer in other messages then they will need to call retain() on
+ * it before passing it to this constructor. This may be null if and only if
+ * `body` is a {@link FileRegion}.
+ * @param header the message header.
+ * @param body the message body. Must be either a {@link ByteBuf} or a {@link FileRegion}.
+ * @param bodyLength the length of the message body, in bytes.
+ */
+ MessageWithHeader(
+ @Nullable ManagedBuffer managedBuffer,
+ ByteBuf header,
+ Object body,
+ long bodyLength) {
Preconditions.checkArgument(body instanceof ByteBuf || body instanceof FileRegion,
"Body must be a ByteBuf or a FileRegion.");
+ this.managedBuffer = managedBuffer;
this.header = header;
this.headerLength = header.readableBytes();
this.body = body;
@@ -99,6 +122,9 @@ class MessageWithHeader extends AbstractReferenceCounted implements FileRegion {
protected void deallocate() {
header.release();
ReferenceCountUtil.release(body);
+ if (managedBuffer != null) {
+ managedBuffer.release();
+ }
}
private int copyByteBuf(ByteBuf buf, WritableByteChannel target) throws IOException {
diff --git a/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java b/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java
index e671854da1..ea9e735e0a 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java
@@ -20,7 +20,6 @@ package org.apache.spark.network.server;
import java.util.Iterator;
import java.util.Map;
import java.util.Random;
-import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;
diff --git a/network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java b/network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java
index 6c98e733b4..fbbe4b7014 100644
--- a/network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java
+++ b/network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java
@@ -26,9 +26,13 @@ import io.netty.buffer.Unpooled;
import io.netty.channel.FileRegion;
import io.netty.util.AbstractReferenceCounted;
import org.junit.Test;
+import org.mockito.Mockito;
import static org.junit.Assert.*;
+import org.apache.spark.network.TestManagedBuffer;
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.buffer.NettyManagedBuffer;
import org.apache.spark.network.util.ByteArrayWritableChannel;
public class MessageWithHeaderSuite {
@@ -46,20 +50,43 @@ public class MessageWithHeaderSuite {
@Test
public void testByteBufBody() throws Exception {
ByteBuf header = Unpooled.copyLong(42);
- ByteBuf body = Unpooled.copyLong(84);
- MessageWithHeader msg = new MessageWithHeader(header, body, body.readableBytes());
+ ByteBuf bodyPassedToNettyManagedBuffer = Unpooled.copyLong(84);
+ assertEquals(1, header.refCnt());
+ assertEquals(1, bodyPassedToNettyManagedBuffer.refCnt());
+ ManagedBuffer managedBuf = new NettyManagedBuffer(bodyPassedToNettyManagedBuffer);
+ Object body = managedBuf.convertToNetty();
+ assertEquals(2, bodyPassedToNettyManagedBuffer.refCnt());
+ assertEquals(1, header.refCnt());
+
+ MessageWithHeader msg = new MessageWithHeader(managedBuf, header, body, managedBuf.size());
ByteBuf result = doWrite(msg, 1);
assertEquals(msg.count(), result.readableBytes());
assertEquals(42, result.readLong());
assertEquals(84, result.readLong());
+
+ assert(msg.release());
+ assertEquals(0, bodyPassedToNettyManagedBuffer.refCnt());
+ assertEquals(0, header.refCnt());
+ }
+
+ @Test
+ public void testDeallocateReleasesManagedBuffer() throws Exception {
+ ByteBuf header = Unpooled.copyLong(42);
+ ManagedBuffer managedBuf = Mockito.spy(new TestManagedBuffer(84));
+ ByteBuf body = (ByteBuf) managedBuf.convertToNetty();
+ assertEquals(2, body.refCnt());
+ MessageWithHeader msg = new MessageWithHeader(managedBuf, header, body, body.readableBytes());
+ assert(msg.release());
+ Mockito.verify(managedBuf, Mockito.times(1)).release();
+ assertEquals(0, body.refCnt());
}
private void testFileRegionBody(int totalWrites, int writesPerCall) throws Exception {
ByteBuf header = Unpooled.copyLong(42);
int headerLength = header.readableBytes();
TestFileRegion region = new TestFileRegion(totalWrites, writesPerCall);
- MessageWithHeader msg = new MessageWithHeader(header, region, region.count());
+ MessageWithHeader msg = new MessageWithHeader(null, header, region, region.count());
ByteBuf result = doWrite(msg, totalWrites / writesPerCall);
assertEquals(headerLength + region.count(), result.readableBytes());
@@ -67,6 +94,7 @@ public class MessageWithHeaderSuite {
for (long i = 0; i < 8; i++) {
assertEquals(i, result.readLong());
}
+ assert(msg.release());
}
private ByteBuf doWrite(MessageWithHeader msg, int minExpectedWrites) throws Exception {
diff --git a/network/common/src/test/java/org/apache/spark/network/server/OneForOneStreamManagerSuite.java b/network/common/src/test/java/org/apache/spark/network/server/OneForOneStreamManagerSuite.java
new file mode 100644
index 0000000000..c647525d8f
--- /dev/null
+++ b/network/common/src/test/java/org/apache/spark/network/server/OneForOneStreamManagerSuite.java
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.server;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import io.netty.channel.Channel;
+import org.junit.Test;
+import org.mockito.Mockito;
+
+import org.apache.spark.network.TestManagedBuffer;
+import org.apache.spark.network.buffer.ManagedBuffer;
+
+public class OneForOneStreamManagerSuite {
+
+ @Test
+  public void managedBuffersAreFreedWhenConnectionIsClosed() throws Exception {
+ OneForOneStreamManager manager = new OneForOneStreamManager();
+ List<ManagedBuffer> buffers = new ArrayList<>();
+ TestManagedBuffer buffer1 = Mockito.spy(new TestManagedBuffer(10));
+ TestManagedBuffer buffer2 = Mockito.spy(new TestManagedBuffer(20));
+ buffers.add(buffer1);
+ buffers.add(buffer2);
+ long streamId = manager.registerStream("appId", buffers.iterator());
+
+ Channel dummyChannel = Mockito.mock(Channel.class, Mockito.RETURNS_SMART_NULLS);
+ manager.registerChannel(dummyChannel, streamId);
+
+ manager.connectionTerminated(dummyChannel);
+
+ Mockito.verify(buffer1, Mockito.times(1)).release();
+ Mockito.verify(buffer2, Mockito.times(1)).release();
+ }
+}