aboutsummaryrefslogtreecommitdiff
path: root/network
diff options
context:
space:
mode:
authorMarcelo Vanzin <vanzin@cloudera.com>2015-11-23 10:45:23 -0800
committerMarcelo Vanzin <vanzin@cloudera.com>2015-11-23 10:45:23 -0800
commit5231cd5acaae69d735ba3209531705cc222f3cfb (patch)
tree5d323395ff1c9a1386bf289f4fa05827fcd1b715 /network
parent5fd86e4fc2e06d2403ca538ae417580c93b69e06 (diff)
downloadspark-5231cd5acaae69d735ba3209531705cc222f3cfb.tar.gz
spark-5231cd5acaae69d735ba3209531705cc222f3cfb.tar.bz2
spark-5231cd5acaae69d735ba3209531705cc222f3cfb.zip
[SPARK-11762][NETWORK] Account for active streams when couting outstanding requests.
This way the timeout handling code can correctly close "hung" channels that are processing streams. Author: Marcelo Vanzin <vanzin@cloudera.com> Closes #9747 from vanzin/SPARK-11762.
Diffstat (limited to 'network')
-rw-r--r--network/common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java12
-rw-r--r--network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java15
-rw-r--r--network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java27
3 files changed, 51 insertions, 3 deletions
diff --git a/network/common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java b/network/common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java
index 02230a00e6..88ba3ccebd 100644
--- a/network/common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java
+++ b/network/common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java
@@ -30,13 +30,19 @@ import org.apache.spark.network.util.TransportFrameDecoder;
*/
class StreamInterceptor implements TransportFrameDecoder.Interceptor {
+ private final TransportResponseHandler handler;
private final String streamId;
private final long byteCount;
private final StreamCallback callback;
private volatile long bytesRead;
- StreamInterceptor(String streamId, long byteCount, StreamCallback callback) {
+ StreamInterceptor(
+ TransportResponseHandler handler,
+ String streamId,
+ long byteCount,
+ StreamCallback callback) {
+ this.handler = handler;
this.streamId = streamId;
this.byteCount = byteCount;
this.callback = callback;
@@ -45,11 +51,13 @@ class StreamInterceptor implements TransportFrameDecoder.Interceptor {
@Override
public void exceptionCaught(Throwable cause) throws Exception {
+ handler.deactivateStream();
callback.onFailure(streamId, cause);
}
@Override
public void channelInactive() throws Exception {
+ handler.deactivateStream();
callback.onFailure(streamId, new ClosedChannelException());
}
@@ -65,8 +73,10 @@ class StreamInterceptor implements TransportFrameDecoder.Interceptor {
RuntimeException re = new IllegalStateException(String.format(
"Read too many bytes? Expected %d, but read %d.", byteCount, bytesRead));
callback.onFailure(streamId, re);
+ handler.deactivateStream();
throw re;
} else if (bytesRead == byteCount) {
+ handler.deactivateStream();
callback.onComplete(streamId);
}
diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java
index ed3f36af58..cc88991b58 100644
--- a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java
@@ -24,6 +24,7 @@ import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.atomic.AtomicLong;
+import com.google.common.annotations.VisibleForTesting;
import io.netty.channel.Channel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -56,6 +57,7 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
private final Map<Long, RpcResponseCallback> outstandingRpcs;
private final Queue<StreamCallback> streamCallbacks;
+ private volatile boolean streamActive;
/** Records the time (in system nanoseconds) that the last fetch or RPC request was sent. */
private final AtomicLong timeOfLastRequestNs;
@@ -87,9 +89,15 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
}
public void addStreamCallback(StreamCallback callback) {
+ timeOfLastRequestNs.set(System.nanoTime());
streamCallbacks.offer(callback);
}
+ @VisibleForTesting
+ public void deactivateStream() {
+ streamActive = false;
+ }
+
/**
* Fire the failure callback for all outstanding requests. This is called when we have an
* uncaught exception or pre-mature connection termination.
@@ -177,14 +185,16 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
StreamResponse resp = (StreamResponse) message;
StreamCallback callback = streamCallbacks.poll();
if (callback != null) {
- StreamInterceptor interceptor = new StreamInterceptor(resp.streamId, resp.byteCount,
+ StreamInterceptor interceptor = new StreamInterceptor(this, resp.streamId, resp.byteCount,
callback);
try {
TransportFrameDecoder frameDecoder = (TransportFrameDecoder)
channel.pipeline().get(TransportFrameDecoder.HANDLER_NAME);
frameDecoder.setInterceptor(interceptor);
+ streamActive = true;
} catch (Exception e) {
logger.error("Error installing stream handler.", e);
+ deactivateStream();
}
} else {
logger.error("Could not find callback for StreamResponse.");
@@ -208,7 +218,8 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
/** Returns total number of outstanding requests (fetch requests + rpcs) */
public int numOutstandingRequests() {
- return outstandingFetches.size() + outstandingRpcs.size();
+ return outstandingFetches.size() + outstandingRpcs.size() + streamCallbacks.size() +
+ (streamActive ? 1 : 0);
}
/** Returns the time in nanoseconds of when the last request was sent out. */
diff --git a/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java b/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java
index 17a03ebe88..30144f4a9f 100644
--- a/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java
+++ b/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java
@@ -17,6 +17,7 @@
package org.apache.spark.network;
+import io.netty.channel.Channel;
import io.netty.channel.local.LocalChannel;
import org.junit.Test;
@@ -28,12 +29,16 @@ import static org.mockito.Mockito.*;
import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.client.ChunkReceivedCallback;
import org.apache.spark.network.client.RpcResponseCallback;
+import org.apache.spark.network.client.StreamCallback;
import org.apache.spark.network.client.TransportResponseHandler;
import org.apache.spark.network.protocol.ChunkFetchFailure;
import org.apache.spark.network.protocol.ChunkFetchSuccess;
import org.apache.spark.network.protocol.RpcFailure;
import org.apache.spark.network.protocol.RpcResponse;
import org.apache.spark.network.protocol.StreamChunkId;
+import org.apache.spark.network.protocol.StreamFailure;
+import org.apache.spark.network.protocol.StreamResponse;
+import org.apache.spark.network.util.TransportFrameDecoder;
public class TransportResponseHandlerSuite {
@Test
@@ -112,4 +117,26 @@ public class TransportResponseHandlerSuite {
verify(callback, times(1)).onFailure((Throwable) any());
assertEquals(0, handler.numOutstandingRequests());
}
+
+ @Test
+ public void testActiveStreams() {
+ Channel c = new LocalChannel();
+ c.pipeline().addLast(TransportFrameDecoder.HANDLER_NAME, new TransportFrameDecoder());
+ TransportResponseHandler handler = new TransportResponseHandler(c);
+
+ StreamResponse response = new StreamResponse("stream", 1234L, null);
+ StreamCallback cb = mock(StreamCallback.class);
+ handler.addStreamCallback(cb);
+ assertEquals(1, handler.numOutstandingRequests());
+ handler.handle(response);
+ assertEquals(1, handler.numOutstandingRequests());
+ handler.deactivateStream();
+ assertEquals(0, handler.numOutstandingRequests());
+
+ StreamFailure failure = new StreamFailure("stream", "uh-oh");
+ handler.addStreamCallback(cb);
+ assertEquals(1, handler.numOutstandingRequests());
+ handler.handle(failure);
+ assertEquals(0, handler.numOutstandingRequests());
+ }
}