 network/common/src/main/java/org/apache/spark/network/TransportContext.java                |   5 +-
 network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java |  14 +-
 network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java  |  33 ++-
 network/common/src/main/java/org/apache/spark/network/util/MapConfigProvider.java          |  41 ++
 network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java                 |   2 +-
 network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java  | 277 ++++++++++++
 network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java     |  21 +-
 7 files changed, 375 insertions(+), 18 deletions(-)
diff --git a/network/common/src/main/java/org/apache/spark/network/TransportContext.java b/network/common/src/main/java/org/apache/spark/network/TransportContext.java
index f0a89c9d91..3fe69b1bd8 100644
--- a/network/common/src/main/java/org/apache/spark/network/TransportContext.java
+++ b/network/common/src/main/java/org/apache/spark/network/TransportContext.java
@@ -22,6 +22,7 @@ import java.util.List;
import com.google.common.collect.Lists;
import io.netty.channel.Channel;
import io.netty.channel.socket.SocketChannel;
+import io.netty.handler.timeout.IdleStateHandler;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -106,6 +107,7 @@ public class TransportContext {
.addLast("encoder", encoder)
.addLast("frameDecoder", NettyUtils.createFrameDecoder())
.addLast("decoder", decoder)
+ .addLast("idleStateHandler", new IdleStateHandler(0, 0, conf.connectionTimeoutMs() / 1000))
// NOTE: Chunks are currently guaranteed to be returned in the order of request, but this
// would require more logic to guarantee if this were not part of the same event loop.
.addLast("handler", channelHandler);
@@ -126,7 +128,8 @@ public class TransportContext {
TransportClient client = new TransportClient(channel, responseHandler);
TransportRequestHandler requestHandler = new TransportRequestHandler(channel, client,
rpcHandler);
- return new TransportChannelHandler(client, responseHandler, requestHandler);
+ return new TransportChannelHandler(client, responseHandler, requestHandler,
+ conf.connectionTimeoutMs());
}
public TransportConf getConf() { return conf; }
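A note on the `IdleStateHandler` wiring above: the three-int-argument constructor interprets its arguments as seconds, which is why `connectionTimeoutMs()` is divided by 1000 (any sub-second remainder is dropped). Netty also offers a unit-explicit overload; a minimal sketch of an equivalent call, assuming the same `conf` object:

```java
import java.util.concurrent.TimeUnit;
import io.netty.handler.timeout.IdleStateHandler;

// readerIdleTime = 0 and writerIdleTime = 0 disable the per-direction checks;
// only the ALL_IDLE event (no reads and no writes) will ever fire.
IdleStateHandler idleHandler =
    new IdleStateHandler(0, 0, conf.connectionTimeoutMs(), TimeUnit.MILLISECONDS);
```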
diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java
index 2044afb0d8..94fc21af5e 100644
--- a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java
@@ -20,8 +20,8 @@ package org.apache.spark.network.client;
import java.io.IOException;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.atomic.AtomicLong;
-import com.google.common.annotations.VisibleForTesting;
import io.netty.channel.Channel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -50,13 +50,18 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
private final Map<Long, RpcResponseCallback> outstandingRpcs;
+ /** Records the time (in system nanoseconds) that the last fetch or RPC request was sent. */
+ private final AtomicLong timeOfLastRequestNs;
+
public TransportResponseHandler(Channel channel) {
this.channel = channel;
this.outstandingFetches = new ConcurrentHashMap<StreamChunkId, ChunkReceivedCallback>();
this.outstandingRpcs = new ConcurrentHashMap<Long, RpcResponseCallback>();
+ this.timeOfLastRequestNs = new AtomicLong(0);
}
public void addFetchRequest(StreamChunkId streamChunkId, ChunkReceivedCallback callback) {
+ timeOfLastRequestNs.set(System.nanoTime());
outstandingFetches.put(streamChunkId, callback);
}
@@ -65,6 +70,7 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
}
public void addRpcRequest(long requestId, RpcResponseCallback callback) {
+ timeOfLastRequestNs.set(System.nanoTime());
outstandingRpcs.put(requestId, callback);
}
@@ -161,8 +167,12 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
}
/** Returns total number of outstanding requests (fetch requests + rpcs) */
- @VisibleForTesting
public int numOutstandingRequests() {
return outstandingFetches.size() + outstandingRpcs.size();
}
+
+  /** Returns the time (in nanoseconds) at which the last fetch or RPC request was sent. */
+ public long getTimeOfLastRequestNs() {
+ return timeOfLastRequestNs.get();
+ }
}
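The response handler records request times with `System.nanoTime()` rather than `System.currentTimeMillis()`: `nanoTime()` is monotonic and unaffected by wall-clock adjustments, which matters when the value is later compared against a timeout. A sketch of the elapsed-time pattern the channel handler builds on (illustrative only, with a hypothetical 2-second timeout):

```java
import java.util.concurrent.atomic.AtomicLong;

AtomicLong timeOfLastRequestNs = new AtomicLong(0);

// On each outgoing fetch or RPC request (possibly from several threads):
timeOfLastRequestNs.set(System.nanoTime());

// Later, when deciding whether the connection is truly overdue:
long requestTimeoutNs = 2L * 1000 * 1000 * 1000; // 2 seconds in nanoseconds
boolean overdue = System.nanoTime() - timeOfLastRequestNs.get() > requestTimeoutNs;
```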
diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java
index e491367fa4..8e0ee709e3 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java
@@ -19,6 +19,8 @@ package org.apache.spark.network.server;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;
+import io.netty.handler.timeout.IdleState;
+import io.netty.handler.timeout.IdleStateEvent;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -40,6 +42,11 @@ import org.apache.spark.network.util.NettyUtils;
* Client.
* This means that the Client also needs a RequestHandler and the Server needs a ResponseHandler,
* for the Client's responses to the Server's requests.
+ *
+ * This class also handles timeouts from a {@link io.netty.handler.timeout.IdleStateHandler}.
+ * We consider a connection timed out if there are outstanding fetch or RPC requests but no traffic
+ * on the channel for at least `requestTimeoutMs`. Note that this is duplex traffic; for
+ * simplicity, we will not time out if the client is continuously sending but receiving no responses.
*/
public class TransportChannelHandler extends SimpleChannelInboundHandler<Message> {
private final Logger logger = LoggerFactory.getLogger(TransportChannelHandler.class);
@@ -47,14 +54,17 @@ public class TransportChannelHandler extends SimpleChannelInboundHandler<Message>
private final TransportClient client;
private final TransportResponseHandler responseHandler;
private final TransportRequestHandler requestHandler;
+ private final long requestTimeoutNs;
public TransportChannelHandler(
TransportClient client,
TransportResponseHandler responseHandler,
- TransportRequestHandler requestHandler) {
+ TransportRequestHandler requestHandler,
+ long requestTimeoutMs) {
this.client = client;
this.responseHandler = responseHandler;
this.requestHandler = requestHandler;
+ this.requestTimeoutNs = requestTimeoutMs * 1000L * 1000;
}
public TransportClient getClient() {
@@ -93,4 +103,25 @@ public class TransportChannelHandler extends SimpleChannelInboundHandler<Message>
responseHandler.handle((ResponseMessage) request);
}
}
+
+ /** Triggered based on events from an {@link io.netty.handler.timeout.IdleStateHandler}. */
+ @Override
+ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
+ if (evt instanceof IdleStateEvent) {
+ IdleStateEvent e = (IdleStateEvent) evt;
+ // See class comment for timeout semantics. In addition to ensuring we only timeout while
+ // there are outstanding requests, we also do a secondary consistency check to ensure
+ // there's no race between the idle timeout and incrementing the numOutstandingRequests.
+ boolean hasInFlightRequests = responseHandler.numOutstandingRequests() > 0;
+ boolean isActuallyOverdue =
+ System.nanoTime() - responseHandler.getTimeOfLastRequestNs() > requestTimeoutNs;
+ if (e.state() == IdleState.ALL_IDLE && hasInFlightRequests && isActuallyOverdue) {
+ String address = NettyUtils.getRemoteAddress(ctx.channel());
+ logger.error("Connection to {} has been quiet for {} ms while there are outstanding " +
+ "requests. Assuming connection is dead; please adjust spark.network.timeout if this " +
+ "is wrong.", address, requestTimeoutNs / 1000 / 1000);
+ ctx.close();
+ }
+ }
+ }
}
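For readers unfamiliar with Netty's idle detection: `IdleStateHandler` never closes a channel itself; it only fires an `IdleStateEvent` down the pipeline, and an application handler decides how to react. A minimal, self-contained sketch of that pattern (a hypothetical `IdleCloseHandler`, not the Spark class above):

```java
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.handler.timeout.IdleState;
import io.netty.handler.timeout.IdleStateEvent;

public class IdleCloseHandler extends ChannelInboundHandlerAdapter {
  @Override
  public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
    if (evt instanceof IdleStateEvent
        && ((IdleStateEvent) evt).state() == IdleState.ALL_IDLE) {
      // A production handler, like TransportChannelHandler, should first check
      // for in-flight requests and re-verify the elapsed time before closing.
      ctx.close();
    } else {
      super.userEventTriggered(ctx, evt);
    }
  }
}
```

The secondary check in `TransportChannelHandler` guards against a race: the idle timer may fire just after a new request was registered, in which case `getTimeOfLastRequestNs()` reveals that the connection is not actually overdue and the channel is left open.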
diff --git a/network/common/src/main/java/org/apache/spark/network/util/MapConfigProvider.java b/network/common/src/main/java/org/apache/spark/network/util/MapConfigProvider.java
new file mode 100644
index 0000000000..668d2356b9
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/util/MapConfigProvider.java
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.util;
+
+import com.google.common.collect.Maps;
+
+import java.util.Map;
+import java.util.NoSuchElementException;
+
+/** ConfigProvider based on a Map (copied in the constructor). */
+public class MapConfigProvider extends ConfigProvider {
+ private final Map<String, String> config;
+
+ public MapConfigProvider(Map<String, String> config) {
+ this.config = Maps.newHashMap(config);
+ }
+
+ @Override
+ public String get(String name) {
+ String value = config.get(name);
+ if (value == null) {
+ throw new NoSuchElementException(name);
+ }
+ return value;
+ }
+}
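`MapConfigProvider` replaces the anonymous `ConfigProvider` subclasses previously scattered through the tests. Typical usage, mirroring the test setup below:

```java
import java.util.Map;
import com.google.common.collect.Maps;
import org.apache.spark.network.util.MapConfigProvider;
import org.apache.spark.network.util.TransportConf;

Map<String, String> configMap = Maps.newHashMap();
configMap.put("spark.shuffle.io.connectionTimeout", "2s");
TransportConf conf = new TransportConf(new MapConfigProvider(configMap));
```

Since the constructor copies the map, later mutations of `configMap` do not affect the provider.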
diff --git a/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java b/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java
index dabd6261d2..26c6399ce7 100644
--- a/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java
+++ b/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java
@@ -98,7 +98,7 @@ public class NettyUtils {
return new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 8, -8, 8);
}
- /** Returns the remote address on the channel or "&lt;remote address&gt;" if none exists. */
+ /** Returns the remote address on the channel or "&lt;unknown remote&gt;" if none exists. */
public static String getRemoteAddress(Channel channel) {
if (channel != null && channel.remoteAddress() != null) {
return channel.remoteAddress().toString();
diff --git a/network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java
new file mode 100644
index 0000000000..84ebb337e6
--- /dev/null
+++ b/network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java
@@ -0,0 +1,277 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network;
+
+import com.google.common.collect.Maps;
+import com.google.common.util.concurrent.Uninterruptibles;
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.buffer.NioManagedBuffer;
+import org.apache.spark.network.client.ChunkReceivedCallback;
+import org.apache.spark.network.client.RpcResponseCallback;
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.client.TransportClientFactory;
+import org.apache.spark.network.server.RpcHandler;
+import org.apache.spark.network.server.StreamManager;
+import org.apache.spark.network.server.TransportServer;
+import org.apache.spark.network.util.MapConfigProvider;
+import org.apache.spark.network.util.TransportConf;
+import org.junit.*;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.*;
+import java.util.concurrent.Semaphore;
+import java.util.concurrent.TimeUnit;
+
+/**
+ * Suite which ensures that requests that go without a response for the network timeout period are
+ * failed, and the connection closed.
+ *
+ * In this suite, we use 2 seconds as the connection timeout, with some slack given in the tests,
+ * to ensure stability in different test environments.
+ */
+public class RequestTimeoutIntegrationSuite {
+
+ private TransportServer server;
+ private TransportClientFactory clientFactory;
+
+ private StreamManager defaultManager;
+ private TransportConf conf;
+
+ // A large timeout that "shouldn't happen", for the sake of faulty tests not hanging forever.
+ private final int FOREVER = 60 * 1000;
+
+ @Before
+ public void setUp() throws Exception {
+ Map<String, String> configMap = Maps.newHashMap();
+ configMap.put("spark.shuffle.io.connectionTimeout", "2s");
+ conf = new TransportConf(new MapConfigProvider(configMap));
+
+ defaultManager = new StreamManager() {
+ @Override
+ public ManagedBuffer getChunk(long streamId, int chunkIndex) {
+ throw new UnsupportedOperationException();
+ }
+ };
+ }
+
+ @After
+ public void tearDown() {
+ if (server != null) {
+ server.close();
+ }
+ if (clientFactory != null) {
+ clientFactory.close();
+ }
+ }
+
+ // Basic suite: First request completes quickly, and second waits for longer than network timeout.
+ @Test
+ public void timeoutInactiveRequests() throws Exception {
+ final Semaphore semaphore = new Semaphore(1);
+ final byte[] response = new byte[16];
+ RpcHandler handler = new RpcHandler() {
+ @Override
+ public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) {
+ try {
+ semaphore.tryAcquire(FOREVER, TimeUnit.MILLISECONDS);
+ callback.onSuccess(response);
+ } catch (InterruptedException e) {
+ // do nothing
+ }
+ }
+
+ @Override
+ public StreamManager getStreamManager() {
+ return defaultManager;
+ }
+ };
+
+ TransportContext context = new TransportContext(conf, handler);
+ server = context.createServer();
+ clientFactory = context.createClientFactory();
+ TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
+
+ // First completes quickly (semaphore starts at 1).
+ TestCallback callback0 = new TestCallback();
+ synchronized (callback0) {
+ client.sendRpc(new byte[0], callback0);
+ callback0.wait(FOREVER);
+ assert (callback0.success.length == response.length);
+ }
+
+ // Second times out after 2 seconds, with slack. Must be IOException.
+ TestCallback callback1 = new TestCallback();
+ synchronized (callback1) {
+ client.sendRpc(new byte[0], callback1);
+ callback1.wait(4 * 1000);
+ assert (callback1.failure != null);
+ assert (callback1.failure instanceof IOException);
+ }
+ semaphore.release();
+ }
+
+ // A timeout will cause the connection to be closed, invalidating the current TransportClient.
+ // It should be the case that requesting a client from the factory produces a new, valid one.
+ @Test
+ public void timeoutCleanlyClosesClient() throws Exception {
+ final Semaphore semaphore = new Semaphore(0);
+ final byte[] response = new byte[16];
+ RpcHandler handler = new RpcHandler() {
+ @Override
+ public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) {
+ try {
+ semaphore.tryAcquire(FOREVER, TimeUnit.MILLISECONDS);
+ callback.onSuccess(response);
+ } catch (InterruptedException e) {
+ // do nothing
+ }
+ }
+
+ @Override
+ public StreamManager getStreamManager() {
+ return defaultManager;
+ }
+ };
+
+ TransportContext context = new TransportContext(conf, handler);
+ server = context.createServer();
+ clientFactory = context.createClientFactory();
+
+ // First request should eventually fail.
+ TransportClient client0 =
+ clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
+ TestCallback callback0 = new TestCallback();
+ synchronized (callback0) {
+ client0.sendRpc(new byte[0], callback0);
+ callback0.wait(FOREVER);
+ assert (callback0.failure instanceof IOException);
+ assert (!client0.isActive());
+ }
+
+ // Increment the semaphore and the second request should succeed quickly.
+ semaphore.release(2);
+ TransportClient client1 =
+ clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
+ TestCallback callback1 = new TestCallback();
+ synchronized (callback1) {
+ client1.sendRpc(new byte[0], callback1);
+ callback1.wait(FOREVER);
+ assert (callback1.success.length == response.length);
+ assert (callback1.failure == null);
+ }
+ }
+
+  // The timeout is relative to the LAST request sent, which is somewhat unintuitive, but is the
+  // current behavior.
+ // This test also makes sure the timeout works for Fetch requests as well as RPCs.
+ @Test
+ public void furtherRequestsDelay() throws Exception {
+ final byte[] response = new byte[16];
+ final StreamManager manager = new StreamManager() {
+ @Override
+ public ManagedBuffer getChunk(long streamId, int chunkIndex) {
+ Uninterruptibles.sleepUninterruptibly(FOREVER, TimeUnit.MILLISECONDS);
+ return new NioManagedBuffer(ByteBuffer.wrap(response));
+ }
+ };
+ RpcHandler handler = new RpcHandler() {
+ @Override
+ public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public StreamManager getStreamManager() {
+ return manager;
+ }
+ };
+
+ TransportContext context = new TransportContext(conf, handler);
+ server = context.createServer();
+ clientFactory = context.createClientFactory();
+ TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
+
+ // Send one request, which will eventually fail.
+ TestCallback callback0 = new TestCallback();
+ client.fetchChunk(0, 0, callback0);
+ Uninterruptibles.sleepUninterruptibly(1200, TimeUnit.MILLISECONDS);
+
+ // Send a second request before the first has failed.
+ TestCallback callback1 = new TestCallback();
+ client.fetchChunk(0, 1, callback1);
+ Uninterruptibles.sleepUninterruptibly(1200, TimeUnit.MILLISECONDS);
+
+ synchronized (callback0) {
+ // not complete yet, but should complete soon
+ assert (callback0.success == null && callback0.failure == null);
+ callback0.wait(2 * 1000);
+ assert (callback0.failure instanceof IOException);
+ }
+
+ synchronized (callback1) {
+      // The second fetch should have failed at the same time as the first.
+      assert (callback1.failure instanceof IOException);
+ }
+ }
+
+ /**
+ * Callback which sets 'success' or 'failure' on completion.
+ * Additionally notifies all waiters on this callback when invoked.
+ */
+ class TestCallback implements RpcResponseCallback, ChunkReceivedCallback {
+
+ byte[] success;
+ Throwable failure;
+
+ @Override
+ public void onSuccess(byte[] response) {
+ synchronized(this) {
+ success = response;
+ this.notifyAll();
+ }
+ }
+
+ @Override
+ public void onFailure(Throwable e) {
+ synchronized(this) {
+ failure = e;
+ this.notifyAll();
+ }
+ }
+
+ @Override
+ public void onSuccess(int chunkIndex, ManagedBuffer buffer) {
+ synchronized(this) {
+ try {
+ success = buffer.nioByteBuffer().array();
+ this.notifyAll();
+        } catch (IOException e) {
+          // Unexpected for the in-memory buffers used here; record it so waiters are not left hanging.
+          failure = e;
+          this.notifyAll();
+        }
+ }
+ }
+
+ @Override
+ public void onFailure(int chunkIndex, Throwable e) {
+ synchronized(this) {
+ failure = e;
+ this.notifyAll();
+ }
+ }
+ }
+}
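The tests synchronize on each `TestCallback` and use `wait`/`notifyAll` to block until a response or failure arrives; the assertion after each `wait` is what actually validates the outcome, since `wait` can also return on timeout (or spuriously). An equivalent arrangement with `java.util.concurrent` primitives, sketched here only as an alternative, would look like:

```java
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

CountDownLatch done = new CountDownLatch(1);
// In onSuccess/onFailure: record the result, then call done.countDown().
// In the test body, instead of synchronized + wait:
boolean completed = done.await(4, TimeUnit.SECONDS);
```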
diff --git a/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java b/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java
index 416dc1b969..35de5e57cc 100644
--- a/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java
+++ b/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java
@@ -20,10 +20,11 @@ package org.apache.spark.network;
import java.io.IOException;
import java.util.Collections;
import java.util.HashSet;
-import java.util.NoSuchElementException;
+import java.util.Map;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
+import com.google.common.collect.Maps;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
@@ -36,9 +37,9 @@ import org.apache.spark.network.client.TransportClientFactory;
import org.apache.spark.network.server.NoOpRpcHandler;
import org.apache.spark.network.server.RpcHandler;
import org.apache.spark.network.server.TransportServer;
-import org.apache.spark.network.util.ConfigProvider;
-import org.apache.spark.network.util.JavaUtils;
import org.apache.spark.network.util.SystemPropertyConfigProvider;
+import org.apache.spark.network.util.JavaUtils;
+import org.apache.spark.network.util.MapConfigProvider;
import org.apache.spark.network.util.TransportConf;
public class TransportClientFactorySuite {
@@ -70,16 +71,10 @@ public class TransportClientFactorySuite {
*/
private void testClientReuse(final int maxConnections, boolean concurrent)
throws IOException, InterruptedException {
- TransportConf conf = new TransportConf(new ConfigProvider() {
- @Override
- public String get(String name) {
- if (name.equals("spark.shuffle.io.numConnectionsPerPeer")) {
- return Integer.toString(maxConnections);
- } else {
- throw new NoSuchElementException();
- }
- }
- });
+
+ Map<String, String> configMap = Maps.newHashMap();
+ configMap.put("spark.shuffle.io.numConnectionsPerPeer", Integer.toString(maxConnections));
+ TransportConf conf = new TransportConf(new MapConfigProvider(configMap));
RpcHandler rpcHandler = new NoOpRpcHandler();
TransportContext context = new TransportContext(conf, rpcHandler);