aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala3
-rw-r--r--core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala2
-rw-r--r--network/common/pom.xml4
-rw-r--r--network/common/src/main/java/org/apache/spark/network/client/TransportClient.java22
-rw-r--r--network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java2
-rw-r--r--network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java1
-rw-r--r--network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java31
-rw-r--r--network/common/src/main/java/org/apache/spark/network/server/StreamManager.java9
-rw-r--r--network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java1
-rw-r--r--network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java16
-rw-r--r--network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java163
-rw-r--r--network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java2
-rw-r--r--project/MimaExcludes.scala1
13 files changed, 221 insertions, 36 deletions
diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala
index 7c170a742f..76968249fb 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala
+++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala
@@ -38,6 +38,7 @@ import org.apache.spark.storage.{BlockId, StorageLevel}
* is equivalent to one Spark-level shuffle block.
*/
class NettyBlockRpcServer(
+ appId: String,
serializer: Serializer,
blockManager: BlockDataManager)
extends RpcHandler with Logging {
@@ -55,7 +56,7 @@ class NettyBlockRpcServer(
case openBlocks: OpenBlocks =>
val blocks: Seq[ManagedBuffer] =
openBlocks.blockIds.map(BlockId.apply).map(blockManager.getBlockData)
- val streamId = streamManager.registerStream(blocks.iterator.asJava)
+ val streamId = streamManager.registerStream(appId, blocks.iterator.asJava)
logTrace(s"Registered streamId $streamId with ${blocks.size} buffers")
responseContext.onSuccess(new StreamHandle(streamId, blocks.size).toByteArray)
diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
index ff8aae9ebe..d5ad2c9ad0 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
+++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
@@ -49,7 +49,7 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage
private[this] var appId: String = _
override def init(blockDataManager: BlockDataManager): Unit = {
- val rpcHandler = new NettyBlockRpcServer(serializer, blockDataManager)
+ val rpcHandler = new NettyBlockRpcServer(conf.getAppId, serializer, blockDataManager)
var serverBootstrap: Option[TransportServerBootstrap] = None
var clientBootstrap: Option[TransportClientBootstrap] = None
if (authEnabled) {
diff --git a/network/common/pom.xml b/network/common/pom.xml
index 7dc3068ab8..4141fcb826 100644
--- a/network/common/pom.xml
+++ b/network/common/pom.xml
@@ -48,6 +48,10 @@
<artifactId>slf4j-api</artifactId>
<scope>provided</scope>
</dependency>
+ <dependency>
+ <groupId>com.google.code.findbugs</groupId>
+ <artifactId>jsr305</artifactId>
+ </dependency>
<!--
Promote Guava to "compile" so that maven-shade-plugin picks it up (for packaging the Optional
class exposed in the Java API). The plugin will then remove this dependency from the published
diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
index e8e7f06247..df841288a0 100644
--- a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
+++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
@@ -23,6 +23,7 @@ import java.net.SocketAddress;
import java.util.UUID;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
+import javax.annotation.Nullable;
import com.google.common.base.Objects;
import com.google.common.base.Preconditions;
@@ -70,6 +71,7 @@ public class TransportClient implements Closeable {
private final Channel channel;
private final TransportResponseHandler handler;
+ @Nullable private String clientId;
public TransportClient(Channel channel, TransportResponseHandler handler) {
this.channel = Preconditions.checkNotNull(channel);
@@ -85,6 +87,25 @@ public class TransportClient implements Closeable {
}
/**
+ * Returns the ID used by the client to authenticate itself when authentication is enabled.
+ *
+ * @return The client ID, or null if authentication is disabled.
+ */
+ public String getClientId() {
+ return clientId;
+ }
+
+ /**
+ * Sets the authenticated client ID. This is meant to be used by the authentication layer.
+ *
+ * Trying to set a different client ID after it's been set will result in an exception.
+ */
+ public void setClientId(String id) {
+ Preconditions.checkState(clientId == null, "Client ID has already been set.");
+ this.clientId = id;
+ }
+
+ /**
* Requests a single chunk from the remote side, from the pre-negotiated streamId.
*
* Chunk indices go from 0 onwards. It is valid to request the same chunk multiple times, though
@@ -207,6 +228,7 @@ public class TransportClient implements Closeable {
public String toString() {
return Objects.toStringHelper(this)
.add("remoteAdress", channel.remoteAddress())
+ .add("clientId", clientId)
.add("isActive", isActive())
.toString();
}
diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java b/network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java
index 185ba2ef3b..69923769d4 100644
--- a/network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java
+++ b/network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java
@@ -77,6 +77,8 @@ public class SaslClientBootstrap implements TransportClientBootstrap {
payload = saslClient.response(response);
}
+ client.setClientId(appId);
+
if (encrypt) {
if (!SparkSaslServer.QOP_AUTH_CONF.equals(saslClient.getNegotiatedProperty(Sasl.QOP))) {
throw new RuntimeException(
diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java b/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java
index be6165caf3..3f2ebe3288 100644
--- a/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java
@@ -81,6 +81,7 @@ class SaslRpcHandler extends RpcHandler {
if (saslServer == null) {
// First message in the handshake, setup the necessary state.
+ client.setClientId(saslMessage.appId);
saslServer = new SparkSaslServer(saslMessage.appId, secretKeyHolder,
conf.saslServerAlwaysEncrypt());
}
diff --git a/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java b/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java
index c95e64e8e2..e671854da1 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java
@@ -24,13 +24,13 @@ import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;
+import com.google.common.base.Preconditions;
import io.netty.channel.Channel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.spark.network.buffer.ManagedBuffer;
-
-import com.google.common.base.Preconditions;
+import org.apache.spark.network.client.TransportClient;
/**
* StreamManager which allows registration of an Iterator&lt;ManagedBuffer&gt;, which are individually
@@ -44,6 +44,7 @@ public class OneForOneStreamManager extends StreamManager {
/** State of a single stream. */
private static class StreamState {
+ final String appId;
final Iterator<ManagedBuffer> buffers;
// The channel associated to the stream
@@ -53,7 +54,8 @@ public class OneForOneStreamManager extends StreamManager {
// that the caller only requests each chunk one at a time, in order.
int curChunk = 0;
- StreamState(Iterator<ManagedBuffer> buffers) {
+ StreamState(String appId, Iterator<ManagedBuffer> buffers) {
+ this.appId = appId;
this.buffers = Preconditions.checkNotNull(buffers);
}
}
@@ -109,15 +111,34 @@ public class OneForOneStreamManager extends StreamManager {
}
}
+ @Override
+ public void checkAuthorization(TransportClient client, long streamId) {
+ if (client.getClientId() != null) {
+ StreamState state = streams.get(streamId);
+ Preconditions.checkArgument(state != null, "Unknown stream ID.");
+ if (!client.getClientId().equals(state.appId)) {
+ throw new SecurityException(String.format(
+ "Client %s not authorized to read stream %d (app %s).",
+ client.getClientId(),
+ streamId,
+ state.appId));
+ }
+ }
+ }
+
/**
* Registers a stream of ManagedBuffers which are served as individual chunks one at a time to
* callers. Each ManagedBuffer will be release()'d after it is transferred on the wire. If a
* client connection is closed before the iterator is fully drained, then the remaining buffers
* will all be release()'d.
+ *
+ * If an app ID is provided, only callers who've authenticated with the given app ID will be
+ * allowed to fetch from this stream.
*/
- public long registerStream(Iterator<ManagedBuffer> buffers) {
+ public long registerStream(String appId, Iterator<ManagedBuffer> buffers) {
long myStreamId = nextStreamId.getAndIncrement();
- streams.put(myStreamId, new StreamState(buffers));
+ streams.put(myStreamId, new StreamState(appId, buffers));
return myStreamId;
}
+
}
diff --git a/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java b/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java
index 929f789bf9..aaa677c965 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java
@@ -20,6 +20,7 @@ package org.apache.spark.network.server;
import io.netty.channel.Channel;
import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.client.TransportClient;
/**
* The StreamManager is used to fetch individual chunks from a stream. This is used in
@@ -60,4 +61,12 @@ public abstract class StreamManager {
* to read from the associated streams again, so any state can be cleaned up.
*/
public void connectionTerminated(Channel channel) { }
+
+ /**
+ * Verify that the client is authorized to read from the given stream.
+ *
+ * @throws SecurityException If client is not authorized.
+ */
+ public void checkAuthorization(TransportClient client, long streamId) { }
+
}
diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
index e5159ab56d..df60278058 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
@@ -97,6 +97,7 @@ public class TransportRequestHandler extends MessageHandler<RequestMessage> {
ManagedBuffer buf;
try {
+ streamManager.checkAuthorization(reverseClient, req.streamChunkId.streamId);
streamManager.registerChannel(channel, req.streamChunkId.streamId);
buf = streamManager.getChunk(req.streamChunkId.streamId, req.streamChunkId.chunkIndex);
} catch (Exception e) {
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java
index 0df1dd621f..3ddf5c3c39 100644
--- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java
@@ -58,7 +58,7 @@ public class ExternalShuffleBlockHandler extends RpcHandler {
/** Enables mocking out the StreamManager and BlockManager. */
@VisibleForTesting
- ExternalShuffleBlockHandler(
+ public ExternalShuffleBlockHandler(
OneForOneStreamManager streamManager,
ExternalShuffleBlockResolver blockManager) {
this.streamManager = streamManager;
@@ -77,17 +77,19 @@ public class ExternalShuffleBlockHandler extends RpcHandler {
RpcResponseCallback callback) {
if (msgObj instanceof OpenBlocks) {
OpenBlocks msg = (OpenBlocks) msgObj;
- List<ManagedBuffer> blocks = Lists.newArrayList();
+ checkAuth(client, msg.appId);
+ List<ManagedBuffer> blocks = Lists.newArrayList();
for (String blockId : msg.blockIds) {
blocks.add(blockManager.getBlockData(msg.appId, msg.execId, blockId));
}
- long streamId = streamManager.registerStream(blocks.iterator());
+ long streamId = streamManager.registerStream(client.getClientId(), blocks.iterator());
logger.trace("Registered streamId {} with {} buffers", streamId, msg.blockIds.length);
callback.onSuccess(new StreamHandle(streamId, msg.blockIds.length).toByteArray());
} else if (msgObj instanceof RegisterExecutor) {
RegisterExecutor msg = (RegisterExecutor) msgObj;
+ checkAuth(client, msg.appId);
blockManager.registerExecutor(msg.appId, msg.execId, msg.executorInfo);
callback.onSuccess(new byte[0]);
@@ -126,4 +128,12 @@ public class ExternalShuffleBlockHandler extends RpcHandler {
public void close() {
blockManager.close();
}
+
+ private void checkAuth(TransportClient client, String appId) {
+ if (client.getClientId() != null && !client.getClientId().equals(appId)) {
+ throw new SecurityException(String.format(
+ "Client for %s not authorized for application %s.", client.getClientId(), appId));
+ }
+ }
+
}
diff --git a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java
index 382f613ecb..5cb0e4d4a6 100644
--- a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java
+++ b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java
@@ -19,6 +19,7 @@ package org.apache.spark.network.sasl;
import java.io.IOException;
import java.util.Arrays;
+import java.util.concurrent.atomic.AtomicReference;
import com.google.common.collect.Lists;
import org.junit.After;
@@ -27,9 +28,12 @@ import org.junit.BeforeClass;
import org.junit.Test;
import static org.junit.Assert.*;
+import static org.mockito.Mockito.*;
import org.apache.spark.network.TestUtils;
import org.apache.spark.network.TransportContext;
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.client.ChunkReceivedCallback;
import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.client.TransportClientBootstrap;
@@ -39,44 +43,39 @@ import org.apache.spark.network.server.RpcHandler;
import org.apache.spark.network.server.StreamManager;
import org.apache.spark.network.server.TransportServer;
import org.apache.spark.network.server.TransportServerBootstrap;
+import org.apache.spark.network.shuffle.BlockFetchingListener;
import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler;
+import org.apache.spark.network.shuffle.ExternalShuffleBlockResolver;
+import org.apache.spark.network.shuffle.OneForOneBlockFetcher;
+import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
+import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo;
+import org.apache.spark.network.shuffle.protocol.OpenBlocks;
+import org.apache.spark.network.shuffle.protocol.RegisterExecutor;
+import org.apache.spark.network.shuffle.protocol.StreamHandle;
import org.apache.spark.network.util.SystemPropertyConfigProvider;
import org.apache.spark.network.util.TransportConf;
public class SaslIntegrationSuite {
- static ExternalShuffleBlockHandler handler;
static TransportServer server;
static TransportConf conf;
static TransportContext context;
+ static SecretKeyHolder secretKeyHolder;
TransportClientFactory clientFactory;
- /** Provides a secret key holder which always returns the given secret key. */
- static class TestSecretKeyHolder implements SecretKeyHolder {
-
- private final String secretKey;
-
- TestSecretKeyHolder(String secretKey) {
- this.secretKey = secretKey;
- }
-
- @Override
- public String getSaslUser(String appId) {
- return "user";
- }
- @Override
- public String getSecretKey(String appId) {
- return secretKey;
- }
- }
-
-
@BeforeClass
public static void beforeAll() throws IOException {
- SecretKeyHolder secretKeyHolder = new TestSecretKeyHolder("good-key");
conf = new TransportConf(new SystemPropertyConfigProvider());
context = new TransportContext(conf, new TestRpcHandler());
+ secretKeyHolder = mock(SecretKeyHolder.class);
+ when(secretKeyHolder.getSaslUser(eq("app-1"))).thenReturn("app-1");
+ when(secretKeyHolder.getSecretKey(eq("app-1"))).thenReturn("app-1");
+ when(secretKeyHolder.getSaslUser(eq("app-2"))).thenReturn("app-2");
+ when(secretKeyHolder.getSecretKey(eq("app-2"))).thenReturn("app-2");
+ when(secretKeyHolder.getSaslUser(anyString())).thenReturn("other-app");
+ when(secretKeyHolder.getSecretKey(anyString())).thenReturn("correct-password");
+
TransportServerBootstrap bootstrap = new SaslServerBootstrap(conf, secretKeyHolder);
server = context.createServer(Arrays.asList(bootstrap));
}
@@ -99,7 +98,7 @@ public class SaslIntegrationSuite {
public void testGoodClient() throws IOException {
clientFactory = context.createClientFactory(
Lists.<TransportClientBootstrap>newArrayList(
- new SaslClientBootstrap(conf, "app-id", new TestSecretKeyHolder("good-key"))));
+ new SaslClientBootstrap(conf, "app-1", secretKeyHolder)));
TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
String msg = "Hello, World!";
@@ -109,13 +108,17 @@ public class SaslIntegrationSuite {
@Test
public void testBadClient() {
+ SecretKeyHolder badKeyHolder = mock(SecretKeyHolder.class);
+ when(badKeyHolder.getSaslUser(anyString())).thenReturn("other-app");
+ when(badKeyHolder.getSecretKey(anyString())).thenReturn("wrong-password");
clientFactory = context.createClientFactory(
Lists.<TransportClientBootstrap>newArrayList(
- new SaslClientBootstrap(conf, "app-id", new TestSecretKeyHolder("bad-key"))));
+ new SaslClientBootstrap(conf, "unknown-app", badKeyHolder)));
try {
// Bootstrap should fail on startup.
clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
+ fail("Connection should have failed.");
} catch (Exception e) {
assertTrue(e.getMessage(), e.getMessage().contains("Mismatched response"));
}
@@ -149,7 +152,7 @@ public class SaslIntegrationSuite {
TransportContext context = new TransportContext(conf, handler);
clientFactory = context.createClientFactory(
Lists.<TransportClientBootstrap>newArrayList(
- new SaslClientBootstrap(conf, "app-id", new TestSecretKeyHolder("key"))));
+ new SaslClientBootstrap(conf, "app-1", secretKeyHolder)));
TransportServer server = context.createServer();
try {
clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
@@ -160,6 +163,110 @@ public class SaslIntegrationSuite {
}
}
+ /**
+ * This test is not actually testing SASL behavior, but testing that the shuffle service
+ * performs correct authorization checks based on the SASL authentication data.
+ */
+ @Test
+ public void testAppIsolation() throws Exception {
+ // Start a new server with the correct RPC handler to serve block data.
+ ExternalShuffleBlockResolver blockResolver = mock(ExternalShuffleBlockResolver.class);
+ ExternalShuffleBlockHandler blockHandler = new ExternalShuffleBlockHandler(
+ new OneForOneStreamManager(), blockResolver);
+ TransportServerBootstrap bootstrap = new SaslServerBootstrap(conf, secretKeyHolder);
+ TransportContext blockServerContext = new TransportContext(conf, blockHandler);
+ TransportServer blockServer = blockServerContext.createServer(Arrays.asList(bootstrap));
+
+ TransportClient client1 = null;
+ TransportClient client2 = null;
+ TransportClientFactory clientFactory2 = null;
+ try {
+ // Create a client, and make a request to fetch blocks from a different app.
+ clientFactory = blockServerContext.createClientFactory(
+ Lists.<TransportClientBootstrap>newArrayList(
+ new SaslClientBootstrap(conf, "app-1", secretKeyHolder)));
+ client1 = clientFactory.createClient(TestUtils.getLocalHost(),
+ blockServer.getPort());
+
+ final AtomicReference<Throwable> exception = new AtomicReference<>();
+
+ BlockFetchingListener listener = new BlockFetchingListener() {
+ @Override
+ public synchronized void onBlockFetchSuccess(String blockId, ManagedBuffer data) {
+ notifyAll();
+ }
+
+ @Override
+ public synchronized void onBlockFetchFailure(String blockId, Throwable t) {
+ exception.set(t);
+ notifyAll();
+ }
+ };
+
+ String[] blockIds = new String[] { "shuffle_2_3_4", "shuffle_6_7_8" };
+ OneForOneBlockFetcher fetcher = new OneForOneBlockFetcher(client1, "app-2", "0",
+ blockIds, listener);
+ synchronized (listener) {
+ fetcher.start();
+ listener.wait();
+ }
+ checkSecurityException(exception.get());
+
+ // Register an executor so that the next steps work.
+ ExecutorShuffleInfo executorInfo = new ExecutorShuffleInfo(
+ new String[] { System.getProperty("java.io.tmpdir") }, 1,
+ "org.apache.spark.shuffle.sort.SortShuffleManager");
+ RegisterExecutor regmsg = new RegisterExecutor("app-1", "0", executorInfo);
+ client1.sendRpcSync(regmsg.toByteArray(), 10000);
+
+ // Make a successful request to fetch blocks, which creates a new stream. But do not actually
+ // fetch any blocks, to keep the stream open.
+ OpenBlocks openMessage = new OpenBlocks("app-1", "0", blockIds);
+ byte[] response = client1.sendRpcSync(openMessage.toByteArray(), 10000);
+ StreamHandle stream = (StreamHandle) BlockTransferMessage.Decoder.fromByteArray(response);
+ long streamId = stream.streamId;
+
+ // Create a second client, authenticated with a different app ID, and try to read from
+ // the stream created for the previous app.
+ clientFactory2 = blockServerContext.createClientFactory(
+ Lists.<TransportClientBootstrap>newArrayList(
+ new SaslClientBootstrap(conf, "app-2", secretKeyHolder)));
+ client2 = clientFactory2.createClient(TestUtils.getLocalHost(),
+ blockServer.getPort());
+
+ ChunkReceivedCallback callback = new ChunkReceivedCallback() {
+ @Override
+ public synchronized void onSuccess(int chunkIndex, ManagedBuffer buffer) {
+ notifyAll();
+ }
+
+ @Override
+ public synchronized void onFailure(int chunkIndex, Throwable t) {
+ exception.set(t);
+ notifyAll();
+ }
+ };
+
+ exception.set(null);
+ synchronized (callback) {
+ client2.fetchChunk(streamId, 0, callback);
+ callback.wait();
+ }
+ checkSecurityException(exception.get());
+ } finally {
+ if (client1 != null) {
+ client1.close();
+ }
+ if (client2 != null) {
+ client2.close();
+ }
+ if (clientFactory2 != null) {
+ clientFactory2.close();
+ }
+ blockServer.close();
+ }
+ }
+
/** RPC handler which simply responds with the message it received. */
public static class TestRpcHandler extends RpcHandler {
@Override
@@ -172,4 +279,10 @@ public class SaslIntegrationSuite {
return new OneForOneStreamManager();
}
}
+
+ private void checkSecurityException(Throwable t) {
+ assertNotNull("No exception was caught.", t);
+ assertTrue("Expected SecurityException.",
+ t.getMessage().contains(SecurityException.class.getName()));
+ }
}
diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java
index 1d197497b7..e61390cf57 100644
--- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java
+++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java
@@ -93,7 +93,7 @@ public class ExternalShuffleBlockHandlerSuite {
@SuppressWarnings("unchecked")
ArgumentCaptor<Iterator<ManagedBuffer>> stream = (ArgumentCaptor<Iterator<ManagedBuffer>>)
(ArgumentCaptor<?>) ArgumentCaptor.forClass(Iterator.class);
- verify(streamManager, times(1)).registerStream(stream.capture());
+ verify(streamManager, times(1)).registerStream(anyString(), stream.capture());
Iterator<ManagedBuffer> buffers = stream.getValue();
assertEquals(block0Marker, buffers.next());
assertEquals(block1Marker, buffers.next());
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 88745dc086..714ce3cd9b 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -37,6 +37,7 @@ object MimaExcludes {
case v if v.startsWith("1.5") =>
Seq(
MimaBuild.excludeSparkPackage("deploy"),
+ MimaBuild.excludeSparkPackage("network"),
// These are needed if checking against the sbt build, since they are part of
// the maven-generated artifacts in 1.3.
excludePackage("org.spark-project.jetty"),