aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--common/network-common/src/main/java/org/apache/spark/network/crypto/AuthClientBootstrap.java128
-rw-r--r--common/network-common/src/main/java/org/apache/spark/network/crypto/AuthEngine.java284
-rw-r--r--common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java170
-rw-r--r--common/network-common/src/main/java/org/apache/spark/network/crypto/AuthServerBootstrap.java55
-rw-r--r--common/network-common/src/main/java/org/apache/spark/network/crypto/ClientChallenge.java101
-rw-r--r--common/network-common/src/main/java/org/apache/spark/network/crypto/README.md158
-rw-r--r--common/network-common/src/main/java/org/apache/spark/network/crypto/ServerResponse.java85
-rw-r--r--common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java (renamed from common/network-common/src/main/java/org/apache/spark/network/sasl/aes/AesCipher.java)138
-rw-r--r--common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java36
-rw-r--r--common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java41
-rw-r--r--common/network-common/src/main/java/org/apache/spark/network/sasl/aes/AesConfigMessage.java101
-rw-r--r--common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java92
-rw-r--r--common/network-common/src/test/java/org/apache/spark/network/crypto/AuthEngineSuite.java109
-rw-r--r--common/network-common/src/test/java/org/apache/spark/network/crypto/AuthIntegrationSuite.java213
-rw-r--r--common/network-common/src/test/java/org/apache/spark/network/crypto/AuthMessagesSuite.java80
-rw-r--r--common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java97
-rw-r--r--common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java19
-rw-r--r--common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java5
-rw-r--r--common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java4
-rw-r--r--common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java9
-rw-r--r--common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java4
-rw-r--r--core/src/main/scala/org/apache/spark/SecurityManager.scala11
-rw-r--r--core/src/main/scala/org/apache/spark/SparkConf.scala5
-rw-r--r--core/src/main/scala/org/apache/spark/SparkEnv.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala10
-rw-r--r--core/src/main/scala/org/apache/spark/internal/config/package.scala16
-rw-r--r--core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala7
-rw-r--r--core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala8
-rw-r--r--core/src/main/scala/org/apache/spark/storage/BlockManager.scala3
-rw-r--r--core/src/test/scala/org/apache/spark/SparkConfSuite.scala19
-rw-r--r--core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala14
-rw-r--r--core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala54
-rw-r--r--docs/configuration.md50
-rw-r--r--resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala3
34 files changed, 1709 insertions, 422 deletions
diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthClientBootstrap.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthClientBootstrap.java
new file mode 100644
index 0000000000..980525dbf0
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthClientBootstrap.java
@@ -0,0 +1,128 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.crypto;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.security.GeneralSecurityException;
+import java.security.Key;
+import javax.crypto.KeyGenerator;
+import javax.crypto.Mac;
+import static java.nio.charset.StandardCharsets.UTF_8;
+
+import com.google.common.base.Preconditions;
+import com.google.common.base.Throwables;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+import io.netty.channel.Channel;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.client.TransportClientBootstrap;
+import org.apache.spark.network.sasl.SaslClientBootstrap;
+import org.apache.spark.network.sasl.SecretKeyHolder;
+import org.apache.spark.network.util.JavaUtils;
+import org.apache.spark.network.util.TransportConf;
+
+/**
+ * Bootstraps a {@link TransportClient} by performing authentication using Spark's auth protocol.
+ *
+ * This bootstrap falls back to using the SASL bootstrap if the server throws an error during
+ * authentication, and the configuration allows it. This is used for backwards compatibility
+ * with external shuffle services that do not support the new protocol.
+ *
+ * It also automatically falls back to SASL if the new encryption backend is disabled, so that
+ * callers only need to install this bootstrap when authentication is enabled.
+ */
+public class AuthClientBootstrap implements TransportClientBootstrap {
+
+ private static final Logger LOG = LoggerFactory.getLogger(AuthClientBootstrap.class);
+
+ private final TransportConf conf;
+ private final String appId;
+ private final String authUser;
+ private final SecretKeyHolder secretKeyHolder;
+
+ public AuthClientBootstrap(
+ TransportConf conf,
+ String appId,
+ SecretKeyHolder secretKeyHolder) {
+ this.conf = conf;
+ // TODO: right now this behaves like the SASL backend, because when executors start up
+ // they don't necessarily know the app ID. So they send a hardcoded "user" that is defined
+ // in the SecurityManager, which will also always return the same secret (regardless of the
+ // user name). All that's needed here is for this "user" to match on both sides, since that's
+ // required by the protocol. At some point, though, it would be better for the actual app ID
+ // to be provided here.
+ this.appId = appId;
+ this.authUser = secretKeyHolder.getSaslUser(appId);
+ this.secretKeyHolder = secretKeyHolder;
+ }
+
+ @Override
+ public void doBootstrap(TransportClient client, Channel channel) {
+ if (!conf.encryptionEnabled()) {
+ LOG.debug("AES encryption disabled, using old auth protocol.");
+ doSaslAuth(client, channel);
+ return;
+ }
+
+ try {
+ doSparkAuth(client, channel);
+ } catch (GeneralSecurityException | IOException e) {
+ throw Throwables.propagate(e);
+ } catch (RuntimeException e) {
+ // There isn't a good exception that can be caught here to know whether it's really
+ // OK to switch back to SASL (because the server doesn't speak the new protocol). So
+ // try it anyway, and in the worst case things will fail again.
+ if (conf.saslFallback()) {
+ LOG.warn("New auth protocol failed, trying SASL.", e);
+ doSaslAuth(client, channel);
+ } else {
+ throw e;
+ }
+ }
+ }
+
+ private void doSparkAuth(TransportClient client, Channel channel)
+ throws GeneralSecurityException, IOException {
+
+ AuthEngine engine = new AuthEngine(authUser, secretKeyHolder.getSecretKey(authUser), conf);
+ try {
+ ClientChallenge challenge = engine.challenge();
+ ByteBuf challengeData = Unpooled.buffer(challenge.encodedLength());
+ challenge.encode(challengeData);
+
+ ByteBuffer responseData = client.sendRpcSync(challengeData.nioBuffer(),
+ conf.authRTTimeoutMs());
+ ServerResponse response = ServerResponse.decodeMessage(responseData);
+
+ engine.validate(response);
+ engine.sessionCipher().addToChannel(channel);
+ } finally {
+ engine.close();
+ }
+ }
+
+ private void doSaslAuth(TransportClient client, Channel channel) {
+ SaslClientBootstrap sasl = new SaslClientBootstrap(conf, appId, secretKeyHolder);
+ sasl.doBootstrap(client, channel);
+ }
+
+}
diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthEngine.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthEngine.java
new file mode 100644
index 0000000000..b769ebeba3
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthEngine.java
@@ -0,0 +1,284 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.crypto;
+
+import java.io.Closeable;
+import java.io.IOException;
+import java.math.BigInteger;
+import java.security.GeneralSecurityException;
+import java.util.Arrays;
+import java.util.Properties;
+import javax.crypto.Cipher;
+import javax.crypto.SecretKey;
+import javax.crypto.SecretKeyFactory;
+import javax.crypto.ShortBufferException;
+import javax.crypto.spec.IvParameterSpec;
+import javax.crypto.spec.PBEKeySpec;
+import javax.crypto.spec.SecretKeySpec;
+import static java.nio.charset.StandardCharsets.UTF_8;
+
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Preconditions;
+import com.google.common.primitives.Bytes;
+import org.apache.commons.crypto.cipher.CryptoCipher;
+import org.apache.commons.crypto.cipher.CryptoCipherFactory;
+import org.apache.commons.crypto.random.CryptoRandom;
+import org.apache.commons.crypto.random.CryptoRandomFactory;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.util.TransportConf;
+
+/**
+ * A helper class for abstracting authentication and key negotiation details. This is used by
+ * both client and server sides, since the operations are basically the same.
+ */
+class AuthEngine implements Closeable {
+
+ private static final Logger LOG = LoggerFactory.getLogger(AuthEngine.class);
+ private static final BigInteger ONE = new BigInteger(new byte[] { 0x1 });
+
+ private final byte[] appId;
+ private final char[] secret;
+ private final TransportConf conf;
+ private final Properties cryptoConf;
+ private final CryptoRandom random;
+
+ private byte[] authNonce;
+
+ @VisibleForTesting
+ byte[] challenge;
+
+ private TransportCipher sessionCipher;
+ private CryptoCipher encryptor;
+ private CryptoCipher decryptor;
+
+ AuthEngine(String appId, String secret, TransportConf conf) throws GeneralSecurityException {
+ this.appId = appId.getBytes(UTF_8);
+ this.conf = conf;
+ this.cryptoConf = conf.cryptoConf();
+ this.secret = secret.toCharArray();
+ this.random = CryptoRandomFactory.getCryptoRandom(cryptoConf);
+ }
+
+ /**
+ * Create the client challenge.
+ *
+ * @return A challenge to be sent the remote side.
+ */
+ ClientChallenge challenge() throws GeneralSecurityException, IOException {
+ this.authNonce = randomBytes(conf.encryptionKeyLength() / Byte.SIZE);
+ SecretKeySpec authKey = generateKey(conf.keyFactoryAlgorithm(), conf.keyFactoryIterations(),
+ authNonce, conf.encryptionKeyLength());
+ initializeForAuth(conf.cipherTransformation(), authNonce, authKey);
+
+ this.challenge = randomBytes(conf.encryptionKeyLength() / Byte.SIZE);
+ return new ClientChallenge(new String(appId, UTF_8),
+ conf.keyFactoryAlgorithm(),
+ conf.keyFactoryIterations(),
+ conf.cipherTransformation(),
+ conf.encryptionKeyLength(),
+ authNonce,
+ challenge(appId, authNonce, challenge));
+ }
+
+ /**
+ * Validates the client challenge, and create the encryption backend for the channel from the
+ * parameters sent by the client.
+ *
+ * @param clientChallenge The challenge from the client.
+ * @return A response to be sent to the client.
+ */
+ ServerResponse respond(ClientChallenge clientChallenge)
+ throws GeneralSecurityException, IOException {
+
+ SecretKeySpec authKey = generateKey(clientChallenge.kdf, clientChallenge.iterations,
+ clientChallenge.nonce, clientChallenge.keyLength);
+ initializeForAuth(clientChallenge.cipher, clientChallenge.nonce, authKey);
+
+ byte[] challenge = validateChallenge(clientChallenge.nonce, clientChallenge.challenge);
+ byte[] response = challenge(appId, clientChallenge.nonce, rawResponse(challenge));
+ byte[] sessionNonce = randomBytes(conf.encryptionKeyLength() / Byte.SIZE);
+ byte[] inputIv = randomBytes(conf.ivLength());
+ byte[] outputIv = randomBytes(conf.ivLength());
+
+ SecretKeySpec sessionKey = generateKey(clientChallenge.kdf, clientChallenge.iterations,
+ sessionNonce, clientChallenge.keyLength);
+ this.sessionCipher = new TransportCipher(cryptoConf, clientChallenge.cipher, sessionKey,
+ inputIv, outputIv);
+
+ // Note the IVs are swapped in the response.
+ return new ServerResponse(response, encrypt(sessionNonce), encrypt(outputIv), encrypt(inputIv));
+ }
+
+ /**
+ * Validates the server response and initializes the cipher to use for the session.
+ *
+ * @param serverResponse The response from the server.
+ */
+ void validate(ServerResponse serverResponse) throws GeneralSecurityException {
+ byte[] response = validateChallenge(authNonce, serverResponse.response);
+
+ byte[] expected = rawResponse(challenge);
+ Preconditions.checkArgument(Arrays.equals(expected, response));
+
+ byte[] nonce = decrypt(serverResponse.nonce);
+ byte[] inputIv = decrypt(serverResponse.inputIv);
+ byte[] outputIv = decrypt(serverResponse.outputIv);
+
+ SecretKeySpec sessionKey = generateKey(conf.keyFactoryAlgorithm(), conf.keyFactoryIterations(),
+ nonce, conf.encryptionKeyLength());
+ this.sessionCipher = new TransportCipher(cryptoConf, conf.cipherTransformation(), sessionKey,
+ inputIv, outputIv);
+ }
+
+ TransportCipher sessionCipher() {
+ Preconditions.checkState(sessionCipher != null);
+ return sessionCipher;
+ }
+
+ @Override
+ public void close() throws IOException {
+ // Close ciphers (by calling "doFinal()" with dummy data) and the random instance so that
+ // internal state is cleaned up. Error handling here is just for paranoia, and not meant to
+ // accurately report the errors when they happen.
+ RuntimeException error = null;
+ byte[] dummy = new byte[8];
+ try {
+ doCipherOp(encryptor, dummy, true);
+ } catch (Exception e) {
+ error = new RuntimeException(e);
+ }
+ try {
+ doCipherOp(decryptor, dummy, true);
+ } catch (Exception e) {
+ error = new RuntimeException(e);
+ }
+ random.close();
+
+ if (error != null) {
+ throw error;
+ }
+ }
+
+ @VisibleForTesting
+ byte[] challenge(byte[] appId, byte[] nonce, byte[] challenge) throws GeneralSecurityException {
+ return encrypt(Bytes.concat(appId, nonce, challenge));
+ }
+
+ @VisibleForTesting
+ byte[] rawResponse(byte[] challenge) {
+ BigInteger orig = new BigInteger(challenge);
+ BigInteger response = orig.add(ONE);
+ return response.toByteArray();
+ }
+
+ private byte[] decrypt(byte[] in) throws GeneralSecurityException {
+ return doCipherOp(decryptor, in, false);
+ }
+
+ private byte[] encrypt(byte[] in) throws GeneralSecurityException {
+ return doCipherOp(encryptor, in, false);
+ }
+
+ private void initializeForAuth(String cipher, byte[] nonce, SecretKeySpec key)
+ throws GeneralSecurityException {
+
+ // commons-crypto currently only supports ciphers that require an initial vector; so
+ // create a dummy vector so that we can initialize the ciphers. In the future, if
+ // different ciphers are supported, this will have to be configurable somehow.
+ byte[] iv = new byte[conf.ivLength()];
+ System.arraycopy(nonce, 0, iv, 0, Math.min(nonce.length, iv.length));
+
+ encryptor = CryptoCipherFactory.getCryptoCipher(cipher, cryptoConf);
+ encryptor.init(Cipher.ENCRYPT_MODE, key, new IvParameterSpec(iv));
+
+ decryptor = CryptoCipherFactory.getCryptoCipher(cipher, cryptoConf);
+ decryptor.init(Cipher.DECRYPT_MODE, key, new IvParameterSpec(iv));
+ }
+
+ /**
+ * Validates an encrypted challenge as defined in the protocol, and returns the byte array
+ * that corresponds to the actual challenge data.
+ */
+ private byte[] validateChallenge(byte[] nonce, byte[] encryptedChallenge)
+ throws GeneralSecurityException {
+
+ byte[] challenge = decrypt(encryptedChallenge);
+ checkSubArray(appId, challenge, 0);
+ checkSubArray(nonce, challenge, appId.length);
+ return Arrays.copyOfRange(challenge, appId.length + nonce.length, challenge.length);
+ }
+
+ private SecretKeySpec generateKey(String kdf, int iterations, byte[] salt, int keyLength)
+ throws GeneralSecurityException {
+
+ SecretKeyFactory factory = SecretKeyFactory.getInstance(kdf);
+ PBEKeySpec spec = new PBEKeySpec(secret, salt, iterations, keyLength);
+
+ long start = System.nanoTime();
+ SecretKey key = factory.generateSecret(spec);
+ long end = System.nanoTime();
+
+ LOG.debug("Generated key with {} iterations in {} us.", conf.keyFactoryIterations(),
+ (end - start) / 1000);
+
+ return new SecretKeySpec(key.getEncoded(), conf.keyAlgorithm());
+ }
+
+ private byte[] doCipherOp(CryptoCipher cipher, byte[] in, boolean isFinal)
+ throws GeneralSecurityException {
+
+ Preconditions.checkState(cipher != null);
+
+ int scale = 1;
+ while (true) {
+ int size = in.length * scale;
+ byte[] buffer = new byte[size];
+ try {
+ int outSize = isFinal ? cipher.doFinal(in, 0, in.length, buffer, 0)
+ : cipher.update(in, 0, in.length, buffer, 0);
+ if (outSize != buffer.length) {
+ byte[] output = new byte[outSize];
+ System.arraycopy(buffer, 0, output, 0, output.length);
+ return output;
+ } else {
+ return buffer;
+ }
+ } catch (ShortBufferException e) {
+ // Try again with a bigger buffer.
+ scale *= 2;
+ }
+ }
+ }
+
+ private byte[] randomBytes(int count) {
+ byte[] bytes = new byte[count];
+ random.nextBytes(bytes);
+ return bytes;
+ }
+
+ /** Checks that the "test" array is in the data array starting at the given offset. */
+ private void checkSubArray(byte[] test, byte[] data, int offset) {
+ Preconditions.checkArgument(data.length >= test.length + offset);
+ for (int i = 0; i < test.length; i++) {
+ Preconditions.checkArgument(test[i] == data[i + offset]);
+ }
+ }
+
+}
diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java
new file mode 100644
index 0000000000..991d8ba95f
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java
@@ -0,0 +1,170 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.crypto;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import javax.security.sasl.Sasl;
+
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Throwables;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+import io.netty.channel.Channel;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.client.RpcResponseCallback;
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.sasl.SecretKeyHolder;
+import org.apache.spark.network.sasl.SaslRpcHandler;
+import org.apache.spark.network.server.RpcHandler;
+import org.apache.spark.network.server.StreamManager;
+import org.apache.spark.network.util.JavaUtils;
+import org.apache.spark.network.util.TransportConf;
+
+/**
+ * RPC Handler which performs authentication using Spark's auth protocol before delegating to a
+ * child RPC handler. If the configuration allows, this handler will delegate messages to a SASL
+ * RPC handler for further authentication, to support for clients that do not support Spark's
+ * protocol.
+ *
+ * The delegate will only receive messages if the given connection has been successfully
+ * authenticated. A connection may be authenticated at most once.
+ */
+class AuthRpcHandler extends RpcHandler {
+ private static final Logger LOG = LoggerFactory.getLogger(AuthRpcHandler.class);
+
+ /** Transport configuration. */
+ private final TransportConf conf;
+
+ /** The client channel. */
+ private final Channel channel;
+
+ /**
+ * RpcHandler we will delegate to for authenticated connections. When falling back to SASL
+ * this will be replaced with the SASL RPC handler.
+ */
+ @VisibleForTesting
+ RpcHandler delegate;
+
+ /** Class which provides secret keys which are shared by server and client on a per-app basis. */
+ private final SecretKeyHolder secretKeyHolder;
+
+ /** Whether auth is done and future calls should be delegated. */
+ @VisibleForTesting
+ boolean doDelegate;
+
+ AuthRpcHandler(
+ TransportConf conf,
+ Channel channel,
+ RpcHandler delegate,
+ SecretKeyHolder secretKeyHolder) {
+ this.conf = conf;
+ this.channel = channel;
+ this.delegate = delegate;
+ this.secretKeyHolder = secretKeyHolder;
+ }
+
+ @Override
+ public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) {
+ if (doDelegate) {
+ delegate.receive(client, message, callback);
+ return;
+ }
+
+ int position = message.position();
+ int limit = message.limit();
+
+ ClientChallenge challenge;
+ try {
+ challenge = ClientChallenge.decodeMessage(message);
+ LOG.debug("Received new auth challenge for client {}.", channel.remoteAddress());
+ } catch (RuntimeException e) {
+ if (conf.saslFallback()) {
+ LOG.warn("Failed to parse new auth challenge, reverting to SASL for client {}.",
+ channel.remoteAddress());
+ delegate = new SaslRpcHandler(conf, channel, delegate, secretKeyHolder);
+ message.position(position);
+ message.limit(limit);
+ delegate.receive(client, message, callback);
+ doDelegate = true;
+ } else {
+ LOG.debug("Unexpected challenge message from client {}, closing channel.",
+ channel.remoteAddress());
+ callback.onFailure(new IllegalArgumentException("Unknown challenge message."));
+ channel.close();
+ }
+ return;
+ }
+
+ // Here we have the client challenge, so perform the new auth protocol and set up the channel.
+ AuthEngine engine = null;
+ try {
+ engine = new AuthEngine(challenge.appId, secretKeyHolder.getSecretKey(challenge.appId), conf);
+ ServerResponse response = engine.respond(challenge);
+ ByteBuf responseData = Unpooled.buffer(response.encodedLength());
+ response.encode(responseData);
+ callback.onSuccess(responseData.nioBuffer());
+ engine.sessionCipher().addToChannel(channel);
+ } catch (Exception e) {
+ // This is a fatal error: authentication has failed. Close the channel explicitly.
+ LOG.debug("Authentication failed for client {}, closing channel.", channel.remoteAddress());
+ callback.onFailure(new IllegalArgumentException("Authentication failed."));
+ channel.close();
+ return;
+ } finally {
+ if (engine != null) {
+ try {
+ engine.close();
+ } catch (Exception e) {
+ throw Throwables.propagate(e);
+ }
+ }
+ }
+
+ LOG.debug("Authorization successful for client {}.", channel.remoteAddress());
+ doDelegate = true;
+ }
+
+ @Override
+ public void receive(TransportClient client, ByteBuffer message) {
+ delegate.receive(client, message);
+ }
+
+ @Override
+ public StreamManager getStreamManager() {
+ return delegate.getStreamManager();
+ }
+
+ @Override
+ public void channelActive(TransportClient client) {
+ delegate.channelActive(client);
+ }
+
+ @Override
+ public void channelInactive(TransportClient client) {
+ delegate.channelInactive(client);
+ }
+
+ @Override
+ public void exceptionCaught(Throwable cause, TransportClient client) {
+ delegate.exceptionCaught(cause, client);
+ }
+
+}
diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthServerBootstrap.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthServerBootstrap.java
new file mode 100644
index 0000000000..77a2a6af4d
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthServerBootstrap.java
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.crypto;
+
+import io.netty.channel.Channel;
+
+import org.apache.spark.network.sasl.SaslServerBootstrap;
+import org.apache.spark.network.sasl.SecretKeyHolder;
+import org.apache.spark.network.server.RpcHandler;
+import org.apache.spark.network.server.TransportServerBootstrap;
+import org.apache.spark.network.util.TransportConf;
+
+/**
+ * A bootstrap which is executed on a TransportServer's client channel once a client connects
+ * to the server, enabling authentication using Spark's auth protocol (and optionally SASL for
+ * clients that don't support the new protocol).
+ *
+ * It also automatically falls back to SASL if the new encryption backend is disabled, so that
+ * callers only need to install this bootstrap when authentication is enabled.
+ */
+public class AuthServerBootstrap implements TransportServerBootstrap {
+
+ private final TransportConf conf;
+ private final SecretKeyHolder secretKeyHolder;
+
+ public AuthServerBootstrap(TransportConf conf, SecretKeyHolder secretKeyHolder) {
+ this.conf = conf;
+ this.secretKeyHolder = secretKeyHolder;
+ }
+
+ public RpcHandler doBootstrap(Channel channel, RpcHandler rpcHandler) {
+ if (!conf.encryptionEnabled()) {
+ TransportServerBootstrap sasl = new SaslServerBootstrap(conf, secretKeyHolder);
+ return sasl.doBootstrap(channel, rpcHandler);
+ }
+
+ return new AuthRpcHandler(conf, channel, rpcHandler, secretKeyHolder);
+ }
+
+}
diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/ClientChallenge.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/ClientChallenge.java
new file mode 100644
index 0000000000..3312a5bd81
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/ClientChallenge.java
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.crypto;
+
+import java.nio.ByteBuffer;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+
+import org.apache.spark.network.protocol.Encodable;
+import org.apache.spark.network.protocol.Encoders;
+
+/**
+ * The client challenge message, used to initiate authentication.
+ *
+ * @see README.md
+ */
+public class ClientChallenge implements Encodable {
+ /** Serialization tag used to catch incorrect payloads. */
+ private static final byte TAG_BYTE = (byte) 0xFA;
+
+ public final String appId;
+ public final String kdf;
+ public final int iterations;
+ public final String cipher;
+ public final int keyLength;
+ public final byte[] nonce;
+ public final byte[] challenge;
+
+ public ClientChallenge(
+ String appId,
+ String kdf,
+ int iterations,
+ String cipher,
+ int keyLength,
+ byte[] nonce,
+ byte[] challenge) {
+ this.appId = appId;
+ this.kdf = kdf;
+ this.iterations = iterations;
+ this.cipher = cipher;
+ this.keyLength = keyLength;
+ this.nonce = nonce;
+ this.challenge = challenge;
+ }
+
+ @Override
+ public int encodedLength() {
+ return 1 + 4 + 4 +
+ Encoders.Strings.encodedLength(appId) +
+ Encoders.Strings.encodedLength(kdf) +
+ Encoders.Strings.encodedLength(cipher) +
+ Encoders.ByteArrays.encodedLength(nonce) +
+ Encoders.ByteArrays.encodedLength(challenge);
+ }
+
+ @Override
+ public void encode(ByteBuf buf) {
+ buf.writeByte(TAG_BYTE);
+ Encoders.Strings.encode(buf, appId);
+ Encoders.Strings.encode(buf, kdf);
+ buf.writeInt(iterations);
+ Encoders.Strings.encode(buf, cipher);
+ buf.writeInt(keyLength);
+ Encoders.ByteArrays.encode(buf, nonce);
+ Encoders.ByteArrays.encode(buf, challenge);
+ }
+
+ public static ClientChallenge decodeMessage(ByteBuffer buffer) {
+ ByteBuf buf = Unpooled.wrappedBuffer(buffer);
+
+ if (buf.readByte() != TAG_BYTE) {
+ throw new IllegalArgumentException("Expected ClientChallenge, received something else.");
+ }
+
+ return new ClientChallenge(
+ Encoders.Strings.decode(buf),
+ Encoders.Strings.decode(buf),
+ buf.readInt(),
+ Encoders.Strings.decode(buf),
+ buf.readInt(),
+ Encoders.ByteArrays.decode(buf),
+ Encoders.ByteArrays.decode(buf));
+ }
+
+}
diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/README.md b/common/network-common/src/main/java/org/apache/spark/network/crypto/README.md
new file mode 100644
index 0000000000..14df703270
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/README.md
@@ -0,0 +1,158 @@
+Spark Auth Protocol and AES Encryption Support
+==============================================
+
+This file describes an auth protocol used by Spark as a more secure alternative to DIGEST-MD5. This
+protocol is built on symmetric key encryption, based on the assumption that the two endpoints being
+authenticated share a common secret, which is how Spark authentication currently works. The protocol
+provides mutual authentication, meaning that after the negotiation both parties know that the remote
+side knows the shared secret. The protocol is influenced by the ISO/IEC 9798 protocol, although it's
+not an implementation of it.
+
+This protocol could be replaced with TLS PSK, except no PSK ciphers are available in the currently
+released JREs.
+
+The protocol aims at solving the following shortcomings in Spark's current usage of DIGEST-MD5:
+
+- MD5 is an aging hash algorithm with known weaknesses, and a more secure alternative is desired.
+- DIGEST-MD5 has a pre-defined set of ciphers for which it can generate keys. The only
+ viable, supported cipher these days is 3DES, and a more modern alternative is desired.
+- Encrypting AES session keys with 3DES doesn't solve the issue, since the weakest link
+ in the negotiation would still be MD5 and 3DES.
+
+The protocol assumes that the shared secret is generated and distributed in a secure manner.
+
+The protocol always negotiates encryption keys. If encryption is not desired, the existing
+SASL-based authentication, or no authentication at all, can be chosen instead.
+
+When messages are described below, it's expected that the implementation should support
+arbitrary sizes for fields that don't have a fixed size.
+
+Client Challenge
+----------------
+
+The auth negotiation is started by the client. The client starts by generating an encryption
+key based on the application's shared secret, and a nonce.
+
+ KEY = KDF(SECRET, SALT, KEY_LENGTH)
+
+Where:
+- KDF(): a key derivation function that takes a secret, a salt, a configurable number of
+ iterations, and a configurable key length.
+- SALT: a byte sequence used to salt the key derivation function.
+- KEY_LENGTH: length of the encryption key to generate.
+
+
+The client generates a message with the following content:
+
+ CLIENT_CHALLENGE = (
+ APP_ID,
+ KDF,
+ ITERATIONS,
+ CIPHER,
+ KEY_LENGTH,
+ ANONCE,
+ ENC(APP_ID || ANONCE || CHALLENGE))
+
+Where:
+
+- APP_ID: the application ID which the server uses to identify the shared secret.
+- KDF: the key derivation function described above.
+- ITERATIONS: number of iterations to run the KDF when generating keys.
+- CIPHER: the cipher used to encrypt data.
+- KEY_LENGTH: length of the encryption keys to generate, in bits.
+- ANONCE: the nonce used as the salt when generating the auth key.
+- ENC(): an encryption function that uses the cipher and the generated key. This function
+ will also be used in the definition of other messages below.
+- CHALLENGE: a byte sequence used as a challenge to the server.
+- ||: concatenation operator.
+
+When strings are used where byte arrays are expected, the UTF-8 representation of the string
+is assumed.
+
+To respond to the challenge, the server should consider the byte array as representing an
+arbitrary-length integer, and respond with the value of the integer plus one.
+
+
+Server Response And Challenge
+-----------------------------
+
+Once the client challenge is received, the server will generate the same auth key by
+using the same algorithm the client has used. It will then verify the client challenge:
+if the APP_ID and ANONCE fields match, the server knows that the client has the shared
+secret. The server then creates a response to the client challenge, to prove that it also
+has the secret key, and provides parameters to be used when creating the session key.
+
+The following describes the response from the server:
+
+ SERVER_CHALLENGE = (
+ ENC(APP_ID || ANONCE || RESPONSE),
+ ENC(SNONCE),
+ ENC(INIV),
+ ENC(OUTIV))
+
+Where:
+
+- RESPONSE: the server's response to the client challenge.
+- SNONCE: a nonce to be used as salt when generating the session key.
+- INIV: initialization vector used to initialize the input channel of the client.
+- OUTIV: initialization vector used to initialize the output channel of the client.
+
+At this point the server considers the client to be authenticated, and will try to
+decrypt any data further sent by the client using the session key.
+
+
+Default Algorithms
+------------------
+
+Configuration options are available for the KDF and cipher algorithms to use.
+
+The default KDF is "PBKDF2WithHmacSHA1". Users should be able to select any algorithm
+from those supported by the `javax.crypto.SecretKeyFactory` class, as long as they support
+PBEKeySpec when generating keys. The default number of iterations was chosen to take a
+reasonable amount of time on modern CPUs. See the documentation in TransportConf for more
+details.
+
+The default cipher algorithm is "AES/CTR/NoPadding". Users should be able to select any
+algorithm supported by the commons-crypto library. It should allow the cipher to operate
+in stream mode.
+
+The default key length is 128 (bits).
+
+
+Implementation Details
+----------------------
+
+The commons-crypto library currently only supports AES ciphers, and requires an initialization
+vector (IV). This first version of the protocol does not explicitly include the IV in the client
+challenge message. Instead, the IV should be derived from the nonce, including the needed bytes, and
+padding the IV with zeroes in case the nonce is not long enough.
+
+Future versions of the protocol might add support for new ciphers and explicitly include needed
+configuration parameters in the messages.
+
+
+Threat Assessment
+-----------------
+
+The protocol is secure against different forms of attack:
+
+* Eavesdropping: the protocol is built on the assumption that it's computationally infeasible
+ to calculate the original secret from the encrypted messages. Neither the secret nor any
+ encryption keys are transmitted on the wire, encrypted or not.
+
+* Man-in-the-middle: because the protocol performs mutual authentication, both ends need to
+ know the shared secret to be able to decrypt session data. Even if an attacker is able to insert a
+ malicious "proxy" between endpoints, the attacker won't be able to read any of the data exchanged
+ between client and server, nor insert arbitrary commands for the server to execute.
+
+* Replay attacks: the use of nonces when generating keys prevents an attacker from being able to
+ just replay messages sniffed from the communication channel.
+
+An attacker may replay the client challenge and successfully "prove" to a server that it "knows" the
+shared secret. But the attacker won't be able to decrypt the server's response, and thus won't be
+able to generate a session key, which will make it hard to craft a valid, encrypted message that the
+server will be able to understand. This will cause the server to close the connection as soon as the
+attacker tries to send any command to the server. The attacker can just hold the channel open for
+some time, which will be closed when the server times out the channel. These issues could be
+separately mitigated by adding a shorter timeout for the first message after authentication, and
+potentially by adding host blacklists if a possible attack is detected from a particular host.
diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/ServerResponse.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/ServerResponse.java
new file mode 100644
index 0000000000..affdbf450b
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/ServerResponse.java
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.crypto;
+
+import java.nio.ByteBuffer;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+
+import org.apache.spark.network.protocol.Encodable;
+import org.apache.spark.network.protocol.Encoders;
+
+/**
+ * Server's response to client's challenge.
+ *
+ * @see README.md
+ */
+public class ServerResponse implements Encodable {
+ /** Serialization tag used to catch incorrect payloads. */
+ private static final byte TAG_BYTE = (byte) 0xFB;
+
+ public final byte[] response;
+ public final byte[] nonce;
+ public final byte[] inputIv;
+ public final byte[] outputIv;
+
+ public ServerResponse(
+ byte[] response,
+ byte[] nonce,
+ byte[] inputIv,
+ byte[] outputIv) {
+ this.response = response;
+ this.nonce = nonce;
+ this.inputIv = inputIv;
+ this.outputIv = outputIv;
+ }
+
+ @Override
+ public int encodedLength() {
+ return 1 +
+ Encoders.ByteArrays.encodedLength(response) +
+ Encoders.ByteArrays.encodedLength(nonce) +
+ Encoders.ByteArrays.encodedLength(inputIv) +
+ Encoders.ByteArrays.encodedLength(outputIv);
+ }
+
+ @Override
+ public void encode(ByteBuf buf) {
+ buf.writeByte(TAG_BYTE);
+ Encoders.ByteArrays.encode(buf, response);
+ Encoders.ByteArrays.encode(buf, nonce);
+ Encoders.ByteArrays.encode(buf, inputIv);
+ Encoders.ByteArrays.encode(buf, outputIv);
+ }
+
+ public static ServerResponse decodeMessage(ByteBuffer buffer) {
+ ByteBuf buf = Unpooled.wrappedBuffer(buffer);
+
+ if (buf.readByte() != TAG_BYTE) {
+ throw new IllegalArgumentException("Expected ServerResponse, received something else.");
+ }
+
+ return new ServerResponse(
+ Encoders.ByteArrays.decode(buf),
+ Encoders.ByteArrays.decode(buf),
+ Encoders.ByteArrays.decode(buf),
+ Encoders.ByteArrays.decode(buf));
+ }
+
+}
diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/aes/AesCipher.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java
index 340986a63b..7376d1ddc4 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/sasl/aes/AesCipher.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.network.sasl.aes;
+package org.apache.spark.network.crypto;
import java.io.IOException;
import java.nio.ByteBuffer;
@@ -25,115 +25,91 @@ import java.util.Properties;
import javax.crypto.spec.SecretKeySpec;
import javax.crypto.spec.IvParameterSpec;
+import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
-import com.google.common.base.Throwables;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.*;
import io.netty.util.AbstractReferenceCounted;
-import org.apache.commons.crypto.cipher.CryptoCipherFactory;
-import org.apache.commons.crypto.random.CryptoRandom;
-import org.apache.commons.crypto.random.CryptoRandomFactory;
import org.apache.commons.crypto.stream.CryptoInputStream;
import org.apache.commons.crypto.stream.CryptoOutputStream;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
import org.apache.spark.network.util.ByteArrayReadableChannel;
import org.apache.spark.network.util.ByteArrayWritableChannel;
-import org.apache.spark.network.util.TransportConf;
/**
- * AES cipher for encryption and decryption.
+ * Cipher for encryption and decryption.
*/
-public class AesCipher {
- private static final Logger logger = LoggerFactory.getLogger(AesCipher.class);
- public static final String ENCRYPTION_HANDLER_NAME = "AesEncryption";
- public static final String DECRYPTION_HANDLER_NAME = "AesDecryption";
- public static final int STREAM_BUFFER_SIZE = 1024 * 32;
- public static final String TRANSFORM = "AES/CTR/NoPadding";
-
- private final SecretKeySpec inKeySpec;
- private final IvParameterSpec inIvSpec;
- private final SecretKeySpec outKeySpec;
- private final IvParameterSpec outIvSpec;
- private final Properties properties;
-
- public AesCipher(AesConfigMessage configMessage, TransportConf conf) throws IOException {
- this.properties = conf.cryptoConf();
- this.inKeySpec = new SecretKeySpec(configMessage.inKey, "AES");
- this.inIvSpec = new IvParameterSpec(configMessage.inIv);
- this.outKeySpec = new SecretKeySpec(configMessage.outKey, "AES");
- this.outIvSpec = new IvParameterSpec(configMessage.outIv);
+public class TransportCipher {
+ @VisibleForTesting
+ static final String ENCRYPTION_HANDLER_NAME = "TransportEncryption";
+ private static final String DECRYPTION_HANDLER_NAME = "TransportDecryption";
+ private static final int STREAM_BUFFER_SIZE = 1024 * 32;
+
+ private final Properties conf;
+ private final String cipher;
+ private final SecretKeySpec key;
+ private final byte[] inIv;
+ private final byte[] outIv;
+
+ public TransportCipher(
+ Properties conf,
+ String cipher,
+ SecretKeySpec key,
+ byte[] inIv,
+ byte[] outIv) {
+ this.conf = conf;
+ this.cipher = cipher;
+ this.key = key;
+ this.inIv = inIv;
+ this.outIv = outIv;
+ }
+
+ public String getCipherTransformation() {
+ return cipher;
+ }
+
+ @VisibleForTesting
+ SecretKeySpec getKey() {
+ return key;
+ }
+
+ /** The IV for the input channel (i.e. output channel of the remote side). */
+ public byte[] getInputIv() {
+ return inIv;
+ }
+
+ /** The IV for the output channel (i.e. input channel of the remote side). */
+ public byte[] getOutputIv() {
+ return outIv;
}
- /**
- * Create AES crypto output stream
- * @param ch The underlying channel to write out.
- * @return Return output crypto stream for encryption.
- * @throws IOException
- */
private CryptoOutputStream createOutputStream(WritableByteChannel ch) throws IOException {
- return new CryptoOutputStream(TRANSFORM, properties, ch, outKeySpec, outIvSpec);
+ return new CryptoOutputStream(cipher, conf, ch, key, new IvParameterSpec(outIv));
}
- /**
- * Create AES crypto input stream
- * @param ch The underlying channel used to read data.
- * @return Return input crypto stream for decryption.
- * @throws IOException
- */
private CryptoInputStream createInputStream(ReadableByteChannel ch) throws IOException {
- return new CryptoInputStream(TRANSFORM, properties, ch, inKeySpec, inIvSpec);
+ return new CryptoInputStream(cipher, conf, ch, key, new IvParameterSpec(inIv));
}
/**
- * Add handlers to channel
+ * Add handlers to channel.
+ *
* @param ch the channel for adding handlers
* @throws IOException
*/
public void addToChannel(Channel ch) throws IOException {
ch.pipeline()
- .addFirst(ENCRYPTION_HANDLER_NAME, new AesEncryptHandler(this))
- .addFirst(DECRYPTION_HANDLER_NAME, new AesDecryptHandler(this));
- }
-
- /**
- * Create the configuration message
- * @param conf is the local transport configuration.
- * @return Config message for sending.
- */
- public static AesConfigMessage createConfigMessage(TransportConf conf) {
- int keySize = conf.aesCipherKeySize();
- Properties properties = conf.cryptoConf();
-
- try {
- int paramLen = CryptoCipherFactory.getCryptoCipher(AesCipher.TRANSFORM, properties)
- .getBlockSize();
- byte[] inKey = new byte[keySize];
- byte[] outKey = new byte[keySize];
- byte[] inIv = new byte[paramLen];
- byte[] outIv = new byte[paramLen];
-
- CryptoRandom random = CryptoRandomFactory.getCryptoRandom(properties);
- random.nextBytes(inKey);
- random.nextBytes(outKey);
- random.nextBytes(inIv);
- random.nextBytes(outIv);
-
- return new AesConfigMessage(inKey, inIv, outKey, outIv);
- } catch (Exception e) {
- logger.error("AES config error", e);
- throw Throwables.propagate(e);
- }
+ .addFirst(ENCRYPTION_HANDLER_NAME, new EncryptionHandler(this))
+ .addFirst(DECRYPTION_HANDLER_NAME, new DecryptionHandler(this));
}
- private static class AesEncryptHandler extends ChannelOutboundHandlerAdapter {
+ private static class EncryptionHandler extends ChannelOutboundHandlerAdapter {
private final ByteArrayWritableChannel byteChannel;
private final CryptoOutputStream cos;
- AesEncryptHandler(AesCipher cipher) throws IOException {
- byteChannel = new ByteArrayWritableChannel(AesCipher.STREAM_BUFFER_SIZE);
+ EncryptionHandler(TransportCipher cipher) throws IOException {
+ byteChannel = new ByteArrayWritableChannel(STREAM_BUFFER_SIZE);
cos = cipher.createOutputStream(byteChannel);
}
@@ -153,11 +129,11 @@ public class AesCipher {
}
}
- private static class AesDecryptHandler extends ChannelInboundHandlerAdapter {
+ private static class DecryptionHandler extends ChannelInboundHandlerAdapter {
private final CryptoInputStream cis;
private final ByteArrayReadableChannel byteChannel;
- AesDecryptHandler(AesCipher cipher) throws IOException {
+ DecryptionHandler(TransportCipher cipher) throws IOException {
byteChannel = new ByteArrayReadableChannel();
cis = cipher.createInputStream(byteChannel);
}
@@ -207,7 +183,7 @@ public class AesCipher {
this.buf = isByteBuf ? (ByteBuf) msg : null;
this.region = isByteBuf ? null : (FileRegion) msg;
this.transferred = 0;
- this.byteRawChannel = new ByteArrayWritableChannel(AesCipher.STREAM_BUFFER_SIZE);
+ this.byteRawChannel = new ByteArrayWritableChannel(STREAM_BUFFER_SIZE);
this.cos = cos;
this.byteEncChannel = ch;
}
diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java
index a1bb453657..6478137722 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java
@@ -30,8 +30,6 @@ import org.slf4j.LoggerFactory;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.client.TransportClientBootstrap;
-import org.apache.spark.network.sasl.aes.AesCipher;
-import org.apache.spark.network.sasl.aes.AesConfigMessage;
import org.apache.spark.network.util.JavaUtils;
import org.apache.spark.network.util.TransportConf;
@@ -42,24 +40,14 @@ import org.apache.spark.network.util.TransportConf;
public class SaslClientBootstrap implements TransportClientBootstrap {
private static final Logger logger = LoggerFactory.getLogger(SaslClientBootstrap.class);
- private final boolean encrypt;
private final TransportConf conf;
private final String appId;
private final SecretKeyHolder secretKeyHolder;
public SaslClientBootstrap(TransportConf conf, String appId, SecretKeyHolder secretKeyHolder) {
- this(conf, appId, secretKeyHolder, false);
- }
-
- public SaslClientBootstrap(
- TransportConf conf,
- String appId,
- SecretKeyHolder secretKeyHolder,
- boolean encrypt) {
this.conf = conf;
this.appId = appId;
this.secretKeyHolder = secretKeyHolder;
- this.encrypt = encrypt;
}
/**
@@ -69,7 +57,7 @@ public class SaslClientBootstrap implements TransportClientBootstrap {
*/
@Override
public void doBootstrap(TransportClient client, Channel channel) {
- SparkSaslClient saslClient = new SparkSaslClient(appId, secretKeyHolder, encrypt);
+ SparkSaslClient saslClient = new SparkSaslClient(appId, secretKeyHolder, conf.saslEncryption());
try {
byte[] payload = saslClient.firstToken();
@@ -79,35 +67,19 @@ public class SaslClientBootstrap implements TransportClientBootstrap {
msg.encode(buf);
buf.writeBytes(msg.body().nioByteBuffer());
- ByteBuffer response = client.sendRpcSync(buf.nioBuffer(), conf.saslRTTimeoutMs());
+ ByteBuffer response = client.sendRpcSync(buf.nioBuffer(), conf.authRTTimeoutMs());
payload = saslClient.response(JavaUtils.bufferToArray(response));
}
client.setClientId(appId);
- if (encrypt) {
+ if (conf.saslEncryption()) {
if (!SparkSaslServer.QOP_AUTH_CONF.equals(saslClient.getNegotiatedProperty(Sasl.QOP))) {
throw new RuntimeException(
new SaslException("Encryption requests by negotiated non-encrypted connection."));
}
- if (conf.aesEncryptionEnabled()) {
- // Generate a request config message to send to server.
- AesConfigMessage configMessage = AesCipher.createConfigMessage(conf);
- ByteBuffer buf = configMessage.encodeMessage();
-
- // Encrypted the config message.
- byte[] toEncrypt = JavaUtils.bufferToArray(buf);
- ByteBuffer encrypted = ByteBuffer.wrap(saslClient.wrap(toEncrypt, 0, toEncrypt.length));
-
- client.sendRpcSync(encrypted, conf.saslRTTimeoutMs());
- AesCipher cipher = new AesCipher(configMessage, conf);
- logger.info("Enabling AES cipher for client channel {}", client);
- cipher.addToChannel(channel);
- saslClient.dispose();
- } else {
- SaslEncryption.addToChannel(channel, saslClient, conf.maxSaslEncryptedBlockSize());
- }
+ SaslEncryption.addToChannel(channel, saslClient, conf.maxSaslEncryptedBlockSize());
saslClient = null;
logger.debug("Channel {} configured for encryption.", client);
}
diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java
index b2f3ef214b..0231428318 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java
@@ -29,8 +29,6 @@ import org.slf4j.LoggerFactory;
import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.TransportClient;
-import org.apache.spark.network.sasl.aes.AesCipher;
-import org.apache.spark.network.sasl.aes.AesConfigMessage;
import org.apache.spark.network.server.RpcHandler;
import org.apache.spark.network.server.StreamManager;
import org.apache.spark.network.util.JavaUtils;
@@ -44,7 +42,7 @@ import org.apache.spark.network.util.TransportConf;
* Note that the authentication process consists of multiple challenge-response pairs, each of
* which are individual RPCs.
*/
-class SaslRpcHandler extends RpcHandler {
+public class SaslRpcHandler extends RpcHandler {
private static final Logger logger = LoggerFactory.getLogger(SaslRpcHandler.class);
/** Transport configuration. */
@@ -63,7 +61,7 @@ class SaslRpcHandler extends RpcHandler {
private boolean isComplete;
private boolean isAuthenticated;
- SaslRpcHandler(
+ public SaslRpcHandler(
TransportConf conf,
Channel channel,
RpcHandler delegate,
@@ -122,37 +120,10 @@ class SaslRpcHandler extends RpcHandler {
return;
}
- if (!conf.aesEncryptionEnabled()) {
- logger.debug("Enabling encryption for channel {}", client);
- SaslEncryption.addToChannel(channel, saslServer, conf.maxSaslEncryptedBlockSize());
- complete(false);
- return;
- }
-
- // Extra negotiation should happen after authentication, so return directly while
- // processing authenticate.
- if (!isAuthenticated) {
- logger.debug("SASL authentication successful for channel {}", client);
- isAuthenticated = true;
- return;
- }
-
- // Create AES cipher when it is authenticated
- try {
- byte[] encrypted = JavaUtils.bufferToArray(message);
- ByteBuffer decrypted = ByteBuffer.wrap(saslServer.unwrap(encrypted, 0 , encrypted.length));
-
- AesConfigMessage configMessage = AesConfigMessage.decodeMessage(decrypted);
- AesCipher cipher = new AesCipher(configMessage, conf);
-
- // Send response back to client to confirm that server accept config.
- callback.onSuccess(JavaUtils.stringToBytes(AesCipher.TRANSFORM));
- logger.info("Enabling AES cipher for Server channel {}", client);
- cipher.addToChannel(channel);
- complete(true);
- } catch (IOException ioe) {
- throw new RuntimeException(ioe);
- }
+ logger.debug("Enabling encryption for channel {}", client);
+ SaslEncryption.addToChannel(channel, saslServer, conf.maxSaslEncryptedBlockSize());
+ complete(false);
+ return;
}
}
diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/aes/AesConfigMessage.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/aes/AesConfigMessage.java
deleted file mode 100644
index 3ef6f74a1f..0000000000
--- a/common/network-common/src/main/java/org/apache/spark/network/sasl/aes/AesConfigMessage.java
+++ /dev/null
@@ -1,101 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.sasl.aes;
-
-import java.nio.ByteBuffer;
-
-import io.netty.buffer.ByteBuf;
-import io.netty.buffer.Unpooled;
-
-import org.apache.spark.network.protocol.Encodable;
-import org.apache.spark.network.protocol.Encoders;
-
-/**
- * The AES cipher options for encryption negotiation.
- */
-public class AesConfigMessage implements Encodable {
- /** Serialization tag used to catch incorrect payloads. */
- private static final byte TAG_BYTE = (byte) 0xEB;
-
- public byte[] inKey;
- public byte[] outKey;
- public byte[] inIv;
- public byte[] outIv;
-
- public AesConfigMessage(byte[] inKey, byte[] inIv, byte[] outKey, byte[] outIv) {
- if (inKey == null || inIv == null || outKey == null || outIv == null) {
- throw new IllegalArgumentException("Cipher Key or IV must not be null!");
- }
-
- this.inKey = inKey;
- this.inIv = inIv;
- this.outKey = outKey;
- this.outIv = outIv;
- }
-
- @Override
- public int encodedLength() {
- return 1 +
- Encoders.ByteArrays.encodedLength(inKey) + Encoders.ByteArrays.encodedLength(outKey) +
- Encoders.ByteArrays.encodedLength(inIv) + Encoders.ByteArrays.encodedLength(outIv);
- }
-
- @Override
- public void encode(ByteBuf buf) {
- buf.writeByte(TAG_BYTE);
- Encoders.ByteArrays.encode(buf, inKey);
- Encoders.ByteArrays.encode(buf, inIv);
- Encoders.ByteArrays.encode(buf, outKey);
- Encoders.ByteArrays.encode(buf, outIv);
- }
-
- /**
- * Encode the config message.
- * @return ByteBuffer which contains encoded config message.
- */
- public ByteBuffer encodeMessage(){
- ByteBuffer buf = ByteBuffer.allocate(encodedLength());
-
- ByteBuf wrappedBuf = Unpooled.wrappedBuffer(buf);
- wrappedBuf.clear();
- encode(wrappedBuf);
-
- return buf;
- }
-
- /**
- * Decode the config message from buffer
- * @param buffer the buffer contain encoded config message
- * @return config message
- */
- public static AesConfigMessage decodeMessage(ByteBuffer buffer) {
- ByteBuf buf = Unpooled.wrappedBuffer(buffer);
-
- if (buf.readByte() != TAG_BYTE) {
- throw new IllegalStateException("Expected AesConfigMessage, received something else"
- + " (maybe your client does not have AES enabled?)");
- }
-
- byte[] outKey = Encoders.ByteArrays.decode(buf);
- byte[] outIv = Encoders.ByteArrays.decode(buf);
- byte[] inKey = Encoders.ByteArrays.decode(buf);
- byte[] inIv = Encoders.ByteArrays.decode(buf);
- return new AesConfigMessage(inKey, inIv, outKey, outIv);
- }
-
-}
diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java
index 6a557fa75d..c226d8f3bc 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java
@@ -117,9 +117,10 @@ public class TransportConf {
/** Send buffer size (SO_SNDBUF). */
public int sendBuf() { return conf.getInt(SPARK_NETWORK_IO_SENDBUFFER_KEY, -1); }
- /** Timeout for a single round trip of SASL token exchange, in milliseconds. */
- public int saslRTTimeoutMs() {
- return (int) JavaUtils.timeStringAsSec(conf.get(SPARK_NETWORK_SASL_TIMEOUT_KEY, "30s")) * 1000;
+ /** Timeout for a single round trip of auth message exchange, in milliseconds. */
+ public int authRTTimeoutMs() {
+ return (int) JavaUtils.timeStringAsSec(conf.get("spark.network.auth.rpcTimeout",
+ conf.get(SPARK_NETWORK_SASL_TIMEOUT_KEY, "30s"))) * 1000;
}
/**
@@ -162,40 +163,95 @@ public class TransportConf {
}
/**
- * Maximum number of bytes to be encrypted at a time when SASL encryption is enabled.
+ * Enables strong encryption. Also enables the new auth protocol, used to negotiate keys.
*/
- public int maxSaslEncryptedBlockSize() {
- return Ints.checkedCast(JavaUtils.byteStringAsBytes(
- conf.get("spark.network.sasl.maxEncryptedBlockSize", "64k")));
+ public boolean encryptionEnabled() {
+ return conf.getBoolean("spark.network.crypto.enabled", false);
}
/**
- * Whether the server should enforce encryption on SASL-authenticated connections.
+ * The cipher transformation to use for encrypting session data.
*/
- public boolean saslServerAlwaysEncrypt() {
- return conf.getBoolean("spark.network.sasl.serverAlwaysEncrypt", false);
+ public String cipherTransformation() {
+ return conf.get("spark.network.crypto.cipher", "AES/CTR/NoPadding");
+ }
+
+ /**
+ * The key generation algorithm. This should be an algorithm that accepts a "PBEKeySpec"
+ * as input. The default value (PBKDF2WithHmacSHA1) is available in Java 7.
+ */
+ public String keyFactoryAlgorithm() {
+ return conf.get("spark.network.crypto.keyFactoryAlgorithm", "PBKDF2WithHmacSHA1");
+ }
+
+ /**
+ * How many iterations to run when generating keys.
+ *
+ * See some discussion about this at: http://security.stackexchange.com/q/3959
+ * The default value was picked for speed, since it assumes that the secret has good entropy
+ * (128 bits by default), which is not generally the case with user passwords.
+ */
+ public int keyFactoryIterations() {
+ return conf.getInt("spark.networy.crypto.keyFactoryIterations", 1024);
+ }
+
+ /**
+ * Encryption key length, in bits.
+ */
+ public int encryptionKeyLength() {
+ return conf.getInt("spark.network.crypto.keyLength", 128);
+ }
+
+ /**
+ * Initial vector length, in bytes.
+ */
+ public int ivLength() {
+ return conf.getInt("spark.network.crypto.ivLength", 16);
+ }
+
+ /**
+ * The algorithm for generated secret keys. Nobody should really need to change this,
+ * but configurable just in case.
+ */
+ public String keyAlgorithm() {
+ return conf.get("spark.network.crypto.keyAlgorithm", "AES");
+ }
+
+ /**
+ * Whether to fall back to SASL if the new auth protocol fails. Enabled by default for
+ * backwards compatibility.
+ */
+ public boolean saslFallback() {
+ return conf.getBoolean("spark.network.crypto.saslFallback", true);
}
/**
- * The trigger for enabling AES encryption.
+ * Whether to enable SASL-based encryption when authenticating using SASL.
*/
- public boolean aesEncryptionEnabled() {
- return conf.getBoolean("spark.network.aes.enabled", false);
+ public boolean saslEncryption() {
+ return conf.getBoolean("spark.authenticate.enableSaslEncryption", false);
}
/**
- * The key size to use when AES cipher is enabled. Notice that the length should be 16, 24 or 32
- * bytes.
+ * Maximum number of bytes to be encrypted at a time when SASL encryption is used.
*/
- public int aesCipherKeySize() {
- return conf.getInt("spark.network.aes.keySize", 16);
+ public int maxSaslEncryptedBlockSize() {
+ return Ints.checkedCast(JavaUtils.byteStringAsBytes(
+ conf.get("spark.network.sasl.maxEncryptedBlockSize", "64k")));
+ }
+
+ /**
+ * Whether the server should enforce encryption on SASL-authenticated connections.
+ */
+ public boolean saslServerAlwaysEncrypt() {
+ return conf.getBoolean("spark.network.sasl.serverAlwaysEncrypt", false);
}
/**
* The commons-crypto configuration for the module.
*/
public Properties cryptoConf() {
- return CryptoUtils.toCryptoConf("spark.network.aes.config.", conf.getAll());
+ return CryptoUtils.toCryptoConf("spark.network.crypto.config.", conf.getAll());
}
}
diff --git a/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthEngineSuite.java b/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthEngineSuite.java
new file mode 100644
index 0000000000..9a186f2113
--- /dev/null
+++ b/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthEngineSuite.java
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.crypto;
+
+import java.util.Arrays;
+import java.util.Map;
+import static java.nio.charset.StandardCharsets.UTF_8;
+
+import com.google.common.collect.ImmutableMap;
+import org.junit.BeforeClass;
+import org.junit.Test;
+import static org.junit.Assert.*;
+
+import org.apache.spark.network.util.MapConfigProvider;
+import org.apache.spark.network.util.TransportConf;
+
+public class AuthEngineSuite {
+
+ private static TransportConf conf;
+
+ @BeforeClass
+ public static void setUp() {
+ conf = new TransportConf("rpc", MapConfigProvider.EMPTY);
+ }
+
+ @Test
+ public void testAuthEngine() throws Exception {
+ AuthEngine client = new AuthEngine("appId", "secret", conf);
+ AuthEngine server = new AuthEngine("appId", "secret", conf);
+
+ try {
+ ClientChallenge clientChallenge = client.challenge();
+ ServerResponse serverResponse = server.respond(clientChallenge);
+ client.validate(serverResponse);
+
+ TransportCipher serverCipher = server.sessionCipher();
+ TransportCipher clientCipher = client.sessionCipher();
+
+ assertTrue(Arrays.equals(serverCipher.getInputIv(), clientCipher.getOutputIv()));
+ assertTrue(Arrays.equals(serverCipher.getOutputIv(), clientCipher.getInputIv()));
+ assertEquals(serverCipher.getKey(), clientCipher.getKey());
+ } finally {
+ client.close();
+ server.close();
+ }
+ }
+
+ @Test
+ public void testMismatchedSecret() throws Exception {
+ AuthEngine client = new AuthEngine("appId", "secret", conf);
+ AuthEngine server = new AuthEngine("appId", "different_secret", conf);
+
+ ClientChallenge clientChallenge = client.challenge();
+ try {
+ server.respond(clientChallenge);
+ fail("Should have failed to validate response.");
+ } catch (IllegalArgumentException e) {
+ // Expected.
+ }
+ }
+
+ @Test(expected = IllegalArgumentException.class)
+ public void testWrongAppId() throws Exception {
+ AuthEngine engine = new AuthEngine("appId", "secret", conf);
+ ClientChallenge challenge = engine.challenge();
+
+ byte[] badChallenge = engine.challenge(new byte[] { 0x00 }, challenge.nonce,
+ engine.rawResponse(engine.challenge));
+ engine.respond(new ClientChallenge(challenge.appId, challenge.kdf, challenge.iterations,
+ challenge.cipher, challenge.keyLength, challenge.nonce, badChallenge));
+ }
+
+ @Test(expected = IllegalArgumentException.class)
+ public void testWrongNonce() throws Exception {
+ AuthEngine engine = new AuthEngine("appId", "secret", conf);
+ ClientChallenge challenge = engine.challenge();
+
+ byte[] badChallenge = engine.challenge(challenge.appId.getBytes(UTF_8), new byte[] { 0x00 },
+ engine.rawResponse(engine.challenge));
+ engine.respond(new ClientChallenge(challenge.appId, challenge.kdf, challenge.iterations,
+ challenge.cipher, challenge.keyLength, challenge.nonce, badChallenge));
+ }
+
+ @Test(expected = IllegalArgumentException.class)
+ public void testBadChallenge() throws Exception {
+ AuthEngine engine = new AuthEngine("appId", "secret", conf);
+ ClientChallenge challenge = engine.challenge();
+
+ byte[] badChallenge = new byte[challenge.challenge.length];
+ engine.respond(new ClientChallenge(challenge.appId, challenge.kdf, challenge.iterations,
+ challenge.cipher, challenge.keyLength, challenge.nonce, badChallenge));
+ }
+
+}
diff --git a/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthIntegrationSuite.java
new file mode 100644
index 0000000000..21609d5aa2
--- /dev/null
+++ b/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthIntegrationSuite.java
@@ -0,0 +1,213 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.crypto;
+
+import java.nio.ByteBuffer;
+import java.util.List;
+import java.util.Map;
+
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.Lists;
+import io.netty.channel.Channel;
+import org.junit.After;
+import org.junit.Test;
+import static org.junit.Assert.*;
+import static org.mockito.Mockito.*;
+
+import org.apache.spark.network.TestUtils;
+import org.apache.spark.network.TransportContext;
+import org.apache.spark.network.client.RpcResponseCallback;
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.client.TransportClientBootstrap;
+import org.apache.spark.network.sasl.SaslRpcHandler;
+import org.apache.spark.network.sasl.SaslServerBootstrap;
+import org.apache.spark.network.sasl.SecretKeyHolder;
+import org.apache.spark.network.server.RpcHandler;
+import org.apache.spark.network.server.StreamManager;
+import org.apache.spark.network.server.TransportServer;
+import org.apache.spark.network.server.TransportServerBootstrap;
+import org.apache.spark.network.util.JavaUtils;
+import org.apache.spark.network.util.MapConfigProvider;
+import org.apache.spark.network.util.TransportConf;
+
+public class AuthIntegrationSuite {
+
+ private AuthTestCtx ctx;
+
+ @After
+ public void cleanUp() throws Exception {
+ if (ctx != null) {
+ ctx.close();
+ }
+ ctx = null;
+ }
+
+ @Test
+ public void testNewAuth() throws Exception {
+ ctx = new AuthTestCtx();
+ ctx.createServer("secret");
+ ctx.createClient("secret");
+
+ ByteBuffer reply = ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), 5000);
+ assertEquals("Pong", JavaUtils.bytesToString(reply));
+ assertTrue(ctx.authRpcHandler.doDelegate);
+ assertFalse(ctx.authRpcHandler.delegate instanceof SaslRpcHandler);
+ }
+
+ @Test
+ public void testAuthFailure() throws Exception {
+ ctx = new AuthTestCtx();
+ ctx.createServer("server");
+
+ try {
+ ctx.createClient("client");
+ fail("Should have failed to create client.");
+ } catch (Exception e) {
+ assertFalse(ctx.authRpcHandler.doDelegate);
+ assertFalse(ctx.serverChannel.isActive());
+ }
+ }
+
+ @Test
+ public void testSaslServerFallback() throws Exception {
+ ctx = new AuthTestCtx();
+ ctx.createServer("secret", true);
+ ctx.createClient("secret", false);
+
+ ByteBuffer reply = ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), 5000);
+ assertEquals("Pong", JavaUtils.bytesToString(reply));
+ }
+
+ @Test
+ public void testSaslClientFallback() throws Exception {
+ ctx = new AuthTestCtx();
+ ctx.createServer("secret", false);
+ ctx.createClient("secret", true);
+
+ ByteBuffer reply = ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), 5000);
+ assertEquals("Pong", JavaUtils.bytesToString(reply));
+ }
+
+ @Test
+ public void testAuthReplay() throws Exception {
+ // This test covers the case where an attacker replays a challenge message sniffed from the
+ // network, but doesn't know the actual secret. The server should close the connection as
+ // soon as a message is sent after authentication is performed. This is emulated by removing
+ // the client encryption handler after authentication.
+ ctx = new AuthTestCtx();
+ ctx.createServer("secret");
+ ctx.createClient("secret");
+
+ assertNotNull(ctx.client.getChannel().pipeline()
+ .remove(TransportCipher.ENCRYPTION_HANDLER_NAME));
+
+ try {
+ ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), 5000);
+ fail("Should have failed unencrypted RPC.");
+ } catch (Exception e) {
+ assertTrue(ctx.authRpcHandler.doDelegate);
+ }
+ }
+
+ private class AuthTestCtx {
+
+ private final String appId = "testAppId";
+ private final TransportConf conf;
+ private final TransportContext ctx;
+
+ TransportClient client;
+ TransportServer server;
+ volatile Channel serverChannel;
+ volatile AuthRpcHandler authRpcHandler;
+
+ AuthTestCtx() throws Exception {
+ Map<String, String> testConf = ImmutableMap.of("spark.network.crypto.enabled", "true");
+ this.conf = new TransportConf("rpc", new MapConfigProvider(testConf));
+
+ RpcHandler rpcHandler = new RpcHandler() {
+ @Override
+ public void receive(
+ TransportClient client,
+ ByteBuffer message,
+ RpcResponseCallback callback) {
+ assertEquals("Ping", JavaUtils.bytesToString(message));
+ callback.onSuccess(JavaUtils.stringToBytes("Pong"));
+ }
+
+ @Override
+ public StreamManager getStreamManager() {
+ return null;
+ }
+ };
+
+ this.ctx = new TransportContext(conf, rpcHandler);
+ }
+
+ void createServer(String secret) throws Exception {
+ createServer(secret, true);
+ }
+
+ void createServer(String secret, boolean enableAes) throws Exception {
+ TransportServerBootstrap introspector = new TransportServerBootstrap() {
+ @Override
+ public RpcHandler doBootstrap(Channel channel, RpcHandler rpcHandler) {
+ AuthTestCtx.this.serverChannel = channel;
+ if (rpcHandler instanceof AuthRpcHandler) {
+ AuthTestCtx.this.authRpcHandler = (AuthRpcHandler) rpcHandler;
+ }
+ return rpcHandler;
+ }
+ };
+ SecretKeyHolder keyHolder = createKeyHolder(secret);
+ TransportServerBootstrap auth = enableAes ? new AuthServerBootstrap(conf, keyHolder)
+ : new SaslServerBootstrap(conf, keyHolder);
+ this.server = ctx.createServer(Lists.newArrayList(auth, introspector));
+ }
+
+ void createClient(String secret) throws Exception {
+ createClient(secret, true);
+ }
+
+ void createClient(String secret, boolean enableAes) throws Exception {
+ TransportConf clientConf = enableAes ? conf
+ : new TransportConf("rpc", MapConfigProvider.EMPTY);
+ List<TransportClientBootstrap> bootstraps = Lists.<TransportClientBootstrap>newArrayList(
+ new AuthClientBootstrap(clientConf, appId, createKeyHolder(secret)));
+ this.client = ctx.createClientFactory(bootstraps)
+ .createClient(TestUtils.getLocalHost(), server.getPort());
+ }
+
+ void close() {
+ if (client != null) {
+ client.close();
+ }
+ if (server != null) {
+ server.close();
+ }
+ }
+
+ private SecretKeyHolder createKeyHolder(String secret) {
+ SecretKeyHolder keyHolder = mock(SecretKeyHolder.class);
+ when(keyHolder.getSaslUser(anyString())).thenReturn(appId);
+ when(keyHolder.getSecretKey(anyString())).thenReturn(secret);
+ return keyHolder;
+ }
+
+ }
+
+}
diff --git a/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthMessagesSuite.java b/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthMessagesSuite.java
new file mode 100644
index 0000000000..a90ff247da
--- /dev/null
+++ b/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthMessagesSuite.java
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.crypto;
+
+import java.nio.ByteBuffer;
+import java.util.Arrays;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+import org.junit.Test;
+import static org.junit.Assert.*;
+
+import org.apache.spark.network.protocol.Encodable;
+
+public class AuthMessagesSuite {
+
+ private static int COUNTER = 0;
+
+ private static String string() {
+ return String.valueOf(COUNTER++);
+ }
+
+ private static byte[] byteArray() {
+ byte[] bytes = new byte[COUNTER++];
+ for (int i = 0; i < bytes.length; i++) {
+ bytes[i] = (byte) COUNTER;
+ } return bytes;
+ }
+
+ private static int integer() {
+ return COUNTER++;
+ }
+
+ @Test
+ public void testClientChallenge() {
+ ClientChallenge msg = new ClientChallenge(string(), string(), integer(), string(), integer(),
+ byteArray(), byteArray());
+ ClientChallenge decoded = ClientChallenge.decodeMessage(encode(msg));
+
+ assertEquals(msg.appId, decoded.appId);
+ assertEquals(msg.kdf, decoded.kdf);
+ assertEquals(msg.iterations, decoded.iterations);
+ assertEquals(msg.cipher, decoded.cipher);
+ assertEquals(msg.keyLength, decoded.keyLength);
+ assertTrue(Arrays.equals(msg.nonce, decoded.nonce));
+ assertTrue(Arrays.equals(msg.challenge, decoded.challenge));
+ }
+
+ @Test
+ public void testServerResponse() {
+ ServerResponse msg = new ServerResponse(byteArray(), byteArray(), byteArray(), byteArray());
+ ServerResponse decoded = ServerResponse.decodeMessage(encode(msg));
+ assertTrue(Arrays.equals(msg.response, decoded.response));
+ assertTrue(Arrays.equals(msg.nonce, decoded.nonce));
+ assertTrue(Arrays.equals(msg.inputIv, decoded.inputIv));
+ assertTrue(Arrays.equals(msg.outputIv, decoded.outputIv));
+ }
+
+ private ByteBuffer encode(Encodable msg) {
+ ByteBuf buf = Unpooled.buffer();
+ msg.encode(buf);
+ return buf.nioBuffer();
+ }
+
+}
diff --git a/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java
index e27301f49e..87129b900b 100644
--- a/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java
+++ b/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java
@@ -56,7 +56,6 @@ import org.apache.spark.network.client.ChunkReceivedCallback;
import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.client.TransportClientBootstrap;
-import org.apache.spark.network.sasl.aes.AesCipher;
import org.apache.spark.network.server.RpcHandler;
import org.apache.spark.network.server.StreamManager;
import org.apache.spark.network.server.TransportServer;
@@ -153,7 +152,7 @@ public class SparkSaslSuite {
.when(rpcHandler)
.receive(any(TransportClient.class), any(ByteBuffer.class), any(RpcResponseCallback.class));
- SaslTestCtx ctx = new SaslTestCtx(rpcHandler, encrypt, false, false);
+ SaslTestCtx ctx = new SaslTestCtx(rpcHandler, encrypt, false);
try {
ByteBuffer response = ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"),
TimeUnit.SECONDS.toMillis(10));
@@ -279,7 +278,7 @@ public class SparkSaslSuite {
new Random().nextBytes(data);
Files.write(data, file);
- ctx = new SaslTestCtx(rpcHandler, true, false, false, testConf);
+ ctx = new SaslTestCtx(rpcHandler, true, false, testConf);
final CountDownLatch lock = new CountDownLatch(1);
@@ -317,7 +316,7 @@ public class SparkSaslSuite {
public void testServerAlwaysEncrypt() throws Exception {
SaslTestCtx ctx = null;
try {
- ctx = new SaslTestCtx(mock(RpcHandler.class), false, false, false,
+ ctx = new SaslTestCtx(mock(RpcHandler.class), false, false,
ImmutableMap.of("spark.network.sasl.serverAlwaysEncrypt", "true"));
fail("Should have failed to connect without encryption.");
} catch (Exception e) {
@@ -336,7 +335,7 @@ public class SparkSaslSuite {
// able to understand RPCs sent to it and thus close the connection.
SaslTestCtx ctx = null;
try {
- ctx = new SaslTestCtx(mock(RpcHandler.class), true, true, false);
+ ctx = new SaslTestCtx(mock(RpcHandler.class), true, true);
ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"),
TimeUnit.SECONDS.toMillis(10));
fail("Should have failed to send RPC to server.");
@@ -374,69 +373,6 @@ public class SparkSaslSuite {
}
}
- @Test
- public void testAesEncryption() throws Exception {
- final AtomicReference<ManagedBuffer> response = new AtomicReference<>();
- final File file = File.createTempFile("sasltest", ".txt");
- SaslTestCtx ctx = null;
- try {
- final TransportConf conf = new TransportConf("rpc", MapConfigProvider.EMPTY);
- final TransportConf spyConf = spy(conf);
- doReturn(true).when(spyConf).aesEncryptionEnabled();
-
- StreamManager sm = mock(StreamManager.class);
- when(sm.getChunk(anyLong(), anyInt())).thenAnswer(new Answer<ManagedBuffer>() {
- @Override
- public ManagedBuffer answer(InvocationOnMock invocation) {
- return new FileSegmentManagedBuffer(spyConf, file, 0, file.length());
- }
- });
-
- RpcHandler rpcHandler = mock(RpcHandler.class);
- when(rpcHandler.getStreamManager()).thenReturn(sm);
-
- byte[] data = new byte[256 * 1024 * 1024];
- new Random().nextBytes(data);
- Files.write(data, file);
-
- ctx = new SaslTestCtx(rpcHandler, true, false, true);
-
- final Object lock = new Object();
-
- ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class);
- doAnswer(new Answer<Void>() {
- @Override
- public Void answer(InvocationOnMock invocation) {
- response.set((ManagedBuffer) invocation.getArguments()[1]);
- response.get().retain();
- synchronized (lock) {
- lock.notifyAll();
- }
- return null;
- }
- }).when(callback).onSuccess(anyInt(), any(ManagedBuffer.class));
-
- synchronized (lock) {
- ctx.client.fetchChunk(0, 0, callback);
- lock.wait(10 * 1000);
- }
-
- verify(callback, times(1)).onSuccess(anyInt(), any(ManagedBuffer.class));
- verify(callback, never()).onFailure(anyInt(), any(Throwable.class));
-
- byte[] received = ByteStreams.toByteArray(response.get().createInputStream());
- assertTrue(Arrays.equals(data, received));
- } finally {
- file.delete();
- if (ctx != null) {
- ctx.close();
- }
- if (response.get() != null) {
- response.get().release();
- }
- }
- }
-
private static class SaslTestCtx {
final TransportClient client;
@@ -449,46 +385,39 @@ public class SparkSaslSuite {
SaslTestCtx(
RpcHandler rpcHandler,
boolean encrypt,
- boolean disableClientEncryption,
- boolean aesEnable)
+ boolean disableClientEncryption)
throws Exception {
- this(rpcHandler, encrypt, disableClientEncryption, aesEnable,
- Collections.<String, String>emptyMap());
+ this(rpcHandler, encrypt, disableClientEncryption, Collections.<String, String>emptyMap());
}
SaslTestCtx(
RpcHandler rpcHandler,
boolean encrypt,
boolean disableClientEncryption,
- boolean aesEnable,
- Map<String, String> testConf)
+ Map<String, String> extraConf)
throws Exception {
+ Map<String, String> testConf = ImmutableMap.<String, String>builder()
+ .putAll(extraConf)
+ .put("spark.authenticate.enableSaslEncryption", String.valueOf(encrypt))
+ .build();
TransportConf conf = new TransportConf("shuffle", new MapConfigProvider(testConf));
- if (aesEnable) {
- conf = spy(conf);
- doReturn(true).when(conf).aesEncryptionEnabled();
- }
-
SecretKeyHolder keyHolder = mock(SecretKeyHolder.class);
when(keyHolder.getSaslUser(anyString())).thenReturn("user");
when(keyHolder.getSecretKey(anyString())).thenReturn("secret");
TransportContext ctx = new TransportContext(conf, rpcHandler);
- String encryptHandlerName = aesEnable ? AesCipher.ENCRYPTION_HANDLER_NAME :
- SaslEncryption.ENCRYPTION_HANDLER_NAME;
-
- this.checker = new EncryptionCheckerBootstrap(encryptHandlerName);
+ this.checker = new EncryptionCheckerBootstrap(SaslEncryption.ENCRYPTION_HANDLER_NAME);
this.server = ctx.createServer(Arrays.asList(new SaslServerBootstrap(conf, keyHolder),
checker));
try {
List<TransportClientBootstrap> clientBootstraps = Lists.newArrayList();
- clientBootstraps.add(new SaslClientBootstrap(conf, "user", keyHolder, encrypt));
+ clientBootstraps.add(new SaslClientBootstrap(conf, "user", keyHolder));
if (disableClientEncryption) {
clientBootstraps.add(new EncryptionDisablerBootstrap());
}
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
index 772fb88325..616505d979 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
@@ -21,7 +21,6 @@ import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.List;
-import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -30,7 +29,7 @@ import org.apache.spark.network.TransportContext;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.client.TransportClientBootstrap;
import org.apache.spark.network.client.TransportClientFactory;
-import org.apache.spark.network.sasl.SaslClientBootstrap;
+import org.apache.spark.network.crypto.AuthClientBootstrap;
import org.apache.spark.network.sasl.SecretKeyHolder;
import org.apache.spark.network.server.NoOpRpcHandler;
import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo;
@@ -47,8 +46,7 @@ public class ExternalShuffleClient extends ShuffleClient {
private static final Logger logger = LoggerFactory.getLogger(ExternalShuffleClient.class);
private final TransportConf conf;
- private final boolean saslEnabled;
- private final boolean saslEncryptionEnabled;
+ private final boolean authEnabled;
private final SecretKeyHolder secretKeyHolder;
protected TransportClientFactory clientFactory;
@@ -61,15 +59,10 @@ public class ExternalShuffleClient extends ShuffleClient {
public ExternalShuffleClient(
TransportConf conf,
SecretKeyHolder secretKeyHolder,
- boolean saslEnabled,
- boolean saslEncryptionEnabled) {
- Preconditions.checkArgument(
- !saslEncryptionEnabled || saslEnabled,
- "SASL encryption can only be enabled if SASL is also enabled.");
+ boolean authEnabled) {
this.conf = conf;
this.secretKeyHolder = secretKeyHolder;
- this.saslEnabled = saslEnabled;
- this.saslEncryptionEnabled = saslEncryptionEnabled;
+ this.authEnabled = authEnabled;
}
protected void checkInit() {
@@ -81,8 +74,8 @@ public class ExternalShuffleClient extends ShuffleClient {
this.appId = appId;
TransportContext context = new TransportContext(conf, new NoOpRpcHandler(), true);
List<TransportClientBootstrap> bootstraps = Lists.newArrayList();
- if (saslEnabled) {
- bootstraps.add(new SaslClientBootstrap(conf, appId, secretKeyHolder, saslEncryptionEnabled));
+ if (authEnabled) {
+ bootstraps.add(new AuthClientBootstrap(conf, appId, secretKeyHolder));
}
clientFactory = context.createClientFactory(bootstraps);
}
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java
index 42cedd9943..ab49b1c1d7 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java
@@ -60,9 +60,8 @@ public class MesosExternalShuffleClient extends ExternalShuffleClient {
public MesosExternalShuffleClient(
TransportConf conf,
SecretKeyHolder secretKeyHolder,
- boolean saslEnabled,
- boolean saslEncryptionEnabled) {
- super(conf, secretKeyHolder, saslEnabled, saslEncryptionEnabled);
+ boolean authEnabled) {
+ super(conf, secretKeyHolder, authEnabled);
}
public void registerDriverWithShuffleService(
diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java
index 8dd97b29eb..9248ef3c46 100644
--- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java
+++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java
@@ -133,7 +133,7 @@ public class ExternalShuffleIntegrationSuite {
final Semaphore requestsRemaining = new Semaphore(0);
- ExternalShuffleClient client = new ExternalShuffleClient(clientConf, null, false, false);
+ ExternalShuffleClient client = new ExternalShuffleClient(clientConf, null, false);
client.init(APP_ID);
client.fetchBlocks(TestUtils.getLocalHost(), port, execId, blockIds,
new BlockFetchingListener() {
@@ -243,7 +243,7 @@ public class ExternalShuffleIntegrationSuite {
private void registerExecutor(String executorId, ExecutorShuffleInfo executorInfo)
throws IOException {
- ExternalShuffleClient client = new ExternalShuffleClient(conf, null, false, false);
+ ExternalShuffleClient client = new ExternalShuffleClient(conf, null, false);
client.init(APP_ID);
client.registerWithShuffleServer(TestUtils.getLocalHost(), server.getPort(),
executorId, executorInfo);
diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java
index aed25a161e..4ae75a1b17 100644
--- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java
+++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java
@@ -20,6 +20,7 @@ package org.apache.spark.network.shuffle;
import java.io.IOException;
import java.util.Arrays;
+import com.google.common.collect.ImmutableMap;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
@@ -88,8 +89,14 @@ public class ExternalShuffleSecuritySuite {
/** Creates an ExternalShuffleClient and attempts to register with the server. */
private void validate(String appId, String secretKey, boolean encrypt) throws IOException {
+ TransportConf testConf = conf;
+ if (encrypt) {
+ testConf = new TransportConf("shuffle", new MapConfigProvider(
+ ImmutableMap.of("spark.authenticate.enableSaslEncryption", "true")));
+ }
+
ExternalShuffleClient client =
- new ExternalShuffleClient(conf, new TestSecretKeyHolder(appId, secretKey), true, encrypt);
+ new ExternalShuffleClient(testConf, new TestSecretKeyHolder(appId, secretKey), true);
client.init(appId);
// Registration either succeeds or throws an exception.
client.registerWithShuffleServer(TestUtils.getLocalHost(), server.getPort(), "exec0",
diff --git a/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java b/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java
index ea726e3c82..c7620d0fe1 100644
--- a/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java
+++ b/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java
@@ -45,7 +45,7 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.spark.network.TransportContext;
-import org.apache.spark.network.sasl.SaslServerBootstrap;
+import org.apache.spark.network.crypto.AuthServerBootstrap;
import org.apache.spark.network.sasl.ShuffleSecretManager;
import org.apache.spark.network.server.TransportServer;
import org.apache.spark.network.server.TransportServerBootstrap;
@@ -172,7 +172,7 @@ public class YarnShuffleService extends AuxiliaryService {
boolean authEnabled = conf.getBoolean(SPARK_AUTHENTICATE_KEY, DEFAULT_SPARK_AUTHENTICATE);
if (authEnabled) {
createSecretManager();
- bootstraps.add(new SaslServerBootstrap(transportConf, secretManager));
+ bootstraps.add(new AuthServerBootstrap(transportConf, secretManager));
}
int port = conf.getInt(
diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala
index 9bdc5096b6..cde768281f 100644
--- a/core/src/main/scala/org/apache/spark/SecurityManager.scala
+++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala
@@ -29,6 +29,7 @@ import org.apache.hadoop.io.Text
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config._
import org.apache.spark.network.sasl.SecretKeyHolder
import org.apache.spark.util.Utils
@@ -191,7 +192,7 @@ private[spark] class SecurityManager(
// allow all users/groups to have view/modify permissions
private val WILDCARD_ACL = "*"
- private val authOn = sparkConf.getBoolean(SecurityManager.SPARK_AUTH_CONF, false)
+ private val authOn = sparkConf.get(NETWORK_AUTH_ENABLED)
// keep spark.ui.acls.enable for backwards compatibility with 1.0
private var aclsOn =
sparkConf.getBoolean("spark.acls.enable", sparkConf.getBoolean("spark.ui.acls.enable", false))
@@ -516,11 +517,11 @@ private[spark] class SecurityManager(
def isAuthenticationEnabled(): Boolean = authOn
/**
- * Checks whether SASL encryption should be enabled.
- * @return Whether to enable SASL encryption when connecting to services that support it.
+ * Checks whether network encryption should be enabled.
+ * @return Whether to enable encryption when connecting to services that support it.
*/
- def isSaslEncryptionEnabled(): Boolean = {
- sparkConf.getBoolean("spark.authenticate.enableSaslEncryption", false)
+ def isEncryptionEnabled(): Boolean = {
+ sparkConf.get(NETWORK_ENCRYPTION_ENABLED) || sparkConf.get(SASL_ENCRYPTION_ENABLED)
}
/**
diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala
index 601d24191e..308a1ed5fa 100644
--- a/core/src/main/scala/org/apache/spark/SparkConf.scala
+++ b/core/src/main/scala/org/apache/spark/SparkConf.scala
@@ -607,6 +607,10 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria
"\"client\".")
}
}
+
+ val encryptionEnabled = get(NETWORK_ENCRYPTION_ENABLED) || get(SASL_ENCRYPTION_ENABLED)
+ require(!encryptionEnabled || get(NETWORK_AUTH_ENABLED),
+ s"${NETWORK_AUTH_ENABLED.key} must be enabled when enabling encryption.")
}
/**
@@ -726,6 +730,7 @@ private[spark] object SparkConf extends Logging {
(name.startsWith("spark.auth") && name != SecurityManager.SPARK_AUTH_SECRET_CONF) ||
name.startsWith("spark.ssl") ||
name.startsWith("spark.rpc") ||
+ name.startsWith("spark.network") ||
isSparkPortConf(name)
}
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index 1296386ac9..539dbb55ee 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -235,7 +235,7 @@ object SparkEnv extends Logging {
val securityManager = new SecurityManager(conf, ioEncryptionKey)
ioEncryptionKey.foreach { _ =>
- if (!securityManager.isSaslEncryptionEnabled()) {
+ if (!securityManager.isEncryptionEnabled()) {
logWarning("I/O encryption enabled without RPC encryption: keys will be visible on the " +
"wire.")
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala
index 13eadbe44f..8d491ddf6e 100644
--- a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala
@@ -25,8 +25,8 @@ import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.internal.Logging
import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.network.TransportContext
+import org.apache.spark.network.crypto.AuthServerBootstrap
import org.apache.spark.network.netty.SparkTransportConf
-import org.apache.spark.network.sasl.SaslServerBootstrap
import org.apache.spark.network.server.{TransportServer, TransportServerBootstrap}
import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler
import org.apache.spark.network.util.TransportConf
@@ -47,7 +47,6 @@ class ExternalShuffleService(sparkConf: SparkConf, securityManager: SecurityMana
private val enabled = sparkConf.getBoolean("spark.shuffle.service.enabled", false)
private val port = sparkConf.getInt("spark.shuffle.service.port", 7337)
- private val useSasl: Boolean = securityManager.isAuthenticationEnabled()
private val transportConf =
SparkTransportConf.fromSparkConf(sparkConf, "shuffle", numUsableCores = 0)
@@ -74,10 +73,11 @@ class ExternalShuffleService(sparkConf: SparkConf, securityManager: SecurityMana
/** Start the external shuffle service */
def start() {
require(server == null, "Shuffle server already started")
- logInfo(s"Starting shuffle service on port $port with useSasl = $useSasl")
+ val authEnabled = securityManager.isAuthenticationEnabled()
+ logInfo(s"Starting shuffle service on port $port (auth enabled = $authEnabled)")
val bootstraps: Seq[TransportServerBootstrap] =
- if (useSasl) {
- Seq(new SaslServerBootstrap(transportConf, securityManager))
+ if (authEnabled) {
+ Seq(new AuthServerBootstrap(transportConf, securityManager))
} else {
Nil
}
diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala
index aba429bcdc..536f493b41 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/package.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -243,4 +243,20 @@ package object config {
"and event logs.")
.stringConf
.createWithDefault("(?i)secret|password")
+
+ private[spark] val NETWORK_AUTH_ENABLED =
+ ConfigBuilder("spark.authenticate")
+ .booleanConf
+ .createWithDefault(false)
+
+ private[spark] val SASL_ENCRYPTION_ENABLED =
+ ConfigBuilder("spark.authenticate.enableSaslEncryption")
+ .booleanConf
+ .createWithDefault(false)
+
+ private[spark] val NETWORK_ENCRYPTION_ENABLED =
+ ConfigBuilder("spark.network.crypto.enabled")
+ .booleanConf
+ .createWithDefault(false)
+
}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
index 3d4ea3cccc..b75e91b660 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
+++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
@@ -27,7 +27,7 @@ import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.network._
import org.apache.spark.network.buffer.ManagedBuffer
import org.apache.spark.network.client.{RpcResponseCallback, TransportClientBootstrap, TransportClientFactory}
-import org.apache.spark.network.sasl.{SaslClientBootstrap, SaslServerBootstrap}
+import org.apache.spark.network.crypto.{AuthClientBootstrap, AuthServerBootstrap}
import org.apache.spark.network.server._
import org.apache.spark.network.shuffle.{BlockFetchingListener, OneForOneBlockFetcher, RetryingBlockFetcher}
import org.apache.spark.network.shuffle.protocol.UploadBlock
@@ -63,9 +63,8 @@ private[spark] class NettyBlockTransferService(
var serverBootstrap: Option[TransportServerBootstrap] = None
var clientBootstrap: Option[TransportClientBootstrap] = None
if (authEnabled) {
- serverBootstrap = Some(new SaslServerBootstrap(transportConf, securityManager))
- clientBootstrap = Some(new SaslClientBootstrap(transportConf, conf.getAppId, securityManager,
- securityManager.isSaslEncryptionEnabled()))
+ serverBootstrap = Some(new AuthServerBootstrap(transportConf, securityManager))
+ clientBootstrap = Some(new AuthClientBootstrap(transportConf, conf.getAppId, securityManager))
}
transportContext = new TransportContext(transportConf, rpcHandler)
clientFactory = transportContext.createClientFactory(clientBootstrap.toSeq.asJava)
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
index e56943da13..1e448b2f1a 100644
--- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
@@ -33,8 +33,8 @@ import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.internal.Logging
import org.apache.spark.network.TransportContext
import org.apache.spark.network.client._
+import org.apache.spark.network.crypto.{AuthClientBootstrap, AuthServerBootstrap}
import org.apache.spark.network.netty.SparkTransportConf
-import org.apache.spark.network.sasl.{SaslClientBootstrap, SaslServerBootstrap}
import org.apache.spark.network.server._
import org.apache.spark.rpc._
import org.apache.spark.serializer.{JavaSerializer, JavaSerializerInstance}
@@ -60,8 +60,8 @@ private[netty] class NettyRpcEnv(
private def createClientBootstraps(): java.util.List[TransportClientBootstrap] = {
if (securityManager.isAuthenticationEnabled()) {
- java.util.Arrays.asList(new SaslClientBootstrap(transportConf, "", securityManager,
- securityManager.isSaslEncryptionEnabled()))
+ java.util.Arrays.asList(new AuthClientBootstrap(transportConf,
+ securityManager.getSaslUser(), securityManager))
} else {
java.util.Collections.emptyList[TransportClientBootstrap]
}
@@ -111,7 +111,7 @@ private[netty] class NettyRpcEnv(
def startServer(bindAddress: String, port: Int): Unit = {
val bootstraps: java.util.List[TransportServerBootstrap] =
if (securityManager.isAuthenticationEnabled()) {
- java.util.Arrays.asList(new SaslServerBootstrap(transportConf, securityManager))
+ java.util.Arrays.asList(new AuthServerBootstrap(transportConf, securityManager))
} else {
java.util.Collections.emptyList()
}
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 04521c9159..c40186756f 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -125,8 +125,7 @@ private[spark] class BlockManager(
// standard BlockTransferService to directly connect to other Executors.
private[spark] val shuffleClient = if (externalShuffleServiceEnabled) {
val transConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numUsableCores)
- new ExternalShuffleClient(transConf, securityManager, securityManager.isAuthenticationEnabled(),
- securityManager.isSaslEncryptionEnabled())
+ new ExternalShuffleClient(transConf, securityManager, securityManager.isAuthenticationEnabled())
} else {
blockTransferService
}
diff --git a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala
index 83906cff12..0897891ee1 100644
--- a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala
+++ b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala
@@ -303,6 +303,25 @@ class SparkConfSuite extends SparkFunSuite with LocalSparkContext with ResetSyst
}
}
+ test("encryption requires authentication") {
+ val conf = new SparkConf()
+ conf.validateSettings()
+
+ conf.set(NETWORK_ENCRYPTION_ENABLED, true)
+ intercept[IllegalArgumentException] {
+ conf.validateSettings()
+ }
+
+ conf.set(NETWORK_ENCRYPTION_ENABLED, false)
+ conf.set(SASL_ENCRYPTION_ENABLED, true)
+ intercept[IllegalArgumentException] {
+ conf.validateSettings()
+ }
+
+ conf.set(NETWORK_AUTH_ENABLED, true)
+ conf.validateSettings()
+ }
+
}
class Class1 {}
diff --git a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala
index 022fe91eda..fe8955840d 100644
--- a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala
+++ b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala
@@ -94,6 +94,20 @@ class NettyBlockTransferSecuritySuite extends SparkFunSuite with MockitoSugar wi
}
}
+ test("security with aes encryption") {
+ val conf = new SparkConf()
+ .set("spark.authenticate", "true")
+ .set("spark.authenticate.secret", "good")
+ .set("spark.app.id", "app-id")
+ .set("spark.network.crypto.enabled", "true")
+ .set("spark.network.crypto.saslFallback", "false")
+ testConnection(conf, conf) match {
+ case Success(_) => // expected
+ case Failure(t) => fail(t)
+ }
+ }
+
+
/**
* Creates two servers with different configurations and sees if they can talk.
* Returns Success() if they can transfer a block, and Failure() if the block transfer was failed
diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
index acdf21df9a..b4037d7a9c 100644
--- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
@@ -637,11 +637,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
assert(anotherEnv.address.port != env.address.port)
}
- test("send with authentication") {
- val conf = new SparkConf
- conf.set("spark.authenticate", "true")
- conf.set("spark.authenticate.secret", "good")
-
+ private def testSend(conf: SparkConf): Unit = {
val localEnv = createRpcEnv(conf, "authentication-local", 0)
val remoteEnv = createRpcEnv(conf, "authentication-remote", 0, clientMode = true)
@@ -667,11 +663,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
}
}
- test("ask with authentication") {
- val conf = new SparkConf
- conf.set("spark.authenticate", "true")
- conf.set("spark.authenticate.secret", "good")
-
+ private def testAsk(conf: SparkConf): Unit = {
val localEnv = createRpcEnv(conf, "authentication-local", 0)
val remoteEnv = createRpcEnv(conf, "authentication-remote", 0, clientMode = true)
@@ -695,6 +687,48 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
}
}
+ test("send with authentication") {
+ testSend(new SparkConf()
+ .set("spark.authenticate", "true")
+ .set("spark.authenticate.secret", "good"))
+ }
+
+ test("send with SASL encryption") {
+ testSend(new SparkConf()
+ .set("spark.authenticate", "true")
+ .set("spark.authenticate.secret", "good")
+ .set("spark.authenticate.enableSaslEncryption", "true"))
+ }
+
+ test("send with AES encryption") {
+ testSend(new SparkConf()
+ .set("spark.authenticate", "true")
+ .set("spark.authenticate.secret", "good")
+ .set("spark.network.crypto.enabled", "true")
+ .set("spark.network.crypto.saslFallback", "false"))
+ }
+
+ test("ask with authentication") {
+ testAsk(new SparkConf()
+ .set("spark.authenticate", "true")
+ .set("spark.authenticate.secret", "good"))
+ }
+
+ test("ask with SASL encryption") {
+ testAsk(new SparkConf()
+ .set("spark.authenticate", "true")
+ .set("spark.authenticate.secret", "good")
+ .set("spark.authenticate.enableSaslEncryption", "true"))
+ }
+
+ test("ask with AES encryption") {
+ testAsk(new SparkConf()
+ .set("spark.authenticate", "true")
+ .set("spark.authenticate.secret", "good")
+ .set("spark.network.crypto.enabled", "true")
+ .set("spark.network.crypto.saslFallback", "false"))
+ }
+
test("construct RpcTimeout with conf property") {
val conf = new SparkConf
diff --git a/docs/configuration.md b/docs/configuration.md
index b7f10e69f3..7c040330db 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -1639,40 +1639,40 @@ Apart from these, the following properties are also available, and may be useful
</td>
</tr>
<tr>
- <td><code>spark.authenticate.enableSaslEncryption</code></td>
+ <td><code>spark.network.crypto.enabled</code></td>
<td>false</td>
<td>
- Enable encrypted communication when authentication is
- enabled. This is supported by the block transfer service and the
- RPC endpoints.
+ Enable encryption using the commons-crypto library for RPC and block transfer service.
+ Requires <code>spark.authenticate</code> to be enabled.
</td>
</tr>
<tr>
- <td><code>spark.network.sasl.serverAlwaysEncrypt</code></td>
- <td>false</td>
+ <td><code>spark.network.crypto.keyLength</code></td>
+ <td>128</td>
<td>
- Disable unencrypted connections for services that support SASL authentication. This is
- currently supported by the external shuffle service.
+ The length in bits of the encryption key to generate. Valid values are 128, 192 and 256.
</td>
</tr>
<tr>
- <td><code>spark.network.aes.enabled</code></td>
- <td>false</td>
+ <td><code>spark.network.crypto.keyFactoryAlgorithm</code></td>
+ <td>PBKDF2WithHmacSHA1</td>
<td>
- Enable AES for over-the-wire encryption. This is supported for RPC and the block transfer service.
- This option has precedence over SASL-based encryption if both are enabled.
+ The key factory algorithm to use when generating encryption keys. Should be one of the
+ algorithms supported by the javax.crypto.SecretKeyFactory class in the JRE being used.
</td>
</tr>
<tr>
- <td><code>spark.network.aes.keySize</code></td>
- <td>16</td>
+ <td><code>spark.network.crypto.saslFallback</code></td>
+ <td>true</td>
<td>
- The bytes of AES cipher key which is effective when AES cipher is enabled. AES
- works with 16, 24 and 32 bytes keys.
+ Whether to fall back to SASL authentication if authentication fails using Spark's internal
+ mechanism. This is useful when the application is connecting to old shuffle services that
+ do not support the internal Spark authentication protocol. On the server side, this can be
+ used to block older clients from authenticating against a new shuffle service.
</td>
</tr>
<tr>
- <td><code>spark.network.aes.config.*</code></td>
+ <td><code>spark.network.crypto.config.*</code></td>
<td>None</td>
<td>
Configuration values for the commons-crypto library, such as which cipher implementations to
@@ -1681,6 +1681,22 @@ Apart from these, the following properties are also available, and may be useful
</td>
</tr>
<tr>
+ <td><code>spark.authenticate.enableSaslEncryption</code></td>
+ <td>false</td>
+ <td>
+ Enable encrypted communication when authentication is
+ enabled. This is supported by the block transfer service and the
+ RPC endpoints.
+ </td>
+</tr>
+<tr>
+ <td><code>spark.network.sasl.serverAlwaysEncrypt</code></td>
+ <td>false</td>
+ <td>
+ Disable unencrypted connections for services that support SASL authentication.
+ </td>
+</tr>
+<tr>
<td><code>spark.core.connection.ack.wait.timeout</code></td>
<td><code>spark.network.timeout</code></td>
<td>
diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala
index 3258b09c06..f555072c38 100644
--- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala
+++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala
@@ -136,8 +136,7 @@ private[spark] class MesosCoarseGrainedSchedulerBackend(
new MesosExternalShuffleClient(
SparkTransportConf.fromSparkConf(conf, "shuffle"),
securityManager,
- securityManager.isAuthenticationEnabled(),
- securityManager.isSaslEncryptionEnabled())
+ securityManager.isAuthenticationEnabled())
}
var nextMesosTaskId = 0