author     Aaron Davidson <aaron@databricks.com>   2014-11-04 16:15:38 -0800
committer  Reynold Xin <rxin@databricks.com>       2014-11-04 16:15:38 -0800
commit     5e73138a0152b78380b3f1def4b969b58e70dd11 (patch)
tree       d899e389f29e7c87c8f40e84872477a9b5e52277 /network/shuffle/src/test
parent     f90ad5d426cb726079c490a9bb4b1100e2b4e602 (diff)
[SPARK-2938] Support SASL authentication in NettyBlockTransferService
Also lays the groundwork for supporting it inside the external shuffle service.

Author: Aaron Davidson <aaron@databricks.com>

Closes #3087 from aarondav/sasl and squashes the following commits:

3481718 [Aaron Davidson] Delete rogue println
44f8410 [Aaron Davidson] Delete documentation - muahaha!
eb9f065 [Aaron Davidson] Improve documentation and add end-to-end test at Spark-level
a6b95f1 [Aaron Davidson] Address comments
785bbde [Aaron Davidson] Cleanup
79973cb [Aaron Davidson] Remove unused file
151b3c5 [Aaron Davidson] Add docs, timeout config, better failure handling
f6177d7 [Aaron Davidson] Cleanup SASL state upon connection termination
7b42adb [Aaron Davidson] Add unit tests
8191bcb [Aaron Davidson] [SPARK-2938] Support SASL authentication in NettyBlockTransferService
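For orientation, a minimal sketch of the wiring these tests exercise, distilled from SaslIntegrationSuite below (TestRpcHandler and TestSecretKeyHolder are the helpers defined in that suite; this is illustrative only, not the production setup):

  // Server side: wrap the application RpcHandler in a SaslRpcHandler so every
  // connection must complete the SASL handshake before other RPCs are served.
  TransportConf conf = new TransportConf(new SystemPropertyConfigProvider());
  SecretKeyHolder keyHolder = new TestSecretKeyHolder("secret");
  RpcHandler saslHandler = new SaslRpcHandler(new TestRpcHandler(), keyHolder);
  TransportContext context = new TransportContext(conf, saslHandler);
  TransportServer server = context.createServer();

  // Client side: register a SaslClientBootstrap so the handshake runs as soon
  // as the factory establishes a new connection.
  TransportClientFactory factory = context.createClientFactory(
    Lists.<TransportClientBootstrap>newArrayList(
      new SaslClientBootstrap(conf, "app-id", keyHolder)));
  TransportClient client = factory.createClient(TestUtils.getLocalHost(), server.getPort());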
Diffstat (limited to 'network/shuffle/src/test')
-rw-r--r--  network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java               | 172
-rw-r--r--  network/shuffle/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java                     |  89
-rw-r--r--  network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java |   7
3 files changed, 266 insertions(+), 2 deletions(-)
diff --git a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java
new file mode 100644
index 0000000000..8478120786
--- /dev/null
+++ b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java
@@ -0,0 +1,172 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.sasl;
+
+import java.io.IOException;
+
+import com.google.common.collect.Lists;
+import org.junit.After;
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import static org.junit.Assert.*;
+
+import org.apache.spark.network.TestUtils;
+import org.apache.spark.network.TransportContext;
+import org.apache.spark.network.client.RpcResponseCallback;
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.client.TransportClientBootstrap;
+import org.apache.spark.network.client.TransportClientFactory;
+import org.apache.spark.network.server.OneForOneStreamManager;
+import org.apache.spark.network.server.RpcHandler;
+import org.apache.spark.network.server.StreamManager;
+import org.apache.spark.network.server.TransportServer;
+import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler;
+import org.apache.spark.network.util.SystemPropertyConfigProvider;
+import org.apache.spark.network.util.TransportConf;
+
+public class SaslIntegrationSuite {
+ static ExternalShuffleBlockHandler handler;
+ static TransportServer server;
+ static TransportConf conf;
+ static TransportContext context;
+
+ TransportClientFactory clientFactory;
+
+ /** Provides a secret key holder which always returns the given secret key. */
+ static class TestSecretKeyHolder implements SecretKeyHolder {
+
+ private final String secretKey;
+
+ TestSecretKeyHolder(String secretKey) {
+ this.secretKey = secretKey;
+ }
+
+ @Override
+ public String getSaslUser(String appId) {
+ return "user";
+ }
+ @Override
+ public String getSecretKey(String appId) {
+ return secretKey;
+ }
+ }
+
+
+ @BeforeClass
+ public static void beforeAll() throws IOException {
+ SecretKeyHolder secretKeyHolder = new TestSecretKeyHolder("good-key");
+ SaslRpcHandler handler = new SaslRpcHandler(new TestRpcHandler(), secretKeyHolder);
+ conf = new TransportConf(new SystemPropertyConfigProvider());
+ context = new TransportContext(conf, handler);
+ server = context.createServer();
+ }
+
+
+ @AfterClass
+ public static void afterAll() {
+ server.close();
+ }
+
+ @After
+ public void afterEach() {
+ if (clientFactory != null) {
+ clientFactory.close();
+ clientFactory = null;
+ }
+ }
+
+ @Test
+ public void testGoodClient() {
+ clientFactory = context.createClientFactory(
+ Lists.<TransportClientBootstrap>newArrayList(
+ new SaslClientBootstrap(conf, "app-id", new TestSecretKeyHolder("good-key"))));
+
+ TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
+ String msg = "Hello, World!";
+ byte[] resp = client.sendRpcSync(msg.getBytes(), 1000);
+ assertEquals(msg, new String(resp)); // our rpc handler should just return the given msg
+ }
+
+ @Test
+ public void testBadClient() {
+ clientFactory = context.createClientFactory(
+ Lists.<TransportClientBootstrap>newArrayList(
+ new SaslClientBootstrap(conf, "app-id", new TestSecretKeyHolder("bad-key"))));
+
+ try {
+ // Bootstrap should fail on startup.
+ clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
+ } catch (Exception e) {
+ assertTrue(e.getMessage(), e.getMessage().contains("Mismatched response"));
+ }
+ }
+
+ @Test
+ public void testNoSaslClient() {
+ clientFactory = context.createClientFactory(
+ Lists.<TransportClientBootstrap>newArrayList());
+
+ TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
+ try {
+ client.sendRpcSync(new byte[13], 1000);
+ fail("Should have failed");
+ } catch (Exception e) {
+ assertTrue(e.getMessage(), e.getMessage().contains("Expected SaslMessage"));
+ }
+
+ try {
+ // Guessing the right tag byte doesn't magically get you in...
+ client.sendRpcSync(new byte[] { (byte) 0xEA }, 1000);
+ fail("Should have failed");
+ } catch (Exception e) {
+ assertTrue(e.getMessage(), e.getMessage().contains("java.lang.IndexOutOfBoundsException"));
+ }
+ }
+
+ @Test
+ public void testNoSaslServer() {
+ RpcHandler handler = new TestRpcHandler();
+ TransportContext context = new TransportContext(conf, handler);
+ clientFactory = context.createClientFactory(
+ Lists.<TransportClientBootstrap>newArrayList(
+ new SaslClientBootstrap(conf, "app-id", new TestSecretKeyHolder("key"))));
+ TransportServer server = context.createServer();
+ try {
+ clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
+ } catch (Exception e) {
+ assertTrue(e.getMessage(), e.getMessage().contains("Digest-challenge format violation"));
+ } finally {
+ server.close();
+ }
+ }
+
+ /** RPC handler which simply responds with the message it received. */
+ public static class TestRpcHandler extends RpcHandler {
+ @Override
+ public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) {
+ callback.onSuccess(message);
+ }
+
+ @Override
+ public StreamManager getStreamManager() {
+ return new OneForOneStreamManager();
+ }
+ }
+}
diff --git a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java
new file mode 100644
index 0000000000..67a07f38eb
--- /dev/null
+++ b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.sasl;
+
+import java.util.Map;
+
+import com.google.common.collect.ImmutableMap;
+import org.junit.Test;
+
+import static org.junit.Assert.*;
+
+/**
+ * Jointly tests SparkSaslClient and SparkSaslServer, as both are black boxes.
+ */
+public class SparkSaslSuite {
+
+ /** Provides a secret key holder which returns secret key == appId */
+ private SecretKeyHolder secretKeyHolder = new SecretKeyHolder() {
+ @Override
+ public String getSaslUser(String appId) {
+ return "user";
+ }
+
+ @Override
+ public String getSecretKey(String appId) {
+ return appId;
+ }
+ };
+
+ @Test
+ public void testMatching() {
+ SparkSaslClient client = new SparkSaslClient("shared-secret", secretKeyHolder);
+ SparkSaslServer server = new SparkSaslServer("shared-secret", secretKeyHolder);
+
+ assertFalse(client.isComplete());
+ assertFalse(server.isComplete());
+
+ byte[] clientMessage = client.firstToken();
+
+ while (!client.isComplete()) {
+ clientMessage = client.response(server.response(clientMessage));
+ }
+ assertTrue(server.isComplete());
+
+ // Disposal should invalidate
+ server.dispose();
+ assertFalse(server.isComplete());
+ client.dispose();
+ assertFalse(client.isComplete());
+ }
+
+
+ @Test
+ public void testNonMatching() {
+ SparkSaslClient client = new SparkSaslClient("my-secret", secretKeyHolder);
+ SparkSaslServer server = new SparkSaslServer("your-secret", secretKeyHolder);
+
+ assertFalse(client.isComplete());
+ assertFalse(server.isComplete());
+
+ byte[] clientMessage = client.firstToken();
+
+ try {
+ while (!client.isComplete()) {
+ clientMessage = client.response(server.response(clientMessage));
+ }
+ fail("Should not have completed");
+ } catch (Exception e) {
+ assertTrue(e.getMessage().contains("Mismatched response"));
+ assertFalse(client.isComplete());
+ assertFalse(server.isComplete());
+ }
+ }
+}
diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java
index b3bcf5fd68..bc101f5384 100644
--- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java
+++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java
@@ -135,7 +135,8 @@ public class ExternalShuffleIntegrationSuite {
final Semaphore requestsRemaining = new Semaphore(0);
- ExternalShuffleClient client = new ExternalShuffleClient(conf, APP_ID);
+ ExternalShuffleClient client = new ExternalShuffleClient(conf);
+ client.init(APP_ID);
client.fetchBlocks(TestUtils.getLocalHost(), port, execId, blockIds,
new BlockFetchingListener() {
@Override
@@ -164,6 +165,7 @@ public class ExternalShuffleIntegrationSuite {
if (!requestsRemaining.tryAcquire(blockIds.length, 5, TimeUnit.SECONDS)) {
fail("Timeout getting response from the server");
}
+ client.close();
return res;
}
@@ -265,7 +267,8 @@ public class ExternalShuffleIntegrationSuite {
}
private void registerExecutor(String executorId, ExecutorShuffleInfo executorInfo) {
- ExternalShuffleClient client = new ExternalShuffleClient(conf, APP_ID);
+ ExternalShuffleClient client = new ExternalShuffleClient(conf);
+ client.init(APP_ID);
client.registerWithShuffleServer(TestUtils.getLocalHost(), server.getPort(),
executorId, executorInfo);
}