about summary refs log tree commit diff
path: root/network/shuffle
diff options
context:
space:
mode:
authorMarcelo Vanzin <vanzin@cloudera.com>2015-09-02 12:53:24 -0700
committerMarcelo Vanzin <vanzin@cloudera.com>2015-09-02 12:53:24 -0700
commit2da3a9e98e5d129d4507b5db01bba5ee9558d28e (patch)
treec5197f543f18959d793db1caea4ee553acef4f97 /network/shuffle
parentfc48307797912dc1d53893dce741ddda8630957b (diff)
downloadspark-2da3a9e98e5d129d4507b5db01bba5ee9558d28e.tar.gz
spark-2da3a9e98e5d129d4507b5db01bba5ee9558d28e.tar.bz2
spark-2da3a9e98e5d129d4507b5db01bba5ee9558d28e.zip
[SPARK-10004] [SHUFFLE] Perform auth checks when clients read shuffle data.
To correctly isolate applications, when requests to read shuffle data arrive at the shuffle service, proper authorization checks need to be performed. This change makes sure that only the application that created the shuffle data can read from it. Such checks are only enabled when "spark.authenticate" is enabled, otherwise there's no secure way to make sure that the client is really who it says it is. Author: Marcelo Vanzin <vanzin@cloudera.com> Closes #8218 from vanzin/SPARK-10004.
Diffstat (limited to 'network/shuffle')
-rw-r--r--network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java16
-rw-r--r--network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java163
-rw-r--r--network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java2
3 files changed, 152 insertions, 29 deletions
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java
index 0df1dd621f..3ddf5c3c39 100644
--- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java
@@ -58,7 +58,7 @@ public class ExternalShuffleBlockHandler extends RpcHandler {
/** Enables mocking out the StreamManager and BlockManager. */
@VisibleForTesting
- ExternalShuffleBlockHandler(
+ public ExternalShuffleBlockHandler(
OneForOneStreamManager streamManager,
ExternalShuffleBlockResolver blockManager) {
this.streamManager = streamManager;
@@ -77,17 +77,19 @@ public class ExternalShuffleBlockHandler extends RpcHandler {
RpcResponseCallback callback) {
if (msgObj instanceof OpenBlocks) {
OpenBlocks msg = (OpenBlocks) msgObj;
- List<ManagedBuffer> blocks = Lists.newArrayList();
+ checkAuth(client, msg.appId);
+ List<ManagedBuffer> blocks = Lists.newArrayList();
for (String blockId : msg.blockIds) {
blocks.add(blockManager.getBlockData(msg.appId, msg.execId, blockId));
}
- long streamId = streamManager.registerStream(blocks.iterator());
+ long streamId = streamManager.registerStream(client.getClientId(), blocks.iterator());
logger.trace("Registered streamId {} with {} buffers", streamId, msg.blockIds.length);
callback.onSuccess(new StreamHandle(streamId, msg.blockIds.length).toByteArray());
} else if (msgObj instanceof RegisterExecutor) {
RegisterExecutor msg = (RegisterExecutor) msgObj;
+ checkAuth(client, msg.appId);
blockManager.registerExecutor(msg.appId, msg.execId, msg.executorInfo);
callback.onSuccess(new byte[0]);
@@ -126,4 +128,12 @@ public class ExternalShuffleBlockHandler extends RpcHandler {
public void close() {
blockManager.close();
}
+
+ private void checkAuth(TransportClient client, String appId) {
+ if (client.getClientId() != null && !client.getClientId().equals(appId)) {
+ throw new SecurityException(String.format(
+ "Client for %s not authorized for application %s.", client.getClientId(), appId));
+ }
+ }
+
}
diff --git a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java
index 382f613ecb..5cb0e4d4a6 100644
--- a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java
+++ b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java
@@ -19,6 +19,7 @@ package org.apache.spark.network.sasl;
import java.io.IOException;
import java.util.Arrays;
+import java.util.concurrent.atomic.AtomicReference;
import com.google.common.collect.Lists;
import org.junit.After;
@@ -27,9 +28,12 @@ import org.junit.BeforeClass;
import org.junit.Test;
import static org.junit.Assert.*;
+import static org.mockito.Mockito.*;
import org.apache.spark.network.TestUtils;
import org.apache.spark.network.TransportContext;
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.client.ChunkReceivedCallback;
import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.client.TransportClientBootstrap;
@@ -39,44 +43,39 @@ import org.apache.spark.network.server.RpcHandler;
import org.apache.spark.network.server.StreamManager;
import org.apache.spark.network.server.TransportServer;
import org.apache.spark.network.server.TransportServerBootstrap;
+import org.apache.spark.network.shuffle.BlockFetchingListener;
import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler;
+import org.apache.spark.network.shuffle.ExternalShuffleBlockResolver;
+import org.apache.spark.network.shuffle.OneForOneBlockFetcher;
+import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
+import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo;
+import org.apache.spark.network.shuffle.protocol.OpenBlocks;
+import org.apache.spark.network.shuffle.protocol.RegisterExecutor;
+import org.apache.spark.network.shuffle.protocol.StreamHandle;
import org.apache.spark.network.util.SystemPropertyConfigProvider;
import org.apache.spark.network.util.TransportConf;
public class SaslIntegrationSuite {
- static ExternalShuffleBlockHandler handler;
static TransportServer server;
static TransportConf conf;
static TransportContext context;
+ static SecretKeyHolder secretKeyHolder;
TransportClientFactory clientFactory;
- /** Provides a secret key holder which always returns the given secret key. */
- static class TestSecretKeyHolder implements SecretKeyHolder {
-
- private final String secretKey;
-
- TestSecretKeyHolder(String secretKey) {
- this.secretKey = secretKey;
- }
-
- @Override
- public String getSaslUser(String appId) {
- return "user";
- }
- @Override
- public String getSecretKey(String appId) {
- return secretKey;
- }
- }
-
-
@BeforeClass
public static void beforeAll() throws IOException {
- SecretKeyHolder secretKeyHolder = new TestSecretKeyHolder("good-key");
conf = new TransportConf(new SystemPropertyConfigProvider());
context = new TransportContext(conf, new TestRpcHandler());
+ secretKeyHolder = mock(SecretKeyHolder.class);
+ when(secretKeyHolder.getSaslUser(eq("app-1"))).thenReturn("app-1");
+ when(secretKeyHolder.getSecretKey(eq("app-1"))).thenReturn("app-1");
+ when(secretKeyHolder.getSaslUser(eq("app-2"))).thenReturn("app-2");
+ when(secretKeyHolder.getSecretKey(eq("app-2"))).thenReturn("app-2");
+ when(secretKeyHolder.getSaslUser(anyString())).thenReturn("other-app");
+ when(secretKeyHolder.getSecretKey(anyString())).thenReturn("correct-password");
+
TransportServerBootstrap bootstrap = new SaslServerBootstrap(conf, secretKeyHolder);
server = context.createServer(Arrays.asList(bootstrap));
}
@@ -99,7 +98,7 @@ public class SaslIntegrationSuite {
public void testGoodClient() throws IOException {
clientFactory = context.createClientFactory(
Lists.<TransportClientBootstrap>newArrayList(
- new SaslClientBootstrap(conf, "app-id", new TestSecretKeyHolder("good-key"))));
+ new SaslClientBootstrap(conf, "app-1", secretKeyHolder)));
TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
String msg = "Hello, World!";
@@ -109,13 +108,17 @@ public class SaslIntegrationSuite {
@Test
public void testBadClient() {
+ SecretKeyHolder badKeyHolder = mock(SecretKeyHolder.class);
+ when(badKeyHolder.getSaslUser(anyString())).thenReturn("other-app");
+ when(badKeyHolder.getSecretKey(anyString())).thenReturn("wrong-password");
clientFactory = context.createClientFactory(
Lists.<TransportClientBootstrap>newArrayList(
- new SaslClientBootstrap(conf, "app-id", new TestSecretKeyHolder("bad-key"))));
+ new SaslClientBootstrap(conf, "unknown-app", badKeyHolder)));
try {
// Bootstrap should fail on startup.
clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
+ fail("Connection should have failed.");
} catch (Exception e) {
assertTrue(e.getMessage(), e.getMessage().contains("Mismatched response"));
}
@@ -149,7 +152,7 @@ public class SaslIntegrationSuite {
TransportContext context = new TransportContext(conf, handler);
clientFactory = context.createClientFactory(
Lists.<TransportClientBootstrap>newArrayList(
- new SaslClientBootstrap(conf, "app-id", new TestSecretKeyHolder("key"))));
+ new SaslClientBootstrap(conf, "app-1", secretKeyHolder)));
TransportServer server = context.createServer();
try {
clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
@@ -160,6 +163,110 @@ public class SaslIntegrationSuite {
}
}
+ /**
+ * This test is not actually testing SASL behavior, but testing that the shuffle service
+ * performs correct authorization checks based on the SASL authentication data.
+ */
+ @Test
+ public void testAppIsolation() throws Exception {
+ // Start a new server with the correct RPC handler to serve block data.
+ ExternalShuffleBlockResolver blockResolver = mock(ExternalShuffleBlockResolver.class);
+ ExternalShuffleBlockHandler blockHandler = new ExternalShuffleBlockHandler(
+ new OneForOneStreamManager(), blockResolver);
+ TransportServerBootstrap bootstrap = new SaslServerBootstrap(conf, secretKeyHolder);
+ TransportContext blockServerContext = new TransportContext(conf, blockHandler);
+ TransportServer blockServer = blockServerContext.createServer(Arrays.asList(bootstrap));
+
+ TransportClient client1 = null;
+ TransportClient client2 = null;
+ TransportClientFactory clientFactory2 = null;
+ try {
+ // Create a client, and make a request to fetch blocks from a different app.
+ clientFactory = blockServerContext.createClientFactory(
+ Lists.<TransportClientBootstrap>newArrayList(
+ new SaslClientBootstrap(conf, "app-1", secretKeyHolder)));
+ client1 = clientFactory.createClient(TestUtils.getLocalHost(),
+ blockServer.getPort());
+
+ final AtomicReference<Throwable> exception = new AtomicReference<>();
+
+ BlockFetchingListener listener = new BlockFetchingListener() {
+ @Override
+ public synchronized void onBlockFetchSuccess(String blockId, ManagedBuffer data) {
+ notifyAll();
+ }
+
+ @Override
+ public synchronized void onBlockFetchFailure(String blockId, Throwable t) {
+ exception.set(t);
+ notifyAll();
+ }
+ };
+
+ String[] blockIds = new String[] { "shuffle_2_3_4", "shuffle_6_7_8" };
+ OneForOneBlockFetcher fetcher = new OneForOneBlockFetcher(client1, "app-2", "0",
+ blockIds, listener);
+ synchronized (listener) {
+ fetcher.start();
+ listener.wait();
+ }
+ checkSecurityException(exception.get());
+
+ // Register an executor so that the next steps work.
+ ExecutorShuffleInfo executorInfo = new ExecutorShuffleInfo(
+ new String[] { System.getProperty("java.io.tmpdir") }, 1,
+ "org.apache.spark.shuffle.sort.SortShuffleManager");
+ RegisterExecutor regmsg = new RegisterExecutor("app-1", "0", executorInfo);
+ client1.sendRpcSync(regmsg.toByteArray(), 10000);
+
+ // Make a successful request to fetch blocks, which creates a new stream. But do not actually
+ // fetch any blocks, to keep the stream open.
+ OpenBlocks openMessage = new OpenBlocks("app-1", "0", blockIds);
+ byte[] response = client1.sendRpcSync(openMessage.toByteArray(), 10000);
+ StreamHandle stream = (StreamHandle) BlockTransferMessage.Decoder.fromByteArray(response);
+ long streamId = stream.streamId;
+
+ // Create a second client, authenticated with a different app ID, and try to read from
+ // the stream created for the previous app.
+ clientFactory2 = blockServerContext.createClientFactory(
+ Lists.<TransportClientBootstrap>newArrayList(
+ new SaslClientBootstrap(conf, "app-2", secretKeyHolder)));
+ client2 = clientFactory2.createClient(TestUtils.getLocalHost(),
+ blockServer.getPort());
+
+ ChunkReceivedCallback callback = new ChunkReceivedCallback() {
+ @Override
+ public synchronized void onSuccess(int chunkIndex, ManagedBuffer buffer) {
+ notifyAll();
+ }
+
+ @Override
+ public synchronized void onFailure(int chunkIndex, Throwable t) {
+ exception.set(t);
+ notifyAll();
+ }
+ };
+
+ exception.set(null);
+ synchronized (callback) {
+ client2.fetchChunk(streamId, 0, callback);
+ callback.wait();
+ }
+ checkSecurityException(exception.get());
+ } finally {
+ if (client1 != null) {
+ client1.close();
+ }
+ if (client2 != null) {
+ client2.close();
+ }
+ if (clientFactory2 != null) {
+ clientFactory2.close();
+ }
+ blockServer.close();
+ }
+ }
+
/** RPC handler which simply responds with the message it received. */
public static class TestRpcHandler extends RpcHandler {
@Override
@@ -172,4 +279,10 @@ public class SaslIntegrationSuite {
return new OneForOneStreamManager();
}
}
+
+ private void checkSecurityException(Throwable t) {
+ assertNotNull("No exception was caught.", t);
+ assertTrue("Expected SecurityException.",
+ t.getMessage().contains(SecurityException.class.getName()));
+ }
}
diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java
index 1d197497b7..e61390cf57 100644
--- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java
+++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java
@@ -93,7 +93,7 @@ public class ExternalShuffleBlockHandlerSuite {
@SuppressWarnings("unchecked")
ArgumentCaptor<Iterator<ManagedBuffer>> stream = (ArgumentCaptor<Iterator<ManagedBuffer>>)
(ArgumentCaptor<?>) ArgumentCaptor.forClass(Iterator.class);
- verify(streamManager, times(1)).registerStream(stream.capture());
+ verify(streamManager, times(1)).registerStream(anyString(), stream.capture());
Iterator<ManagedBuffer> buffers = stream.getValue();
assertEquals(block0Marker, buffers.next());
assertEquals(block1Marker, buffers.next());