-rw-r--r--  core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala  21
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/client/TransportClient.java  16
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java  13
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java  3
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java  2
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/server/TransportServer.java  8
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java  14
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/util/TransportConf.java  17
-rw-r--r--  network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java  7
-rw-r--r--  network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java  31
-rw-r--r--  network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java  9
-rw-r--r--  network/shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java  234
-rw-r--r--  network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java  4
-rw-r--r--  network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java  18
-rw-r--r--  network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java  6
-rw-r--r--  network/shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java  310
16 files changed, 668 insertions, 45 deletions
diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
index 0d1fc81d2a..b937ea825f 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
+++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
@@ -27,7 +27,7 @@ import org.apache.spark.network.client.{TransportClientBootstrap, RpcResponseCal
import org.apache.spark.network.netty.NettyMessages.{OpenBlocks, UploadBlock}
import org.apache.spark.network.sasl.{SaslRpcHandler, SaslClientBootstrap}
import org.apache.spark.network.server._
-import org.apache.spark.network.shuffle.{BlockFetchingListener, OneForOneBlockFetcher}
+import org.apache.spark.network.shuffle.{RetryingBlockFetcher, BlockFetchingListener, OneForOneBlockFetcher}
import org.apache.spark.serializer.JavaSerializer
import org.apache.spark.storage.{BlockId, StorageLevel}
import org.apache.spark.util.Utils
@@ -71,9 +71,22 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage
listener: BlockFetchingListener): Unit = {
logTrace(s"Fetch blocks from $host:$port (executor id $execId)")
try {
- val client = clientFactory.createClient(host, port)
- new OneForOneBlockFetcher(client, blockIds.toArray, listener)
- .start(OpenBlocks(blockIds.map(BlockId.apply)))
+ val blockFetchStarter = new RetryingBlockFetcher.BlockFetchStarter {
+ override def createAndStart(blockIds: Array[String], listener: BlockFetchingListener) {
+ val client = clientFactory.createClient(host, port)
+ new OneForOneBlockFetcher(client, blockIds.toArray, listener)
+ .start(OpenBlocks(blockIds.map(BlockId.apply)))
+ }
+ }
+
+ val maxRetries = transportConf.maxIORetries()
+ if (maxRetries > 0) {
+ // Note this Fetcher will correctly handle maxRetries == 0; we avoid it just in case there's
+ // a bug in this code. We should remove the if statement once we're sure of the stability.
+ new RetryingBlockFetcher(transportConf, blockFetchStarter, blockIds, listener).start()
+ } else {
+ blockFetchStarter.createAndStart(blockIds, listener)
+ }
} catch {
case e: Exception =>
logError("Exception while beginning fetchBlocks", e)
diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
index a08cee02dd..4e944114e8 100644
--- a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
+++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
@@ -18,7 +18,9 @@
package org.apache.spark.network.client;
import java.io.Closeable;
+import java.io.IOException;
import java.util.UUID;
+import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import com.google.common.base.Objects;
@@ -116,8 +118,12 @@ public class TransportClient implements Closeable {
serverAddr, future.cause());
logger.error(errorMsg, future.cause());
handler.removeFetchRequest(streamChunkId);
- callback.onFailure(chunkIndex, new RuntimeException(errorMsg, future.cause()));
channel.close();
+ try {
+ callback.onFailure(chunkIndex, new IOException(errorMsg, future.cause()));
+ } catch (Exception e) {
+        logger.error("Uncaught exception in chunk fetch callback handler!", e);
+ }
}
}
});
@@ -147,8 +153,12 @@ public class TransportClient implements Closeable {
serverAddr, future.cause());
logger.error(errorMsg, future.cause());
handler.removeRpcRequest(requestId);
- callback.onFailure(new RuntimeException(errorMsg, future.cause()));
channel.close();
+ try {
+ callback.onFailure(new IOException(errorMsg, future.cause()));
+ } catch (Exception e) {
+ logger.error("Uncaught exception in RPC response callback handler!", e);
+ }
}
}
});
@@ -175,6 +185,8 @@ public class TransportClient implements Closeable {
try {
return result.get(timeoutMs, TimeUnit.MILLISECONDS);
+ } catch (ExecutionException e) {
+ throw Throwables.propagate(e.getCause());
} catch (Exception e) {
throw Throwables.propagate(e);
}
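
A side note on the new ExecutionException clause above: result.get() wraps any failure in an ExecutionException, so propagating e.getCause() lets the original IOException surface to callers. That matters for the retry logic added in this patch, which treats an error as retriable only if the IOException is at most one cause level deep. A minimal sketch of the difference (class and variable names here are illustrative, not from the patch):

    import java.io.IOException;
    import java.util.concurrent.ExecutionException;

    public class UnwrapSketch {
      // Mirrors RetryingBlockFetcher.shouldRetry(): an IOException directly,
      // or one level down as the cause, counts as retriable.
      static boolean isRetriable(Throwable t) {
        return t instanceof IOException
          || (t.getCause() != null && t.getCause() instanceof IOException);
      }

      public static void main(String[] args) {
        ExecutionException e = new ExecutionException(new IOException("connection reset"));
        // Roughly what Throwables.propagate does with a checked exception:
        Throwable naive = new RuntimeException(e);                 // propagate(e)
        Throwable unwrapped = new RuntimeException(e.getCause());  // propagate(e.getCause())
        System.out.println(isRetriable(naive));     // false: IOException is two causes deep
        System.out.println(isRetriable(unwrapped)); // true: IOException is the direct cause
      }
    }
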
diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
index 1723fed307..397d3a8455 100644
--- a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
+++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
@@ -18,12 +18,12 @@
package org.apache.spark.network.client;
import java.io.Closeable;
+import java.io.IOException;
import java.lang.reflect.Field;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
-import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicReference;
import com.google.common.base.Preconditions;
@@ -44,7 +44,6 @@ import org.slf4j.LoggerFactory;
import org.apache.spark.network.TransportContext;
import org.apache.spark.network.server.TransportChannelHandler;
import org.apache.spark.network.util.IOMode;
-import org.apache.spark.network.util.JavaUtils;
import org.apache.spark.network.util.NettyUtils;
import org.apache.spark.network.util.TransportConf;
@@ -93,15 +92,17 @@ public class TransportClientFactory implements Closeable {
*
* Concurrency: This method is safe to call from multiple threads.
*/
- public TransportClient createClient(String remoteHost, int remotePort) {
+ public TransportClient createClient(String remoteHost, int remotePort) throws IOException {
// Get connection from the connection pool first.
// If it is not found or not active, create a new one.
final InetSocketAddress address = new InetSocketAddress(remoteHost, remotePort);
TransportClient cachedClient = connectionPool.get(address);
if (cachedClient != null) {
if (cachedClient.isActive()) {
+ logger.trace("Returning cached connection to {}: {}", address, cachedClient);
return cachedClient;
} else {
+ logger.info("Found inactive connection to {}, closing it.", address);
connectionPool.remove(address, cachedClient); // Remove inactive clients.
}
}
@@ -133,10 +134,10 @@ public class TransportClientFactory implements Closeable {
long preConnect = System.currentTimeMillis();
ChannelFuture cf = bootstrap.connect(address);
if (!cf.awaitUninterruptibly(conf.connectionTimeoutMs())) {
- throw new RuntimeException(
+ throw new IOException(
String.format("Connecting to %s timed out (%s ms)", address, conf.connectionTimeoutMs()));
} else if (cf.cause() != null) {
- throw new RuntimeException(String.format("Failed to connect to %s", address), cf.cause());
+ throw new IOException(String.format("Failed to connect to %s", address), cf.cause());
}
TransportClient client = clientRef.get();
@@ -198,7 +199,7 @@ public class TransportClientFactory implements Closeable {
*/
private PooledByteBufAllocator createPooledByteBufAllocator() {
return new PooledByteBufAllocator(
- PlatformDependent.directBufferPreferred(),
+ conf.preferDirectBufs() && PlatformDependent.directBufferPreferred(),
getPrivateStaticField("DEFAULT_NUM_HEAP_ARENA"),
getPrivateStaticField("DEFAULT_NUM_DIRECT_ARENA"),
getPrivateStaticField("DEFAULT_PAGE_SIZE"),
diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java
index d8965590b3..2044afb0d8 100644
--- a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java
@@ -17,6 +17,7 @@
package org.apache.spark.network.client;
+import java.io.IOException;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
@@ -94,7 +95,7 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
String remoteAddress = NettyUtils.getRemoteAddress(channel);
logger.error("Still have {} requests outstanding when connection from {} is closed",
numOutstandingRequests(), remoteAddress);
- failOutstandingRequests(new RuntimeException("Connection from " + remoteAddress + " closed"));
+ failOutstandingRequests(new IOException("Connection from " + remoteAddress + " closed"));
}
}
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java b/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java
index 4cb8becc3e..91d1e8a538 100644
--- a/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java
@@ -66,7 +66,7 @@ public final class MessageEncoder extends MessageToMessageEncoder<Message> {
// All messages have the frame length, message type, and message itself.
int headerLength = 8 + msgType.encodedLength() + in.encodedLength();
long frameLength = headerLength + bodyLength;
- ByteBuf header = ctx.alloc().buffer(headerLength);
+ ByteBuf header = ctx.alloc().heapBuffer(headerLength);
header.writeLong(frameLength);
msgType.encode(header);
in.encode(header);
diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java b/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java
index 70da48ca8e..579676c2c3 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java
@@ -28,6 +28,7 @@ import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.socket.SocketChannel;
+import io.netty.util.internal.PlatformDependent;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -71,11 +72,14 @@ public class TransportServer implements Closeable {
NettyUtils.createEventLoop(ioMode, conf.serverThreads(), "shuffle-server");
EventLoopGroup workerGroup = bossGroup;
+ PooledByteBufAllocator allocator = new PooledByteBufAllocator(
+ conf.preferDirectBufs() && PlatformDependent.directBufferPreferred());
+
bootstrap = new ServerBootstrap()
.group(bossGroup, workerGroup)
.channel(NettyUtils.getServerChannelClass(ioMode))
- .option(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT)
- .childOption(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT);
+ .option(ChannelOption.ALLOCATOR, allocator)
+ .childOption(ChannelOption.ALLOCATOR, allocator);
if (conf.backLog() > 0) {
bootstrap.option(ChannelOption.SO_BACKLOG, conf.backLog());
diff --git a/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java b/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java
index b187234119..2a7664fe89 100644
--- a/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java
+++ b/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java
@@ -37,13 +37,17 @@ import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
* Utilities for creating various Netty constructs based on whether we're using EPOLL or NIO.
*/
public class NettyUtils {
- /** Creates a Netty EventLoopGroup based on the IOMode. */
- public static EventLoopGroup createEventLoop(IOMode mode, int numThreads, String threadPrefix) {
-
- ThreadFactory threadFactory = new ThreadFactoryBuilder()
+ /** Creates a new ThreadFactory which prefixes each thread with the given name. */
+ public static ThreadFactory createThreadFactory(String threadPoolPrefix) {
+ return new ThreadFactoryBuilder()
.setDaemon(true)
- .setNameFormat(threadPrefix + "-%d")
+ .setNameFormat(threadPoolPrefix + "-%d")
.build();
+ }
+
+ /** Creates a Netty EventLoopGroup based on the IOMode. */
+ public static EventLoopGroup createEventLoop(IOMode mode, int numThreads, String threadPrefix) {
+ ThreadFactory threadFactory = createThreadFactory(threadPrefix);
switch (mode) {
case NIO:
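
The refactor above exposes the daemon ThreadFactory on its own so it can be reused outside the event loop; RetryingBlockFetcher (added below) uses it to build its retry executor. A quick usage sketch:

    import java.util.concurrent.ExecutorService;
    import java.util.concurrent.Executors;
    import java.util.concurrent.TimeUnit;
    import org.apache.spark.network.util.NettyUtils;

    public class ThreadFactorySketch {
      public static void main(String[] args) throws InterruptedException {
        // Threads come out daemonized and named "Block Fetch Retry-0", "-1", ...
        ExecutorService pool = Executors.newCachedThreadPool(
          NettyUtils.createThreadFactory("Block Fetch Retry"));
        pool.execute(new Runnable() {
          @Override
          public void run() {
            System.out.println(Thread.currentThread().getName());
          }
        });
        pool.shutdown();
        pool.awaitTermination(1, TimeUnit.SECONDS);
      }
    }
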
diff --git a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java
index 823790dd3c..787a8f0031 100644
--- a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java
+++ b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java
@@ -30,6 +30,11 @@ public class TransportConf {
/** IO mode: nio or epoll */
public String ioMode() { return conf.get("spark.shuffle.io.mode", "NIO").toUpperCase(); }
+ /** If true, we will prefer allocating off-heap byte buffers within Netty. */
+ public boolean preferDirectBufs() {
+ return conf.getBoolean("spark.shuffle.io.preferDirectBufs", true);
+ }
+
/** Connect timeout in secs. Default 120 secs. */
public int connectionTimeoutMs() {
return conf.getInt("spark.shuffle.io.connectionTimeout", 120) * 1000;
@@ -58,4 +63,16 @@ public class TransportConf {
/** Timeout for a single round trip of SASL token exchange, in milliseconds. */
public int saslRTTimeout() { return conf.getInt("spark.shuffle.sasl.timeout", 30000); }
+
+ /**
+   * Max number of times we will retry on IO exceptions (such as connection timeouts) per request.
+ * If set to 0, we will not do any retries.
+ */
+ public int maxIORetries() { return conf.getInt("spark.shuffle.io.maxRetries", 3); }
+
+ /**
+ * Time (in milliseconds) that we will wait in order to perform a retry after an IOException.
+ * Only relevant if maxIORetries > 0.
+ */
+ public int ioRetryWaitTime() { return conf.getInt("spark.shuffle.io.retryWaitMs", 5000); }
}
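
The two new knobs are read through the ConfigProvider backing the TransportConf; the test suites below set them as system properties via SystemPropertyConfigProvider. For example:

    import org.apache.spark.network.util.SystemPropertyConfigProvider;
    import org.apache.spark.network.util.TransportConf;

    public class RetryConfSketch {
      public static void main(String[] args) {
        System.setProperty("spark.shuffle.io.maxRetries", "2");
        System.setProperty("spark.shuffle.io.retryWaitMs", "0");
        TransportConf conf = new TransportConf(new SystemPropertyConfigProvider());
        System.out.println(conf.maxIORetries());    // 2
        System.out.println(conf.ioRetryWaitTime()); // 0
      }
    }
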
diff --git a/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java b/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java
index 5a10fdb384..822bef1d81 100644
--- a/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java
+++ b/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java
@@ -17,6 +17,7 @@
package org.apache.spark.network;
+import java.io.IOException;
import java.util.concurrent.TimeoutException;
import org.junit.After;
@@ -57,7 +58,7 @@ public class TransportClientFactorySuite {
}
@Test
- public void createAndReuseBlockClients() throws TimeoutException {
+ public void createAndReuseBlockClients() throws IOException {
TransportClientFactory factory = context.createClientFactory();
TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort());
TransportClient c2 = factory.createClient(TestUtils.getLocalHost(), server1.getPort());
@@ -70,7 +71,7 @@ public class TransportClientFactorySuite {
}
@Test
- public void neverReturnInactiveClients() throws Exception {
+ public void neverReturnInactiveClients() throws IOException, InterruptedException {
TransportClientFactory factory = context.createClientFactory();
TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort());
c1.close();
@@ -88,7 +89,7 @@ public class TransportClientFactorySuite {
}
@Test
- public void closeBlockClientsWithFactory() throws TimeoutException {
+ public void closeBlockClientsWithFactory() throws IOException {
TransportClientFactory factory = context.createClientFactory();
TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort());
TransportClient c2 = factory.createClient(TestUtils.getLocalHost(), server2.getPort());
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
index 3aa95d00f6..27884b82c8 100644
--- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
@@ -17,6 +17,7 @@
package org.apache.spark.network.shuffle;
+import java.io.IOException;
import java.util.List;
import com.google.common.collect.Lists;
@@ -76,17 +77,33 @@ public class ExternalShuffleClient extends ShuffleClient {
@Override
public void fetchBlocks(
- String host,
- int port,
- String execId,
+ final String host,
+ final int port,
+ final String execId,
String[] blockIds,
BlockFetchingListener listener) {
assert appId != null : "Called before init()";
logger.debug("External shuffle fetch from {}:{} (executor id {})", host, port, execId);
try {
- TransportClient client = clientFactory.createClient(host, port);
- new OneForOneBlockFetcher(client, blockIds, listener)
- .start(new ExternalShuffleMessages.OpenShuffleBlocks(appId, execId, blockIds));
+ RetryingBlockFetcher.BlockFetchStarter blockFetchStarter =
+ new RetryingBlockFetcher.BlockFetchStarter() {
+ @Override
+ public void createAndStart(String[] blockIds, BlockFetchingListener listener)
+ throws IOException {
+ TransportClient client = clientFactory.createClient(host, port);
+ new OneForOneBlockFetcher(client, blockIds, listener)
+ .start(new ExternalShuffleMessages.OpenShuffleBlocks(appId, execId, blockIds));
+ }
+ };
+
+ int maxRetries = conf.maxIORetries();
+ if (maxRetries > 0) {
+ // Note this Fetcher will correctly handle maxRetries == 0; we avoid it just in case there's
+ // a bug in this code. We should remove the if statement once we're sure of the stability.
+ new RetryingBlockFetcher(conf, blockFetchStarter, blockIds, listener).start();
+ } else {
+ blockFetchStarter.createAndStart(blockIds, listener);
+ }
} catch (Exception e) {
logger.error("Exception while beginning fetchBlocks", e);
for (String blockId : blockIds) {
@@ -108,7 +125,7 @@ public class ExternalShuffleClient extends ShuffleClient {
String host,
int port,
String execId,
- ExecutorShuffleInfo executorInfo) {
+ ExecutorShuffleInfo executorInfo) throws IOException {
assert appId != null : "Called before init()";
TransportClient client = clientFactory.createClient(host, port);
byte[] registerExecutorMessage =
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
index 39b6f30f92..9e77a1f68c 100644
--- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
@@ -51,9 +51,6 @@ public class OneForOneBlockFetcher {
TransportClient client,
String[] blockIds,
BlockFetchingListener listener) {
- if (blockIds.length == 0) {
- throw new IllegalArgumentException("Zero-sized blockIds array");
- }
this.client = client;
this.blockIds = blockIds;
this.listener = listener;
@@ -82,6 +79,10 @@ public class OneForOneBlockFetcher {
* {@link ShuffleStreamHandle}. We will send all fetch requests immediately, without throttling.
*/
public void start(Object openBlocksMessage) {
+ if (blockIds.length == 0) {
+ throw new IllegalArgumentException("Zero-sized blockIds array");
+ }
+
client.sendRpc(JavaUtils.serialize(openBlocksMessage), new RpcResponseCallback() {
@Override
public void onSuccess(byte[] response) {
@@ -95,7 +96,7 @@ public class OneForOneBlockFetcher {
client.fetchChunk(streamHandle.streamId, i, chunkCallback);
}
} catch (Exception e) {
- logger.error("Failed while starting block fetches", e);
+ logger.error("Failed while starting block fetches after success", e);
failRemainingBlocks(blockIds, e);
}
}
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java
new file mode 100644
index 0000000000..f8a1a26686
--- /dev/null
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java
@@ -0,0 +1,234 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.LinkedHashSet;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.TimeUnit;
+
+import com.google.common.collect.Sets;
+import com.google.common.util.concurrent.Uninterruptibles;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.util.NettyUtils;
+import org.apache.spark.network.util.TransportConf;
+
+/**
+ * Wraps another BlockFetcher with the ability to automatically retry fetches which fail due to
+ * IOExceptions, which we hope are due to transient network conditions.
+ *
+ * This fetcher provides stronger guarantees regarding the parent BlockFetchingListener. In
+ * particular, the listener will be invoked exactly once per blockId, with a success or failure.
+ */
+public class RetryingBlockFetcher {
+
+ /**
+ * Used to initiate the first fetch for all blocks, and subsequently for retrying the fetch on any
+ * remaining blocks.
+ */
+ public static interface BlockFetchStarter {
+ /**
+     * Creates a new BlockFetcher to fetch the given block ids, which may do some synchronous
+ * bootstrapping followed by fully asynchronous block fetching.
+ * The BlockFetcher must eventually invoke the Listener on every input blockId, or else this
+ * method must throw an exception.
+ *
+ * This method should always attempt to get a new TransportClient from the
+ * {@link org.apache.spark.network.client.TransportClientFactory} in order to fix connection
+ * issues.
+ */
+ void createAndStart(String[] blockIds, BlockFetchingListener listener) throws IOException;
+ }
+
+ /** Shared executor service used for waiting and retrying. */
+ private static final ExecutorService executorService = Executors.newCachedThreadPool(
+ NettyUtils.createThreadFactory("Block Fetch Retry"));
+
+ private final Logger logger = LoggerFactory.getLogger(RetryingBlockFetcher.class);
+
+ /** Used to initiate new Block Fetches on our remaining blocks. */
+ private final BlockFetchStarter fetchStarter;
+
+ /** Parent listener which we delegate all successful or permanently failed block fetches to. */
+ private final BlockFetchingListener listener;
+
+ /** Max number of times we are allowed to retry. */
+ private final int maxRetries;
+
+ /** Milliseconds to wait before each retry. */
+ private final int retryWaitTime;
+
+ // NOTE:
+ // All of our non-final fields are synchronized under 'this' and should only be accessed/mutated
+ // while inside a synchronized block.
+ /** Number of times we've attempted to retry so far. */
+ private int retryCount = 0;
+
+ /**
+   * Set of all block ids which have not yet been fetched successfully or failed permanently
+   * (i.e., with a non-IOException).
+ * A retry involves requesting every outstanding block. Note that since this is a LinkedHashSet,
+ * input ordering is preserved, so we always request blocks in the same order the user provided.
+ */
+ private final LinkedHashSet<String> outstandingBlocksIds;
+
+ /**
+ * The BlockFetchingListener that is active with our current BlockFetcher.
+   * When we start a retry, we immediately replace this with a new Listener, which causes any
+ * old Listeners to ignore all further responses.
+ */
+ private RetryingBlockFetchListener currentListener;
+
+ public RetryingBlockFetcher(
+ TransportConf conf,
+ BlockFetchStarter fetchStarter,
+ String[] blockIds,
+ BlockFetchingListener listener) {
+ this.fetchStarter = fetchStarter;
+ this.listener = listener;
+ this.maxRetries = conf.maxIORetries();
+ this.retryWaitTime = conf.ioRetryWaitTime();
+ this.outstandingBlocksIds = Sets.newLinkedHashSet();
+ Collections.addAll(outstandingBlocksIds, blockIds);
+ this.currentListener = new RetryingBlockFetchListener();
+ }
+
+ /**
+ * Initiates the fetch of all blocks provided in the constructor, with possible retries in the
+ * event of transient IOExceptions.
+ */
+ public void start() {
+ fetchAllOutstanding();
+ }
+
+ /**
+ * Fires off a request to fetch all blocks that have not been fetched successfully or permanently
+ * failed (i.e., by a non-IOException).
+ */
+ private void fetchAllOutstanding() {
+ // Start by retrieving our shared state within a synchronized block.
+ String[] blockIdsToFetch;
+ int numRetries;
+ RetryingBlockFetchListener myListener;
+ synchronized (this) {
+ blockIdsToFetch = outstandingBlocksIds.toArray(new String[outstandingBlocksIds.size()]);
+ numRetries = retryCount;
+ myListener = currentListener;
+ }
+
+ // Now initiate the fetch on all outstanding blocks, possibly initiating a retry if that fails.
+ try {
+ fetchStarter.createAndStart(blockIdsToFetch, myListener);
+ } catch (Exception e) {
+ logger.error(String.format("Exception while beginning fetch of %s outstanding blocks %s",
+ blockIdsToFetch.length, numRetries > 0 ? "(after " + numRetries + " retries)" : ""), e);
+
+ if (shouldRetry(e)) {
+ initiateRetry();
+ } else {
+ for (String bid : blockIdsToFetch) {
+ listener.onBlockFetchFailure(bid, e);
+ }
+ }
+ }
+ }
+
+ /**
+ * Lightweight method which initiates a retry in a different thread. The retry will involve
+ * calling fetchAllOutstanding() after a configured wait time.
+ */
+ private synchronized void initiateRetry() {
+ retryCount += 1;
+ currentListener = new RetryingBlockFetchListener();
+
+ logger.info("Retrying fetch ({}/{}) for {} outstanding blocks after {} ms",
+ retryCount, maxRetries, outstandingBlocksIds.size(), retryWaitTime);
+
+ executorService.submit(new Runnable() {
+ @Override
+ public void run() {
+ Uninterruptibles.sleepUninterruptibly(retryWaitTime, TimeUnit.MILLISECONDS);
+ fetchAllOutstanding();
+ }
+ });
+ }
+
+ /**
+   * Returns true if we should retry due to a block fetch failure. We will retry if and only if
+ * the exception was an IOException and we haven't retried 'maxRetries' times already.
+ */
+ private synchronized boolean shouldRetry(Throwable e) {
+ boolean isIOException = e instanceof IOException
+ || (e.getCause() != null && e.getCause() instanceof IOException);
+ boolean hasRemainingRetries = retryCount < maxRetries;
+ return isIOException && hasRemainingRetries;
+ }
+
+ /**
+ * Our RetryListener intercepts block fetch responses and forwards them to our parent listener.
+ * Note that in the event of a retry, we will immediately replace the 'currentListener' field,
+ * indicating that any responses from non-current Listeners should be ignored.
+ */
+ private class RetryingBlockFetchListener implements BlockFetchingListener {
+ @Override
+ public void onBlockFetchSuccess(String blockId, ManagedBuffer data) {
+ // We will only forward this success message to our parent listener if this block request is
+ // outstanding and we are still the active listener.
+ boolean shouldForwardSuccess = false;
+ synchronized (RetryingBlockFetcher.this) {
+ if (this == currentListener && outstandingBlocksIds.contains(blockId)) {
+ outstandingBlocksIds.remove(blockId);
+ shouldForwardSuccess = true;
+ }
+ }
+
+ // Now actually invoke the parent listener, outside of the synchronized block.
+ if (shouldForwardSuccess) {
+ listener.onBlockFetchSuccess(blockId, data);
+ }
+ }
+
+ @Override
+ public void onBlockFetchFailure(String blockId, Throwable exception) {
+ // We will only forward this failure to our parent listener if this block request is
+ // outstanding, we are still the active listener, AND we cannot retry the fetch.
+ boolean shouldForwardFailure = false;
+ synchronized (RetryingBlockFetcher.this) {
+ if (this == currentListener && outstandingBlocksIds.contains(blockId)) {
+ if (shouldRetry(exception)) {
+ initiateRetry();
+ } else {
+ logger.error(String.format("Failed to fetch block %s, and will not retry (%s retries)",
+ blockId, retryCount), exception);
+ outstandingBlocksIds.remove(blockId);
+ shouldForwardFailure = true;
+ }
+ }
+ }
+
+ // Now actually invoke the parent listener, outside of the synchronized block.
+ if (shouldForwardFailure) {
+ listener.onBlockFetchFailure(blockId, exception);
+ }
+ }
+ }
+}
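
The intended call pattern for the new fetcher mirrors the two call sites patched above: wrap the one-shot fetch in a BlockFetchStarter and let RetryingBlockFetcher drive it. A condensed sketch (clientFactory, conf, and openBlocksMessage are assumed to be fields of the enclosing class, as in ExternalShuffleClient):

    // Hypothetical helper method, shown for illustration only.
    void fetchWithRetry(final String host, final int port, String[] blockIds,
        BlockFetchingListener listener) {
      RetryingBlockFetcher.BlockFetchStarter starter =
        new RetryingBlockFetcher.BlockFetchStarter() {
          @Override
          public void createAndStart(String[] blockIds, BlockFetchingListener listener)
              throws IOException {
            // Always get a (possibly new) client from the factory so that a dead
            // connection is replaced rather than reused on retry.
            TransportClient client = clientFactory.createClient(host, port);
            new OneForOneBlockFetcher(client, blockIds, listener).start(openBlocksMessage);
          }
        };
      new RetryingBlockFetcher(conf, starter, blockIds, listener).start();
    }
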
diff --git a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java
index 8478120786..d25283e46e 100644
--- a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java
+++ b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java
@@ -93,7 +93,7 @@ public class SaslIntegrationSuite {
}
@Test
- public void testGoodClient() {
+ public void testGoodClient() throws IOException {
clientFactory = context.createClientFactory(
Lists.<TransportClientBootstrap>newArrayList(
new SaslClientBootstrap(conf, "app-id", new TestSecretKeyHolder("good-key"))));
@@ -119,7 +119,7 @@ public class SaslIntegrationSuite {
}
@Test
- public void testNoSaslClient() {
+ public void testNoSaslClient() throws IOException {
clientFactory = context.createClientFactory(
Lists.<TransportClientBootstrap>newArrayList());
diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java
index 71e017b9e4..06294fef19 100644
--- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java
+++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java
@@ -259,14 +259,20 @@ public class ExternalShuffleIntegrationSuite {
@Test
public void testFetchNoServer() throws Exception {
- registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER));
- FetchResult execFetch = fetchBlocks("exec-0",
- new String[] { "shuffle_1_0_0", "shuffle_1_0_1" }, 1 /* port */);
- assertTrue(execFetch.successBlocks.isEmpty());
- assertEquals(Sets.newHashSet("shuffle_1_0_0", "shuffle_1_0_1"), execFetch.failedBlocks);
+ System.setProperty("spark.shuffle.io.maxRetries", "0");
+ try {
+ registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER));
+ FetchResult execFetch = fetchBlocks("exec-0",
+ new String[]{"shuffle_1_0_0", "shuffle_1_0_1"}, 1 /* port */);
+ assertTrue(execFetch.successBlocks.isEmpty());
+ assertEquals(Sets.newHashSet("shuffle_1_0_0", "shuffle_1_0_1"), execFetch.failedBlocks);
+ } finally {
+ System.clearProperty("spark.shuffle.io.maxRetries");
+ }
}
- private void registerExecutor(String executorId, ExecutorShuffleInfo executorInfo) {
+ private void registerExecutor(String executorId, ExecutorShuffleInfo executorInfo)
+ throws IOException {
ExternalShuffleClient client = new ExternalShuffleClient(conf, null, false);
client.init(APP_ID);
client.registerWithShuffleServer(TestUtils.getLocalHost(), server.getPort(),
diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java
index 4c18fcdfbc..848c88f743 100644
--- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java
+++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java
@@ -17,6 +17,8 @@
package org.apache.spark.network.shuffle;
+import java.io.IOException;
+
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
@@ -54,7 +56,7 @@ public class ExternalShuffleSecuritySuite {
}
@Test
- public void testValid() {
+ public void testValid() throws IOException {
validate("my-app-id", "secret");
}
@@ -77,7 +79,7 @@ public class ExternalShuffleSecuritySuite {
}
/** Creates an ExternalShuffleClient and attempts to register with the server. */
- private void validate(String appId, String secretKey) {
+ private void validate(String appId, String secretKey) throws IOException {
ExternalShuffleClient client =
new ExternalShuffleClient(conf, new TestSecretKeyHolder(appId, secretKey), true);
client.init(appId);
diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java
new file mode 100644
index 0000000000..0191fe529e
--- /dev/null
+++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java
@@ -0,0 +1,310 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle;
+
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.LinkedHashSet;
+import java.util.Map;
+
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.Sets;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
+import org.mockito.stubbing.Stubber;
+
+import static org.junit.Assert.*;
+import static org.mockito.Mockito.*;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.buffer.NioManagedBuffer;
+import org.apache.spark.network.util.SystemPropertyConfigProvider;
+import org.apache.spark.network.util.TransportConf;
+import static org.apache.spark.network.shuffle.RetryingBlockFetcher.BlockFetchStarter;
+
+/**
+ * Tests retry logic by throwing IOExceptions and ensuring that subsequent attempts are made to
+ * fetch the lost blocks.
+ */
+public class RetryingBlockFetcherSuite {
+
+ ManagedBuffer block0 = new NioManagedBuffer(ByteBuffer.wrap(new byte[13]));
+ ManagedBuffer block1 = new NioManagedBuffer(ByteBuffer.wrap(new byte[7]));
+ ManagedBuffer block2 = new NioManagedBuffer(ByteBuffer.wrap(new byte[19]));
+
+ @Before
+ public void beforeEach() {
+ System.setProperty("spark.shuffle.io.maxRetries", "2");
+ System.setProperty("spark.shuffle.io.retryWaitMs", "0");
+ }
+
+ @After
+ public void afterEach() {
+ System.clearProperty("spark.shuffle.io.maxRetries");
+ System.clearProperty("spark.shuffle.io.retryWaitMs");
+ }
+
+ @Test
+ public void testNoFailures() throws IOException {
+ BlockFetchingListener listener = mock(BlockFetchingListener.class);
+
+ Map[] interactions = new Map[] {
+ // Immediately return both blocks successfully.
+ ImmutableMap.<String, Object>builder()
+ .put("b0", block0)
+ .put("b1", block1)
+ .build(),
+ };
+
+ performInteractions(interactions, listener);
+
+ verify(listener).onBlockFetchSuccess("b0", block0);
+ verify(listener).onBlockFetchSuccess("b1", block1);
+ verifyNoMoreInteractions(listener);
+ }
+
+ @Test
+ public void testUnrecoverableFailure() throws IOException {
+ BlockFetchingListener listener = mock(BlockFetchingListener.class);
+
+ Map[] interactions = new Map[] {
+ // b0 throws a non-IOException error, so it will be failed without retry.
+ ImmutableMap.<String, Object>builder()
+ .put("b0", new RuntimeException("Ouch!"))
+ .put("b1", block1)
+ .build(),
+ };
+
+ performInteractions(interactions, listener);
+
+ verify(listener).onBlockFetchFailure(eq("b0"), (Throwable) any());
+ verify(listener).onBlockFetchSuccess("b1", block1);
+ verifyNoMoreInteractions(listener);
+ }
+
+ @Test
+ public void testSingleIOExceptionOnFirst() throws IOException {
+ BlockFetchingListener listener = mock(BlockFetchingListener.class);
+
+ Map[] interactions = new Map[] {
+ // IOException will cause a retry. Since b0 fails, we will retry both.
+ ImmutableMap.<String, Object>builder()
+ .put("b0", new IOException("Connection failed or something"))
+ .put("b1", block1)
+ .build(),
+ ImmutableMap.<String, Object>builder()
+ .put("b0", block0)
+ .put("b1", block1)
+ .build(),
+ };
+
+ performInteractions(interactions, listener);
+
+ verify(listener, timeout(5000)).onBlockFetchSuccess("b0", block0);
+ verify(listener, timeout(5000)).onBlockFetchSuccess("b1", block1);
+ verifyNoMoreInteractions(listener);
+ }
+
+ @Test
+ public void testSingleIOExceptionOnSecond() throws IOException {
+ BlockFetchingListener listener = mock(BlockFetchingListener.class);
+
+ Map[] interactions = new Map[] {
+ // IOException will cause a retry. Since b1 fails, we will not retry b0.
+ ImmutableMap.<String, Object>builder()
+ .put("b0", block0)
+ .put("b1", new IOException("Connection failed or something"))
+ .build(),
+ ImmutableMap.<String, Object>builder()
+ .put("b1", block1)
+ .build(),
+ };
+
+ performInteractions(interactions, listener);
+
+ verify(listener, timeout(5000)).onBlockFetchSuccess("b0", block0);
+ verify(listener, timeout(5000)).onBlockFetchSuccess("b1", block1);
+ verifyNoMoreInteractions(listener);
+ }
+
+ @Test
+ public void testTwoIOExceptions() throws IOException {
+ BlockFetchingListener listener = mock(BlockFetchingListener.class);
+
+ Map[] interactions = new Map[] {
+ // b0's IOException will trigger retry, b1's will be ignored.
+ ImmutableMap.<String, Object>builder()
+ .put("b0", new IOException())
+ .put("b1", new IOException())
+ .build(),
+ // Next, b0 is successful and b1 errors again, so we just request that one.
+ ImmutableMap.<String, Object>builder()
+ .put("b0", block0)
+ .put("b1", new IOException())
+ .build(),
+ // b1 returns successfully within 2 retries.
+ ImmutableMap.<String, Object>builder()
+ .put("b1", block1)
+ .build(),
+ };
+
+ performInteractions(interactions, listener);
+
+ verify(listener, timeout(5000)).onBlockFetchSuccess("b0", block0);
+ verify(listener, timeout(5000)).onBlockFetchSuccess("b1", block1);
+ verifyNoMoreInteractions(listener);
+ }
+
+ @Test
+ public void testThreeIOExceptions() throws IOException {
+ BlockFetchingListener listener = mock(BlockFetchingListener.class);
+
+ Map[] interactions = new Map[] {
+ // b0's IOException will trigger retry, b1's will be ignored.
+ ImmutableMap.<String, Object>builder()
+ .put("b0", new IOException())
+ .put("b1", new IOException())
+ .build(),
+ // Next, b0 is successful and b1 errors again, so we just request that one.
+ ImmutableMap.<String, Object>builder()
+ .put("b0", block0)
+ .put("b1", new IOException())
+ .build(),
+ // b1 errors again, but this was the last retry
+ ImmutableMap.<String, Object>builder()
+ .put("b1", new IOException())
+ .build(),
+ // This is not reached -- b1 has failed.
+ ImmutableMap.<String, Object>builder()
+ .put("b1", block1)
+ .build(),
+ };
+
+ performInteractions(interactions, listener);
+
+ verify(listener, timeout(5000)).onBlockFetchSuccess("b0", block0);
+ verify(listener, timeout(5000)).onBlockFetchFailure(eq("b1"), (Throwable) any());
+ verifyNoMoreInteractions(listener);
+ }
+
+ @Test
+ public void testRetryAndUnrecoverable() throws IOException {
+ BlockFetchingListener listener = mock(BlockFetchingListener.class);
+
+ Map[] interactions = new Map[] {
+ // b0's IOException will trigger retry, subsequent messages will be ignored.
+ ImmutableMap.<String, Object>builder()
+ .put("b0", new IOException())
+ .put("b1", new RuntimeException())
+ .put("b2", block2)
+ .build(),
+ // Next, b0 is successful, b1 errors unrecoverably, and b2 triggers a retry.
+ ImmutableMap.<String, Object>builder()
+ .put("b0", block0)
+ .put("b1", new RuntimeException())
+ .put("b2", new IOException())
+ .build(),
+ // b2 succeeds in its last retry.
+ ImmutableMap.<String, Object>builder()
+ .put("b2", block2)
+ .build(),
+ };
+
+ performInteractions(interactions, listener);
+
+ verify(listener, timeout(5000)).onBlockFetchSuccess("b0", block0);
+ verify(listener, timeout(5000)).onBlockFetchFailure(eq("b1"), (Throwable) any());
+ verify(listener, timeout(5000)).onBlockFetchSuccess("b2", block2);
+ verifyNoMoreInteractions(listener);
+ }
+
+ /**
+ * Performs a set of interactions in response to block requests from a RetryingBlockFetcher.
+ * Each interaction is a Map from BlockId to either ManagedBuffer or Exception. This interaction
+ * means "respond to the next block fetch request with these Successful buffers and these Failure
+ * exceptions". We verify that the expected block ids are exactly the ones requested.
+ *
+ * If multiple interactions are supplied, they will be used in order. This is useful for encoding
+ * retries -- the first interaction may include an IOException, which causes a retry of some
+ * subset of the original blocks in a second interaction.
+ */
+ @SuppressWarnings("unchecked")
+ private void performInteractions(final Map[] interactions, BlockFetchingListener listener)
+ throws IOException {
+
+ TransportConf conf = new TransportConf(new SystemPropertyConfigProvider());
+ BlockFetchStarter fetchStarter = mock(BlockFetchStarter.class);
+
+ Stubber stub = null;
+
+ // Contains all blockIds that are referenced across all interactions.
+ final LinkedHashSet<String> blockIds = Sets.newLinkedHashSet();
+
+ for (final Map<String, Object> interaction : interactions) {
+ blockIds.addAll(interaction.keySet());
+
+ Answer<Void> answer = new Answer<Void>() {
+ @Override
+ public Void answer(InvocationOnMock invocationOnMock) throws Throwable {
+ try {
+ // Verify that the RetryingBlockFetcher requested the expected blocks.
+ String[] requestedBlockIds = (String[]) invocationOnMock.getArguments()[0];
+ String[] desiredBlockIds = interaction.keySet().toArray(new String[interaction.size()]);
+ assertArrayEquals(desiredBlockIds, requestedBlockIds);
+
+ // Now actually invoke the success/failure callbacks on each block.
+ BlockFetchingListener retryListener =
+ (BlockFetchingListener) invocationOnMock.getArguments()[1];
+ for (Map.Entry<String, Object> block : interaction.entrySet()) {
+ String blockId = block.getKey();
+ Object blockValue = block.getValue();
+
+ if (blockValue instanceof ManagedBuffer) {
+ retryListener.onBlockFetchSuccess(blockId, (ManagedBuffer) blockValue);
+ } else if (blockValue instanceof Exception) {
+ retryListener.onBlockFetchFailure(blockId, (Exception) blockValue);
+ } else {
+ fail("Can only handle ManagedBuffers and Exceptions, got " + blockValue);
+ }
+ }
+ return null;
+ } catch (Throwable e) {
+ e.printStackTrace();
+ throw e;
+ }
+ }
+ };
+
+ // This is either the first stub, or should be chained behind the prior ones.
+ if (stub == null) {
+ stub = doAnswer(answer);
+ } else {
+ stub.doAnswer(answer);
+ }
+ }
+
+ assert stub != null;
+ stub.when(fetchStarter).createAndStart((String[]) any(), (BlockFetchingListener) anyObject());
+ String[] blockIdArray = blockIds.toArray(new String[blockIds.size()]);
+ new RetryingBlockFetcher(conf, fetchStarter, blockIdArray, listener).start();
+ }
+}