aboutsummaryrefslogtreecommitdiff
path: root/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java
diff options
context:
space:
mode:
Diffstat (limited to 'common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java')
-rw-r--r--common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java97
1 files changed, 13 insertions, 84 deletions
diff --git a/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java
index e27301f49e..87129b900b 100644
--- a/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java
+++ b/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java
@@ -56,7 +56,6 @@ import org.apache.spark.network.client.ChunkReceivedCallback;
import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.client.TransportClientBootstrap;
-import org.apache.spark.network.sasl.aes.AesCipher;
import org.apache.spark.network.server.RpcHandler;
import org.apache.spark.network.server.StreamManager;
import org.apache.spark.network.server.TransportServer;
@@ -153,7 +152,7 @@ public class SparkSaslSuite {
.when(rpcHandler)
.receive(any(TransportClient.class), any(ByteBuffer.class), any(RpcResponseCallback.class));
- SaslTestCtx ctx = new SaslTestCtx(rpcHandler, encrypt, false, false);
+ SaslTestCtx ctx = new SaslTestCtx(rpcHandler, encrypt, false);
try {
ByteBuffer response = ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"),
TimeUnit.SECONDS.toMillis(10));
@@ -279,7 +278,7 @@ public class SparkSaslSuite {
new Random().nextBytes(data);
Files.write(data, file);
- ctx = new SaslTestCtx(rpcHandler, true, false, false, testConf);
+ ctx = new SaslTestCtx(rpcHandler, true, false, testConf);
final CountDownLatch lock = new CountDownLatch(1);
@@ -317,7 +316,7 @@ public class SparkSaslSuite {
public void testServerAlwaysEncrypt() throws Exception {
SaslTestCtx ctx = null;
try {
- ctx = new SaslTestCtx(mock(RpcHandler.class), false, false, false,
+ ctx = new SaslTestCtx(mock(RpcHandler.class), false, false,
ImmutableMap.of("spark.network.sasl.serverAlwaysEncrypt", "true"));
fail("Should have failed to connect without encryption.");
} catch (Exception e) {
@@ -336,7 +335,7 @@ public class SparkSaslSuite {
// able to understand RPCs sent to it and thus close the connection.
SaslTestCtx ctx = null;
try {
- ctx = new SaslTestCtx(mock(RpcHandler.class), true, true, false);
+ ctx = new SaslTestCtx(mock(RpcHandler.class), true, true);
ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"),
TimeUnit.SECONDS.toMillis(10));
fail("Should have failed to send RPC to server.");
@@ -374,69 +373,6 @@ public class SparkSaslSuite {
}
}
- @Test
- public void testAesEncryption() throws Exception {
- final AtomicReference<ManagedBuffer> response = new AtomicReference<>();
- final File file = File.createTempFile("sasltest", ".txt");
- SaslTestCtx ctx = null;
- try {
- final TransportConf conf = new TransportConf("rpc", MapConfigProvider.EMPTY);
- final TransportConf spyConf = spy(conf);
- doReturn(true).when(spyConf).aesEncryptionEnabled();
-
- StreamManager sm = mock(StreamManager.class);
- when(sm.getChunk(anyLong(), anyInt())).thenAnswer(new Answer<ManagedBuffer>() {
- @Override
- public ManagedBuffer answer(InvocationOnMock invocation) {
- return new FileSegmentManagedBuffer(spyConf, file, 0, file.length());
- }
- });
-
- RpcHandler rpcHandler = mock(RpcHandler.class);
- when(rpcHandler.getStreamManager()).thenReturn(sm);
-
- byte[] data = new byte[256 * 1024 * 1024];
- new Random().nextBytes(data);
- Files.write(data, file);
-
- ctx = new SaslTestCtx(rpcHandler, true, false, true);
-
- final Object lock = new Object();
-
- ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class);
- doAnswer(new Answer<Void>() {
- @Override
- public Void answer(InvocationOnMock invocation) {
- response.set((ManagedBuffer) invocation.getArguments()[1]);
- response.get().retain();
- synchronized (lock) {
- lock.notifyAll();
- }
- return null;
- }
- }).when(callback).onSuccess(anyInt(), any(ManagedBuffer.class));
-
- synchronized (lock) {
- ctx.client.fetchChunk(0, 0, callback);
- lock.wait(10 * 1000);
- }
-
- verify(callback, times(1)).onSuccess(anyInt(), any(ManagedBuffer.class));
- verify(callback, never()).onFailure(anyInt(), any(Throwable.class));
-
- byte[] received = ByteStreams.toByteArray(response.get().createInputStream());
- assertTrue(Arrays.equals(data, received));
- } finally {
- file.delete();
- if (ctx != null) {
- ctx.close();
- }
- if (response.get() != null) {
- response.get().release();
- }
- }
- }
-
private static class SaslTestCtx {
final TransportClient client;
@@ -449,46 +385,39 @@ public class SparkSaslSuite {
SaslTestCtx(
RpcHandler rpcHandler,
boolean encrypt,
- boolean disableClientEncryption,
- boolean aesEnable)
+ boolean disableClientEncryption)
throws Exception {
- this(rpcHandler, encrypt, disableClientEncryption, aesEnable,
- Collections.<String, String>emptyMap());
+ this(rpcHandler, encrypt, disableClientEncryption, Collections.<String, String>emptyMap());
}
SaslTestCtx(
RpcHandler rpcHandler,
boolean encrypt,
boolean disableClientEncryption,
- boolean aesEnable,
- Map<String, String> testConf)
+ Map<String, String> extraConf)
throws Exception {
+ Map<String, String> testConf = ImmutableMap.<String, String>builder()
+ .putAll(extraConf)
+ .put("spark.authenticate.enableSaslEncryption", String.valueOf(encrypt))
+ .build();
TransportConf conf = new TransportConf("shuffle", new MapConfigProvider(testConf));
- if (aesEnable) {
- conf = spy(conf);
- doReturn(true).when(conf).aesEncryptionEnabled();
- }
-
SecretKeyHolder keyHolder = mock(SecretKeyHolder.class);
when(keyHolder.getSaslUser(anyString())).thenReturn("user");
when(keyHolder.getSecretKey(anyString())).thenReturn("secret");
TransportContext ctx = new TransportContext(conf, rpcHandler);
- String encryptHandlerName = aesEnable ? AesCipher.ENCRYPTION_HANDLER_NAME :
- SaslEncryption.ENCRYPTION_HANDLER_NAME;
-
- this.checker = new EncryptionCheckerBootstrap(encryptHandlerName);
+ this.checker = new EncryptionCheckerBootstrap(SaslEncryption.ENCRYPTION_HANDLER_NAME);
this.server = ctx.createServer(Arrays.asList(new SaslServerBootstrap(conf, keyHolder),
checker));
try {
List<TransportClientBootstrap> clientBootstraps = Lists.newArrayList();
- clientBootstraps.add(new SaslClientBootstrap(conf, "user", keyHolder, encrypt));
+ clientBootstraps.add(new SaslClientBootstrap(conf, "user", keyHolder));
if (disableClientEncryption) {
clientBootstraps.add(new EncryptionDisablerBootstrap());
}