author    Marcelo Vanzin <vanzin@cloudera.com>    2017-01-24 10:44:04 -0800
committer    Shixiong Zhu <shixiong@databricks.com>    2017-01-24 10:44:04 -0800
commit    8f3f73abc1fe62496722476460c174af0250e3fe (patch)
tree    345e96eab2294792a867d6009cc9209d6ec0b27f /common
parent    d9783380ff0a6440117348dee3205826d0f9687e (diff)
[SPARK-19139][CORE] New auth mechanism for transport library.
This change introduces a new auth mechanism to the transport library, to be used when users enable strong encryption. This auth mechanism has better security than the currently used DIGEST-MD5.

The new protocol uses symmetric key encryption to mutually authenticate the endpoints, and is very loosely based on ISO/IEC 9798.

The new protocol falls back to SASL when it thinks the remote end is old. Because SASL does not support asking the server for multiple auth protocols (which would have allowed re-using the existing SASL code by just adding a new SASL provider), the protocol is implemented outside of the SASL API; this also avoids the boilerplate of adding a new provider.

Details of the auth protocol are discussed in the included README.md file.

This change partly undoes the changes added in SPARK-13331; AES encryption is now decoupled from SASL authentication. The encryption code itself, though, has been re-used as part of this change.

## How was this patch tested?

- Unit tests
- Tested Spark 2.2 against Spark 1.6 shuffle service with SASL enabled
- Tested Spark 2.2 against Spark 2.2 shuffle service with SASL fallback disabled

Author: Marcelo Vanzin <vanzin@cloudera.com>

Closes #16521 from vanzin/SPARK-19139.
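For illustration only (not part of this commit): a minimal wiring sketch of how the two new bootstraps would typically be installed, assuming the TransportContext factory methods from network-common. The AuthWiring class and its method names are hypothetical.

    import java.util.Arrays;

    import org.apache.spark.network.TransportContext;
    import org.apache.spark.network.client.TransportClientBootstrap;
    import org.apache.spark.network.client.TransportClientFactory;
    import org.apache.spark.network.crypto.AuthClientBootstrap;
    import org.apache.spark.network.crypto.AuthServerBootstrap;
    import org.apache.spark.network.sasl.SecretKeyHolder;
    import org.apache.spark.network.server.TransportServer;
    import org.apache.spark.network.server.TransportServerBootstrap;
    import org.apache.spark.network.util.TransportConf;

    class AuthWiring {
      // Server side: AuthServerBootstrap wraps the app's RpcHandler with
      // AuthRpcHandler (or a SASL handler when encryption is disabled).
      static TransportServer createServer(
          TransportContext context, TransportConf conf, SecretKeyHolder keys) {
        return context.createServer(
            Arrays.<TransportServerBootstrap>asList(new AuthServerBootstrap(conf, keys)));
      }

      // Client side: AuthClientBootstrap runs the challenge/response exchange
      // on connect, falling back to SASL if configured.
      static TransportClientFactory createClientFactory(
          TransportContext context, TransportConf conf, String appId, SecretKeyHolder keys) {
        return context.createClientFactory(
            Arrays.<TransportClientBootstrap>asList(new AuthClientBootstrap(conf, appId, keys)));
      }
    }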
Diffstat (limited to 'common')
-rw-r--r--  common/network-common/src/main/java/org/apache/spark/network/crypto/AuthClientBootstrap.java  128
-rw-r--r--  common/network-common/src/main/java/org/apache/spark/network/crypto/AuthEngine.java  284
-rw-r--r--  common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java  170
-rw-r--r--  common/network-common/src/main/java/org/apache/spark/network/crypto/AuthServerBootstrap.java  55
-rw-r--r--  common/network-common/src/main/java/org/apache/spark/network/crypto/ClientChallenge.java  101
-rw-r--r--  common/network-common/src/main/java/org/apache/spark/network/crypto/README.md  158
-rw-r--r--  common/network-common/src/main/java/org/apache/spark/network/crypto/ServerResponse.java  85
-rw-r--r--  common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java (renamed from common/network-common/src/main/java/org/apache/spark/network/sasl/aes/AesCipher.java)  138
-rw-r--r--  common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java  36
-rw-r--r--  common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java  41
-rw-r--r--  common/network-common/src/main/java/org/apache/spark/network/sasl/aes/AesConfigMessage.java  101
-rw-r--r--  common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java  92
-rw-r--r--  common/network-common/src/test/java/org/apache/spark/network/crypto/AuthEngineSuite.java  109
-rw-r--r--  common/network-common/src/test/java/org/apache/spark/network/crypto/AuthIntegrationSuite.java  213
-rw-r--r--  common/network-common/src/test/java/org/apache/spark/network/crypto/AuthMessagesSuite.java  80
-rw-r--r--  common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java  97
-rw-r--r--  common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java  19
-rw-r--r--  common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java  5
-rw-r--r--  common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java  4
-rw-r--r--  common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java  9
-rw-r--r--  common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java  4
21 files changed, 1557 insertions, 372 deletions
diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthClientBootstrap.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthClientBootstrap.java
new file mode 100644
index 0000000000..980525dbf0
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthClientBootstrap.java
@@ -0,0 +1,128 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.crypto;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.security.GeneralSecurityException;
+import java.security.Key;
+import javax.crypto.KeyGenerator;
+import javax.crypto.Mac;
+import static java.nio.charset.StandardCharsets.UTF_8;
+
+import com.google.common.base.Preconditions;
+import com.google.common.base.Throwables;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+import io.netty.channel.Channel;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.client.TransportClientBootstrap;
+import org.apache.spark.network.sasl.SaslClientBootstrap;
+import org.apache.spark.network.sasl.SecretKeyHolder;
+import org.apache.spark.network.util.JavaUtils;
+import org.apache.spark.network.util.TransportConf;
+
+/**
+ * Bootstraps a {@link TransportClient} by performing authentication using Spark's auth protocol.
+ *
+ * This bootstrap falls back to using the SASL bootstrap if the server throws an error during
+ * authentication and the configuration allows it. This is used for backwards compatibility
+ * with external shuffle services that do not support the new protocol.
+ *
+ * It also automatically falls back to SASL if the new encryption backend is disabled, so that
+ * callers only need to install this bootstrap when authentication is enabled.
+ */
+public class AuthClientBootstrap implements TransportClientBootstrap {
+
+ private static final Logger LOG = LoggerFactory.getLogger(AuthClientBootstrap.class);
+
+ private final TransportConf conf;
+ private final String appId;
+ private final String authUser;
+ private final SecretKeyHolder secretKeyHolder;
+
+ public AuthClientBootstrap(
+ TransportConf conf,
+ String appId,
+ SecretKeyHolder secretKeyHolder) {
+ this.conf = conf;
+ // TODO: right now this behaves like the SASL backend, because when executors start up
+ // they don't necessarily know the app ID. So they send a hardcoded "user" that is defined
+ // in the SecurityManager, which will also always return the same secret (regardless of the
+ // user name). All that's needed here is for this "user" to match on both sides, since that's
+ // required by the protocol. At some point, though, it would be better for the actual app ID
+ // to be provided here.
+ this.appId = appId;
+ this.authUser = secretKeyHolder.getSaslUser(appId);
+ this.secretKeyHolder = secretKeyHolder;
+ }
+
+ @Override
+ public void doBootstrap(TransportClient client, Channel channel) {
+ if (!conf.encryptionEnabled()) {
+ LOG.debug("AES encryption disabled, using old auth protocol.");
+ doSaslAuth(client, channel);
+ return;
+ }
+
+ try {
+ doSparkAuth(client, channel);
+ } catch (GeneralSecurityException | IOException e) {
+ throw Throwables.propagate(e);
+ } catch (RuntimeException e) {
+ // There isn't a good exception that can be caught here to know whether it's really
+ // OK to switch back to SASL (because the server doesn't speak the new protocol). So
+ // try it anyway, and in the worst case things will fail again.
+ if (conf.saslFallback()) {
+ LOG.warn("New auth protocol failed, trying SASL.", e);
+ doSaslAuth(client, channel);
+ } else {
+ throw e;
+ }
+ }
+ }
+
+ private void doSparkAuth(TransportClient client, Channel channel)
+ throws GeneralSecurityException, IOException {
+
+ AuthEngine engine = new AuthEngine(authUser, secretKeyHolder.getSecretKey(authUser), conf);
+ try {
+ ClientChallenge challenge = engine.challenge();
+ ByteBuf challengeData = Unpooled.buffer(challenge.encodedLength());
+ challenge.encode(challengeData);
+
+ ByteBuffer responseData = client.sendRpcSync(challengeData.nioBuffer(),
+ conf.authRTTimeoutMs());
+ ServerResponse response = ServerResponse.decodeMessage(responseData);
+
+ engine.validate(response);
+ engine.sessionCipher().addToChannel(channel);
+ } finally {
+ engine.close();
+ }
+ }
+
+ private void doSaslAuth(TransportClient client, Channel channel) {
+ SaslClientBootstrap sasl = new SaslClientBootstrap(conf, appId, secretKeyHolder);
+ sasl.doBootstrap(client, channel);
+ }
+
+}
diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthEngine.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthEngine.java
new file mode 100644
index 0000000000..b769ebeba3
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthEngine.java
@@ -0,0 +1,284 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.crypto;
+
+import java.io.Closeable;
+import java.io.IOException;
+import java.math.BigInteger;
+import java.security.GeneralSecurityException;
+import java.util.Arrays;
+import java.util.Properties;
+import javax.crypto.Cipher;
+import javax.crypto.SecretKey;
+import javax.crypto.SecretKeyFactory;
+import javax.crypto.ShortBufferException;
+import javax.crypto.spec.IvParameterSpec;
+import javax.crypto.spec.PBEKeySpec;
+import javax.crypto.spec.SecretKeySpec;
+import static java.nio.charset.StandardCharsets.UTF_8;
+
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Preconditions;
+import com.google.common.primitives.Bytes;
+import org.apache.commons.crypto.cipher.CryptoCipher;
+import org.apache.commons.crypto.cipher.CryptoCipherFactory;
+import org.apache.commons.crypto.random.CryptoRandom;
+import org.apache.commons.crypto.random.CryptoRandomFactory;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.util.TransportConf;
+
+/**
+ * A helper class for abstracting authentication and key negotiation details. This is used by
+ * both client and server sides, since the operations are basically the same.
+ */
+class AuthEngine implements Closeable {
+
+ private static final Logger LOG = LoggerFactory.getLogger(AuthEngine.class);
+ private static final BigInteger ONE = new BigInteger(new byte[] { 0x1 });
+
+ private final byte[] appId;
+ private final char[] secret;
+ private final TransportConf conf;
+ private final Properties cryptoConf;
+ private final CryptoRandom random;
+
+ private byte[] authNonce;
+
+ @VisibleForTesting
+ byte[] challenge;
+
+ private TransportCipher sessionCipher;
+ private CryptoCipher encryptor;
+ private CryptoCipher decryptor;
+
+ AuthEngine(String appId, String secret, TransportConf conf) throws GeneralSecurityException {
+ this.appId = appId.getBytes(UTF_8);
+ this.conf = conf;
+ this.cryptoConf = conf.cryptoConf();
+ this.secret = secret.toCharArray();
+ this.random = CryptoRandomFactory.getCryptoRandom(cryptoConf);
+ }
+
+ /**
+ * Create the client challenge.
+ *
+ * @return A challenge to be sent to the remote side.
+ */
+ ClientChallenge challenge() throws GeneralSecurityException, IOException {
+ this.authNonce = randomBytes(conf.encryptionKeyLength() / Byte.SIZE);
+ SecretKeySpec authKey = generateKey(conf.keyFactoryAlgorithm(), conf.keyFactoryIterations(),
+ authNonce, conf.encryptionKeyLength());
+ initializeForAuth(conf.cipherTransformation(), authNonce, authKey);
+
+ this.challenge = randomBytes(conf.encryptionKeyLength() / Byte.SIZE);
+ return new ClientChallenge(new String(appId, UTF_8),
+ conf.keyFactoryAlgorithm(),
+ conf.keyFactoryIterations(),
+ conf.cipherTransformation(),
+ conf.encryptionKeyLength(),
+ authNonce,
+ challenge(appId, authNonce, challenge));
+ }
+
+ /**
+ * Validates the client challenge and creates the encryption backend for the channel from the
+ * parameters sent by the client.
+ *
+ * @param clientChallenge The challenge from the client.
+ * @return A response to be sent to the client.
+ */
+ ServerResponse respond(ClientChallenge clientChallenge)
+ throws GeneralSecurityException, IOException {
+
+ SecretKeySpec authKey = generateKey(clientChallenge.kdf, clientChallenge.iterations,
+ clientChallenge.nonce, clientChallenge.keyLength);
+ initializeForAuth(clientChallenge.cipher, clientChallenge.nonce, authKey);
+
+ byte[] challenge = validateChallenge(clientChallenge.nonce, clientChallenge.challenge);
+ byte[] response = challenge(appId, clientChallenge.nonce, rawResponse(challenge));
+ byte[] sessionNonce = randomBytes(conf.encryptionKeyLength() / Byte.SIZE);
+ byte[] inputIv = randomBytes(conf.ivLength());
+ byte[] outputIv = randomBytes(conf.ivLength());
+
+ SecretKeySpec sessionKey = generateKey(clientChallenge.kdf, clientChallenge.iterations,
+ sessionNonce, clientChallenge.keyLength);
+ this.sessionCipher = new TransportCipher(cryptoConf, clientChallenge.cipher, sessionKey,
+ inputIv, outputIv);
+
+ // Note the IVs are swapped in the response.
+ return new ServerResponse(response, encrypt(sessionNonce), encrypt(outputIv), encrypt(inputIv));
+ }
+
+ /**
+ * Validates the server response and initializes the cipher to use for the session.
+ *
+ * @param serverResponse The response from the server.
+ */
+ void validate(ServerResponse serverResponse) throws GeneralSecurityException {
+ byte[] response = validateChallenge(authNonce, serverResponse.response);
+
+ byte[] expected = rawResponse(challenge);
+ Preconditions.checkArgument(Arrays.equals(expected, response));
+
+ byte[] nonce = decrypt(serverResponse.nonce);
+ byte[] inputIv = decrypt(serverResponse.inputIv);
+ byte[] outputIv = decrypt(serverResponse.outputIv);
+
+ SecretKeySpec sessionKey = generateKey(conf.keyFactoryAlgorithm(), conf.keyFactoryIterations(),
+ nonce, conf.encryptionKeyLength());
+ this.sessionCipher = new TransportCipher(cryptoConf, conf.cipherTransformation(), sessionKey,
+ inputIv, outputIv);
+ }
+
+ TransportCipher sessionCipher() {
+ Preconditions.checkState(sessionCipher != null);
+ return sessionCipher;
+ }
+
+ @Override
+ public void close() throws IOException {
+ // Close ciphers (by calling "doFinal()" with dummy data) and the random instance so that
+ // internal state is cleaned up. Error handling here is just for paranoia, and not meant to
+ // accurately report the errors when they happen.
+ RuntimeException error = null;
+ byte[] dummy = new byte[8];
+ try {
+ doCipherOp(encryptor, dummy, true);
+ } catch (Exception e) {
+ error = new RuntimeException(e);
+ }
+ try {
+ doCipherOp(decryptor, dummy, true);
+ } catch (Exception e) {
+ error = new RuntimeException(e);
+ }
+ random.close();
+
+ if (error != null) {
+ throw error;
+ }
+ }
+
+ @VisibleForTesting
+ byte[] challenge(byte[] appId, byte[] nonce, byte[] challenge) throws GeneralSecurityException {
+ return encrypt(Bytes.concat(appId, nonce, challenge));
+ }
+
+ @VisibleForTesting
+ byte[] rawResponse(byte[] challenge) {
+ BigInteger orig = new BigInteger(challenge);
+ BigInteger response = orig.add(ONE);
+ return response.toByteArray();
+ }
+
+ private byte[] decrypt(byte[] in) throws GeneralSecurityException {
+ return doCipherOp(decryptor, in, false);
+ }
+
+ private byte[] encrypt(byte[] in) throws GeneralSecurityException {
+ return doCipherOp(encryptor, in, false);
+ }
+
+ private void initializeForAuth(String cipher, byte[] nonce, SecretKeySpec key)
+ throws GeneralSecurityException {
+
+ // commons-crypto currently only supports ciphers that require an initial vector; so
+ // create a dummy vector so that we can initialize the ciphers. In the future, if
+ // different ciphers are supported, this will have to be configurable somehow.
+ byte[] iv = new byte[conf.ivLength()];
+ System.arraycopy(nonce, 0, iv, 0, Math.min(nonce.length, iv.length));
+
+ encryptor = CryptoCipherFactory.getCryptoCipher(cipher, cryptoConf);
+ encryptor.init(Cipher.ENCRYPT_MODE, key, new IvParameterSpec(iv));
+
+ decryptor = CryptoCipherFactory.getCryptoCipher(cipher, cryptoConf);
+ decryptor.init(Cipher.DECRYPT_MODE, key, new IvParameterSpec(iv));
+ }
+
+ /**
+ * Validates an encrypted challenge as defined in the protocol, and returns the byte array
+ * that corresponds to the actual challenge data.
+ */
+ private byte[] validateChallenge(byte[] nonce, byte[] encryptedChallenge)
+ throws GeneralSecurityException {
+
+ byte[] challenge = decrypt(encryptedChallenge);
+ checkSubArray(appId, challenge, 0);
+ checkSubArray(nonce, challenge, appId.length);
+ return Arrays.copyOfRange(challenge, appId.length + nonce.length, challenge.length);
+ }
+
+ private SecretKeySpec generateKey(String kdf, int iterations, byte[] salt, int keyLength)
+ throws GeneralSecurityException {
+
+ SecretKeyFactory factory = SecretKeyFactory.getInstance(kdf);
+ PBEKeySpec spec = new PBEKeySpec(secret, salt, iterations, keyLength);
+
+ long start = System.nanoTime();
+ SecretKey key = factory.generateSecret(spec);
+ long end = System.nanoTime();
+
+ LOG.debug("Generated key with {} iterations in {} us.", conf.keyFactoryIterations(),
+ (end - start) / 1000);
+
+ return new SecretKeySpec(key.getEncoded(), conf.keyAlgorithm());
+ }
+
+ private byte[] doCipherOp(CryptoCipher cipher, byte[] in, boolean isFinal)
+ throws GeneralSecurityException {
+
+ Preconditions.checkState(cipher != null);
+
+ int scale = 1;
+ while (true) {
+ int size = in.length * scale;
+ byte[] buffer = new byte[size];
+ try {
+ int outSize = isFinal ? cipher.doFinal(in, 0, in.length, buffer, 0)
+ : cipher.update(in, 0, in.length, buffer, 0);
+ if (outSize != buffer.length) {
+ byte[] output = new byte[outSize];
+ System.arraycopy(buffer, 0, output, 0, output.length);
+ return output;
+ } else {
+ return buffer;
+ }
+ } catch (ShortBufferException e) {
+ // Try again with a bigger buffer.
+ scale *= 2;
+ }
+ }
+ }
+
+ private byte[] randomBytes(int count) {
+ byte[] bytes = new byte[count];
+ random.nextBytes(bytes);
+ return bytes;
+ }
+
+ /** Checks that the "test" array is in the data array starting at the given offset. */
+ private void checkSubArray(byte[] test, byte[] data, int offset) {
+ Preconditions.checkArgument(data.length >= test.length + offset);
+ for (int i = 0; i < test.length; i++) {
+ Preconditions.checkArgument(test[i] == data[i + offset]);
+ }
+ }
+
+}
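For illustration only (not part of this commit): the challenge/response arithmetic in rawResponse() above matches the README's definition, interpreting the challenge bytes as an arbitrary-length integer and adding one. A standalone sketch; the ChallengeMath class is hypothetical.

    import java.math.BigInteger;
    import java.util.Arrays;

    class ChallengeMath {
      // Response = challenge interpreted as a signed big-endian integer, plus one.
      static byte[] respond(byte[] challenge) {
        return new BigInteger(challenge).add(BigInteger.ONE).toByteArray();
      }

      public static void main(String[] args) {
        byte[] challenge = new byte[] { 0x00, 0x12, 0x34 };
        byte[] expected = respond(challenge);
        // A verifier recomputes the expected response and compares byte-wise,
        // as AuthEngine.validate() does with Arrays.equals().
        System.out.println(Arrays.equals(expected, respond(challenge)));  // prints true
      }
    }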
diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java
new file mode 100644
index 0000000000..991d8ba95f
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java
@@ -0,0 +1,170 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.crypto;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import javax.security.sasl.Sasl;
+
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Throwables;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+import io.netty.channel.Channel;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.client.RpcResponseCallback;
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.sasl.SecretKeyHolder;
+import org.apache.spark.network.sasl.SaslRpcHandler;
+import org.apache.spark.network.server.RpcHandler;
+import org.apache.spark.network.server.StreamManager;
+import org.apache.spark.network.util.JavaUtils;
+import org.apache.spark.network.util.TransportConf;
+
+/**
+ * RPC Handler which performs authentication using Spark's auth protocol before delegating to a
+ * child RPC handler. If the configuration allows, this handler will delegate messages to a SASL
+ * RPC handler for further authentication, to support clients that do not support Spark's
+ * protocol.
+ *
+ * The delegate will only receive messages if the given connection has been successfully
+ * authenticated. A connection may be authenticated at most once.
+ */
+class AuthRpcHandler extends RpcHandler {
+ private static final Logger LOG = LoggerFactory.getLogger(AuthRpcHandler.class);
+
+ /** Transport configuration. */
+ private final TransportConf conf;
+
+ /** The client channel. */
+ private final Channel channel;
+
+ /**
+ * RpcHandler we will delegate to for authenticated connections. When falling back to SASL
+ * this will be replaced with the SASL RPC handler.
+ */
+ @VisibleForTesting
+ RpcHandler delegate;
+
+ /** Class which provides secret keys which are shared by server and client on a per-app basis. */
+ private final SecretKeyHolder secretKeyHolder;
+
+ /** Whether auth is done and future calls should be delegated. */
+ @VisibleForTesting
+ boolean doDelegate;
+
+ AuthRpcHandler(
+ TransportConf conf,
+ Channel channel,
+ RpcHandler delegate,
+ SecretKeyHolder secretKeyHolder) {
+ this.conf = conf;
+ this.channel = channel;
+ this.delegate = delegate;
+ this.secretKeyHolder = secretKeyHolder;
+ }
+
+ @Override
+ public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) {
+ if (doDelegate) {
+ delegate.receive(client, message, callback);
+ return;
+ }
+
+ int position = message.position();
+ int limit = message.limit();
+
+ ClientChallenge challenge;
+ try {
+ challenge = ClientChallenge.decodeMessage(message);
+ LOG.debug("Received new auth challenge for client {}.", channel.remoteAddress());
+ } catch (RuntimeException e) {
+ if (conf.saslFallback()) {
+ LOG.warn("Failed to parse new auth challenge, reverting to SASL for client {}.",
+ channel.remoteAddress());
+ delegate = new SaslRpcHandler(conf, channel, delegate, secretKeyHolder);
+ message.position(position);
+ message.limit(limit);
+ delegate.receive(client, message, callback);
+ doDelegate = true;
+ } else {
+ LOG.debug("Unexpected challenge message from client {}, closing channel.",
+ channel.remoteAddress());
+ callback.onFailure(new IllegalArgumentException("Unknown challenge message."));
+ channel.close();
+ }
+ return;
+ }
+
+ // Here we have the client challenge, so perform the new auth protocol and set up the channel.
+ AuthEngine engine = null;
+ try {
+ engine = new AuthEngine(challenge.appId, secretKeyHolder.getSecretKey(challenge.appId), conf);
+ ServerResponse response = engine.respond(challenge);
+ ByteBuf responseData = Unpooled.buffer(response.encodedLength());
+ response.encode(responseData);
+ callback.onSuccess(responseData.nioBuffer());
+ engine.sessionCipher().addToChannel(channel);
+ } catch (Exception e) {
+ // This is a fatal error: authentication has failed. Close the channel explicitly.
+ LOG.debug("Authentication failed for client {}, closing channel.", channel.remoteAddress());
+ callback.onFailure(new IllegalArgumentException("Authentication failed."));
+ channel.close();
+ return;
+ } finally {
+ if (engine != null) {
+ try {
+ engine.close();
+ } catch (Exception e) {
+ throw Throwables.propagate(e);
+ }
+ }
+ }
+
+ LOG.debug("Authorization successful for client {}.", channel.remoteAddress());
+ doDelegate = true;
+ }
+
+ @Override
+ public void receive(TransportClient client, ByteBuffer message) {
+ delegate.receive(client, message);
+ }
+
+ @Override
+ public StreamManager getStreamManager() {
+ return delegate.getStreamManager();
+ }
+
+ @Override
+ public void channelActive(TransportClient client) {
+ delegate.channelActive(client);
+ }
+
+ @Override
+ public void channelInactive(TransportClient client) {
+ delegate.channelInactive(client);
+ }
+
+ @Override
+ public void exceptionCaught(Throwable cause, TransportClient client) {
+ delegate.exceptionCaught(cause, client);
+ }
+
+}
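For illustration only (not part of this commit): a standalone sketch of the peek-and-replay pattern receive() above uses for SASL fallback, saving the buffer's position and limit, attempting a decode, and restoring the buffer before handing it to the fallback handler. The PeekAndReplay class is hypothetical.

    import java.nio.ByteBuffer;

    class PeekAndReplay {
      public static void main(String[] args) {
        ByteBuffer message = ByteBuffer.wrap(new byte[] { 0x01, 0x02, 0x03 });
        int position = message.position();
        int limit = message.limit();
        try {
          // Attempt to parse; 0x01 is not the expected tag, so this throws,
          // mimicking ClientChallenge.decodeMessage() on a SASL message.
          if (message.get() != (byte) 0xFA) {
            throw new IllegalArgumentException("not a ClientChallenge");
          }
        } catch (IllegalArgumentException e) {
          message.position(position);  // rewind so the fallback handler
          message.limit(limit);        // sees the untouched message
        }
        System.out.println(message.remaining());  // prints 3
      }
    }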
diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthServerBootstrap.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthServerBootstrap.java
new file mode 100644
index 0000000000..77a2a6af4d
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthServerBootstrap.java
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.crypto;
+
+import io.netty.channel.Channel;
+
+import org.apache.spark.network.sasl.SaslServerBootstrap;
+import org.apache.spark.network.sasl.SecretKeyHolder;
+import org.apache.spark.network.server.RpcHandler;
+import org.apache.spark.network.server.TransportServerBootstrap;
+import org.apache.spark.network.util.TransportConf;
+
+/**
+ * A bootstrap which is executed on a TransportServer's client channel once a client connects
+ * to the server, enabling authentication using Spark's auth protocol (and optionally SASL for
+ * clients that don't support the new protocol).
+ *
+ * It also automatically falls back to SASL if the new encryption backend is disabled, so that
+ * callers only need to install this bootstrap when authentication is enabled.
+ */
+public class AuthServerBootstrap implements TransportServerBootstrap {
+
+ private final TransportConf conf;
+ private final SecretKeyHolder secretKeyHolder;
+
+ public AuthServerBootstrap(TransportConf conf, SecretKeyHolder secretKeyHolder) {
+ this.conf = conf;
+ this.secretKeyHolder = secretKeyHolder;
+ }
+
+ public RpcHandler doBootstrap(Channel channel, RpcHandler rpcHandler) {
+ if (!conf.encryptionEnabled()) {
+ TransportServerBootstrap sasl = new SaslServerBootstrap(conf, secretKeyHolder);
+ return sasl.doBootstrap(channel, rpcHandler);
+ }
+
+ return new AuthRpcHandler(conf, channel, rpcHandler, secretKeyHolder);
+ }
+
+}
diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/ClientChallenge.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/ClientChallenge.java
new file mode 100644
index 0000000000..3312a5bd81
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/ClientChallenge.java
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.crypto;
+
+import java.nio.ByteBuffer;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+
+import org.apache.spark.network.protocol.Encodable;
+import org.apache.spark.network.protocol.Encoders;
+
+/**
+ * The client challenge message, used to initiate authentication.
+ *
+ * @see README.md
+ */
+public class ClientChallenge implements Encodable {
+ /** Serialization tag used to catch incorrect payloads. */
+ private static final byte TAG_BYTE = (byte) 0xFA;
+
+ public final String appId;
+ public final String kdf;
+ public final int iterations;
+ public final String cipher;
+ public final int keyLength;
+ public final byte[] nonce;
+ public final byte[] challenge;
+
+ public ClientChallenge(
+ String appId,
+ String kdf,
+ int iterations,
+ String cipher,
+ int keyLength,
+ byte[] nonce,
+ byte[] challenge) {
+ this.appId = appId;
+ this.kdf = kdf;
+ this.iterations = iterations;
+ this.cipher = cipher;
+ this.keyLength = keyLength;
+ this.nonce = nonce;
+ this.challenge = challenge;
+ }
+
+ @Override
+ public int encodedLength() {
+ return 1 + 4 + 4 +
+ Encoders.Strings.encodedLength(appId) +
+ Encoders.Strings.encodedLength(kdf) +
+ Encoders.Strings.encodedLength(cipher) +
+ Encoders.ByteArrays.encodedLength(nonce) +
+ Encoders.ByteArrays.encodedLength(challenge);
+ }
+
+ @Override
+ public void encode(ByteBuf buf) {
+ buf.writeByte(TAG_BYTE);
+ Encoders.Strings.encode(buf, appId);
+ Encoders.Strings.encode(buf, kdf);
+ buf.writeInt(iterations);
+ Encoders.Strings.encode(buf, cipher);
+ buf.writeInt(keyLength);
+ Encoders.ByteArrays.encode(buf, nonce);
+ Encoders.ByteArrays.encode(buf, challenge);
+ }
+
+ public static ClientChallenge decodeMessage(ByteBuffer buffer) {
+ ByteBuf buf = Unpooled.wrappedBuffer(buffer);
+
+ if (buf.readByte() != TAG_BYTE) {
+ throw new IllegalArgumentException("Expected ClientChallenge, received something else.");
+ }
+
+ return new ClientChallenge(
+ Encoders.Strings.decode(buf),
+ Encoders.Strings.decode(buf),
+ buf.readInt(),
+ Encoders.Strings.decode(buf),
+ buf.readInt(),
+ Encoders.ByteArrays.decode(buf),
+ Encoders.ByteArrays.decode(buf));
+ }
+
+}
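For illustration only (not part of this commit): a round-trip sketch of the Encodable pattern used by these messages, encoding into a Netty buffer sized by encodedLength() and decoding from the resulting NIO buffer. The field values and the ChallengeRoundTrip class are arbitrary.

    import java.nio.ByteBuffer;

    import io.netty.buffer.ByteBuf;
    import io.netty.buffer.Unpooled;

    import org.apache.spark.network.crypto.ClientChallenge;

    class ChallengeRoundTrip {
      public static void main(String[] args) {
        ClientChallenge msg = new ClientChallenge("app-1", "PBKDF2WithHmacSHA1",
            10000, "AES/CTR/NoPadding", 128, new byte[16], new byte[16]);
        ByteBuf buf = Unpooled.buffer(msg.encodedLength());
        msg.encode(buf);
        ByteBuffer wire = buf.nioBuffer();  // what sendRpcSync() would carry
        ClientChallenge decoded = ClientChallenge.decodeMessage(wire);
        System.out.println(decoded.appId);  // prints app-1
      }
    }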
diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/README.md b/common/network-common/src/main/java/org/apache/spark/network/crypto/README.md
new file mode 100644
index 0000000000..14df703270
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/README.md
@@ -0,0 +1,158 @@
+Spark Auth Protocol and AES Encryption Support
+==============================================
+
+This file describes an auth protocol used by Spark as a more secure alternative to DIGEST-MD5. This
+protocol is built on symmetric key encryption, based on the assumption that the two endpoints being
+authenticated share a common secret, which is how Spark authentication currently works. The protocol
+provides mutual authentication, meaning that after the negotiation both parties know that the remote
+side knows the shared secret. The protocol is influenced by the ISO/IEC 9798 protocol, although it's
+not an implementation of it.
+
+This protocol could be replaced with TLS PSK, except no PSK ciphers are available in the currently
+released JREs.
+
+The protocol aims at solving the following shortcomings in Spark's current usage of DIGEST-MD5:
+
+- MD5 is an aging hash algorithm with known weaknesses, and a more secure alternative is desired.
+- DIGEST-MD5 has a pre-defined set of ciphers for which it can generate keys. The only
+ viable, supported cipher these days is 3DES, and a more modern alternative is desired.
+- Encrypting AES session keys with 3DES doesn't solve the issue, since the weakest link
+ in the negotiation would still be MD5 and 3DES.
+
+The protocol assumes that the shared secret is generated and distributed in a secure manner.
+
+The protocol always negotiates encryption keys. If encryption is not desired, the existing
+SASL-based authentication, or no authentication at all, can be chosen instead.
+
+When messages are described below, it's expected that the implementation should support
+arbitrary sizes for fields that don't have a fixed size.
+
+Client Challenge
+----------------
+
+The auth negotiation is started by the client. The client starts by generating an encryption
+key based on the application's shared secret, and a nonce.
+
+ KEY = KDF(SECRET, SALT, KEY_LENGTH)
+
+Where:
+- KDF(): a key derivation function that takes a secret, a salt, a configurable number of
+ iterations, and a configurable key length.
+- SALT: a byte sequence used to salt the key derivation function.
+- KEY_LENGTH: length of the encryption key to generate.
+
+
+The client generates a message with the following content:
+
+ CLIENT_CHALLENGE = (
+ APP_ID,
+ KDF,
+ ITERATIONS,
+ CIPHER,
+ KEY_LENGTH,
+ ANONCE,
+ ENC(APP_ID || ANONCE || CHALLENGE))
+
+Where:
+
+- APP_ID: the application ID which the server uses to identify the shared secret.
+- KDF: the key derivation function described above.
+- ITERATIONS: number of iterations to run the KDF when generating keys.
+- CIPHER: the cipher used to encrypt data.
+- KEY_LENGTH: length of the encryption keys to generate, in bits.
+- ANONCE: the nonce used as the salt when generating the auth key.
+- ENC(): an encryption function that uses the cipher and the generated key. This function
+ will also be used in the definition of other messages below.
+- CHALLENGE: a byte sequence used as a challenge to the server.
+- ||: concatenation operator.
+
+When strings are used where byte arrays are expected, the UTF-8 representation of the string
+is assumed.
+
+To respond to the challenge, the server should consider the byte array as representing an
+arbitrary-length integer, and respond with the value of the integer plus one.
+
+
+Server Response And Challenge
+-----------------------------
+
+Once the client challenge is received, the server will generate the same auth key by
+using the same algorithm the client has used. It will then verify the client challenge:
+if the APP_ID and ANONCE fields match, the server knows that the client has the shared
+secret. The server then creates a response to the client challenge, to prove that it also
+has the secret key, and provides parameters to be used when creating the session key.
+
+The following describes the response from the server:
+
+ SERVER_CHALLENGE = (
+ ENC(APP_ID || ANONCE || RESPONSE),
+ ENC(SNONCE),
+ ENC(INIV),
+ ENC(OUTIV))
+
+Where:
+
+- RESPONSE: the server's response to the client challenge.
+- SNONCE: a nonce to be used as salt when generating the session key.
+- INIV: initialization vector used to initialize the input channel of the client.
+- OUTIV: initialization vector used to initialize the output channel of the client.
+
+At this point the server considers the client to be authenticated, and will try to
+decrypt any data further sent by the client using the session key.
+
+
+Default Algorithms
+------------------
+
+Configuration options are available for the KDF and cipher algorithms to use.
+
+The default KDF is "PBKDF2WithHmacSHA1". Users should be able to select any algorithm
+from those supported by the `javax.crypto.SecretKeyFactory` class, as long as they support
+PBEKeySpec when generating keys. The default number of iterations was chosen to take a
+reasonable amount of time on modern CPUs. See the documentation in TransportConf for more
+details.
+
+The default cipher algorithm is "AES/CTR/NoPadding". Users should be able to select any
+algorithm supported by the commons-crypto library. It should allow the cipher to operate
+in stream mode.
+
+The default key length is 128 (bits).
+
+
+Implementation Details
+----------------------
+
+The commons-crypto library currently only supports AES ciphers, and requires an initialization
+vector (IV). This first version of the protocol does not explicitly include the IV in the client
+challenge message. Instead, the IV should be derived from the nonce, taking the needed bytes from
+it and padding the IV with zeroes in case the nonce is not long enough.
+
+Future versions of the protocol might add support for new ciphers and explicitly include needed
+configuration parameters in the messages.
+
+
+Threat Assessment
+-----------------
+
+The protocol is secure against different forms of attack:
+
+* Eavesdropping: the protocol is built on the assumption that it's computationally infeasible
+ to calculate the original secret from the encrypted messages. Neither the secret nor any
+ encryption keys are transmitted on the wire, encrypted or not.
+
+* Man-in-the-middle: because the protocol performs mutual authentication, both ends need to
+ know the shared secret to be able to decrypt session data. Even if an attacker is able to insert a
+ malicious "proxy" between endpoints, the attacker won't be able to read any of the data exchanged
+ between client and server, nor insert arbitrary commands for the server to execute.
+
+* Replay attacks: the use of nonces when generating keys prevents an attacker from being able to
+ just replay messages sniffed from the communication channel.
+
+An attacker may replay the client challenge and successfully "prove" to a server that it "knows" the
+shared secret. But the attacker won't be able to decrypt the server's response, and thus won't be
+able to generate a session key, which will make it hard to craft a valid, encrypted message that the
+server will be able to understand. This will cause the server to close the connection as soon as the
+attacker tries to send any command to the server. The attacker can just hold the channel open for
+some time, which will be closed when the server times out the channel. These issues could be
+separately mitigated by adding a shorter timeout for the first message after authentication, and
+potentially by adding host blacklists if a possible attack is detected from a particular host.
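For illustration only (not part of this commit): the KDF step described in this README, sketched with the JDK's SecretKeyFactory. The iteration count and key length are illustrative; Spark reads the real values from TransportConf. The KdfSketch class is hypothetical.

    import javax.crypto.SecretKeyFactory;
    import javax.crypto.spec.PBEKeySpec;
    import javax.crypto.spec.SecretKeySpec;

    class KdfSketch {
      // KEY = KDF(SECRET, SALT, KEY_LENGTH), per the protocol description.
      static SecretKeySpec deriveKey(char[] secret, byte[] salt) throws Exception {
        SecretKeyFactory factory = SecretKeyFactory.getInstance("PBKDF2WithHmacSHA1");
        // 10000 iterations and a 128-bit key are illustrative defaults.
        PBEKeySpec spec = new PBEKeySpec(secret, salt, 10000, 128);
        return new SecretKeySpec(factory.generateSecret(spec).getEncoded(), "AES");
      }
    }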
diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/ServerResponse.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/ServerResponse.java
new file mode 100644
index 0000000000..affdbf450b
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/ServerResponse.java
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.crypto;
+
+import java.nio.ByteBuffer;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+
+import org.apache.spark.network.protocol.Encodable;
+import org.apache.spark.network.protocol.Encoders;
+
+/**
+ * Server's response to client's challenge.
+ *
+ * @see README.md
+ */
+public class ServerResponse implements Encodable {
+ /** Serialization tag used to catch incorrect payloads. */
+ private static final byte TAG_BYTE = (byte) 0xFB;
+
+ public final byte[] response;
+ public final byte[] nonce;
+ public final byte[] inputIv;
+ public final byte[] outputIv;
+
+ public ServerResponse(
+ byte[] response,
+ byte[] nonce,
+ byte[] inputIv,
+ byte[] outputIv) {
+ this.response = response;
+ this.nonce = nonce;
+ this.inputIv = inputIv;
+ this.outputIv = outputIv;
+ }
+
+ @Override
+ public int encodedLength() {
+ return 1 +
+ Encoders.ByteArrays.encodedLength(response) +
+ Encoders.ByteArrays.encodedLength(nonce) +
+ Encoders.ByteArrays.encodedLength(inputIv) +
+ Encoders.ByteArrays.encodedLength(outputIv);
+ }
+
+ @Override
+ public void encode(ByteBuf buf) {
+ buf.writeByte(TAG_BYTE);
+ Encoders.ByteArrays.encode(buf, response);
+ Encoders.ByteArrays.encode(buf, nonce);
+ Encoders.ByteArrays.encode(buf, inputIv);
+ Encoders.ByteArrays.encode(buf, outputIv);
+ }
+
+ public static ServerResponse decodeMessage(ByteBuffer buffer) {
+ ByteBuf buf = Unpooled.wrappedBuffer(buffer);
+
+ if (buf.readByte() != TAG_BYTE) {
+ throw new IllegalArgumentException("Expected ServerResponse, received something else.");
+ }
+
+ return new ServerResponse(
+ Encoders.ByteArrays.decode(buf),
+ Encoders.ByteArrays.decode(buf),
+ Encoders.ByteArrays.decode(buf),
+ Encoders.ByteArrays.decode(buf));
+ }
+
+}
diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/aes/AesCipher.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java
index 340986a63b..7376d1ddc4 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/sasl/aes/AesCipher.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.network.sasl.aes;
+package org.apache.spark.network.crypto;
import java.io.IOException;
import java.nio.ByteBuffer;
@@ -25,115 +25,91 @@ import java.util.Properties;
import javax.crypto.spec.SecretKeySpec;
import javax.crypto.spec.IvParameterSpec;
+import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
-import com.google.common.base.Throwables;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.*;
import io.netty.util.AbstractReferenceCounted;
-import org.apache.commons.crypto.cipher.CryptoCipherFactory;
-import org.apache.commons.crypto.random.CryptoRandom;
-import org.apache.commons.crypto.random.CryptoRandomFactory;
import org.apache.commons.crypto.stream.CryptoInputStream;
import org.apache.commons.crypto.stream.CryptoOutputStream;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
import org.apache.spark.network.util.ByteArrayReadableChannel;
import org.apache.spark.network.util.ByteArrayWritableChannel;
-import org.apache.spark.network.util.TransportConf;
/**
- * AES cipher for encryption and decryption.
+ * Cipher for encryption and decryption.
*/
-public class AesCipher {
- private static final Logger logger = LoggerFactory.getLogger(AesCipher.class);
- public static final String ENCRYPTION_HANDLER_NAME = "AesEncryption";
- public static final String DECRYPTION_HANDLER_NAME = "AesDecryption";
- public static final int STREAM_BUFFER_SIZE = 1024 * 32;
- public static final String TRANSFORM = "AES/CTR/NoPadding";
-
- private final SecretKeySpec inKeySpec;
- private final IvParameterSpec inIvSpec;
- private final SecretKeySpec outKeySpec;
- private final IvParameterSpec outIvSpec;
- private final Properties properties;
-
- public AesCipher(AesConfigMessage configMessage, TransportConf conf) throws IOException {
- this.properties = conf.cryptoConf();
- this.inKeySpec = new SecretKeySpec(configMessage.inKey, "AES");
- this.inIvSpec = new IvParameterSpec(configMessage.inIv);
- this.outKeySpec = new SecretKeySpec(configMessage.outKey, "AES");
- this.outIvSpec = new IvParameterSpec(configMessage.outIv);
+public class TransportCipher {
+ @VisibleForTesting
+ static final String ENCRYPTION_HANDLER_NAME = "TransportEncryption";
+ private static final String DECRYPTION_HANDLER_NAME = "TransportDecryption";
+ private static final int STREAM_BUFFER_SIZE = 1024 * 32;
+
+ private final Properties conf;
+ private final String cipher;
+ private final SecretKeySpec key;
+ private final byte[] inIv;
+ private final byte[] outIv;
+
+ public TransportCipher(
+ Properties conf,
+ String cipher,
+ SecretKeySpec key,
+ byte[] inIv,
+ byte[] outIv) {
+ this.conf = conf;
+ this.cipher = cipher;
+ this.key = key;
+ this.inIv = inIv;
+ this.outIv = outIv;
+ }
+
+ public String getCipherTransformation() {
+ return cipher;
+ }
+
+ @VisibleForTesting
+ SecretKeySpec getKey() {
+ return key;
+ }
+
+ /** The IV for the input channel (i.e. output channel of the remote side). */
+ public byte[] getInputIv() {
+ return inIv;
+ }
+
+ /** The IV for the output channel (i.e. input channel of the remote side). */
+ public byte[] getOutputIv() {
+ return outIv;
}
- /**
- * Create AES crypto output stream
- * @param ch The underlying channel to write out.
- * @return Return output crypto stream for encryption.
- * @throws IOException
- */
private CryptoOutputStream createOutputStream(WritableByteChannel ch) throws IOException {
- return new CryptoOutputStream(TRANSFORM, properties, ch, outKeySpec, outIvSpec);
+ return new CryptoOutputStream(cipher, conf, ch, key, new IvParameterSpec(outIv));
}
- /**
- * Create AES crypto input stream
- * @param ch The underlying channel used to read data.
- * @return Return input crypto stream for decryption.
- * @throws IOException
- */
private CryptoInputStream createInputStream(ReadableByteChannel ch) throws IOException {
- return new CryptoInputStream(TRANSFORM, properties, ch, inKeySpec, inIvSpec);
+ return new CryptoInputStream(cipher, conf, ch, key, new IvParameterSpec(inIv));
}
/**
- * Add handlers to channel
+ * Add handlers to channel.
+ *
* @param ch the channel for adding handlers
* @throws IOException
*/
public void addToChannel(Channel ch) throws IOException {
ch.pipeline()
- .addFirst(ENCRYPTION_HANDLER_NAME, new AesEncryptHandler(this))
- .addFirst(DECRYPTION_HANDLER_NAME, new AesDecryptHandler(this));
- }
-
- /**
- * Create the configuration message
- * @param conf is the local transport configuration.
- * @return Config message for sending.
- */
- public static AesConfigMessage createConfigMessage(TransportConf conf) {
- int keySize = conf.aesCipherKeySize();
- Properties properties = conf.cryptoConf();
-
- try {
- int paramLen = CryptoCipherFactory.getCryptoCipher(AesCipher.TRANSFORM, properties)
- .getBlockSize();
- byte[] inKey = new byte[keySize];
- byte[] outKey = new byte[keySize];
- byte[] inIv = new byte[paramLen];
- byte[] outIv = new byte[paramLen];
-
- CryptoRandom random = CryptoRandomFactory.getCryptoRandom(properties);
- random.nextBytes(inKey);
- random.nextBytes(outKey);
- random.nextBytes(inIv);
- random.nextBytes(outIv);
-
- return new AesConfigMessage(inKey, inIv, outKey, outIv);
- } catch (Exception e) {
- logger.error("AES config error", e);
- throw Throwables.propagate(e);
- }
+ .addFirst(ENCRYPTION_HANDLER_NAME, new EncryptionHandler(this))
+ .addFirst(DECRYPTION_HANDLER_NAME, new DecryptionHandler(this));
}
- private static class AesEncryptHandler extends ChannelOutboundHandlerAdapter {
+ private static class EncryptionHandler extends ChannelOutboundHandlerAdapter {
private final ByteArrayWritableChannel byteChannel;
private final CryptoOutputStream cos;
- AesEncryptHandler(AesCipher cipher) throws IOException {
- byteChannel = new ByteArrayWritableChannel(AesCipher.STREAM_BUFFER_SIZE);
+ EncryptionHandler(TransportCipher cipher) throws IOException {
+ byteChannel = new ByteArrayWritableChannel(STREAM_BUFFER_SIZE);
cos = cipher.createOutputStream(byteChannel);
}
@@ -153,11 +129,11 @@ public class AesCipher {
}
}
- private static class AesDecryptHandler extends ChannelInboundHandlerAdapter {
+ private static class DecryptionHandler extends ChannelInboundHandlerAdapter {
private final CryptoInputStream cis;
private final ByteArrayReadableChannel byteChannel;
- AesDecryptHandler(AesCipher cipher) throws IOException {
+ DecryptionHandler(TransportCipher cipher) throws IOException {
byteChannel = new ByteArrayReadableChannel();
cis = cipher.createInputStream(byteChannel);
}
@@ -207,7 +183,7 @@ public class AesCipher {
this.buf = isByteBuf ? (ByteBuf) msg : null;
this.region = isByteBuf ? null : (FileRegion) msg;
this.transferred = 0;
- this.byteRawChannel = new ByteArrayWritableChannel(AesCipher.STREAM_BUFFER_SIZE);
+ this.byteRawChannel = new ByteArrayWritableChannel(STREAM_BUFFER_SIZE);
this.cos = cos;
this.byteEncChannel = ch;
}
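For illustration only (not part of this commit): TransportCipher delegates the actual encryption to commons-crypto stream classes, as createOutputStream() above shows. A minimal standalone sketch of a CryptoOutputStream encrypting to a byte sink; the key, IV, and CipherStreamSketch class are illustrative assumptions.

    import java.io.ByteArrayOutputStream;
    import java.nio.charset.StandardCharsets;
    import java.util.Properties;
    import javax.crypto.spec.IvParameterSpec;
    import javax.crypto.spec.SecretKeySpec;

    import org.apache.commons.crypto.stream.CryptoOutputStream;

    class CipherStreamSketch {
      public static void main(String[] args) throws Exception {
        SecretKeySpec key = new SecretKeySpec(new byte[16], "AES");  // illustrative key
        IvParameterSpec iv = new IvParameterSpec(new byte[16]);      // illustrative IV
        ByteArrayOutputStream sink = new ByteArrayOutputStream();
        try (CryptoOutputStream out = new CryptoOutputStream(
            "AES/CTR/NoPadding", new Properties(), sink, key, iv)) {
          out.write("hello".getBytes(StandardCharsets.UTF_8));
          out.flush();
        }
        // CTR mode is length-preserving: 5 plaintext bytes -> 5 ciphertext bytes.
        System.out.println(sink.size());
      }
    }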
diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java
index a1bb453657..6478137722 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java
@@ -30,8 +30,6 @@ import org.slf4j.LoggerFactory;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.client.TransportClientBootstrap;
-import org.apache.spark.network.sasl.aes.AesCipher;
-import org.apache.spark.network.sasl.aes.AesConfigMessage;
import org.apache.spark.network.util.JavaUtils;
import org.apache.spark.network.util.TransportConf;
@@ -42,24 +40,14 @@ import org.apache.spark.network.util.TransportConf;
public class SaslClientBootstrap implements TransportClientBootstrap {
private static final Logger logger = LoggerFactory.getLogger(SaslClientBootstrap.class);
- private final boolean encrypt;
private final TransportConf conf;
private final String appId;
private final SecretKeyHolder secretKeyHolder;
public SaslClientBootstrap(TransportConf conf, String appId, SecretKeyHolder secretKeyHolder) {
- this(conf, appId, secretKeyHolder, false);
- }
-
- public SaslClientBootstrap(
- TransportConf conf,
- String appId,
- SecretKeyHolder secretKeyHolder,
- boolean encrypt) {
this.conf = conf;
this.appId = appId;
this.secretKeyHolder = secretKeyHolder;
- this.encrypt = encrypt;
}
/**
@@ -69,7 +57,7 @@ public class SaslClientBootstrap implements TransportClientBootstrap {
*/
@Override
public void doBootstrap(TransportClient client, Channel channel) {
- SparkSaslClient saslClient = new SparkSaslClient(appId, secretKeyHolder, encrypt);
+ SparkSaslClient saslClient = new SparkSaslClient(appId, secretKeyHolder, conf.saslEncryption());
try {
byte[] payload = saslClient.firstToken();
@@ -79,35 +67,19 @@ public class SaslClientBootstrap implements TransportClientBootstrap {
msg.encode(buf);
buf.writeBytes(msg.body().nioByteBuffer());
- ByteBuffer response = client.sendRpcSync(buf.nioBuffer(), conf.saslRTTimeoutMs());
+ ByteBuffer response = client.sendRpcSync(buf.nioBuffer(), conf.authRTTimeoutMs());
payload = saslClient.response(JavaUtils.bufferToArray(response));
}
client.setClientId(appId);
- if (encrypt) {
+ if (conf.saslEncryption()) {
if (!SparkSaslServer.QOP_AUTH_CONF.equals(saslClient.getNegotiatedProperty(Sasl.QOP))) {
throw new RuntimeException(
new SaslException("Encryption requests by negotiated non-encrypted connection."));
}
- if (conf.aesEncryptionEnabled()) {
- // Generate a request config message to send to server.
- AesConfigMessage configMessage = AesCipher.createConfigMessage(conf);
- ByteBuffer buf = configMessage.encodeMessage();
-
- // Encrypted the config message.
- byte[] toEncrypt = JavaUtils.bufferToArray(buf);
- ByteBuffer encrypted = ByteBuffer.wrap(saslClient.wrap(toEncrypt, 0, toEncrypt.length));
-
- client.sendRpcSync(encrypted, conf.saslRTTimeoutMs());
- AesCipher cipher = new AesCipher(configMessage, conf);
- logger.info("Enabling AES cipher for client channel {}", client);
- cipher.addToChannel(channel);
- saslClient.dispose();
- } else {
- SaslEncryption.addToChannel(channel, saslClient, conf.maxSaslEncryptedBlockSize());
- }
+ SaslEncryption.addToChannel(channel, saslClient, conf.maxSaslEncryptedBlockSize());
saslClient = null;
logger.debug("Channel {} configured for encryption.", client);
}
diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java
index b2f3ef214b..0231428318 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java
@@ -29,8 +29,6 @@ import org.slf4j.LoggerFactory;
import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.TransportClient;
-import org.apache.spark.network.sasl.aes.AesCipher;
-import org.apache.spark.network.sasl.aes.AesConfigMessage;
import org.apache.spark.network.server.RpcHandler;
import org.apache.spark.network.server.StreamManager;
import org.apache.spark.network.util.JavaUtils;
@@ -44,7 +42,7 @@ import org.apache.spark.network.util.TransportConf;
* Note that the authentication process consists of multiple challenge-response pairs, each of
* which are individual RPCs.
*/
-class SaslRpcHandler extends RpcHandler {
+public class SaslRpcHandler extends RpcHandler {
private static final Logger logger = LoggerFactory.getLogger(SaslRpcHandler.class);
/** Transport configuration. */
@@ -63,7 +61,7 @@ class SaslRpcHandler extends RpcHandler {
private boolean isComplete;
private boolean isAuthenticated;
- SaslRpcHandler(
+ public SaslRpcHandler(
TransportConf conf,
Channel channel,
RpcHandler delegate,
@@ -122,37 +120,10 @@ class SaslRpcHandler extends RpcHandler {
return;
}
- if (!conf.aesEncryptionEnabled()) {
- logger.debug("Enabling encryption for channel {}", client);
- SaslEncryption.addToChannel(channel, saslServer, conf.maxSaslEncryptedBlockSize());
- complete(false);
- return;
- }
-
- // Extra negotiation should happen after authentication, so return directly while
- // processing authenticate.
- if (!isAuthenticated) {
- logger.debug("SASL authentication successful for channel {}", client);
- isAuthenticated = true;
- return;
- }
-
- // Create AES cipher when it is authenticated
- try {
- byte[] encrypted = JavaUtils.bufferToArray(message);
- ByteBuffer decrypted = ByteBuffer.wrap(saslServer.unwrap(encrypted, 0 , encrypted.length));
-
- AesConfigMessage configMessage = AesConfigMessage.decodeMessage(decrypted);
- AesCipher cipher = new AesCipher(configMessage, conf);
-
- // Send response back to client to confirm that server accept config.
- callback.onSuccess(JavaUtils.stringToBytes(AesCipher.TRANSFORM));
- logger.info("Enabling AES cipher for Server channel {}", client);
- cipher.addToChannel(channel);
- complete(true);
- } catch (IOException ioe) {
- throw new RuntimeException(ioe);
- }
+ logger.debug("Enabling encryption for channel {}", client);
+ SaslEncryption.addToChannel(channel, saslServer, conf.maxSaslEncryptedBlockSize());
+ complete(false);
+ return;
}
}
diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/aes/AesConfigMessage.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/aes/AesConfigMessage.java
deleted file mode 100644
index 3ef6f74a1f..0000000000
--- a/common/network-common/src/main/java/org/apache/spark/network/sasl/aes/AesConfigMessage.java
+++ /dev/null
@@ -1,101 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.sasl.aes;
-
-import java.nio.ByteBuffer;
-
-import io.netty.buffer.ByteBuf;
-import io.netty.buffer.Unpooled;
-
-import org.apache.spark.network.protocol.Encodable;
-import org.apache.spark.network.protocol.Encoders;
-
-/**
- * The AES cipher options for encryption negotiation.
- */
-public class AesConfigMessage implements Encodable {
- /** Serialization tag used to catch incorrect payloads. */
- private static final byte TAG_BYTE = (byte) 0xEB;
-
- public byte[] inKey;
- public byte[] outKey;
- public byte[] inIv;
- public byte[] outIv;
-
- public AesConfigMessage(byte[] inKey, byte[] inIv, byte[] outKey, byte[] outIv) {
- if (inKey == null || inIv == null || outKey == null || outIv == null) {
- throw new IllegalArgumentException("Cipher Key or IV must not be null!");
- }
-
- this.inKey = inKey;
- this.inIv = inIv;
- this.outKey = outKey;
- this.outIv = outIv;
- }
-
- @Override
- public int encodedLength() {
- return 1 +
- Encoders.ByteArrays.encodedLength(inKey) + Encoders.ByteArrays.encodedLength(outKey) +
- Encoders.ByteArrays.encodedLength(inIv) + Encoders.ByteArrays.encodedLength(outIv);
- }
-
- @Override
- public void encode(ByteBuf buf) {
- buf.writeByte(TAG_BYTE);
- Encoders.ByteArrays.encode(buf, inKey);
- Encoders.ByteArrays.encode(buf, inIv);
- Encoders.ByteArrays.encode(buf, outKey);
- Encoders.ByteArrays.encode(buf, outIv);
- }
-
- /**
- * Encode the config message.
- * @return ByteBuffer which contains encoded config message.
- */
- public ByteBuffer encodeMessage(){
- ByteBuffer buf = ByteBuffer.allocate(encodedLength());
-
- ByteBuf wrappedBuf = Unpooled.wrappedBuffer(buf);
- wrappedBuf.clear();
- encode(wrappedBuf);
-
- return buf;
- }
-
- /**
- * Decode the config message from buffer
- * @param buffer the buffer contain encoded config message
- * @return config message
- */
- public static AesConfigMessage decodeMessage(ByteBuffer buffer) {
- ByteBuf buf = Unpooled.wrappedBuffer(buffer);
-
- if (buf.readByte() != TAG_BYTE) {
- throw new IllegalStateException("Expected AesConfigMessage, received something else"
- + " (maybe your client does not have AES enabled?)");
- }
-
- byte[] outKey = Encoders.ByteArrays.decode(buf);
- byte[] outIv = Encoders.ByteArrays.decode(buf);
- byte[] inKey = Encoders.ByteArrays.decode(buf);
- byte[] inIv = Encoders.ByteArrays.decode(buf);
- return new AesConfigMessage(inKey, inIv, outKey, outIv);
- }
-
-}
diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java
index 6a557fa75d..c226d8f3bc 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java
@@ -117,9 +117,10 @@ public class TransportConf {
/** Send buffer size (SO_SNDBUF). */
public int sendBuf() { return conf.getInt(SPARK_NETWORK_IO_SENDBUFFER_KEY, -1); }
- /** Timeout for a single round trip of SASL token exchange, in milliseconds. */
- public int saslRTTimeoutMs() {
- return (int) JavaUtils.timeStringAsSec(conf.get(SPARK_NETWORK_SASL_TIMEOUT_KEY, "30s")) * 1000;
+ /** Timeout for a single round trip of auth message exchange, in milliseconds. */
+ public int authRTTimeoutMs() {
+ return (int) JavaUtils.timeStringAsSec(conf.get("spark.network.auth.rpcTimeout",
+ conf.get(SPARK_NETWORK_SASL_TIMEOUT_KEY, "30s"))) * 1000;
}
/**
@@ -162,40 +163,95 @@ public class TransportConf {
}
/**
- * Maximum number of bytes to be encrypted at a time when SASL encryption is enabled.
+ * Enables strong encryption. Also enables the new auth protocol, used to negotiate keys.
*/
- public int maxSaslEncryptedBlockSize() {
- return Ints.checkedCast(JavaUtils.byteStringAsBytes(
- conf.get("spark.network.sasl.maxEncryptedBlockSize", "64k")));
+ public boolean encryptionEnabled() {
+ return conf.getBoolean("spark.network.crypto.enabled", false);
}
/**
- * Whether the server should enforce encryption on SASL-authenticated connections.
+ * The cipher transformation to use for encrypting session data.
*/
- public boolean saslServerAlwaysEncrypt() {
- return conf.getBoolean("spark.network.sasl.serverAlwaysEncrypt", false);
+ public String cipherTransformation() {
+ return conf.get("spark.network.crypto.cipher", "AES/CTR/NoPadding");
+ }
+
+ /**
+ * The key generation algorithm. This should be an algorithm that accepts a "PBEKeySpec"
+ * as input. The default value (PBKDF2WithHmacSHA1) is available in Java 7.
+ */
+ public String keyFactoryAlgorithm() {
+ return conf.get("spark.network.crypto.keyFactoryAlgorithm", "PBKDF2WithHmacSHA1");
+ }
+
+ /**
+ * How many iterations to run when generating keys.
+ *
+ * See some discussion about this at: http://security.stackexchange.com/q/3959
+ * The default value was picked for speed: it assumes that the secret already has good entropy
+ * (128 bits by default), which is generally not the case with user passwords.
+ */
+ public int keyFactoryIterations() {
+ return conf.getInt("spark.networy.crypto.keyFactoryIterations", 1024);
+ }
+
+ /**
+ * Encryption key length, in bits.
+ */
+ public int encryptionKeyLength() {
+ return conf.getInt("spark.network.crypto.keyLength", 128);
+ }
+
+ /**
+ * Initialization vector (IV) length, in bytes.
+ */
+ public int ivLength() {
+ return conf.getInt("spark.network.crypto.ivLength", 16);
+ }
+
+ /**
+ * The algorithm for generated secret keys. Nobody should really need to change this, but it
+ * is configurable just in case.
+ */
+ public String keyAlgorithm() {
+ return conf.get("spark.network.crypto.keyAlgorithm", "AES");
+ }
+
+ /**
+ * Whether to fall back to SASL if the new auth protocol fails. Enabled by default for
+ * backwards compatibility.
+ */
+ public boolean saslFallback() {
+ return conf.getBoolean("spark.network.crypto.saslFallback", true);
}
/**
- * The trigger for enabling AES encryption.
+ * Whether to enable SASL-based encryption when authenticating using SASL.
*/
- public boolean aesEncryptionEnabled() {
- return conf.getBoolean("spark.network.aes.enabled", false);
+ public boolean saslEncryption() {
+ return conf.getBoolean("spark.authenticate.enableSaslEncryption", false);
}
/**
- * The key size to use when AES cipher is enabled. Notice that the length should be 16, 24 or 32
- * bytes.
+ * Maximum number of bytes to be encrypted at a time when SASL encryption is used.
*/
- public int aesCipherKeySize() {
- return conf.getInt("spark.network.aes.keySize", 16);
+ public int maxSaslEncryptedBlockSize() {
+ return Ints.checkedCast(JavaUtils.byteStringAsBytes(
+ conf.get("spark.network.sasl.maxEncryptedBlockSize", "64k")));
+ }
+
+ /**
+ * Whether the server should enforce encryption on SASL-authenticated connections.
+ */
+ public boolean saslServerAlwaysEncrypt() {
+ return conf.getBoolean("spark.network.sasl.serverAlwaysEncrypt", false);
}
/**
* The commons-crypto configuration for the module.
*/
public Properties cryptoConf() {
- return CryptoUtils.toCryptoConf("spark.network.aes.config.", conf.getAll());
+ return CryptoUtils.toCryptoConf("spark.network.crypto.config.", conf.getAll());
}
}
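
The new spark.network.crypto.* getters above replace the old spark.network.aes.* family. Here is a minimal sketch of exercising them, using the same MapConfigProvider helper the tests use (the key names come from the hunk above; the class itself is illustrative):

import com.google.common.collect.ImmutableMap;

import org.apache.spark.network.util.MapConfigProvider;
import org.apache.spark.network.util.TransportConf;

class CryptoConfSketch {
  public static void main(String[] args) {
    // Turn on the new auth protocol and override two of the defaults above.
    TransportConf conf = new TransportConf("rpc", new MapConfigProvider(
        ImmutableMap.of(
            "spark.network.crypto.enabled", "true",
            "spark.network.crypto.keyLength", "256",
            "spark.network.crypto.saslFallback", "false")));

    System.out.println(conf.encryptionEnabled());     // true
    System.out.println(conf.cipherTransformation());  // AES/CTR/NoPadding (default)
    System.out.println(conf.encryptionKeyLength());   // 256
    System.out.println(conf.saslFallback());          // false
  }
}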
diff --git a/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthEngineSuite.java b/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthEngineSuite.java
new file mode 100644
index 0000000000..9a186f2113
--- /dev/null
+++ b/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthEngineSuite.java
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.crypto;
+
+import java.util.Arrays;
+import java.util.Map;
+import static java.nio.charset.StandardCharsets.UTF_8;
+
+import com.google.common.collect.ImmutableMap;
+import org.junit.BeforeClass;
+import org.junit.Test;
+import static org.junit.Assert.*;
+
+import org.apache.spark.network.util.MapConfigProvider;
+import org.apache.spark.network.util.TransportConf;
+
+public class AuthEngineSuite {
+
+ private static TransportConf conf;
+
+ @BeforeClass
+ public static void setUp() {
+ conf = new TransportConf("rpc", MapConfigProvider.EMPTY);
+ }
+
+ @Test
+ public void testAuthEngine() throws Exception {
+ AuthEngine client = new AuthEngine("appId", "secret", conf);
+ AuthEngine server = new AuthEngine("appId", "secret", conf);
+
+ try {
+ ClientChallenge clientChallenge = client.challenge();
+ ServerResponse serverResponse = server.respond(clientChallenge);
+ client.validate(serverResponse);
+
+ TransportCipher serverCipher = server.sessionCipher();
+ TransportCipher clientCipher = client.sessionCipher();
+
+ assertTrue(Arrays.equals(serverCipher.getInputIv(), clientCipher.getOutputIv()));
+ assertTrue(Arrays.equals(serverCipher.getOutputIv(), clientCipher.getInputIv()));
+ assertEquals(serverCipher.getKey(), clientCipher.getKey());
+ } finally {
+ client.close();
+ server.close();
+ }
+ }
+
+ @Test
+ public void testMismatchedSecret() throws Exception {
+ AuthEngine client = new AuthEngine("appId", "secret", conf);
+ AuthEngine server = new AuthEngine("appId", "different_secret", conf);
+
+ ClientChallenge clientChallenge = client.challenge();
+ try {
+ server.respond(clientChallenge);
+ fail("Should have failed to validate response.");
+ } catch (IllegalArgumentException e) {
+ // Expected.
+ }
+ }
+
+ @Test(expected = IllegalArgumentException.class)
+ public void testWrongAppId() throws Exception {
+ AuthEngine engine = new AuthEngine("appId", "secret", conf);
+ ClientChallenge challenge = engine.challenge();
+
+ byte[] badChallenge = engine.challenge(new byte[] { 0x00 }, challenge.nonce,
+ engine.rawResponse(engine.challenge));
+ engine.respond(new ClientChallenge(challenge.appId, challenge.kdf, challenge.iterations,
+ challenge.cipher, challenge.keyLength, challenge.nonce, badChallenge));
+ }
+
+ @Test(expected = IllegalArgumentException.class)
+ public void testWrongNonce() throws Exception {
+ AuthEngine engine = new AuthEngine("appId", "secret", conf);
+ ClientChallenge challenge = engine.challenge();
+
+ byte[] badChallenge = engine.challenge(challenge.appId.getBytes(UTF_8), new byte[] { 0x00 },
+ engine.rawResponse(engine.challenge));
+ engine.respond(new ClientChallenge(challenge.appId, challenge.kdf, challenge.iterations,
+ challenge.cipher, challenge.keyLength, challenge.nonce, badChallenge));
+ }
+
+ @Test(expected = IllegalArgumentException.class)
+ public void testBadChallenge() throws Exception {
+ AuthEngine engine = new AuthEngine("appId", "secret", conf);
+ ClientChallenge challenge = engine.challenge();
+
+ byte[] badChallenge = new byte[challenge.challenge.length];
+ engine.respond(new ClientChallenge(challenge.appId, challenge.kdf, challenge.iterations,
+ challenge.cipher, challenge.keyLength, challenge.nonce, badChallenge));
+ }
+
+}
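
Spelled out, the handshake the suite above exercises is a three-step exchange; in production the two engines sit on opposite ends of a connection and the messages travel as RPCs, but a minimal in-process sketch looks like this (placed in the same package as the suite, since the engine classes are not public; all calls mirror testAuthEngine above):

package org.apache.spark.network.crypto;

import org.apache.spark.network.util.MapConfigProvider;
import org.apache.spark.network.util.TransportConf;

class HandshakeSketch {
  public static void main(String[] args) throws Exception {
    TransportConf conf = new TransportConf("rpc", MapConfigProvider.EMPTY);
    AuthEngine client = new AuthEngine("appId", "secret", conf);
    AuthEngine server = new AuthEngine("appId", "secret", conf);
    try {
      ClientChallenge challenge = client.challenge();       // 1. client -> server
      ServerResponse response = server.respond(challenge);  // 2. server -> client
      client.validate(response);                            // 3. client checks the server
      TransportCipher cipher = client.sessionCipher();      // session keys are now agreed
      System.out.println("Handshake complete: " + (cipher != null));
    } finally {
      client.close();
      server.close();
    }
  }
}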
diff --git a/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthIntegrationSuite.java
new file mode 100644
index 0000000000..21609d5aa2
--- /dev/null
+++ b/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthIntegrationSuite.java
@@ -0,0 +1,213 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.crypto;
+
+import java.nio.ByteBuffer;
+import java.util.List;
+import java.util.Map;
+
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.Lists;
+import io.netty.channel.Channel;
+import org.junit.After;
+import org.junit.Test;
+import static org.junit.Assert.*;
+import static org.mockito.Mockito.*;
+
+import org.apache.spark.network.TestUtils;
+import org.apache.spark.network.TransportContext;
+import org.apache.spark.network.client.RpcResponseCallback;
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.client.TransportClientBootstrap;
+import org.apache.spark.network.sasl.SaslRpcHandler;
+import org.apache.spark.network.sasl.SaslServerBootstrap;
+import org.apache.spark.network.sasl.SecretKeyHolder;
+import org.apache.spark.network.server.RpcHandler;
+import org.apache.spark.network.server.StreamManager;
+import org.apache.spark.network.server.TransportServer;
+import org.apache.spark.network.server.TransportServerBootstrap;
+import org.apache.spark.network.util.JavaUtils;
+import org.apache.spark.network.util.MapConfigProvider;
+import org.apache.spark.network.util.TransportConf;
+
+public class AuthIntegrationSuite {
+
+ private AuthTestCtx ctx;
+
+ @After
+ public void cleanUp() throws Exception {
+ if (ctx != null) {
+ ctx.close();
+ }
+ ctx = null;
+ }
+
+ @Test
+ public void testNewAuth() throws Exception {
+ ctx = new AuthTestCtx();
+ ctx.createServer("secret");
+ ctx.createClient("secret");
+
+ ByteBuffer reply = ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), 5000);
+ assertEquals("Pong", JavaUtils.bytesToString(reply));
+ assertTrue(ctx.authRpcHandler.doDelegate);
+ assertFalse(ctx.authRpcHandler.delegate instanceof SaslRpcHandler);
+ }
+
+ @Test
+ public void testAuthFailure() throws Exception {
+ ctx = new AuthTestCtx();
+ ctx.createServer("server");
+
+ try {
+ ctx.createClient("client");
+ fail("Should have failed to create client.");
+ } catch (Exception e) {
+ assertFalse(ctx.authRpcHandler.doDelegate);
+ assertFalse(ctx.serverChannel.isActive());
+ }
+ }
+
+ @Test
+ public void testSaslServerFallback() throws Exception {
+ ctx = new AuthTestCtx();
+ ctx.createServer("secret", true);
+ ctx.createClient("secret", false);
+
+ ByteBuffer reply = ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), 5000);
+ assertEquals("Pong", JavaUtils.bytesToString(reply));
+ }
+
+ @Test
+ public void testSaslClientFallback() throws Exception {
+ ctx = new AuthTestCtx();
+ ctx.createServer("secret", false);
+ ctx.createClient("secret", true);
+
+ ByteBuffer reply = ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), 5000);
+ assertEquals("Pong", JavaUtils.bytesToString(reply));
+ }
+
+ @Test
+ public void testAuthReplay() throws Exception {
+ // This test covers the case where an attacker replays a challenge message sniffed from the
+ // network but doesn't know the actual secret. The server should close the connection as soon
+ // as an unencrypted message arrives after authentication succeeds. This is emulated by
+ // removing the client-side encryption handler after authentication.
+ ctx = new AuthTestCtx();
+ ctx.createServer("secret");
+ ctx.createClient("secret");
+
+ assertNotNull(ctx.client.getChannel().pipeline()
+ .remove(TransportCipher.ENCRYPTION_HANDLER_NAME));
+
+ try {
+ ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), 5000);
+ fail("Should have failed unencrypted RPC.");
+ } catch (Exception e) {
+ assertTrue(ctx.authRpcHandler.doDelegate);
+ }
+ }
+
+ private class AuthTestCtx {
+
+ private final String appId = "testAppId";
+ private final TransportConf conf;
+ private final TransportContext ctx;
+
+ TransportClient client;
+ TransportServer server;
+ volatile Channel serverChannel;
+ volatile AuthRpcHandler authRpcHandler;
+
+ AuthTestCtx() throws Exception {
+ Map<String, String> testConf = ImmutableMap.of("spark.network.crypto.enabled", "true");
+ this.conf = new TransportConf("rpc", new MapConfigProvider(testConf));
+
+ RpcHandler rpcHandler = new RpcHandler() {
+ @Override
+ public void receive(
+ TransportClient client,
+ ByteBuffer message,
+ RpcResponseCallback callback) {
+ assertEquals("Ping", JavaUtils.bytesToString(message));
+ callback.onSuccess(JavaUtils.stringToBytes("Pong"));
+ }
+
+ @Override
+ public StreamManager getStreamManager() {
+ return null;
+ }
+ };
+
+ this.ctx = new TransportContext(conf, rpcHandler);
+ }
+
+ void createServer(String secret) throws Exception {
+ createServer(secret, true);
+ }
+
+ void createServer(String secret, boolean enableAes) throws Exception {
+ TransportServerBootstrap introspector = new TransportServerBootstrap() {
+ @Override
+ public RpcHandler doBootstrap(Channel channel, RpcHandler rpcHandler) {
+ AuthTestCtx.this.serverChannel = channel;
+ if (rpcHandler instanceof AuthRpcHandler) {
+ AuthTestCtx.this.authRpcHandler = (AuthRpcHandler) rpcHandler;
+ }
+ return rpcHandler;
+ }
+ };
+ SecretKeyHolder keyHolder = createKeyHolder(secret);
+ TransportServerBootstrap auth = enableAes ? new AuthServerBootstrap(conf, keyHolder)
+ : new SaslServerBootstrap(conf, keyHolder);
+ this.server = ctx.createServer(Lists.newArrayList(auth, introspector));
+ }
+
+ void createClient(String secret) throws Exception {
+ createClient(secret, true);
+ }
+
+ void createClient(String secret, boolean enableAes) throws Exception {
+ TransportConf clientConf = enableAes ? conf
+ : new TransportConf("rpc", MapConfigProvider.EMPTY);
+ List<TransportClientBootstrap> bootstraps = Lists.<TransportClientBootstrap>newArrayList(
+ new AuthClientBootstrap(clientConf, appId, createKeyHolder(secret)));
+ this.client = ctx.createClientFactory(bootstraps)
+ .createClient(TestUtils.getLocalHost(), server.getPort());
+ }
+
+ void close() {
+ if (client != null) {
+ client.close();
+ }
+ if (server != null) {
+ server.close();
+ }
+ }
+
+ private SecretKeyHolder createKeyHolder(String secret) {
+ SecretKeyHolder keyHolder = mock(SecretKeyHolder.class);
+ when(keyHolder.getSaslUser(anyString())).thenReturn(appId);
+ when(keyHolder.getSecretKey(anyString())).thenReturn(secret);
+ return keyHolder;
+ }
+
+ }
+
+}
diff --git a/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthMessagesSuite.java b/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthMessagesSuite.java
new file mode 100644
index 0000000000..a90ff247da
--- /dev/null
+++ b/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthMessagesSuite.java
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.crypto;
+
+import java.nio.ByteBuffer;
+import java.util.Arrays;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+import org.junit.Test;
+import static org.junit.Assert.*;
+
+import org.apache.spark.network.protocol.Encodable;
+
+public class AuthMessagesSuite {
+
+ private static int COUNTER = 0;
+
+ private static String string() {
+ return String.valueOf(COUNTER++);
+ }
+
+ private static byte[] byteArray() {
+ byte[] bytes = new byte[COUNTER++];
+ for (int i = 0; i < bytes.length; i++) {
+ bytes[i] = (byte) COUNTER;
+ }
+ return bytes;
+ }
+
+ private static int integer() {
+ return COUNTER++;
+ }
+
+ @Test
+ public void testClientChallenge() {
+ ClientChallenge msg = new ClientChallenge(string(), string(), integer(), string(), integer(),
+ byteArray(), byteArray());
+ ClientChallenge decoded = ClientChallenge.decodeMessage(encode(msg));
+
+ assertEquals(msg.appId, decoded.appId);
+ assertEquals(msg.kdf, decoded.kdf);
+ assertEquals(msg.iterations, decoded.iterations);
+ assertEquals(msg.cipher, decoded.cipher);
+ assertEquals(msg.keyLength, decoded.keyLength);
+ assertTrue(Arrays.equals(msg.nonce, decoded.nonce));
+ assertTrue(Arrays.equals(msg.challenge, decoded.challenge));
+ }
+
+ @Test
+ public void testServerResponse() {
+ ServerResponse msg = new ServerResponse(byteArray(), byteArray(), byteArray(), byteArray());
+ ServerResponse decoded = ServerResponse.decodeMessage(encode(msg));
+ assertTrue(Arrays.equals(msg.response, decoded.response));
+ assertTrue(Arrays.equals(msg.nonce, decoded.nonce));
+ assertTrue(Arrays.equals(msg.inputIv, decoded.inputIv));
+ assertTrue(Arrays.equals(msg.outputIv, decoded.outputIv));
+ }
+
+ private ByteBuffer encode(Encodable msg) {
+ ByteBuf buf = Unpooled.buffer();
+ msg.encode(buf);
+ return buf.nioBuffer();
+ }
+
+}
diff --git a/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java
index e27301f49e..87129b900b 100644
--- a/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java
+++ b/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java
@@ -56,7 +56,6 @@ import org.apache.spark.network.client.ChunkReceivedCallback;
import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.client.TransportClientBootstrap;
-import org.apache.spark.network.sasl.aes.AesCipher;
import org.apache.spark.network.server.RpcHandler;
import org.apache.spark.network.server.StreamManager;
import org.apache.spark.network.server.TransportServer;
@@ -153,7 +152,7 @@ public class SparkSaslSuite {
.when(rpcHandler)
.receive(any(TransportClient.class), any(ByteBuffer.class), any(RpcResponseCallback.class));
- SaslTestCtx ctx = new SaslTestCtx(rpcHandler, encrypt, false, false);
+ SaslTestCtx ctx = new SaslTestCtx(rpcHandler, encrypt, false);
try {
ByteBuffer response = ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"),
TimeUnit.SECONDS.toMillis(10));
@@ -279,7 +278,7 @@ public class SparkSaslSuite {
new Random().nextBytes(data);
Files.write(data, file);
- ctx = new SaslTestCtx(rpcHandler, true, false, false, testConf);
+ ctx = new SaslTestCtx(rpcHandler, true, false, testConf);
final CountDownLatch lock = new CountDownLatch(1);
@@ -317,7 +316,7 @@ public class SparkSaslSuite {
public void testServerAlwaysEncrypt() throws Exception {
SaslTestCtx ctx = null;
try {
- ctx = new SaslTestCtx(mock(RpcHandler.class), false, false, false,
+ ctx = new SaslTestCtx(mock(RpcHandler.class), false, false,
ImmutableMap.of("spark.network.sasl.serverAlwaysEncrypt", "true"));
fail("Should have failed to connect without encryption.");
} catch (Exception e) {
@@ -336,7 +335,7 @@ public class SparkSaslSuite {
// able to understand RPCs sent to it and thus close the connection.
SaslTestCtx ctx = null;
try {
- ctx = new SaslTestCtx(mock(RpcHandler.class), true, true, false);
+ ctx = new SaslTestCtx(mock(RpcHandler.class), true, true);
ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"),
TimeUnit.SECONDS.toMillis(10));
fail("Should have failed to send RPC to server.");
@@ -374,69 +373,6 @@ public class SparkSaslSuite {
}
}
- @Test
- public void testAesEncryption() throws Exception {
- final AtomicReference<ManagedBuffer> response = new AtomicReference<>();
- final File file = File.createTempFile("sasltest", ".txt");
- SaslTestCtx ctx = null;
- try {
- final TransportConf conf = new TransportConf("rpc", MapConfigProvider.EMPTY);
- final TransportConf spyConf = spy(conf);
- doReturn(true).when(spyConf).aesEncryptionEnabled();
-
- StreamManager sm = mock(StreamManager.class);
- when(sm.getChunk(anyLong(), anyInt())).thenAnswer(new Answer<ManagedBuffer>() {
- @Override
- public ManagedBuffer answer(InvocationOnMock invocation) {
- return new FileSegmentManagedBuffer(spyConf, file, 0, file.length());
- }
- });
-
- RpcHandler rpcHandler = mock(RpcHandler.class);
- when(rpcHandler.getStreamManager()).thenReturn(sm);
-
- byte[] data = new byte[256 * 1024 * 1024];
- new Random().nextBytes(data);
- Files.write(data, file);
-
- ctx = new SaslTestCtx(rpcHandler, true, false, true);
-
- final Object lock = new Object();
-
- ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class);
- doAnswer(new Answer<Void>() {
- @Override
- public Void answer(InvocationOnMock invocation) {
- response.set((ManagedBuffer) invocation.getArguments()[1]);
- response.get().retain();
- synchronized (lock) {
- lock.notifyAll();
- }
- return null;
- }
- }).when(callback).onSuccess(anyInt(), any(ManagedBuffer.class));
-
- synchronized (lock) {
- ctx.client.fetchChunk(0, 0, callback);
- lock.wait(10 * 1000);
- }
-
- verify(callback, times(1)).onSuccess(anyInt(), any(ManagedBuffer.class));
- verify(callback, never()).onFailure(anyInt(), any(Throwable.class));
-
- byte[] received = ByteStreams.toByteArray(response.get().createInputStream());
- assertTrue(Arrays.equals(data, received));
- } finally {
- file.delete();
- if (ctx != null) {
- ctx.close();
- }
- if (response.get() != null) {
- response.get().release();
- }
- }
- }
-
private static class SaslTestCtx {
final TransportClient client;
@@ -449,46 +385,39 @@ public class SparkSaslSuite {
SaslTestCtx(
RpcHandler rpcHandler,
boolean encrypt,
- boolean disableClientEncryption,
- boolean aesEnable)
+ boolean disableClientEncryption)
throws Exception {
- this(rpcHandler, encrypt, disableClientEncryption, aesEnable,
- Collections.<String, String>emptyMap());
+ this(rpcHandler, encrypt, disableClientEncryption, Collections.<String, String>emptyMap());
}
SaslTestCtx(
RpcHandler rpcHandler,
boolean encrypt,
boolean disableClientEncryption,
- boolean aesEnable,
- Map<String, String> testConf)
+ Map<String, String> extraConf)
throws Exception {
+ Map<String, String> testConf = ImmutableMap.<String, String>builder()
+ .putAll(extraConf)
+ .put("spark.authenticate.enableSaslEncryption", String.valueOf(encrypt))
+ .build();
TransportConf conf = new TransportConf("shuffle", new MapConfigProvider(testConf));
- if (aesEnable) {
- conf = spy(conf);
- doReturn(true).when(conf).aesEncryptionEnabled();
- }
-
SecretKeyHolder keyHolder = mock(SecretKeyHolder.class);
when(keyHolder.getSaslUser(anyString())).thenReturn("user");
when(keyHolder.getSecretKey(anyString())).thenReturn("secret");
TransportContext ctx = new TransportContext(conf, rpcHandler);
- String encryptHandlerName = aesEnable ? AesCipher.ENCRYPTION_HANDLER_NAME :
- SaslEncryption.ENCRYPTION_HANDLER_NAME;
-
- this.checker = new EncryptionCheckerBootstrap(encryptHandlerName);
+ this.checker = new EncryptionCheckerBootstrap(SaslEncryption.ENCRYPTION_HANDLER_NAME);
this.server = ctx.createServer(Arrays.asList(new SaslServerBootstrap(conf, keyHolder),
checker));
try {
List<TransportClientBootstrap> clientBootstraps = Lists.newArrayList();
- clientBootstraps.add(new SaslClientBootstrap(conf, "user", keyHolder, encrypt));
+ clientBootstraps.add(new SaslClientBootstrap(conf, "user", keyHolder));
if (disableClientEncryption) {
clientBootstraps.add(new EncryptionDisablerBootstrap());
}
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
index 772fb88325..616505d979 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
@@ -21,7 +21,6 @@ import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.List;
-import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -30,7 +29,7 @@ import org.apache.spark.network.TransportContext;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.client.TransportClientBootstrap;
import org.apache.spark.network.client.TransportClientFactory;
-import org.apache.spark.network.sasl.SaslClientBootstrap;
+import org.apache.spark.network.crypto.AuthClientBootstrap;
import org.apache.spark.network.sasl.SecretKeyHolder;
import org.apache.spark.network.server.NoOpRpcHandler;
import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo;
@@ -47,8 +46,7 @@ public class ExternalShuffleClient extends ShuffleClient {
private static final Logger logger = LoggerFactory.getLogger(ExternalShuffleClient.class);
private final TransportConf conf;
- private final boolean saslEnabled;
- private final boolean saslEncryptionEnabled;
+ private final boolean authEnabled;
private final SecretKeyHolder secretKeyHolder;
protected TransportClientFactory clientFactory;
@@ -61,15 +59,10 @@ public class ExternalShuffleClient extends ShuffleClient {
public ExternalShuffleClient(
TransportConf conf,
SecretKeyHolder secretKeyHolder,
- boolean saslEnabled,
- boolean saslEncryptionEnabled) {
- Preconditions.checkArgument(
- !saslEncryptionEnabled || saslEnabled,
- "SASL encryption can only be enabled if SASL is also enabled.");
+ boolean authEnabled) {
this.conf = conf;
this.secretKeyHolder = secretKeyHolder;
- this.saslEnabled = saslEnabled;
- this.saslEncryptionEnabled = saslEncryptionEnabled;
+ this.authEnabled = authEnabled;
}
protected void checkInit() {
@@ -81,8 +74,8 @@ public class ExternalShuffleClient extends ShuffleClient {
this.appId = appId;
TransportContext context = new TransportContext(conf, new NoOpRpcHandler(), true);
List<TransportClientBootstrap> bootstraps = Lists.newArrayList();
- if (saslEnabled) {
- bootstraps.add(new SaslClientBootstrap(conf, appId, secretKeyHolder, saslEncryptionEnabled));
+ if (authEnabled) {
+ bootstraps.add(new AuthClientBootstrap(conf, appId, secretKeyHolder));
}
clientFactory = context.createClientFactory(bootstraps);
}
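
The two SASL booleans collapse into a single authEnabled flag; whether the channel is additionally encrypted now comes from the TransportConf rather than a constructor argument. A minimal sketch of the new call site (the helper and its names are illustrative, not part of this patch):

import org.apache.spark.network.sasl.SecretKeyHolder;
import org.apache.spark.network.shuffle.ExternalShuffleClient;
import org.apache.spark.network.util.TransportConf;

class ShuffleClientSketch {
  // Illustrative helper: old code passed (conf, keyHolder, saslEnabled,
  // saslEncryptionEnabled); now a single flag says whether to authenticate,
  // and conf decides between the new auth protocol and SASL encryption.
  static ExternalShuffleClient createAuthenticatedClient(
      TransportConf conf, SecretKeyHolder keyHolder, String appId) {
    ExternalShuffleClient client = new ExternalShuffleClient(conf, keyHolder, true);
    client.init(appId);
    return client;
  }
}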
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java
index 42cedd9943..ab49b1c1d7 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java
@@ -60,9 +60,8 @@ public class MesosExternalShuffleClient extends ExternalShuffleClient {
public MesosExternalShuffleClient(
TransportConf conf,
SecretKeyHolder secretKeyHolder,
- boolean saslEnabled,
- boolean saslEncryptionEnabled) {
- super(conf, secretKeyHolder, saslEnabled, saslEncryptionEnabled);
+ boolean authEnabled) {
+ super(conf, secretKeyHolder, authEnabled);
}
public void registerDriverWithShuffleService(
diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java
index 8dd97b29eb..9248ef3c46 100644
--- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java
+++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java
@@ -133,7 +133,7 @@ public class ExternalShuffleIntegrationSuite {
final Semaphore requestsRemaining = new Semaphore(0);
- ExternalShuffleClient client = new ExternalShuffleClient(clientConf, null, false, false);
+ ExternalShuffleClient client = new ExternalShuffleClient(clientConf, null, false);
client.init(APP_ID);
client.fetchBlocks(TestUtils.getLocalHost(), port, execId, blockIds,
new BlockFetchingListener() {
@@ -243,7 +243,7 @@ public class ExternalShuffleIntegrationSuite {
private void registerExecutor(String executorId, ExecutorShuffleInfo executorInfo)
throws IOException {
- ExternalShuffleClient client = new ExternalShuffleClient(conf, null, false, false);
+ ExternalShuffleClient client = new ExternalShuffleClient(conf, null, false);
client.init(APP_ID);
client.registerWithShuffleServer(TestUtils.getLocalHost(), server.getPort(),
executorId, executorInfo);
diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java
index aed25a161e..4ae75a1b17 100644
--- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java
+++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java
@@ -20,6 +20,7 @@ package org.apache.spark.network.shuffle;
import java.io.IOException;
import java.util.Arrays;
+import com.google.common.collect.ImmutableMap;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
@@ -88,8 +89,14 @@ public class ExternalShuffleSecuritySuite {
/** Creates an ExternalShuffleClient and attempts to register with the server. */
private void validate(String appId, String secretKey, boolean encrypt) throws IOException {
+ TransportConf testConf = conf;
+ if (encrypt) {
+ testConf = new TransportConf("shuffle", new MapConfigProvider(
+ ImmutableMap.of("spark.authenticate.enableSaslEncryption", "true")));
+ }
+
ExternalShuffleClient client =
- new ExternalShuffleClient(conf, new TestSecretKeyHolder(appId, secretKey), true, encrypt);
+ new ExternalShuffleClient(testConf, new TestSecretKeyHolder(appId, secretKey), true);
client.init(appId);
// Registration either succeeds or throws an exception.
client.registerWithShuffleServer(TestUtils.getLocalHost(), server.getPort(), "exec0",
diff --git a/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java b/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java
index ea726e3c82..c7620d0fe1 100644
--- a/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java
+++ b/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java
@@ -45,7 +45,7 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.spark.network.TransportContext;
-import org.apache.spark.network.sasl.SaslServerBootstrap;
+import org.apache.spark.network.crypto.AuthServerBootstrap;
import org.apache.spark.network.sasl.ShuffleSecretManager;
import org.apache.spark.network.server.TransportServer;
import org.apache.spark.network.server.TransportServerBootstrap;
@@ -172,7 +172,7 @@ public class YarnShuffleService extends AuxiliaryService {
boolean authEnabled = conf.getBoolean(SPARK_AUTHENTICATE_KEY, DEFAULT_SPARK_AUTHENTICATE);
if (authEnabled) {
createSecretManager();
- bootstraps.add(new SaslServerBootstrap(transportConf, secretManager));
+ bootstraps.add(new AuthServerBootstrap(transportConf, secretManager));
}
int port = conf.getInt(