author     Liang-Chi Hsieh <viirya@gmail.com>       2015-05-01 11:59:12 -0700
committer  Aaron Davidson <aaron@databricks.com>    2015-05-01 11:59:12 -0700
commit     16860327286bc08b4e2283d51b4c8fe024ba5006 (patch)
tree       8ed64b241ffdd85db091c2eb9be7b1c34dc38c40 /network
parent     1262e310cd294c8fd936c55c3281ed855824ea27 (diff)
[SPARK-7183] [NETWORK] Fix memory leak of TransportRequestHandler.streamIds
JIRA: https://issues.apache.org/jira/browse/SPARK-7183

Author: Liang-Chi Hsieh <viirya@gmail.com>

Closes #5743 from viirya/fix_requesthandler_memory_leak and squashes the following commits:

cf2c086 [Liang-Chi Hsieh] For comments.
97e205c [Liang-Chi Hsieh] Remove unused import.
d35f19a [Liang-Chi Hsieh] For comments.
f9a0c37 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into fix_requesthandler_memory_leak
45908b7 [Liang-Chi Hsieh] for style.
17f020f [Liang-Chi Hsieh] Remove unused import.
37a4b6c [Liang-Chi Hsieh] Remove streamIds from TransportRequestHandler.
3b3f38a [Liang-Chi Hsieh] Fix memory leak of TransportRequestHandler.streamIds.
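For context, the leak and its fix reduce to the pattern below: before this patch, each TransportRequestHandler accumulated stream ids in a per-connection Set<Long> that was only drained on channelUnregistered; after it, the stream manager tags every stream with the channel that reads it and sweeps all matching streams when that channel dies. The following is a minimal, self-contained sketch of that pattern, not Spark code: Channel is stubbed as Object so it compiles without Netty, and registerStream() stands in for the manager's real registration method, which this diff does not show.

import java.util.Iterator;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;

class StreamSweepSketch {
  static class StreamState {
    final Iterator<String> buffers;   // stands in for Iterator<ManagedBuffer>
    Object associatedChannel = null;  // set once the first fetch arrives
    StreamState(Iterator<String> buffers) { this.buffers = buffers; }
  }

  private final AtomicLong nextStreamId = new AtomicLong(0);
  private final ConcurrentHashMap<Long, StreamState> streams = new ConcurrentHashMap<>();

  long registerStream(Iterator<String> buffers) {
    long id = nextStreamId.getAndIncrement();
    streams.put(id, new StreamState(buffers));
    return id;
  }

  // Called on the first fetch for a stream: tie the stream to its channel.
  void registerChannel(Object channel, long streamId) {
    StreamState state = streams.get(streamId);
    if (state != null) {
      state.associatedChannel = channel;
    }
  }

  // Sweep every stream tied to the dead channel. No per-handler id set
  // exists any more, so nothing is left behind to leak.
  void connectionTerminated(Object channel) {
    for (Map.Entry<Long, StreamState> entry : streams.entrySet()) {
      if (entry.getValue().associatedChannel == channel) {
        streams.remove(entry.getKey());
        // In the real class, each remaining ManagedBuffer is release()d here.
      }
    }
  }
}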
Diffstat (limited to 'network')
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java  | 35
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/server/StreamManager.java           | 19
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java | 14
3 files changed, 44 insertions(+), 24 deletions(-)
diff --git a/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java b/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java
index a6d390e13f..c95e64e8e2 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java
@@ -20,14 +20,18 @@ package org.apache.spark.network.server;
import java.util.Iterator;
import java.util.Map;
import java.util.Random;
+import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;
+import io.netty.channel.Channel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.spark.network.buffer.ManagedBuffer;
+import com.google.common.base.Preconditions;
+
/**
* StreamManager which allows registration of an Iterator<ManagedBuffer>, which are individually
* fetched as chunks by the client. Each registered buffer is one chunk.
@@ -36,18 +40,21 @@ public class OneForOneStreamManager extends StreamManager {
private final Logger logger = LoggerFactory.getLogger(OneForOneStreamManager.class);
private final AtomicLong nextStreamId;
- private final Map<Long, StreamState> streams;
+ private final ConcurrentHashMap<Long, StreamState> streams;
/** State of a single stream. */
private static class StreamState {
final Iterator<ManagedBuffer> buffers;
+ // The channel associated with the stream
+ Channel associatedChannel = null;
+
// Used to keep track of the index of the buffer that the user has retrieved, just to ensure
// that the caller only requests each chunk one at a time, in order.
int curChunk = 0;
StreamState(Iterator<ManagedBuffer> buffers) {
- this.buffers = buffers;
+ this.buffers = Preconditions.checkNotNull(buffers);
}
}
@@ -59,6 +66,13 @@ public class OneForOneStreamManager extends StreamManager {
}
@Override
+ public void registerChannel(Channel channel, long streamId) {
+ if (streams.containsKey(streamId)) {
+ streams.get(streamId).associatedChannel = channel;
+ }
+ }
+
+ @Override
public ManagedBuffer getChunk(long streamId, int chunkIndex) {
StreamState state = streams.get(streamId);
if (chunkIndex != state.curChunk) {
@@ -80,12 +94,17 @@ public class OneForOneStreamManager extends StreamManager {
}
@Override
- public void connectionTerminated(long streamId) {
- // Release all remaining buffers.
- StreamState state = streams.remove(streamId);
- if (state != null && state.buffers != null) {
- while (state.buffers.hasNext()) {
- state.buffers.next().release();
+ public void connectionTerminated(Channel channel) {
+ // Close all streams which have been associated with the channel.
+ for (Map.Entry<Long, StreamState> entry: streams.entrySet()) {
+ StreamState state = entry.getValue();
+ if (state.associatedChannel == channel) {
+ streams.remove(entry.getKey());
+
+ // Release all remaining buffers.
+ while (state.buffers.hasNext()) {
+ state.buffers.next().release();
+ }
}
}
}
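To make the lifecycle enforced by this file concrete, here is a hypothetical walk-through using the StreamSweepSketch class sketched above (assumed to be on the classpath): a stream is registered, tied to its channel when the first ChunkFetchRequest arrives, and swept when the channel closes.

import java.util.Arrays;

class StreamSweepDemo {
  public static void main(String[] args) {
    StreamSweepSketch manager = new StreamSweepSketch();
    Object channel = new Object(); // stands in for a Netty Channel

    long streamId = manager.registerStream(Arrays.asList("chunk-0", "chunk-1").iterator());
    manager.registerChannel(channel, streamId);  // first fetch on this connection
    manager.connectionTerminated(channel);       // channel closes: stream is swept
  }
}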
diff --git a/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java b/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java
index 5a9a14a180..929f789bf9 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java
@@ -17,6 +17,8 @@
package org.apache.spark.network.server;
+import io.netty.channel.Channel;
+
import org.apache.spark.network.buffer.ManagedBuffer;
/**
@@ -44,9 +46,18 @@ public abstract class StreamManager {
public abstract ManagedBuffer getChunk(long streamId, int chunkIndex);
/**
- * Indicates that the TCP connection that was tied to the given stream has been terminated. After
- * this occurs, we are guaranteed not to read from the stream again, so any state can be cleaned
- * up.
+ * Associates a stream with a single client connection, which is guaranteed to be the only reader
+ * of the stream. The getChunk() method will be called serially on this connection and once the
+ * connection is closed, the stream will never be used again, enabling cleanup.
+ *
+ * This must be called before the first getChunk() on the stream, but it may be invoked multiple
+ * times with the same channel and stream id.
+ */
+ public void registerChannel(Channel channel, long streamId) { }
+
+ /**
+ * Indicates that the given channel has been terminated. After this occurs, we are guaranteed not
+ * to read from the associated streams again, so any state can be cleaned up.
*/
- public void connectionTerminated(long streamId) { }
+ public void connectionTerminated(Channel channel) { }
}
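The javadoc above defines the contract a subclass must honor: registerChannel() precedes the first getChunk(), getChunk() is called serially on one connection, and connectionTerminated() ends all reads on the channel's streams. A minimal hypothetical subclass might look like the following; SingleStreamManager and its fields are illustrative names, not part of Spark, and only the methods visible in this file are assumed.

import java.util.Iterator;

import io.netty.channel.Channel;

import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.server.StreamManager;

public class SingleStreamManager extends StreamManager {
  private final long streamId;
  private final Iterator<ManagedBuffer> buffers;
  private volatile Channel owner;

  public SingleStreamManager(long streamId, Iterator<ManagedBuffer> buffers) {
    this.streamId = streamId;
    this.buffers = buffers;
  }

  @Override
  public void registerChannel(Channel channel, long streamId) {
    if (this.streamId == streamId) {
      owner = channel; // the contract allows repeated calls with the same pair
    }
  }

  @Override
  public ManagedBuffer getChunk(long streamId, int chunkIndex) {
    // The contract guarantees serial, in-order fetches on the registered
    // channel, so a bare iterator suffices for this sketch.
    return this.streamId == streamId ? buffers.next() : null;
  }

  @Override
  public void connectionTerminated(Channel channel) {
    if (channel == owner) {
      while (buffers.hasNext()) {
        buffers.next().release(); // free whatever the client never fetched
      }
    }
  }
}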
diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
index 1580180cc1..e5159ab56d 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
@@ -17,10 +17,7 @@
package org.apache.spark.network.server;
-import java.util.Set;
-
import com.google.common.base.Throwables;
-import com.google.common.collect.Sets;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
@@ -62,9 +59,6 @@ public class TransportRequestHandler extends MessageHandler<RequestMessage> {
/** Returns each chunk part of a stream. */
private final StreamManager streamManager;
- /** List of all stream ids that have been read on this handler, used for cleanup. */
- private final Set<Long> streamIds;
-
public TransportRequestHandler(
Channel channel,
TransportClient reverseClient,
@@ -73,7 +67,6 @@ public class TransportRequestHandler extends MessageHandler<RequestMessage> {
this.reverseClient = reverseClient;
this.rpcHandler = rpcHandler;
this.streamManager = rpcHandler.getStreamManager();
- this.streamIds = Sets.newHashSet();
}
@Override
@@ -82,10 +75,7 @@ public class TransportRequestHandler extends MessageHandler<RequestMessage> {
@Override
public void channelUnregistered() {
- // Inform the StreamManager that these streams will no longer be read from.
- for (long streamId : streamIds) {
- streamManager.connectionTerminated(streamId);
- }
+ streamManager.connectionTerminated(channel);
rpcHandler.connectionTerminated(reverseClient);
}
@@ -102,12 +92,12 @@ public class TransportRequestHandler extends MessageHandler<RequestMessage> {
private void processFetchRequest(final ChunkFetchRequest req) {
final String client = NettyUtils.getRemoteAddress(channel);
- streamIds.add(req.streamChunkId.streamId);
logger.trace("Received req from {} to fetch block {}", client, req.streamChunkId);
ManagedBuffer buf;
try {
+ streamManager.registerChannel(channel, req.streamChunkId.streamId);
buf = streamManager.getChunk(req.streamChunkId.streamId, req.streamChunkId.chunkIndex);
} catch (Exception e) {
logger.error(String.format(