author    Marcelo Vanzin <vanzin@cloudera.com>  2015-09-02 12:53:24 -0700
committer Marcelo Vanzin <vanzin@cloudera.com>  2015-09-02 12:53:24 -0700
commit    2da3a9e98e5d129d4507b5db01bba5ee9558d28e (patch)
tree      c5197f543f18959d793db1caea4ee553acef4f97 /network
parent    fc48307797912dc1d53893dce741ddda8630957b (diff)
[SPARK-10004] [SHUFFLE] Perform auth checks when clients read shuffle data.
To correctly isolate applications, the shuffle service needs to perform proper authorization checks when requests to read shuffle data arrive. This change ensures that only the application that created the shuffle data can read it back. The checks are enabled only when "spark.authenticate" is on; otherwise there is no secure way to verify that a client is who it claims to be.

Author: Marcelo Vanzin <vanzin@cloudera.com>

Closes #8218 from vanzin/SPARK-10004.
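The enforcement pattern is the same at every entry point the patch touches: a connection authenticated via SASL carries its app ID, and any request for application data must carry a matching ID; when authentication is disabled the client ID is null and the check is skipped. A minimal stand-alone sketch of that pattern (the class and method names are illustrative, not part of the patch; the @Nullable annotation comes from the jsr305 dependency the patch adds):

import javax.annotation.Nullable;

final class AuthGuard {
  // Mirrors the checkAuth/checkAuthorization logic in the diff below:
  // a null clientId means "spark.authenticate" is off, so nothing is enforced.
  static void checkOwnership(@Nullable String clientId, String ownerAppId) {
    if (clientId != null && !clientId.equals(ownerAppId)) {
      throw new SecurityException(String.format(
          "Client %s not authorized for application %s.", clientId, ownerAppId));
    }
  }

  public static void main(String[] args) {
    checkOwnership(null, "app-1");    // auth disabled: allowed
    checkOwnership("app-1", "app-1"); // owner reads its own data: allowed
    checkOwnership("app-2", "app-1"); // throws SecurityException
  }
}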
Diffstat (limited to 'network')
-rw-r--r--  network/common/pom.xml                                                                           |   4
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/client/TransportClient.java                |  22
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java              |   2
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java                   |   1
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java         |  31
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/server/StreamManager.java                  |   9
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java        |   1
-rw-r--r--  network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java  |  16
-rw-r--r--  network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java            | 163
-rw-r--r--  network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java |   2
10 files changed, 217 insertions(+), 34 deletions(-)
diff --git a/network/common/pom.xml b/network/common/pom.xml
index 7dc3068ab8..4141fcb826 100644
--- a/network/common/pom.xml
+++ b/network/common/pom.xml
@@ -48,6 +48,10 @@
<artifactId>slf4j-api</artifactId>
<scope>provided</scope>
</dependency>
+ <dependency>
+ <groupId>com.google.code.findbugs</groupId>
+ <artifactId>jsr305</artifactId>
+ </dependency>
<!--
Promote Guava to "compile" so that maven-shade-plugin picks it up (for packaging the Optional
class exposed in the Java API). The plugin will then remove this dependency from the published
diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
index e8e7f06247..df841288a0 100644
--- a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
+++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
@@ -23,6 +23,7 @@ import java.net.SocketAddress;
import java.util.UUID;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
+import javax.annotation.Nullable;
import com.google.common.base.Objects;
import com.google.common.base.Preconditions;
@@ -70,6 +71,7 @@ public class TransportClient implements Closeable {
private final Channel channel;
private final TransportResponseHandler handler;
+ @Nullable private String clientId;
public TransportClient(Channel channel, TransportResponseHandler handler) {
this.channel = Preconditions.checkNotNull(channel);
@@ -85,6 +87,25 @@ public class TransportClient implements Closeable {
}
/**
+ * Returns the ID used by the client to authenticate itself when authentication is enabled.
+ *
+ * @return The client ID, or null if authentication is disabled.
+ */
+ public String getClientId() {
+ return clientId;
+ }
+
+ /**
+ * Sets the authenticated client ID. This is meant to be used by the authentication layer.
+ *
+ * Trying to set a different client ID after it's been set will result in an exception.
+ */
+ public void setClientId(String id) {
+ Preconditions.checkState(clientId == null, "Client ID has already been set.");
+ this.clientId = id;
+ }
+
+ /**
* Requests a single chunk from the remote side, from the pre-negotiated streamId.
*
* Chunk indices go from 0 onwards. It is valid to request the same chunk multiple times, though
@@ -207,6 +228,7 @@ public class TransportClient implements Closeable {
public String toString() {
return Objects.toStringHelper(this)
.add("remoteAdress", channel.remoteAddress())
+ .add("clientId", clientId)
.add("isActive", isActive())
.toString();
}
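A note on the lifecycle of the new clientId field: it is meant to be written exactly once by the authentication layer (see the SaslClientBootstrap and SaslRpcHandler changes below), and the Preconditions.checkState guard makes a second write fail with an IllegalStateException. A hedged sketch of a custom client bootstrap doing the same tagging (this class is illustrative, and it assumes the two-argument doBootstrap signature that SaslClientBootstrap implements at this point in Spark's history):

import io.netty.channel.Channel;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.client.TransportClientBootstrap;

/** Illustrative bootstrap that tags a freshly created connection with its app ID. */
class AppIdTaggingBootstrap implements TransportClientBootstrap {
  private final String appId;

  AppIdTaggingBootstrap(String appId) {
    this.appId = appId;
  }

  @Override
  public void doBootstrap(TransportClient client, Channel channel) {
    client.setClientId(appId); // a second call here would throw IllegalStateException
  }
}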
diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java b/network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java
index 185ba2ef3b..69923769d4 100644
--- a/network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java
+++ b/network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java
@@ -77,6 +77,8 @@ public class SaslClientBootstrap implements TransportClientBootstrap {
payload = saslClient.response(response);
}
+ client.setClientId(appId);
+
if (encrypt) {
if (!SparkSaslServer.QOP_AUTH_CONF.equals(saslClient.getNegotiatedProperty(Sasl.QOP))) {
throw new RuntimeException(
diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java b/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java
index be6165caf3..3f2ebe3288 100644
--- a/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java
@@ -81,6 +81,7 @@ class SaslRpcHandler extends RpcHandler {
if (saslServer == null) {
// First message in the handshake, setup the necessary state.
+ client.setClientId(saslMessage.appId);
saslServer = new SparkSaslServer(saslMessage.appId, secretKeyHolder,
conf.saslServerAlwaysEncrypt());
}
diff --git a/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java b/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java
index c95e64e8e2..e671854da1 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java
@@ -24,13 +24,13 @@ import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;
+import com.google.common.base.Preconditions;
import io.netty.channel.Channel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.spark.network.buffer.ManagedBuffer;
-
-import com.google.common.base.Preconditions;
+import org.apache.spark.network.client.TransportClient;
/**
* StreamManager which allows registration of an Iterator<ManagedBuffer>, which are individually
@@ -44,6 +44,7 @@ public class OneForOneStreamManager extends StreamManager {
/** State of a single stream. */
private static class StreamState {
+ final String appId;
final Iterator<ManagedBuffer> buffers;
// The channel associated to the stream
@@ -53,7 +54,8 @@ public class OneForOneStreamManager extends StreamManager {
// that the caller only requests each chunk one at a time, in order.
int curChunk = 0;
- StreamState(Iterator<ManagedBuffer> buffers) {
+ StreamState(String appId, Iterator<ManagedBuffer> buffers) {
+ this.appId = appId;
this.buffers = Preconditions.checkNotNull(buffers);
}
}
@@ -109,15 +111,34 @@ public class OneForOneStreamManager extends StreamManager {
}
}
+ @Override
+ public void checkAuthorization(TransportClient client, long streamId) {
+ if (client.getClientId() != null) {
+ StreamState state = streams.get(streamId);
+ Preconditions.checkArgument(state != null, "Unknown stream ID.");
+ if (!client.getClientId().equals(state.appId)) {
+ throw new SecurityException(String.format(
+ "Client %s not authorized to read stream %d (app %s).",
+ client.getClientId(),
+ streamId,
+ state.appId));
+ }
+ }
+ }
+
/**
* Registers a stream of ManagedBuffers which are served as individual chunks one at a time to
* callers. Each ManagedBuffer will be release()'d after it is transferred on the wire. If a
* client connection is closed before the iterator is fully drained, then the remaining buffers
* will all be release()'d.
+ *
+ * If an app ID is provided, only callers who've authenticated with the given app ID will be
+ * allowed to fetch from this stream.
*/
- public long registerStream(Iterator<ManagedBuffer> buffers) {
+ public long registerStream(String appId, Iterator<ManagedBuffer> buffers) {
long myStreamId = nextStreamId.getAndIncrement();
- streams.put(myStreamId, new StreamState(buffers));
+ streams.put(myStreamId, new StreamState(appId, buffers));
return myStreamId;
}
+
}
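To see the two new methods cooperating: registerStream records which app owns the buffers, and checkAuthorization compares that owner against the caller's authenticated identity. A minimal sketch, mocking TransportClient the same way the tests in this patch do:

import java.util.Collections;

import static org.mockito.Mockito.*;

import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.server.OneForOneStreamManager;

public class StreamAuthSketch {
  public static void main(String[] args) {
    OneForOneStreamManager manager = new OneForOneStreamManager();
    long streamId = manager.registerStream(
        "app-1", Collections.<ManagedBuffer>emptyList().iterator());

    TransportClient sameApp = mock(TransportClient.class);
    when(sameApp.getClientId()).thenReturn("app-1");
    manager.checkAuthorization(sameApp, streamId); // passes silently

    TransportClient otherApp = mock(TransportClient.class);
    when(otherApp.getClientId()).thenReturn("app-2");
    try {
      manager.checkAuthorization(otherApp, streamId); // wrong owner
    } catch (SecurityException expected) {
      System.out.println(expected.getMessage());
    }
  }
}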
diff --git a/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java b/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java
index 929f789bf9..aaa677c965 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java
@@ -20,6 +20,7 @@ package org.apache.spark.network.server;
import io.netty.channel.Channel;
import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.client.TransportClient;
/**
* The StreamManager is used to fetch individual chunks from a stream. This is used in
@@ -60,4 +61,12 @@ public abstract class StreamManager {
* to read from the associated streams again, so any state can be cleaned up.
*/
public void connectionTerminated(Channel channel) { }
+
+ /**
+ * Verify that the client is authorized to read from the given stream.
+ *
+ * @throws SecurityException If client is not authorized.
+ */
+ public void checkAuthorization(TransportClient client, long streamId) { }
+
}
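The default implementation is an intentional no-op so existing StreamManager subclasses stay permissive; only managers that track stream ownership need to override the hook, as OneForOneStreamManager does above. A sketch of a custom subclass wiring it in (the one-buffer-per-manager design is purely illustrative):

import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.server.StreamManager;

/** Illustrative manager that serves a single buffer and enforces its owner. */
class SingleOwnerStreamManager extends StreamManager {
  private final String ownerAppId;
  private final ManagedBuffer buffer;

  SingleOwnerStreamManager(String ownerAppId, ManagedBuffer buffer) {
    this.ownerAppId = ownerAppId;
    this.buffer = buffer;
  }

  @Override
  public void checkAuthorization(TransportClient client, long streamId) {
    // Same contract as the base class documents: enforce only for authenticated clients.
    if (client.getClientId() != null && !client.getClientId().equals(ownerAppId)) {
      throw new SecurityException(
          "Client " + client.getClientId() + " not authorized to read stream " + streamId);
    }
  }

  @Override
  public ManagedBuffer getChunk(long streamId, int chunkIndex) {
    return buffer;
  }
}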
diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
index e5159ab56d..df60278058 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
@@ -97,6 +97,7 @@ public class TransportRequestHandler extends MessageHandler<RequestMessage> {
ManagedBuffer buf;
try {
+ streamManager.checkAuthorization(reverseClient, req.streamChunkId.streamId);
streamManager.registerChannel(channel, req.streamChunkId.streamId);
buf = streamManager.getChunk(req.streamChunkId.streamId, req.streamChunkId.chunkIndex);
} catch (Exception e) {
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java
index 0df1dd621f..3ddf5c3c39 100644
--- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java
@@ -58,7 +58,7 @@ public class ExternalShuffleBlockHandler extends RpcHandler {
/** Enables mocking out the StreamManager and BlockManager. */
@VisibleForTesting
- ExternalShuffleBlockHandler(
+ public ExternalShuffleBlockHandler(
OneForOneStreamManager streamManager,
ExternalShuffleBlockResolver blockManager) {
this.streamManager = streamManager;
@@ -77,17 +77,19 @@ public class ExternalShuffleBlockHandler extends RpcHandler {
RpcResponseCallback callback) {
if (msgObj instanceof OpenBlocks) {
OpenBlocks msg = (OpenBlocks) msgObj;
- List<ManagedBuffer> blocks = Lists.newArrayList();
+ checkAuth(client, msg.appId);
+ List<ManagedBuffer> blocks = Lists.newArrayList();
for (String blockId : msg.blockIds) {
blocks.add(blockManager.getBlockData(msg.appId, msg.execId, blockId));
}
- long streamId = streamManager.registerStream(blocks.iterator());
+ long streamId = streamManager.registerStream(client.getClientId(), blocks.iterator());
logger.trace("Registered streamId {} with {} buffers", streamId, msg.blockIds.length);
callback.onSuccess(new StreamHandle(streamId, msg.blockIds.length).toByteArray());
} else if (msgObj instanceof RegisterExecutor) {
RegisterExecutor msg = (RegisterExecutor) msgObj;
+ checkAuth(client, msg.appId);
blockManager.registerExecutor(msg.appId, msg.execId, msg.executorInfo);
callback.onSuccess(new byte[0]);
@@ -126,4 +128,12 @@ public class ExternalShuffleBlockHandler extends RpcHandler {
public void close() {
blockManager.close();
}
+
+ private void checkAuth(TransportClient client, String appId) {
+ if (client.getClientId() != null && !client.getClientId().equals(appId)) {
+ throw new SecurityException(String.format(
+ "Client for %s not authorized for application %s.", client.getClientId(), appId));
+ }
+ }
+
}
diff --git a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java
index 382f613ecb..5cb0e4d4a6 100644
--- a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java
+++ b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java
@@ -19,6 +19,7 @@ package org.apache.spark.network.sasl;
import java.io.IOException;
import java.util.Arrays;
+import java.util.concurrent.atomic.AtomicReference;
import com.google.common.collect.Lists;
import org.junit.After;
@@ -27,9 +28,12 @@ import org.junit.BeforeClass;
import org.junit.Test;
import static org.junit.Assert.*;
+import static org.mockito.Mockito.*;
import org.apache.spark.network.TestUtils;
import org.apache.spark.network.TransportContext;
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.client.ChunkReceivedCallback;
import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.client.TransportClientBootstrap;
@@ -39,44 +43,39 @@ import org.apache.spark.network.server.RpcHandler;
import org.apache.spark.network.server.StreamManager;
import org.apache.spark.network.server.TransportServer;
import org.apache.spark.network.server.TransportServerBootstrap;
+import org.apache.spark.network.shuffle.BlockFetchingListener;
import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler;
+import org.apache.spark.network.shuffle.ExternalShuffleBlockResolver;
+import org.apache.spark.network.shuffle.OneForOneBlockFetcher;
+import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
+import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo;
+import org.apache.spark.network.shuffle.protocol.OpenBlocks;
+import org.apache.spark.network.shuffle.protocol.RegisterExecutor;
+import org.apache.spark.network.shuffle.protocol.StreamHandle;
import org.apache.spark.network.util.SystemPropertyConfigProvider;
import org.apache.spark.network.util.TransportConf;
public class SaslIntegrationSuite {
- static ExternalShuffleBlockHandler handler;
static TransportServer server;
static TransportConf conf;
static TransportContext context;
+ static SecretKeyHolder secretKeyHolder;
TransportClientFactory clientFactory;
- /** Provides a secret key holder which always returns the given secret key. */
- static class TestSecretKeyHolder implements SecretKeyHolder {
-
- private final String secretKey;
-
- TestSecretKeyHolder(String secretKey) {
- this.secretKey = secretKey;
- }
-
- @Override
- public String getSaslUser(String appId) {
- return "user";
- }
- @Override
- public String getSecretKey(String appId) {
- return secretKey;
- }
- }
-
-
@BeforeClass
public static void beforeAll() throws IOException {
- SecretKeyHolder secretKeyHolder = new TestSecretKeyHolder("good-key");
conf = new TransportConf(new SystemPropertyConfigProvider());
context = new TransportContext(conf, new TestRpcHandler());
+ secretKeyHolder = mock(SecretKeyHolder.class);
+ when(secretKeyHolder.getSaslUser(eq("app-1"))).thenReturn("app-1");
+ when(secretKeyHolder.getSecretKey(eq("app-1"))).thenReturn("app-1");
+ when(secretKeyHolder.getSaslUser(eq("app-2"))).thenReturn("app-2");
+ when(secretKeyHolder.getSecretKey(eq("app-2"))).thenReturn("app-2");
+ when(secretKeyHolder.getSaslUser(anyString())).thenReturn("other-app");
+ when(secretKeyHolder.getSecretKey(anyString())).thenReturn("correct-password");
+
TransportServerBootstrap bootstrap = new SaslServerBootstrap(conf, secretKeyHolder);
server = context.createServer(Arrays.asList(bootstrap));
}
@@ -99,7 +98,7 @@ public class SaslIntegrationSuite {
public void testGoodClient() throws IOException {
clientFactory = context.createClientFactory(
Lists.<TransportClientBootstrap>newArrayList(
- new SaslClientBootstrap(conf, "app-id", new TestSecretKeyHolder("good-key"))));
+ new SaslClientBootstrap(conf, "app-1", secretKeyHolder)));
TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
String msg = "Hello, World!";
@@ -109,13 +108,17 @@ public class SaslIntegrationSuite {
@Test
public void testBadClient() {
+ SecretKeyHolder badKeyHolder = mock(SecretKeyHolder.class);
+ when(badKeyHolder.getSaslUser(anyString())).thenReturn("other-app");
+ when(badKeyHolder.getSecretKey(anyString())).thenReturn("wrong-password");
clientFactory = context.createClientFactory(
Lists.<TransportClientBootstrap>newArrayList(
- new SaslClientBootstrap(conf, "app-id", new TestSecretKeyHolder("bad-key"))));
+ new SaslClientBootstrap(conf, "unknown-app", badKeyHolder)));
try {
// Bootstrap should fail on startup.
clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
+ fail("Connection should have failed.");
} catch (Exception e) {
assertTrue(e.getMessage(), e.getMessage().contains("Mismatched response"));
}
@@ -149,7 +152,7 @@ public class SaslIntegrationSuite {
TransportContext context = new TransportContext(conf, handler);
clientFactory = context.createClientFactory(
Lists.<TransportClientBootstrap>newArrayList(
- new SaslClientBootstrap(conf, "app-id", new TestSecretKeyHolder("key"))));
+ new SaslClientBootstrap(conf, "app-1", secretKeyHolder)));
TransportServer server = context.createServer();
try {
clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
@@ -160,6 +163,110 @@ public class SaslIntegrationSuite {
}
}
+ /**
+ * This test is not actually testing SASL behavior, but testing that the shuffle service
+ * performs correct authorization checks based on the SASL authentication data.
+ */
+ @Test
+ public void testAppIsolation() throws Exception {
+ // Start a new server with the correct RPC handler to serve block data.
+ ExternalShuffleBlockResolver blockResolver = mock(ExternalShuffleBlockResolver.class);
+ ExternalShuffleBlockHandler blockHandler = new ExternalShuffleBlockHandler(
+ new OneForOneStreamManager(), blockResolver);
+ TransportServerBootstrap bootstrap = new SaslServerBootstrap(conf, secretKeyHolder);
+ TransportContext blockServerContext = new TransportContext(conf, blockHandler);
+ TransportServer blockServer = blockServerContext.createServer(Arrays.asList(bootstrap));
+
+ TransportClient client1 = null;
+ TransportClient client2 = null;
+ TransportClientFactory clientFactory2 = null;
+ try {
+ // Create a client, and make a request to fetch blocks from a different app.
+ clientFactory = blockServerContext.createClientFactory(
+ Lists.<TransportClientBootstrap>newArrayList(
+ new SaslClientBootstrap(conf, "app-1", secretKeyHolder)));
+ client1 = clientFactory.createClient(TestUtils.getLocalHost(),
+ blockServer.getPort());
+
+ final AtomicReference<Throwable> exception = new AtomicReference<>();
+
+ BlockFetchingListener listener = new BlockFetchingListener() {
+ @Override
+ public synchronized void onBlockFetchSuccess(String blockId, ManagedBuffer data) {
+ notifyAll();
+ }
+
+ @Override
+ public synchronized void onBlockFetchFailure(String blockId, Throwable t) {
+ exception.set(t);
+ notifyAll();
+ }
+ };
+
+ String[] blockIds = new String[] { "shuffle_2_3_4", "shuffle_6_7_8" };
+ OneForOneBlockFetcher fetcher = new OneForOneBlockFetcher(client1, "app-2", "0",
+ blockIds, listener);
+ synchronized (listener) {
+ fetcher.start();
+ listener.wait();
+ }
+ checkSecurityException(exception.get());
+
+ // Register an executor so that the next steps work.
+ ExecutorShuffleInfo executorInfo = new ExecutorShuffleInfo(
+ new String[] { System.getProperty("java.io.tmpdir") }, 1,
+ "org.apache.spark.shuffle.sort.SortShuffleManager");
+ RegisterExecutor regmsg = new RegisterExecutor("app-1", "0", executorInfo);
+ client1.sendRpcSync(regmsg.toByteArray(), 10000);
+
+ // Make a successful request to fetch blocks, which creates a new stream. But do not actually
+ // fetch any blocks, to keep the stream open.
+ OpenBlocks openMessage = new OpenBlocks("app-1", "0", blockIds);
+ byte[] response = client1.sendRpcSync(openMessage.toByteArray(), 10000);
+ StreamHandle stream = (StreamHandle) BlockTransferMessage.Decoder.fromByteArray(response);
+ long streamId = stream.streamId;
+
+ // Create a second client, authenticated with a different app ID, and try to read from
+ // the stream created for the previous app.
+ clientFactory2 = blockServerContext.createClientFactory(
+ Lists.<TransportClientBootstrap>newArrayList(
+ new SaslClientBootstrap(conf, "app-2", secretKeyHolder)));
+ client2 = clientFactory2.createClient(TestUtils.getLocalHost(),
+ blockServer.getPort());
+
+ ChunkReceivedCallback callback = new ChunkReceivedCallback() {
+ @Override
+ public synchronized void onSuccess(int chunkIndex, ManagedBuffer buffer) {
+ notifyAll();
+ }
+
+ @Override
+ public synchronized void onFailure(int chunkIndex, Throwable t) {
+ exception.set(t);
+ notifyAll();
+ }
+ };
+
+ exception.set(null);
+ synchronized (callback) {
+ client2.fetchChunk(streamId, 0, callback);
+ callback.wait();
+ }
+ checkSecurityException(exception.get());
+ } finally {
+ if (client1 != null) {
+ client1.close();
+ }
+ if (client2 != null) {
+ client2.close();
+ }
+ if (clientFactory2 != null) {
+ clientFactory2.close();
+ }
+ blockServer.close();
+ }
+ }
+
/** RPC handler which simply responds with the message it received. */
public static class TestRpcHandler extends RpcHandler {
@Override
@@ -172,4 +279,10 @@ public class SaslIntegrationSuite {
return new OneForOneStreamManager();
}
}
+
+ private void checkSecurityException(Throwable t) {
+ assertNotNull("No exception was caught.", t);
+ assertTrue("Expected SecurityException.",
+ t.getMessage().contains(SecurityException.class.getName()));
+ }
}
diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java
index 1d197497b7..e61390cf57 100644
--- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java
+++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java
@@ -93,7 +93,7 @@ public class ExternalShuffleBlockHandlerSuite {
@SuppressWarnings("unchecked")
ArgumentCaptor<Iterator<ManagedBuffer>> stream = (ArgumentCaptor<Iterator<ManagedBuffer>>)
(ArgumentCaptor<?>) ArgumentCaptor.forClass(Iterator.class);
- verify(streamManager, times(1)).registerStream(stream.capture());
+ verify(streamManager, times(1)).registerStream(anyString(), stream.capture());
Iterator<ManagedBuffer> buffers = stream.getValue();
assertEquals(block0Marker, buffers.next());
assertEquals(block1Marker, buffers.next());