aboutsummaryrefslogtreecommitdiff
path: root/network/shuffle
diff options
context:
space:
mode:
Diffstat (limited to 'network/shuffle')
-rw-r--r--network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java31
-rw-r--r--network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java9
-rw-r--r--network/shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java234
-rw-r--r--network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java4
-rw-r--r--network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java18
-rw-r--r--network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java6
-rw-r--r--network/shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java310
7 files changed, 591 insertions, 21 deletions
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
index 3aa95d00f6..27884b82c8 100644
--- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
@@ -17,6 +17,7 @@
package org.apache.spark.network.shuffle;
+import java.io.IOException;
import java.util.List;
import com.google.common.collect.Lists;
@@ -76,17 +77,33 @@ public class ExternalShuffleClient extends ShuffleClient {
@Override
public void fetchBlocks(
- String host,
- int port,
- String execId,
+ final String host,
+ final int port,
+ final String execId,
String[] blockIds,
BlockFetchingListener listener) {
assert appId != null : "Called before init()";
logger.debug("External shuffle fetch from {}:{} (executor id {})", host, port, execId);
try {
- TransportClient client = clientFactory.createClient(host, port);
- new OneForOneBlockFetcher(client, blockIds, listener)
- .start(new ExternalShuffleMessages.OpenShuffleBlocks(appId, execId, blockIds));
+ RetryingBlockFetcher.BlockFetchStarter blockFetchStarter =
+ new RetryingBlockFetcher.BlockFetchStarter() {
+ @Override
+ public void createAndStart(String[] blockIds, BlockFetchingListener listener)
+ throws IOException {
+ TransportClient client = clientFactory.createClient(host, port);
+ new OneForOneBlockFetcher(client, blockIds, listener)
+ .start(new ExternalShuffleMessages.OpenShuffleBlocks(appId, execId, blockIds));
+ }
+ };
+
+ int maxRetries = conf.maxIORetries();
+ if (maxRetries > 0) {
+ // Note this Fetcher will correctly handle maxRetries == 0; we avoid it just in case there's
+ // a bug in this code. We should remove the if statement once we're sure of the stability.
+ new RetryingBlockFetcher(conf, blockFetchStarter, blockIds, listener).start();
+ } else {
+ blockFetchStarter.createAndStart(blockIds, listener);
+ }
} catch (Exception e) {
logger.error("Exception while beginning fetchBlocks", e);
for (String blockId : blockIds) {
@@ -108,7 +125,7 @@ public class ExternalShuffleClient extends ShuffleClient {
String host,
int port,
String execId,
- ExecutorShuffleInfo executorInfo) {
+ ExecutorShuffleInfo executorInfo) throws IOException {
assert appId != null : "Called before init()";
TransportClient client = clientFactory.createClient(host, port);
byte[] registerExecutorMessage =
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
index 39b6f30f92..9e77a1f68c 100644
--- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
@@ -51,9 +51,6 @@ public class OneForOneBlockFetcher {
TransportClient client,
String[] blockIds,
BlockFetchingListener listener) {
- if (blockIds.length == 0) {
- throw new IllegalArgumentException("Zero-sized blockIds array");
- }
this.client = client;
this.blockIds = blockIds;
this.listener = listener;
@@ -82,6 +79,10 @@ public class OneForOneBlockFetcher {
* {@link ShuffleStreamHandle}. We will send all fetch requests immediately, without throttling.
*/
public void start(Object openBlocksMessage) {
+ if (blockIds.length == 0) {
+ throw new IllegalArgumentException("Zero-sized blockIds array");
+ }
+
client.sendRpc(JavaUtils.serialize(openBlocksMessage), new RpcResponseCallback() {
@Override
public void onSuccess(byte[] response) {
@@ -95,7 +96,7 @@ public class OneForOneBlockFetcher {
client.fetchChunk(streamHandle.streamId, i, chunkCallback);
}
} catch (Exception e) {
- logger.error("Failed while starting block fetches", e);
+ logger.error("Failed while starting block fetches after success", e);
failRemainingBlocks(blockIds, e);
}
}
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java
new file mode 100644
index 0000000000..f8a1a26686
--- /dev/null
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java
@@ -0,0 +1,234 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.LinkedHashSet;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.TimeUnit;
+
+import com.google.common.collect.Sets;
+import com.google.common.util.concurrent.Uninterruptibles;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.util.NettyUtils;
+import org.apache.spark.network.util.TransportConf;
+
+/**
+ * Wraps another BlockFetcher with the ability to automatically retry fetches which fail due to
+ * IOExceptions, which we hope are due to transient network conditions.
+ *
+ * This fetcher provides stronger guarantees regarding the parent BlockFetchingListener. In
+ * particular, the listener will be invoked exactly once per blockId, with a success or failure.
+ */
+public class RetryingBlockFetcher {
+
+ /**
+ * Used to initiate the first fetch for all blocks, and subsequently for retrying the fetch on any
+ * remaining blocks.
+ */
+ public static interface BlockFetchStarter {
+ /**
+ * Creates a new BlockFetcher to fetch the given block ids which may do some synchronous
+ * bootstrapping followed by fully asynchronous block fetching.
+ * The BlockFetcher must eventually invoke the Listener on every input blockId, or else this
+ * method must throw an exception.
+ *
+ * This method should always attempt to get a new TransportClient from the
+ * {@link org.apache.spark.network.client.TransportClientFactory} in order to fix connection
+ * issues.
+ */
+ void createAndStart(String[] blockIds, BlockFetchingListener listener) throws IOException;
+ }
+
+ /** Shared executor service used for waiting and retrying. */
+ private static final ExecutorService executorService = Executors.newCachedThreadPool(
+ NettyUtils.createThreadFactory("Block Fetch Retry"));
+
+ private final Logger logger = LoggerFactory.getLogger(RetryingBlockFetcher.class);
+
+ /** Used to initiate new Block Fetches on our remaining blocks. */
+ private final BlockFetchStarter fetchStarter;
+
+ /** Parent listener which we delegate all successful or permanently failed block fetches to. */
+ private final BlockFetchingListener listener;
+
+ /** Max number of times we are allowed to retry. */
+ private final int maxRetries;
+
+ /** Milliseconds to wait before each retry. */
+ private final int retryWaitTime;
+
+ // NOTE:
+ // All of our non-final fields are synchronized under 'this' and should only be accessed/mutated
+ // while inside a synchronized block.
+ /** Number of times we've attempted to retry so far. */
+ private int retryCount = 0;
+
+ /**
+ * Set of all block ids which have not been fetched successfully or with a non-IO Exception.
+ * A retry involves requesting every outstanding block. Note that since this is a LinkedHashSet,
+ * input ordering is preserved, so we always request blocks in the same order the user provided.
+ */
+ private final LinkedHashSet<String> outstandingBlocksIds;
+
+ /**
+ * The BlockFetchingListener that is active with our current BlockFetcher.
+ * When we start a retry, we immediately replace this with a new Listener, which causes all any
+ * old Listeners to ignore all further responses.
+ */
+ private RetryingBlockFetchListener currentListener;
+
+ public RetryingBlockFetcher(
+ TransportConf conf,
+ BlockFetchStarter fetchStarter,
+ String[] blockIds,
+ BlockFetchingListener listener) {
+ this.fetchStarter = fetchStarter;
+ this.listener = listener;
+ this.maxRetries = conf.maxIORetries();
+ this.retryWaitTime = conf.ioRetryWaitTime();
+ this.outstandingBlocksIds = Sets.newLinkedHashSet();
+ Collections.addAll(outstandingBlocksIds, blockIds);
+ this.currentListener = new RetryingBlockFetchListener();
+ }
+
+ /**
+ * Initiates the fetch of all blocks provided in the constructor, with possible retries in the
+ * event of transient IOExceptions.
+ */
+ public void start() {
+ fetchAllOutstanding();
+ }
+
+ /**
+ * Fires off a request to fetch all blocks that have not been fetched successfully or permanently
+ * failed (i.e., by a non-IOException).
+ */
+ private void fetchAllOutstanding() {
+ // Start by retrieving our shared state within a synchronized block.
+ String[] blockIdsToFetch;
+ int numRetries;
+ RetryingBlockFetchListener myListener;
+ synchronized (this) {
+ blockIdsToFetch = outstandingBlocksIds.toArray(new String[outstandingBlocksIds.size()]);
+ numRetries = retryCount;
+ myListener = currentListener;
+ }
+
+ // Now initiate the fetch on all outstanding blocks, possibly initiating a retry if that fails.
+ try {
+ fetchStarter.createAndStart(blockIdsToFetch, myListener);
+ } catch (Exception e) {
+ logger.error(String.format("Exception while beginning fetch of %s outstanding blocks %s",
+ blockIdsToFetch.length, numRetries > 0 ? "(after " + numRetries + " retries)" : ""), e);
+
+ if (shouldRetry(e)) {
+ initiateRetry();
+ } else {
+ for (String bid : blockIdsToFetch) {
+ listener.onBlockFetchFailure(bid, e);
+ }
+ }
+ }
+ }
+
+ /**
+ * Lightweight method which initiates a retry in a different thread. The retry will involve
+ * calling fetchAllOutstanding() after a configured wait time.
+ */
+ private synchronized void initiateRetry() {
+ retryCount += 1;
+ currentListener = new RetryingBlockFetchListener();
+
+ logger.info("Retrying fetch ({}/{}) for {} outstanding blocks after {} ms",
+ retryCount, maxRetries, outstandingBlocksIds.size(), retryWaitTime);
+
+ executorService.submit(new Runnable() {
+ @Override
+ public void run() {
+ Uninterruptibles.sleepUninterruptibly(retryWaitTime, TimeUnit.MILLISECONDS);
+ fetchAllOutstanding();
+ }
+ });
+ }
+
+ /**
+ * Returns true if we should retry due a block fetch failure. We will retry if and only if
+ * the exception was an IOException and we haven't retried 'maxRetries' times already.
+ */
+ private synchronized boolean shouldRetry(Throwable e) {
+ boolean isIOException = e instanceof IOException
+ || (e.getCause() != null && e.getCause() instanceof IOException);
+ boolean hasRemainingRetries = retryCount < maxRetries;
+ return isIOException && hasRemainingRetries;
+ }
+
+ /**
+ * Our RetryListener intercepts block fetch responses and forwards them to our parent listener.
+ * Note that in the event of a retry, we will immediately replace the 'currentListener' field,
+ * indicating that any responses from non-current Listeners should be ignored.
+ */
+ private class RetryingBlockFetchListener implements BlockFetchingListener {
+ @Override
+ public void onBlockFetchSuccess(String blockId, ManagedBuffer data) {
+ // We will only forward this success message to our parent listener if this block request is
+ // outstanding and we are still the active listener.
+ boolean shouldForwardSuccess = false;
+ synchronized (RetryingBlockFetcher.this) {
+ if (this == currentListener && outstandingBlocksIds.contains(blockId)) {
+ outstandingBlocksIds.remove(blockId);
+ shouldForwardSuccess = true;
+ }
+ }
+
+ // Now actually invoke the parent listener, outside of the synchronized block.
+ if (shouldForwardSuccess) {
+ listener.onBlockFetchSuccess(blockId, data);
+ }
+ }
+
+ @Override
+ public void onBlockFetchFailure(String blockId, Throwable exception) {
+ // We will only forward this failure to our parent listener if this block request is
+ // outstanding, we are still the active listener, AND we cannot retry the fetch.
+ boolean shouldForwardFailure = false;
+ synchronized (RetryingBlockFetcher.this) {
+ if (this == currentListener && outstandingBlocksIds.contains(blockId)) {
+ if (shouldRetry(exception)) {
+ initiateRetry();
+ } else {
+ logger.error(String.format("Failed to fetch block %s, and will not retry (%s retries)",
+ blockId, retryCount), exception);
+ outstandingBlocksIds.remove(blockId);
+ shouldForwardFailure = true;
+ }
+ }
+ }
+
+ // Now actually invoke the parent listener, outside of the synchronized block.
+ if (shouldForwardFailure) {
+ listener.onBlockFetchFailure(blockId, exception);
+ }
+ }
+ }
+}
diff --git a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java
index 8478120786..d25283e46e 100644
--- a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java
+++ b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java
@@ -93,7 +93,7 @@ public class SaslIntegrationSuite {
}
@Test
- public void testGoodClient() {
+ public void testGoodClient() throws IOException {
clientFactory = context.createClientFactory(
Lists.<TransportClientBootstrap>newArrayList(
new SaslClientBootstrap(conf, "app-id", new TestSecretKeyHolder("good-key"))));
@@ -119,7 +119,7 @@ public class SaslIntegrationSuite {
}
@Test
- public void testNoSaslClient() {
+ public void testNoSaslClient() throws IOException {
clientFactory = context.createClientFactory(
Lists.<TransportClientBootstrap>newArrayList());
diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java
index 71e017b9e4..06294fef19 100644
--- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java
+++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java
@@ -259,14 +259,20 @@ public class ExternalShuffleIntegrationSuite {
@Test
public void testFetchNoServer() throws Exception {
- registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER));
- FetchResult execFetch = fetchBlocks("exec-0",
- new String[] { "shuffle_1_0_0", "shuffle_1_0_1" }, 1 /* port */);
- assertTrue(execFetch.successBlocks.isEmpty());
- assertEquals(Sets.newHashSet("shuffle_1_0_0", "shuffle_1_0_1"), execFetch.failedBlocks);
+ System.setProperty("spark.shuffle.io.maxRetries", "0");
+ try {
+ registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER));
+ FetchResult execFetch = fetchBlocks("exec-0",
+ new String[]{"shuffle_1_0_0", "shuffle_1_0_1"}, 1 /* port */);
+ assertTrue(execFetch.successBlocks.isEmpty());
+ assertEquals(Sets.newHashSet("shuffle_1_0_0", "shuffle_1_0_1"), execFetch.failedBlocks);
+ } finally {
+ System.clearProperty("spark.shuffle.io.maxRetries");
+ }
}
- private void registerExecutor(String executorId, ExecutorShuffleInfo executorInfo) {
+ private void registerExecutor(String executorId, ExecutorShuffleInfo executorInfo)
+ throws IOException {
ExternalShuffleClient client = new ExternalShuffleClient(conf, null, false);
client.init(APP_ID);
client.registerWithShuffleServer(TestUtils.getLocalHost(), server.getPort(),
diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java
index 4c18fcdfbc..848c88f743 100644
--- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java
+++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java
@@ -17,6 +17,8 @@
package org.apache.spark.network.shuffle;
+import java.io.IOException;
+
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
@@ -54,7 +56,7 @@ public class ExternalShuffleSecuritySuite {
}
@Test
- public void testValid() {
+ public void testValid() throws IOException {
validate("my-app-id", "secret");
}
@@ -77,7 +79,7 @@ public class ExternalShuffleSecuritySuite {
}
/** Creates an ExternalShuffleClient and attempts to register with the server. */
- private void validate(String appId, String secretKey) {
+ private void validate(String appId, String secretKey) throws IOException {
ExternalShuffleClient client =
new ExternalShuffleClient(conf, new TestSecretKeyHolder(appId, secretKey), true);
client.init(appId);
diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java
new file mode 100644
index 0000000000..0191fe529e
--- /dev/null
+++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java
@@ -0,0 +1,310 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle;
+
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.LinkedHashSet;
+import java.util.Map;
+
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.Sets;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
+import org.mockito.stubbing.Stubber;
+
+import static org.junit.Assert.*;
+import static org.mockito.Mockito.*;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.buffer.NioManagedBuffer;
+import org.apache.spark.network.util.SystemPropertyConfigProvider;
+import org.apache.spark.network.util.TransportConf;
+import static org.apache.spark.network.shuffle.RetryingBlockFetcher.BlockFetchStarter;
+
+/**
+ * Tests retry logic by throwing IOExceptions and ensuring that subsequent attempts are made to
+ * fetch the lost blocks.
+ */
+public class RetryingBlockFetcherSuite {
+
+ ManagedBuffer block0 = new NioManagedBuffer(ByteBuffer.wrap(new byte[13]));
+ ManagedBuffer block1 = new NioManagedBuffer(ByteBuffer.wrap(new byte[7]));
+ ManagedBuffer block2 = new NioManagedBuffer(ByteBuffer.wrap(new byte[19]));
+
+ @Before
+ public void beforeEach() {
+ System.setProperty("spark.shuffle.io.maxRetries", "2");
+ System.setProperty("spark.shuffle.io.retryWaitMs", "0");
+ }
+
+ @After
+ public void afterEach() {
+ System.clearProperty("spark.shuffle.io.maxRetries");
+ System.clearProperty("spark.shuffle.io.retryWaitMs");
+ }
+
+ @Test
+ public void testNoFailures() throws IOException {
+ BlockFetchingListener listener = mock(BlockFetchingListener.class);
+
+ Map[] interactions = new Map[] {
+ // Immediately return both blocks successfully.
+ ImmutableMap.<String, Object>builder()
+ .put("b0", block0)
+ .put("b1", block1)
+ .build(),
+ };
+
+ performInteractions(interactions, listener);
+
+ verify(listener).onBlockFetchSuccess("b0", block0);
+ verify(listener).onBlockFetchSuccess("b1", block1);
+ verifyNoMoreInteractions(listener);
+ }
+
+ @Test
+ public void testUnrecoverableFailure() throws IOException {
+ BlockFetchingListener listener = mock(BlockFetchingListener.class);
+
+ Map[] interactions = new Map[] {
+ // b0 throws a non-IOException error, so it will be failed without retry.
+ ImmutableMap.<String, Object>builder()
+ .put("b0", new RuntimeException("Ouch!"))
+ .put("b1", block1)
+ .build(),
+ };
+
+ performInteractions(interactions, listener);
+
+ verify(listener).onBlockFetchFailure(eq("b0"), (Throwable) any());
+ verify(listener).onBlockFetchSuccess("b1", block1);
+ verifyNoMoreInteractions(listener);
+ }
+
+ @Test
+ public void testSingleIOExceptionOnFirst() throws IOException {
+ BlockFetchingListener listener = mock(BlockFetchingListener.class);
+
+ Map[] interactions = new Map[] {
+ // IOException will cause a retry. Since b0 fails, we will retry both.
+ ImmutableMap.<String, Object>builder()
+ .put("b0", new IOException("Connection failed or something"))
+ .put("b1", block1)
+ .build(),
+ ImmutableMap.<String, Object>builder()
+ .put("b0", block0)
+ .put("b1", block1)
+ .build(),
+ };
+
+ performInteractions(interactions, listener);
+
+ verify(listener, timeout(5000)).onBlockFetchSuccess("b0", block0);
+ verify(listener, timeout(5000)).onBlockFetchSuccess("b1", block1);
+ verifyNoMoreInteractions(listener);
+ }
+
+ @Test
+ public void testSingleIOExceptionOnSecond() throws IOException {
+ BlockFetchingListener listener = mock(BlockFetchingListener.class);
+
+ Map[] interactions = new Map[] {
+ // IOException will cause a retry. Since b1 fails, we will not retry b0.
+ ImmutableMap.<String, Object>builder()
+ .put("b0", block0)
+ .put("b1", new IOException("Connection failed or something"))
+ .build(),
+ ImmutableMap.<String, Object>builder()
+ .put("b1", block1)
+ .build(),
+ };
+
+ performInteractions(interactions, listener);
+
+ verify(listener, timeout(5000)).onBlockFetchSuccess("b0", block0);
+ verify(listener, timeout(5000)).onBlockFetchSuccess("b1", block1);
+ verifyNoMoreInteractions(listener);
+ }
+
+ @Test
+ public void testTwoIOExceptions() throws IOException {
+ BlockFetchingListener listener = mock(BlockFetchingListener.class);
+
+ Map[] interactions = new Map[] {
+ // b0's IOException will trigger retry, b1's will be ignored.
+ ImmutableMap.<String, Object>builder()
+ .put("b0", new IOException())
+ .put("b1", new IOException())
+ .build(),
+ // Next, b0 is successful and b1 errors again, so we just request that one.
+ ImmutableMap.<String, Object>builder()
+ .put("b0", block0)
+ .put("b1", new IOException())
+ .build(),
+ // b1 returns successfully within 2 retries.
+ ImmutableMap.<String, Object>builder()
+ .put("b1", block1)
+ .build(),
+ };
+
+ performInteractions(interactions, listener);
+
+ verify(listener, timeout(5000)).onBlockFetchSuccess("b0", block0);
+ verify(listener, timeout(5000)).onBlockFetchSuccess("b1", block1);
+ verifyNoMoreInteractions(listener);
+ }
+
+ @Test
+ public void testThreeIOExceptions() throws IOException {
+ BlockFetchingListener listener = mock(BlockFetchingListener.class);
+
+ Map[] interactions = new Map[] {
+ // b0's IOException will trigger retry, b1's will be ignored.
+ ImmutableMap.<String, Object>builder()
+ .put("b0", new IOException())
+ .put("b1", new IOException())
+ .build(),
+ // Next, b0 is successful and b1 errors again, so we just request that one.
+ ImmutableMap.<String, Object>builder()
+ .put("b0", block0)
+ .put("b1", new IOException())
+ .build(),
+ // b1 errors again, but this was the last retry
+ ImmutableMap.<String, Object>builder()
+ .put("b1", new IOException())
+ .build(),
+ // This is not reached -- b1 has failed.
+ ImmutableMap.<String, Object>builder()
+ .put("b1", block1)
+ .build(),
+ };
+
+ performInteractions(interactions, listener);
+
+ verify(listener, timeout(5000)).onBlockFetchSuccess("b0", block0);
+ verify(listener, timeout(5000)).onBlockFetchFailure(eq("b1"), (Throwable) any());
+ verifyNoMoreInteractions(listener);
+ }
+
+ @Test
+ public void testRetryAndUnrecoverable() throws IOException {
+ BlockFetchingListener listener = mock(BlockFetchingListener.class);
+
+ Map[] interactions = new Map[] {
+ // b0's IOException will trigger retry, subsequent messages will be ignored.
+ ImmutableMap.<String, Object>builder()
+ .put("b0", new IOException())
+ .put("b1", new RuntimeException())
+ .put("b2", block2)
+ .build(),
+ // Next, b0 is successful, b1 errors unrecoverably, and b2 triggers a retry.
+ ImmutableMap.<String, Object>builder()
+ .put("b0", block0)
+ .put("b1", new RuntimeException())
+ .put("b2", new IOException())
+ .build(),
+ // b2 succeeds in its last retry.
+ ImmutableMap.<String, Object>builder()
+ .put("b2", block2)
+ .build(),
+ };
+
+ performInteractions(interactions, listener);
+
+ verify(listener, timeout(5000)).onBlockFetchSuccess("b0", block0);
+ verify(listener, timeout(5000)).onBlockFetchFailure(eq("b1"), (Throwable) any());
+ verify(listener, timeout(5000)).onBlockFetchSuccess("b2", block2);
+ verifyNoMoreInteractions(listener);
+ }
+
+ /**
+ * Performs a set of interactions in response to block requests from a RetryingBlockFetcher.
+ * Each interaction is a Map from BlockId to either ManagedBuffer or Exception. This interaction
+ * means "respond to the next block fetch request with these Successful buffers and these Failure
+ * exceptions". We verify that the expected block ids are exactly the ones requested.
+ *
+ * If multiple interactions are supplied, they will be used in order. This is useful for encoding
+ * retries -- the first interaction may include an IOException, which causes a retry of some
+ * subset of the original blocks in a second interaction.
+ */
+ @SuppressWarnings("unchecked")
+ private void performInteractions(final Map[] interactions, BlockFetchingListener listener)
+ throws IOException {
+
+ TransportConf conf = new TransportConf(new SystemPropertyConfigProvider());
+ BlockFetchStarter fetchStarter = mock(BlockFetchStarter.class);
+
+ Stubber stub = null;
+
+ // Contains all blockIds that are referenced across all interactions.
+ final LinkedHashSet<String> blockIds = Sets.newLinkedHashSet();
+
+ for (final Map<String, Object> interaction : interactions) {
+ blockIds.addAll(interaction.keySet());
+
+ Answer<Void> answer = new Answer<Void>() {
+ @Override
+ public Void answer(InvocationOnMock invocationOnMock) throws Throwable {
+ try {
+ // Verify that the RetryingBlockFetcher requested the expected blocks.
+ String[] requestedBlockIds = (String[]) invocationOnMock.getArguments()[0];
+ String[] desiredBlockIds = interaction.keySet().toArray(new String[interaction.size()]);
+ assertArrayEquals(desiredBlockIds, requestedBlockIds);
+
+ // Now actually invoke the success/failure callbacks on each block.
+ BlockFetchingListener retryListener =
+ (BlockFetchingListener) invocationOnMock.getArguments()[1];
+ for (Map.Entry<String, Object> block : interaction.entrySet()) {
+ String blockId = block.getKey();
+ Object blockValue = block.getValue();
+
+ if (blockValue instanceof ManagedBuffer) {
+ retryListener.onBlockFetchSuccess(blockId, (ManagedBuffer) blockValue);
+ } else if (blockValue instanceof Exception) {
+ retryListener.onBlockFetchFailure(blockId, (Exception) blockValue);
+ } else {
+ fail("Can only handle ManagedBuffers and Exceptions, got " + blockValue);
+ }
+ }
+ return null;
+ } catch (Throwable e) {
+ e.printStackTrace();
+ throw e;
+ }
+ }
+ };
+
+ // This is either the first stub, or should be chained behind the prior ones.
+ if (stub == null) {
+ stub = doAnswer(answer);
+ } else {
+ stub.doAnswer(answer);
+ }
+ }
+
+ assert stub != null;
+ stub.when(fetchStarter).createAndStart((String[]) any(), (BlockFetchingListener) anyObject());
+ String[] blockIdArray = blockIds.toArray(new String[blockIds.size()]);
+ new RetryingBlockFetcher(conf, fetchStarter, blockIdArray, listener).start();
+ }
+}