diff options
Diffstat (limited to 'common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java')
-rw-r--r-- | common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java | 97 |
1 files changed, 13 insertions, 84 deletions
diff --git a/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java index e27301f49e..87129b900b 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java @@ -56,7 +56,6 @@ import org.apache.spark.network.client.ChunkReceivedCallback; import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.client.TransportClientBootstrap; -import org.apache.spark.network.sasl.aes.AesCipher; import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.StreamManager; import org.apache.spark.network.server.TransportServer; @@ -153,7 +152,7 @@ public class SparkSaslSuite { .when(rpcHandler) .receive(any(TransportClient.class), any(ByteBuffer.class), any(RpcResponseCallback.class)); - SaslTestCtx ctx = new SaslTestCtx(rpcHandler, encrypt, false, false); + SaslTestCtx ctx = new SaslTestCtx(rpcHandler, encrypt, false); try { ByteBuffer response = ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), TimeUnit.SECONDS.toMillis(10)); @@ -279,7 +278,7 @@ public class SparkSaslSuite { new Random().nextBytes(data); Files.write(data, file); - ctx = new SaslTestCtx(rpcHandler, true, false, false, testConf); + ctx = new SaslTestCtx(rpcHandler, true, false, testConf); final CountDownLatch lock = new CountDownLatch(1); @@ -317,7 +316,7 @@ public class SparkSaslSuite { public void testServerAlwaysEncrypt() throws Exception { SaslTestCtx ctx = null; try { - ctx = new SaslTestCtx(mock(RpcHandler.class), false, false, false, + ctx = new SaslTestCtx(mock(RpcHandler.class), false, false, ImmutableMap.of("spark.network.sasl.serverAlwaysEncrypt", "true")); fail("Should have failed to connect without encryption."); } catch (Exception e) { @@ -336,7 +335,7 @@ public class SparkSaslSuite { // able to understand RPCs sent to it and thus close the connection. SaslTestCtx ctx = null; try { - ctx = new SaslTestCtx(mock(RpcHandler.class), true, true, false); + ctx = new SaslTestCtx(mock(RpcHandler.class), true, true); ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), TimeUnit.SECONDS.toMillis(10)); fail("Should have failed to send RPC to server."); @@ -374,69 +373,6 @@ public class SparkSaslSuite { } } - @Test - public void testAesEncryption() throws Exception { - final AtomicReference<ManagedBuffer> response = new AtomicReference<>(); - final File file = File.createTempFile("sasltest", ".txt"); - SaslTestCtx ctx = null; - try { - final TransportConf conf = new TransportConf("rpc", MapConfigProvider.EMPTY); - final TransportConf spyConf = spy(conf); - doReturn(true).when(spyConf).aesEncryptionEnabled(); - - StreamManager sm = mock(StreamManager.class); - when(sm.getChunk(anyLong(), anyInt())).thenAnswer(new Answer<ManagedBuffer>() { - @Override - public ManagedBuffer answer(InvocationOnMock invocation) { - return new FileSegmentManagedBuffer(spyConf, file, 0, file.length()); - } - }); - - RpcHandler rpcHandler = mock(RpcHandler.class); - when(rpcHandler.getStreamManager()).thenReturn(sm); - - byte[] data = new byte[256 * 1024 * 1024]; - new Random().nextBytes(data); - Files.write(data, file); - - ctx = new SaslTestCtx(rpcHandler, true, false, true); - - final Object lock = new Object(); - - ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class); - doAnswer(new Answer<Void>() { - @Override - public Void answer(InvocationOnMock invocation) { - response.set((ManagedBuffer) invocation.getArguments()[1]); - response.get().retain(); - synchronized (lock) { - lock.notifyAll(); - } - return null; - } - }).when(callback).onSuccess(anyInt(), any(ManagedBuffer.class)); - - synchronized (lock) { - ctx.client.fetchChunk(0, 0, callback); - lock.wait(10 * 1000); - } - - verify(callback, times(1)).onSuccess(anyInt(), any(ManagedBuffer.class)); - verify(callback, never()).onFailure(anyInt(), any(Throwable.class)); - - byte[] received = ByteStreams.toByteArray(response.get().createInputStream()); - assertTrue(Arrays.equals(data, received)); - } finally { - file.delete(); - if (ctx != null) { - ctx.close(); - } - if (response.get() != null) { - response.get().release(); - } - } - } - private static class SaslTestCtx { final TransportClient client; @@ -449,46 +385,39 @@ public class SparkSaslSuite { SaslTestCtx( RpcHandler rpcHandler, boolean encrypt, - boolean disableClientEncryption, - boolean aesEnable) + boolean disableClientEncryption) throws Exception { - this(rpcHandler, encrypt, disableClientEncryption, aesEnable, - Collections.<String, String>emptyMap()); + this(rpcHandler, encrypt, disableClientEncryption, Collections.<String, String>emptyMap()); } SaslTestCtx( RpcHandler rpcHandler, boolean encrypt, boolean disableClientEncryption, - boolean aesEnable, - Map<String, String> testConf) + Map<String, String> extraConf) throws Exception { + Map<String, String> testConf = ImmutableMap.<String, String>builder() + .putAll(extraConf) + .put("spark.authenticate.enableSaslEncryption", String.valueOf(encrypt)) + .build(); TransportConf conf = new TransportConf("shuffle", new MapConfigProvider(testConf)); - if (aesEnable) { - conf = spy(conf); - doReturn(true).when(conf).aesEncryptionEnabled(); - } - SecretKeyHolder keyHolder = mock(SecretKeyHolder.class); when(keyHolder.getSaslUser(anyString())).thenReturn("user"); when(keyHolder.getSecretKey(anyString())).thenReturn("secret"); TransportContext ctx = new TransportContext(conf, rpcHandler); - String encryptHandlerName = aesEnable ? AesCipher.ENCRYPTION_HANDLER_NAME : - SaslEncryption.ENCRYPTION_HANDLER_NAME; - - this.checker = new EncryptionCheckerBootstrap(encryptHandlerName); + this.checker = new EncryptionCheckerBootstrap(SaslEncryption.ENCRYPTION_HANDLER_NAME); this.server = ctx.createServer(Arrays.asList(new SaslServerBootstrap(conf, keyHolder), checker)); try { List<TransportClientBootstrap> clientBootstraps = Lists.newArrayList(); - clientBootstraps.add(new SaslClientBootstrap(conf, "user", keyHolder, encrypt)); + clientBootstraps.add(new SaslClientBootstrap(conf, "user", keyHolder)); if (disableClientEncryption) { clientBootstraps.add(new EncryptionDisablerBootstrap()); } |