aboutsummaryrefslogtreecommitdiff
path: root/common/network-common/src/main/java/org/apache/spark
diff options
context:
space:
mode:
Diffstat (limited to 'common/network-common/src/main/java/org/apache/spark')
-rw-r--r--common/network-common/src/main/java/org/apache/spark/network/crypto/AuthClientBootstrap.java128
-rw-r--r--common/network-common/src/main/java/org/apache/spark/network/crypto/AuthEngine.java284
-rw-r--r--common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java170
-rw-r--r--common/network-common/src/main/java/org/apache/spark/network/crypto/AuthServerBootstrap.java55
-rw-r--r--common/network-common/src/main/java/org/apache/spark/network/crypto/ClientChallenge.java101
-rw-r--r--common/network-common/src/main/java/org/apache/spark/network/crypto/README.md158
-rw-r--r--common/network-common/src/main/java/org/apache/spark/network/crypto/ServerResponse.java85
-rw-r--r--common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java (renamed from common/network-common/src/main/java/org/apache/spark/network/sasl/aes/AesCipher.java)138
-rw-r--r--common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java36
-rw-r--r--common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java41
-rw-r--r--common/network-common/src/main/java/org/apache/spark/network/sasl/aes/AesConfigMessage.java101
-rw-r--r--common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java92
12 files changed, 1122 insertions, 267 deletions
diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthClientBootstrap.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthClientBootstrap.java
new file mode 100644
index 0000000000..980525dbf0
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthClientBootstrap.java
@@ -0,0 +1,128 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.crypto;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.security.GeneralSecurityException;
+import java.security.Key;
+import javax.crypto.KeyGenerator;
+import javax.crypto.Mac;
+import static java.nio.charset.StandardCharsets.UTF_8;
+
+import com.google.common.base.Preconditions;
+import com.google.common.base.Throwables;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+import io.netty.channel.Channel;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.client.TransportClientBootstrap;
+import org.apache.spark.network.sasl.SaslClientBootstrap;
+import org.apache.spark.network.sasl.SecretKeyHolder;
+import org.apache.spark.network.util.JavaUtils;
+import org.apache.spark.network.util.TransportConf;
+
+/**
+ * Bootstraps a {@link TransportClient} by performing authentication using Spark's auth protocol.
+ *
+ * This bootstrap falls back to using the SASL bootstrap if the server throws an error during
+ * authentication, and the configuration allows it. This is used for backwards compatibility
+ * with external shuffle services that do not support the new protocol.
+ *
+ * It also automatically falls back to SASL if the new encryption backend is disabled, so that
+ * callers only need to install this bootstrap when authentication is enabled.
+ */
+public class AuthClientBootstrap implements TransportClientBootstrap {
+
+ private static final Logger LOG = LoggerFactory.getLogger(AuthClientBootstrap.class);
+
+ private final TransportConf conf;
+ private final String appId;
+ private final String authUser;
+ private final SecretKeyHolder secretKeyHolder;
+
+ public AuthClientBootstrap(
+ TransportConf conf,
+ String appId,
+ SecretKeyHolder secretKeyHolder) {
+ this.conf = conf;
+ // TODO: right now this behaves like the SASL backend, because when executors start up
+ // they don't necessarily know the app ID. So they send a hardcoded "user" that is defined
+ // in the SecurityManager, which will also always return the same secret (regardless of the
+ // user name). All that's needed here is for this "user" to match on both sides, since that's
+ // required by the protocol. At some point, though, it would be better for the actual app ID
+ // to be provided here.
+ this.appId = appId;
+ this.authUser = secretKeyHolder.getSaslUser(appId);
+ this.secretKeyHolder = secretKeyHolder;
+ }
+
+ @Override
+ public void doBootstrap(TransportClient client, Channel channel) {
+ if (!conf.encryptionEnabled()) {
+ LOG.debug("AES encryption disabled, using old auth protocol.");
+ doSaslAuth(client, channel);
+ return;
+ }
+
+ try {
+ doSparkAuth(client, channel);
+ } catch (GeneralSecurityException | IOException e) {
+ throw Throwables.propagate(e);
+ } catch (RuntimeException e) {
+ // There isn't a good exception that can be caught here to know whether it's really
+ // OK to switch back to SASL (because the server doesn't speak the new protocol). So
+ // try it anyway, and in the worst case things will fail again.
+ if (conf.saslFallback()) {
+ LOG.warn("New auth protocol failed, trying SASL.", e);
+ doSaslAuth(client, channel);
+ } else {
+ throw e;
+ }
+ }
+ }
+
+ private void doSparkAuth(TransportClient client, Channel channel)
+ throws GeneralSecurityException, IOException {
+
+ AuthEngine engine = new AuthEngine(authUser, secretKeyHolder.getSecretKey(authUser), conf);
+ try {
+ ClientChallenge challenge = engine.challenge();
+ ByteBuf challengeData = Unpooled.buffer(challenge.encodedLength());
+ challenge.encode(challengeData);
+
+ ByteBuffer responseData = client.sendRpcSync(challengeData.nioBuffer(),
+ conf.authRTTimeoutMs());
+ ServerResponse response = ServerResponse.decodeMessage(responseData);
+
+ engine.validate(response);
+ engine.sessionCipher().addToChannel(channel);
+ } finally {
+ engine.close();
+ }
+ }
+
+ private void doSaslAuth(TransportClient client, Channel channel) {
+ SaslClientBootstrap sasl = new SaslClientBootstrap(conf, appId, secretKeyHolder);
+ sasl.doBootstrap(client, channel);
+ }
+
+}
diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthEngine.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthEngine.java
new file mode 100644
index 0000000000..b769ebeba3
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthEngine.java
@@ -0,0 +1,284 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.crypto;
+
+import java.io.Closeable;
+import java.io.IOException;
+import java.math.BigInteger;
+import java.security.GeneralSecurityException;
+import java.util.Arrays;
+import java.util.Properties;
+import javax.crypto.Cipher;
+import javax.crypto.SecretKey;
+import javax.crypto.SecretKeyFactory;
+import javax.crypto.ShortBufferException;
+import javax.crypto.spec.IvParameterSpec;
+import javax.crypto.spec.PBEKeySpec;
+import javax.crypto.spec.SecretKeySpec;
+import static java.nio.charset.StandardCharsets.UTF_8;
+
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Preconditions;
+import com.google.common.primitives.Bytes;
+import org.apache.commons.crypto.cipher.CryptoCipher;
+import org.apache.commons.crypto.cipher.CryptoCipherFactory;
+import org.apache.commons.crypto.random.CryptoRandom;
+import org.apache.commons.crypto.random.CryptoRandomFactory;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.util.TransportConf;
+
+/**
+ * A helper class for abstracting authentication and key negotiation details. This is used by
+ * both client and server sides, since the operations are basically the same.
+ */
+class AuthEngine implements Closeable {
+
+ private static final Logger LOG = LoggerFactory.getLogger(AuthEngine.class);
+ private static final BigInteger ONE = new BigInteger(new byte[] { 0x1 });
+
+ private final byte[] appId;
+ private final char[] secret;
+ private final TransportConf conf;
+ private final Properties cryptoConf;
+ private final CryptoRandom random;
+
+ private byte[] authNonce;
+
+ @VisibleForTesting
+ byte[] challenge;
+
+ private TransportCipher sessionCipher;
+ private CryptoCipher encryptor;
+ private CryptoCipher decryptor;
+
+ AuthEngine(String appId, String secret, TransportConf conf) throws GeneralSecurityException {
+ this.appId = appId.getBytes(UTF_8);
+ this.conf = conf;
+ this.cryptoConf = conf.cryptoConf();
+ this.secret = secret.toCharArray();
+ this.random = CryptoRandomFactory.getCryptoRandom(cryptoConf);
+ }
+
+ /**
+ * Create the client challenge.
+ *
+ * @return A challenge to be sent the remote side.
+ */
+ ClientChallenge challenge() throws GeneralSecurityException, IOException {
+ this.authNonce = randomBytes(conf.encryptionKeyLength() / Byte.SIZE);
+ SecretKeySpec authKey = generateKey(conf.keyFactoryAlgorithm(), conf.keyFactoryIterations(),
+ authNonce, conf.encryptionKeyLength());
+ initializeForAuth(conf.cipherTransformation(), authNonce, authKey);
+
+ this.challenge = randomBytes(conf.encryptionKeyLength() / Byte.SIZE);
+ return new ClientChallenge(new String(appId, UTF_8),
+ conf.keyFactoryAlgorithm(),
+ conf.keyFactoryIterations(),
+ conf.cipherTransformation(),
+ conf.encryptionKeyLength(),
+ authNonce,
+ challenge(appId, authNonce, challenge));
+ }
+
+ /**
+ * Validates the client challenge, and create the encryption backend for the channel from the
+ * parameters sent by the client.
+ *
+ * @param clientChallenge The challenge from the client.
+ * @return A response to be sent to the client.
+ */
+ ServerResponse respond(ClientChallenge clientChallenge)
+ throws GeneralSecurityException, IOException {
+
+ SecretKeySpec authKey = generateKey(clientChallenge.kdf, clientChallenge.iterations,
+ clientChallenge.nonce, clientChallenge.keyLength);
+ initializeForAuth(clientChallenge.cipher, clientChallenge.nonce, authKey);
+
+ byte[] challenge = validateChallenge(clientChallenge.nonce, clientChallenge.challenge);
+ byte[] response = challenge(appId, clientChallenge.nonce, rawResponse(challenge));
+ byte[] sessionNonce = randomBytes(conf.encryptionKeyLength() / Byte.SIZE);
+ byte[] inputIv = randomBytes(conf.ivLength());
+ byte[] outputIv = randomBytes(conf.ivLength());
+
+ SecretKeySpec sessionKey = generateKey(clientChallenge.kdf, clientChallenge.iterations,
+ sessionNonce, clientChallenge.keyLength);
+ this.sessionCipher = new TransportCipher(cryptoConf, clientChallenge.cipher, sessionKey,
+ inputIv, outputIv);
+
+ // Note the IVs are swapped in the response.
+ return new ServerResponse(response, encrypt(sessionNonce), encrypt(outputIv), encrypt(inputIv));
+ }
+
+ /**
+ * Validates the server response and initializes the cipher to use for the session.
+ *
+ * @param serverResponse The response from the server.
+ */
+ void validate(ServerResponse serverResponse) throws GeneralSecurityException {
+ byte[] response = validateChallenge(authNonce, serverResponse.response);
+
+ byte[] expected = rawResponse(challenge);
+ Preconditions.checkArgument(Arrays.equals(expected, response));
+
+ byte[] nonce = decrypt(serverResponse.nonce);
+ byte[] inputIv = decrypt(serverResponse.inputIv);
+ byte[] outputIv = decrypt(serverResponse.outputIv);
+
+ SecretKeySpec sessionKey = generateKey(conf.keyFactoryAlgorithm(), conf.keyFactoryIterations(),
+ nonce, conf.encryptionKeyLength());
+ this.sessionCipher = new TransportCipher(cryptoConf, conf.cipherTransformation(), sessionKey,
+ inputIv, outputIv);
+ }
+
+ TransportCipher sessionCipher() {
+ Preconditions.checkState(sessionCipher != null);
+ return sessionCipher;
+ }
+
+ @Override
+ public void close() throws IOException {
+ // Close ciphers (by calling "doFinal()" with dummy data) and the random instance so that
+ // internal state is cleaned up. Error handling here is just for paranoia, and not meant to
+ // accurately report the errors when they happen.
+ RuntimeException error = null;
+ byte[] dummy = new byte[8];
+ try {
+ doCipherOp(encryptor, dummy, true);
+ } catch (Exception e) {
+ error = new RuntimeException(e);
+ }
+ try {
+ doCipherOp(decryptor, dummy, true);
+ } catch (Exception e) {
+ error = new RuntimeException(e);
+ }
+ random.close();
+
+ if (error != null) {
+ throw error;
+ }
+ }
+
+ @VisibleForTesting
+ byte[] challenge(byte[] appId, byte[] nonce, byte[] challenge) throws GeneralSecurityException {
+ return encrypt(Bytes.concat(appId, nonce, challenge));
+ }
+
+ @VisibleForTesting
+ byte[] rawResponse(byte[] challenge) {
+ BigInteger orig = new BigInteger(challenge);
+ BigInteger response = orig.add(ONE);
+ return response.toByteArray();
+ }
+
+ private byte[] decrypt(byte[] in) throws GeneralSecurityException {
+ return doCipherOp(decryptor, in, false);
+ }
+
+ private byte[] encrypt(byte[] in) throws GeneralSecurityException {
+ return doCipherOp(encryptor, in, false);
+ }
+
+ private void initializeForAuth(String cipher, byte[] nonce, SecretKeySpec key)
+ throws GeneralSecurityException {
+
+ // commons-crypto currently only supports ciphers that require an initial vector; so
+ // create a dummy vector so that we can initialize the ciphers. In the future, if
+ // different ciphers are supported, this will have to be configurable somehow.
+ byte[] iv = new byte[conf.ivLength()];
+ System.arraycopy(nonce, 0, iv, 0, Math.min(nonce.length, iv.length));
+
+ encryptor = CryptoCipherFactory.getCryptoCipher(cipher, cryptoConf);
+ encryptor.init(Cipher.ENCRYPT_MODE, key, new IvParameterSpec(iv));
+
+ decryptor = CryptoCipherFactory.getCryptoCipher(cipher, cryptoConf);
+ decryptor.init(Cipher.DECRYPT_MODE, key, new IvParameterSpec(iv));
+ }
+
+ /**
+ * Validates an encrypted challenge as defined in the protocol, and returns the byte array
+ * that corresponds to the actual challenge data.
+ */
+ private byte[] validateChallenge(byte[] nonce, byte[] encryptedChallenge)
+ throws GeneralSecurityException {
+
+ byte[] challenge = decrypt(encryptedChallenge);
+ checkSubArray(appId, challenge, 0);
+ checkSubArray(nonce, challenge, appId.length);
+ return Arrays.copyOfRange(challenge, appId.length + nonce.length, challenge.length);
+ }
+
+ private SecretKeySpec generateKey(String kdf, int iterations, byte[] salt, int keyLength)
+ throws GeneralSecurityException {
+
+ SecretKeyFactory factory = SecretKeyFactory.getInstance(kdf);
+ PBEKeySpec spec = new PBEKeySpec(secret, salt, iterations, keyLength);
+
+ long start = System.nanoTime();
+ SecretKey key = factory.generateSecret(spec);
+ long end = System.nanoTime();
+
+ LOG.debug("Generated key with {} iterations in {} us.", conf.keyFactoryIterations(),
+ (end - start) / 1000);
+
+ return new SecretKeySpec(key.getEncoded(), conf.keyAlgorithm());
+ }
+
+ private byte[] doCipherOp(CryptoCipher cipher, byte[] in, boolean isFinal)
+ throws GeneralSecurityException {
+
+ Preconditions.checkState(cipher != null);
+
+ int scale = 1;
+ while (true) {
+ int size = in.length * scale;
+ byte[] buffer = new byte[size];
+ try {
+ int outSize = isFinal ? cipher.doFinal(in, 0, in.length, buffer, 0)
+ : cipher.update(in, 0, in.length, buffer, 0);
+ if (outSize != buffer.length) {
+ byte[] output = new byte[outSize];
+ System.arraycopy(buffer, 0, output, 0, output.length);
+ return output;
+ } else {
+ return buffer;
+ }
+ } catch (ShortBufferException e) {
+ // Try again with a bigger buffer.
+ scale *= 2;
+ }
+ }
+ }
+
+ private byte[] randomBytes(int count) {
+ byte[] bytes = new byte[count];
+ random.nextBytes(bytes);
+ return bytes;
+ }
+
+ /** Checks that the "test" array is in the data array starting at the given offset. */
+ private void checkSubArray(byte[] test, byte[] data, int offset) {
+ Preconditions.checkArgument(data.length >= test.length + offset);
+ for (int i = 0; i < test.length; i++) {
+ Preconditions.checkArgument(test[i] == data[i + offset]);
+ }
+ }
+
+}
diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java
new file mode 100644
index 0000000000..991d8ba95f
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java
@@ -0,0 +1,170 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.crypto;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import javax.security.sasl.Sasl;
+
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Throwables;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+import io.netty.channel.Channel;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.client.RpcResponseCallback;
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.sasl.SecretKeyHolder;
+import org.apache.spark.network.sasl.SaslRpcHandler;
+import org.apache.spark.network.server.RpcHandler;
+import org.apache.spark.network.server.StreamManager;
+import org.apache.spark.network.util.JavaUtils;
+import org.apache.spark.network.util.TransportConf;
+
+/**
+ * RPC Handler which performs authentication using Spark's auth protocol before delegating to a
+ * child RPC handler. If the configuration allows, this handler will delegate messages to a SASL
+ * RPC handler for further authentication, to support for clients that do not support Spark's
+ * protocol.
+ *
+ * The delegate will only receive messages if the given connection has been successfully
+ * authenticated. A connection may be authenticated at most once.
+ */
+class AuthRpcHandler extends RpcHandler {
+ private static final Logger LOG = LoggerFactory.getLogger(AuthRpcHandler.class);
+
+ /** Transport configuration. */
+ private final TransportConf conf;
+
+ /** The client channel. */
+ private final Channel channel;
+
+ /**
+ * RpcHandler we will delegate to for authenticated connections. When falling back to SASL
+ * this will be replaced with the SASL RPC handler.
+ */
+ @VisibleForTesting
+ RpcHandler delegate;
+
+ /** Class which provides secret keys which are shared by server and client on a per-app basis. */
+ private final SecretKeyHolder secretKeyHolder;
+
+ /** Whether auth is done and future calls should be delegated. */
+ @VisibleForTesting
+ boolean doDelegate;
+
+ AuthRpcHandler(
+ TransportConf conf,
+ Channel channel,
+ RpcHandler delegate,
+ SecretKeyHolder secretKeyHolder) {
+ this.conf = conf;
+ this.channel = channel;
+ this.delegate = delegate;
+ this.secretKeyHolder = secretKeyHolder;
+ }
+
+ @Override
+ public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) {
+ if (doDelegate) {
+ delegate.receive(client, message, callback);
+ return;
+ }
+
+ int position = message.position();
+ int limit = message.limit();
+
+ ClientChallenge challenge;
+ try {
+ challenge = ClientChallenge.decodeMessage(message);
+ LOG.debug("Received new auth challenge for client {}.", channel.remoteAddress());
+ } catch (RuntimeException e) {
+ if (conf.saslFallback()) {
+ LOG.warn("Failed to parse new auth challenge, reverting to SASL for client {}.",
+ channel.remoteAddress());
+ delegate = new SaslRpcHandler(conf, channel, delegate, secretKeyHolder);
+ message.position(position);
+ message.limit(limit);
+ delegate.receive(client, message, callback);
+ doDelegate = true;
+ } else {
+ LOG.debug("Unexpected challenge message from client {}, closing channel.",
+ channel.remoteAddress());
+ callback.onFailure(new IllegalArgumentException("Unknown challenge message."));
+ channel.close();
+ }
+ return;
+ }
+
+ // Here we have the client challenge, so perform the new auth protocol and set up the channel.
+ AuthEngine engine = null;
+ try {
+ engine = new AuthEngine(challenge.appId, secretKeyHolder.getSecretKey(challenge.appId), conf);
+ ServerResponse response = engine.respond(challenge);
+ ByteBuf responseData = Unpooled.buffer(response.encodedLength());
+ response.encode(responseData);
+ callback.onSuccess(responseData.nioBuffer());
+ engine.sessionCipher().addToChannel(channel);
+ } catch (Exception e) {
+ // This is a fatal error: authentication has failed. Close the channel explicitly.
+ LOG.debug("Authentication failed for client {}, closing channel.", channel.remoteAddress());
+ callback.onFailure(new IllegalArgumentException("Authentication failed."));
+ channel.close();
+ return;
+ } finally {
+ if (engine != null) {
+ try {
+ engine.close();
+ } catch (Exception e) {
+ throw Throwables.propagate(e);
+ }
+ }
+ }
+
+ LOG.debug("Authorization successful for client {}.", channel.remoteAddress());
+ doDelegate = true;
+ }
+
+ @Override
+ public void receive(TransportClient client, ByteBuffer message) {
+ delegate.receive(client, message);
+ }
+
+ @Override
+ public StreamManager getStreamManager() {
+ return delegate.getStreamManager();
+ }
+
+ @Override
+ public void channelActive(TransportClient client) {
+ delegate.channelActive(client);
+ }
+
+ @Override
+ public void channelInactive(TransportClient client) {
+ delegate.channelInactive(client);
+ }
+
+ @Override
+ public void exceptionCaught(Throwable cause, TransportClient client) {
+ delegate.exceptionCaught(cause, client);
+ }
+
+}
diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthServerBootstrap.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthServerBootstrap.java
new file mode 100644
index 0000000000..77a2a6af4d
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthServerBootstrap.java
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.crypto;
+
+import io.netty.channel.Channel;
+
+import org.apache.spark.network.sasl.SaslServerBootstrap;
+import org.apache.spark.network.sasl.SecretKeyHolder;
+import org.apache.spark.network.server.RpcHandler;
+import org.apache.spark.network.server.TransportServerBootstrap;
+import org.apache.spark.network.util.TransportConf;
+
+/**
+ * A bootstrap which is executed on a TransportServer's client channel once a client connects
+ * to the server, enabling authentication using Spark's auth protocol (and optionally SASL for
+ * clients that don't support the new protocol).
+ *
+ * It also automatically falls back to SASL if the new encryption backend is disabled, so that
+ * callers only need to install this bootstrap when authentication is enabled.
+ */
+public class AuthServerBootstrap implements TransportServerBootstrap {
+
+ private final TransportConf conf;
+ private final SecretKeyHolder secretKeyHolder;
+
+ public AuthServerBootstrap(TransportConf conf, SecretKeyHolder secretKeyHolder) {
+ this.conf = conf;
+ this.secretKeyHolder = secretKeyHolder;
+ }
+
+ public RpcHandler doBootstrap(Channel channel, RpcHandler rpcHandler) {
+ if (!conf.encryptionEnabled()) {
+ TransportServerBootstrap sasl = new SaslServerBootstrap(conf, secretKeyHolder);
+ return sasl.doBootstrap(channel, rpcHandler);
+ }
+
+ return new AuthRpcHandler(conf, channel, rpcHandler, secretKeyHolder);
+ }
+
+}
diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/ClientChallenge.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/ClientChallenge.java
new file mode 100644
index 0000000000..3312a5bd81
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/ClientChallenge.java
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.crypto;
+
+import java.nio.ByteBuffer;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+
+import org.apache.spark.network.protocol.Encodable;
+import org.apache.spark.network.protocol.Encoders;
+
+/**
+ * The client challenge message, used to initiate authentication.
+ *
+ * @see README.md
+ */
+public class ClientChallenge implements Encodable {
+ /** Serialization tag used to catch incorrect payloads. */
+ private static final byte TAG_BYTE = (byte) 0xFA;
+
+ public final String appId;
+ public final String kdf;
+ public final int iterations;
+ public final String cipher;
+ public final int keyLength;
+ public final byte[] nonce;
+ public final byte[] challenge;
+
+ public ClientChallenge(
+ String appId,
+ String kdf,
+ int iterations,
+ String cipher,
+ int keyLength,
+ byte[] nonce,
+ byte[] challenge) {
+ this.appId = appId;
+ this.kdf = kdf;
+ this.iterations = iterations;
+ this.cipher = cipher;
+ this.keyLength = keyLength;
+ this.nonce = nonce;
+ this.challenge = challenge;
+ }
+
+ @Override
+ public int encodedLength() {
+ return 1 + 4 + 4 +
+ Encoders.Strings.encodedLength(appId) +
+ Encoders.Strings.encodedLength(kdf) +
+ Encoders.Strings.encodedLength(cipher) +
+ Encoders.ByteArrays.encodedLength(nonce) +
+ Encoders.ByteArrays.encodedLength(challenge);
+ }
+
+ @Override
+ public void encode(ByteBuf buf) {
+ buf.writeByte(TAG_BYTE);
+ Encoders.Strings.encode(buf, appId);
+ Encoders.Strings.encode(buf, kdf);
+ buf.writeInt(iterations);
+ Encoders.Strings.encode(buf, cipher);
+ buf.writeInt(keyLength);
+ Encoders.ByteArrays.encode(buf, nonce);
+ Encoders.ByteArrays.encode(buf, challenge);
+ }
+
+ public static ClientChallenge decodeMessage(ByteBuffer buffer) {
+ ByteBuf buf = Unpooled.wrappedBuffer(buffer);
+
+ if (buf.readByte() != TAG_BYTE) {
+ throw new IllegalArgumentException("Expected ClientChallenge, received something else.");
+ }
+
+ return new ClientChallenge(
+ Encoders.Strings.decode(buf),
+ Encoders.Strings.decode(buf),
+ buf.readInt(),
+ Encoders.Strings.decode(buf),
+ buf.readInt(),
+ Encoders.ByteArrays.decode(buf),
+ Encoders.ByteArrays.decode(buf));
+ }
+
+}
diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/README.md b/common/network-common/src/main/java/org/apache/spark/network/crypto/README.md
new file mode 100644
index 0000000000..14df703270
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/README.md
@@ -0,0 +1,158 @@
+Spark Auth Protocol and AES Encryption Support
+==============================================
+
+This file describes an auth protocol used by Spark as a more secure alternative to DIGEST-MD5. This
+protocol is built on symmetric key encryption, based on the assumption that the two endpoints being
+authenticated share a common secret, which is how Spark authentication currently works. The protocol
+provides mutual authentication, meaning that after the negotiation both parties know that the remote
+side knows the shared secret. The protocol is influenced by the ISO/IEC 9798 protocol, although it's
+not an implementation of it.
+
+This protocol could be replaced with TLS PSK, except no PSK ciphers are available in the currently
+released JREs.
+
+The protocol aims at solving the following shortcomings in Spark's current usage of DIGEST-MD5:
+
+- MD5 is an aging hash algorithm with known weaknesses, and a more secure alternative is desired.
+- DIGEST-MD5 has a pre-defined set of ciphers for which it can generate keys. The only
+ viable, supported cipher these days is 3DES, and a more modern alternative is desired.
+- Encrypting AES session keys with 3DES doesn't solve the issue, since the weakest link
+ in the negotiation would still be MD5 and 3DES.
+
+The protocol assumes that the shared secret is generated and distributed in a secure manner.
+
+The protocol always negotiates encryption keys. If encryption is not desired, the existing
+SASL-based authentication, or no authentication at all, can be chosen instead.
+
+When messages are described below, it's expected that the implementation should support
+arbitrary sizes for fields that don't have a fixed size.
+
+Client Challenge
+----------------
+
+The auth negotiation is started by the client. The client starts by generating an encryption
+key based on the application's shared secret, and a nonce.
+
+ KEY = KDF(SECRET, SALT, KEY_LENGTH)
+
+Where:
+- KDF(): a key derivation function that takes a secret, a salt, a configurable number of
+ iterations, and a configurable key length.
+- SALT: a byte sequence used to salt the key derivation function.
+- KEY_LENGTH: length of the encryption key to generate.
+
+
+The client generates a message with the following content:
+
+ CLIENT_CHALLENGE = (
+ APP_ID,
+ KDF,
+ ITERATIONS,
+ CIPHER,
+ KEY_LENGTH,
+ ANONCE,
+ ENC(APP_ID || ANONCE || CHALLENGE))
+
+Where:
+
+- APP_ID: the application ID which the server uses to identify the shared secret.
+- KDF: the key derivation function described above.
+- ITERATIONS: number of iterations to run the KDF when generating keys.
+- CIPHER: the cipher used to encrypt data.
+- KEY_LENGTH: length of the encryption keys to generate, in bits.
+- ANONCE: the nonce used as the salt when generating the auth key.
+- ENC(): an encryption function that uses the cipher and the generated key. This function
+ will also be used in the definition of other messages below.
+- CHALLENGE: a byte sequence used as a challenge to the server.
+- ||: concatenation operator.
+
+When strings are used where byte arrays are expected, the UTF-8 representation of the string
+is assumed.
+
+To respond to the challenge, the server should consider the byte array as representing an
+arbitrary-length integer, and respond with the value of the integer plus one.
+
+
+Server Response And Challenge
+-----------------------------
+
+Once the client challenge is received, the server will generate the same auth key by
+using the same algorithm the client has used. It will then verify the client challenge:
+if the APP_ID and ANONCE fields match, the server knows that the client has the shared
+secret. The server then creates a response to the client challenge, to prove that it also
+has the secret key, and provides parameters to be used when creating the session key.
+
+The following describes the response from the server:
+
+ SERVER_CHALLENGE = (
+ ENC(APP_ID || ANONCE || RESPONSE),
+ ENC(SNONCE),
+ ENC(INIV),
+ ENC(OUTIV))
+
+Where:
+
+- RESPONSE: the server's response to the client challenge.
+- SNONCE: a nonce to be used as salt when generating the session key.
+- INIV: initialization vector used to initialize the input channel of the client.
+- OUTIV: initialization vector used to initialize the output channel of the client.
+
+At this point the server considers the client to be authenticated, and will try to
+decrypt any data further sent by the client using the session key.
+
+
+Default Algorithms
+------------------
+
+Configuration options are available for the KDF and cipher algorithms to use.
+
+The default KDF is "PBKDF2WithHmacSHA1". Users should be able to select any algorithm
+from those supported by the `javax.crypto.SecretKeyFactory` class, as long as they support
+PBEKeySpec when generating keys. The default number of iterations was chosen to take a
+reasonable amount of time on modern CPUs. See the documentation in TransportConf for more
+details.
+
+The default cipher algorithm is "AES/CTR/NoPadding". Users should be able to select any
+algorithm supported by the commons-crypto library. It should allow the cipher to operate
+in stream mode.
+
+The default key length is 128 (bits).
+
+
+Implementation Details
+----------------------
+
+The commons-crypto library currently only supports AES ciphers, and requires an initialization
+vector (IV). This first version of the protocol does not explicitly include the IV in the client
+challenge message. Instead, the IV should be derived from the nonce, including the needed bytes, and
+padding the IV with zeroes in case the nonce is not long enough.
+
+Future versions of the protocol might add support for new ciphers and explicitly include needed
+configuration parameters in the messages.
+
+
+Threat Assessment
+-----------------
+
+The protocol is secure against different forms of attack:
+
+* Eavesdropping: the protocol is built on the assumption that it's computationally infeasible
+ to calculate the original secret from the encrypted messages. Neither the secret nor any
+ encryption keys are transmitted on the wire, encrypted or not.
+
+* Man-in-the-middle: because the protocol performs mutual authentication, both ends need to
+ know the shared secret to be able to decrypt session data. Even if an attacker is able to insert a
+ malicious "proxy" between endpoints, the attacker won't be able to read any of the data exchanged
+ between client and server, nor insert arbitrary commands for the server to execute.
+
+* Replay attacks: the use of nonces when generating keys prevents an attacker from being able to
+ just replay messages sniffed from the communication channel.
+
+An attacker may replay the client challenge and successfully "prove" to a server that it "knows" the
+shared secret. But the attacker won't be able to decrypt the server's response, and thus won't be
+able to generate a session key, which will make it hard to craft a valid, encrypted message that the
+server will be able to understand. This will cause the server to close the connection as soon as the
+attacker tries to send any command to the server. The attacker can just hold the channel open for
+some time, which will be closed when the server times out the channel. These issues could be
+separately mitigated by adding a shorter timeout for the first message after authentication, and
+potentially by adding host blacklists if a possible attack is detected from a particular host.
diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/ServerResponse.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/ServerResponse.java
new file mode 100644
index 0000000000..affdbf450b
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/ServerResponse.java
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.crypto;
+
+import java.nio.ByteBuffer;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+
+import org.apache.spark.network.protocol.Encodable;
+import org.apache.spark.network.protocol.Encoders;
+
+/**
+ * Server's response to client's challenge.
+ *
+ * @see README.md
+ */
+public class ServerResponse implements Encodable {
+ /** Serialization tag used to catch incorrect payloads. */
+ private static final byte TAG_BYTE = (byte) 0xFB;
+
+ public final byte[] response;
+ public final byte[] nonce;
+ public final byte[] inputIv;
+ public final byte[] outputIv;
+
+ public ServerResponse(
+ byte[] response,
+ byte[] nonce,
+ byte[] inputIv,
+ byte[] outputIv) {
+ this.response = response;
+ this.nonce = nonce;
+ this.inputIv = inputIv;
+ this.outputIv = outputIv;
+ }
+
+ @Override
+ public int encodedLength() {
+ return 1 +
+ Encoders.ByteArrays.encodedLength(response) +
+ Encoders.ByteArrays.encodedLength(nonce) +
+ Encoders.ByteArrays.encodedLength(inputIv) +
+ Encoders.ByteArrays.encodedLength(outputIv);
+ }
+
+ @Override
+ public void encode(ByteBuf buf) {
+ buf.writeByte(TAG_BYTE);
+ Encoders.ByteArrays.encode(buf, response);
+ Encoders.ByteArrays.encode(buf, nonce);
+ Encoders.ByteArrays.encode(buf, inputIv);
+ Encoders.ByteArrays.encode(buf, outputIv);
+ }
+
+ public static ServerResponse decodeMessage(ByteBuffer buffer) {
+ ByteBuf buf = Unpooled.wrappedBuffer(buffer);
+
+ if (buf.readByte() != TAG_BYTE) {
+ throw new IllegalArgumentException("Expected ServerResponse, received something else.");
+ }
+
+ return new ServerResponse(
+ Encoders.ByteArrays.decode(buf),
+ Encoders.ByteArrays.decode(buf),
+ Encoders.ByteArrays.decode(buf),
+ Encoders.ByteArrays.decode(buf));
+ }
+
+}
diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/aes/AesCipher.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java
index 340986a63b..7376d1ddc4 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/sasl/aes/AesCipher.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.network.sasl.aes;
+package org.apache.spark.network.crypto;
import java.io.IOException;
import java.nio.ByteBuffer;
@@ -25,115 +25,91 @@ import java.util.Properties;
import javax.crypto.spec.SecretKeySpec;
import javax.crypto.spec.IvParameterSpec;
+import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
-import com.google.common.base.Throwables;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.*;
import io.netty.util.AbstractReferenceCounted;
-import org.apache.commons.crypto.cipher.CryptoCipherFactory;
-import org.apache.commons.crypto.random.CryptoRandom;
-import org.apache.commons.crypto.random.CryptoRandomFactory;
import org.apache.commons.crypto.stream.CryptoInputStream;
import org.apache.commons.crypto.stream.CryptoOutputStream;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
import org.apache.spark.network.util.ByteArrayReadableChannel;
import org.apache.spark.network.util.ByteArrayWritableChannel;
-import org.apache.spark.network.util.TransportConf;
/**
- * AES cipher for encryption and decryption.
+ * Cipher for encryption and decryption.
*/
-public class AesCipher {
- private static final Logger logger = LoggerFactory.getLogger(AesCipher.class);
- public static final String ENCRYPTION_HANDLER_NAME = "AesEncryption";
- public static final String DECRYPTION_HANDLER_NAME = "AesDecryption";
- public static final int STREAM_BUFFER_SIZE = 1024 * 32;
- public static final String TRANSFORM = "AES/CTR/NoPadding";
-
- private final SecretKeySpec inKeySpec;
- private final IvParameterSpec inIvSpec;
- private final SecretKeySpec outKeySpec;
- private final IvParameterSpec outIvSpec;
- private final Properties properties;
-
- public AesCipher(AesConfigMessage configMessage, TransportConf conf) throws IOException {
- this.properties = conf.cryptoConf();
- this.inKeySpec = new SecretKeySpec(configMessage.inKey, "AES");
- this.inIvSpec = new IvParameterSpec(configMessage.inIv);
- this.outKeySpec = new SecretKeySpec(configMessage.outKey, "AES");
- this.outIvSpec = new IvParameterSpec(configMessage.outIv);
+public class TransportCipher {
+ @VisibleForTesting
+ static final String ENCRYPTION_HANDLER_NAME = "TransportEncryption";
+ private static final String DECRYPTION_HANDLER_NAME = "TransportDecryption";
+ private static final int STREAM_BUFFER_SIZE = 1024 * 32;
+
+ private final Properties conf;
+ private final String cipher;
+ private final SecretKeySpec key;
+ private final byte[] inIv;
+ private final byte[] outIv;
+
+ public TransportCipher(
+ Properties conf,
+ String cipher,
+ SecretKeySpec key,
+ byte[] inIv,
+ byte[] outIv) {
+ this.conf = conf;
+ this.cipher = cipher;
+ this.key = key;
+ this.inIv = inIv;
+ this.outIv = outIv;
+ }
+
+ public String getCipherTransformation() {
+ return cipher;
+ }
+
+ @VisibleForTesting
+ SecretKeySpec getKey() {
+ return key;
+ }
+
+ /** The IV for the input channel (i.e. output channel of the remote side). */
+ public byte[] getInputIv() {
+ return inIv;
+ }
+
+ /** The IV for the output channel (i.e. input channel of the remote side). */
+ public byte[] getOutputIv() {
+ return outIv;
}
- /**
- * Create AES crypto output stream
- * @param ch The underlying channel to write out.
- * @return Return output crypto stream for encryption.
- * @throws IOException
- */
private CryptoOutputStream createOutputStream(WritableByteChannel ch) throws IOException {
- return new CryptoOutputStream(TRANSFORM, properties, ch, outKeySpec, outIvSpec);
+ return new CryptoOutputStream(cipher, conf, ch, key, new IvParameterSpec(outIv));
}
- /**
- * Create AES crypto input stream
- * @param ch The underlying channel used to read data.
- * @return Return input crypto stream for decryption.
- * @throws IOException
- */
private CryptoInputStream createInputStream(ReadableByteChannel ch) throws IOException {
- return new CryptoInputStream(TRANSFORM, properties, ch, inKeySpec, inIvSpec);
+ return new CryptoInputStream(cipher, conf, ch, key, new IvParameterSpec(inIv));
}
/**
- * Add handlers to channel
+ * Add handlers to channel.
+ *
* @param ch the channel for adding handlers
* @throws IOException
*/
public void addToChannel(Channel ch) throws IOException {
ch.pipeline()
- .addFirst(ENCRYPTION_HANDLER_NAME, new AesEncryptHandler(this))
- .addFirst(DECRYPTION_HANDLER_NAME, new AesDecryptHandler(this));
- }
-
- /**
- * Create the configuration message
- * @param conf is the local transport configuration.
- * @return Config message for sending.
- */
- public static AesConfigMessage createConfigMessage(TransportConf conf) {
- int keySize = conf.aesCipherKeySize();
- Properties properties = conf.cryptoConf();
-
- try {
- int paramLen = CryptoCipherFactory.getCryptoCipher(AesCipher.TRANSFORM, properties)
- .getBlockSize();
- byte[] inKey = new byte[keySize];
- byte[] outKey = new byte[keySize];
- byte[] inIv = new byte[paramLen];
- byte[] outIv = new byte[paramLen];
-
- CryptoRandom random = CryptoRandomFactory.getCryptoRandom(properties);
- random.nextBytes(inKey);
- random.nextBytes(outKey);
- random.nextBytes(inIv);
- random.nextBytes(outIv);
-
- return new AesConfigMessage(inKey, inIv, outKey, outIv);
- } catch (Exception e) {
- logger.error("AES config error", e);
- throw Throwables.propagate(e);
- }
+ .addFirst(ENCRYPTION_HANDLER_NAME, new EncryptionHandler(this))
+ .addFirst(DECRYPTION_HANDLER_NAME, new DecryptionHandler(this));
}
- private static class AesEncryptHandler extends ChannelOutboundHandlerAdapter {
+ private static class EncryptionHandler extends ChannelOutboundHandlerAdapter {
private final ByteArrayWritableChannel byteChannel;
private final CryptoOutputStream cos;
- AesEncryptHandler(AesCipher cipher) throws IOException {
- byteChannel = new ByteArrayWritableChannel(AesCipher.STREAM_BUFFER_SIZE);
+ EncryptionHandler(TransportCipher cipher) throws IOException {
+ byteChannel = new ByteArrayWritableChannel(STREAM_BUFFER_SIZE);
cos = cipher.createOutputStream(byteChannel);
}
@@ -153,11 +129,11 @@ public class AesCipher {
}
}
- private static class AesDecryptHandler extends ChannelInboundHandlerAdapter {
+ private static class DecryptionHandler extends ChannelInboundHandlerAdapter {
private final CryptoInputStream cis;
private final ByteArrayReadableChannel byteChannel;
- AesDecryptHandler(AesCipher cipher) throws IOException {
+ DecryptionHandler(TransportCipher cipher) throws IOException {
byteChannel = new ByteArrayReadableChannel();
cis = cipher.createInputStream(byteChannel);
}
@@ -207,7 +183,7 @@ public class AesCipher {
this.buf = isByteBuf ? (ByteBuf) msg : null;
this.region = isByteBuf ? null : (FileRegion) msg;
this.transferred = 0;
- this.byteRawChannel = new ByteArrayWritableChannel(AesCipher.STREAM_BUFFER_SIZE);
+ this.byteRawChannel = new ByteArrayWritableChannel(STREAM_BUFFER_SIZE);
this.cos = cos;
this.byteEncChannel = ch;
}
diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java
index a1bb453657..6478137722 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java
@@ -30,8 +30,6 @@ import org.slf4j.LoggerFactory;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.client.TransportClientBootstrap;
-import org.apache.spark.network.sasl.aes.AesCipher;
-import org.apache.spark.network.sasl.aes.AesConfigMessage;
import org.apache.spark.network.util.JavaUtils;
import org.apache.spark.network.util.TransportConf;
@@ -42,24 +40,14 @@ import org.apache.spark.network.util.TransportConf;
public class SaslClientBootstrap implements TransportClientBootstrap {
private static final Logger logger = LoggerFactory.getLogger(SaslClientBootstrap.class);
- private final boolean encrypt;
private final TransportConf conf;
private final String appId;
private final SecretKeyHolder secretKeyHolder;
public SaslClientBootstrap(TransportConf conf, String appId, SecretKeyHolder secretKeyHolder) {
- this(conf, appId, secretKeyHolder, false);
- }
-
- public SaslClientBootstrap(
- TransportConf conf,
- String appId,
- SecretKeyHolder secretKeyHolder,
- boolean encrypt) {
this.conf = conf;
this.appId = appId;
this.secretKeyHolder = secretKeyHolder;
- this.encrypt = encrypt;
}
/**
@@ -69,7 +57,7 @@ public class SaslClientBootstrap implements TransportClientBootstrap {
*/
@Override
public void doBootstrap(TransportClient client, Channel channel) {
- SparkSaslClient saslClient = new SparkSaslClient(appId, secretKeyHolder, encrypt);
+ SparkSaslClient saslClient = new SparkSaslClient(appId, secretKeyHolder, conf.saslEncryption());
try {
byte[] payload = saslClient.firstToken();
@@ -79,35 +67,19 @@ public class SaslClientBootstrap implements TransportClientBootstrap {
msg.encode(buf);
buf.writeBytes(msg.body().nioByteBuffer());
- ByteBuffer response = client.sendRpcSync(buf.nioBuffer(), conf.saslRTTimeoutMs());
+ ByteBuffer response = client.sendRpcSync(buf.nioBuffer(), conf.authRTTimeoutMs());
payload = saslClient.response(JavaUtils.bufferToArray(response));
}
client.setClientId(appId);
- if (encrypt) {
+ if (conf.saslEncryption()) {
if (!SparkSaslServer.QOP_AUTH_CONF.equals(saslClient.getNegotiatedProperty(Sasl.QOP))) {
throw new RuntimeException(
new SaslException("Encryption requests by negotiated non-encrypted connection."));
}
- if (conf.aesEncryptionEnabled()) {
- // Generate a request config message to send to server.
- AesConfigMessage configMessage = AesCipher.createConfigMessage(conf);
- ByteBuffer buf = configMessage.encodeMessage();
-
- // Encrypted the config message.
- byte[] toEncrypt = JavaUtils.bufferToArray(buf);
- ByteBuffer encrypted = ByteBuffer.wrap(saslClient.wrap(toEncrypt, 0, toEncrypt.length));
-
- client.sendRpcSync(encrypted, conf.saslRTTimeoutMs());
- AesCipher cipher = new AesCipher(configMessage, conf);
- logger.info("Enabling AES cipher for client channel {}", client);
- cipher.addToChannel(channel);
- saslClient.dispose();
- } else {
- SaslEncryption.addToChannel(channel, saslClient, conf.maxSaslEncryptedBlockSize());
- }
+ SaslEncryption.addToChannel(channel, saslClient, conf.maxSaslEncryptedBlockSize());
saslClient = null;
logger.debug("Channel {} configured for encryption.", client);
}
diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java
index b2f3ef214b..0231428318 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java
@@ -29,8 +29,6 @@ import org.slf4j.LoggerFactory;
import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.TransportClient;
-import org.apache.spark.network.sasl.aes.AesCipher;
-import org.apache.spark.network.sasl.aes.AesConfigMessage;
import org.apache.spark.network.server.RpcHandler;
import org.apache.spark.network.server.StreamManager;
import org.apache.spark.network.util.JavaUtils;
@@ -44,7 +42,7 @@ import org.apache.spark.network.util.TransportConf;
* Note that the authentication process consists of multiple challenge-response pairs, each of
* which are individual RPCs.
*/
-class SaslRpcHandler extends RpcHandler {
+public class SaslRpcHandler extends RpcHandler {
private static final Logger logger = LoggerFactory.getLogger(SaslRpcHandler.class);
/** Transport configuration. */
@@ -63,7 +61,7 @@ class SaslRpcHandler extends RpcHandler {
private boolean isComplete;
private boolean isAuthenticated;
- SaslRpcHandler(
+ public SaslRpcHandler(
TransportConf conf,
Channel channel,
RpcHandler delegate,
@@ -122,37 +120,10 @@ class SaslRpcHandler extends RpcHandler {
return;
}
- if (!conf.aesEncryptionEnabled()) {
- logger.debug("Enabling encryption for channel {}", client);
- SaslEncryption.addToChannel(channel, saslServer, conf.maxSaslEncryptedBlockSize());
- complete(false);
- return;
- }
-
- // Extra negotiation should happen after authentication, so return directly while
- // processing authenticate.
- if (!isAuthenticated) {
- logger.debug("SASL authentication successful for channel {}", client);
- isAuthenticated = true;
- return;
- }
-
- // Create AES cipher when it is authenticated
- try {
- byte[] encrypted = JavaUtils.bufferToArray(message);
- ByteBuffer decrypted = ByteBuffer.wrap(saslServer.unwrap(encrypted, 0 , encrypted.length));
-
- AesConfigMessage configMessage = AesConfigMessage.decodeMessage(decrypted);
- AesCipher cipher = new AesCipher(configMessage, conf);
-
- // Send response back to client to confirm that server accept config.
- callback.onSuccess(JavaUtils.stringToBytes(AesCipher.TRANSFORM));
- logger.info("Enabling AES cipher for Server channel {}", client);
- cipher.addToChannel(channel);
- complete(true);
- } catch (IOException ioe) {
- throw new RuntimeException(ioe);
- }
+ logger.debug("Enabling encryption for channel {}", client);
+ SaslEncryption.addToChannel(channel, saslServer, conf.maxSaslEncryptedBlockSize());
+ complete(false);
+ return;
}
}
diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/aes/AesConfigMessage.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/aes/AesConfigMessage.java
deleted file mode 100644
index 3ef6f74a1f..0000000000
--- a/common/network-common/src/main/java/org/apache/spark/network/sasl/aes/AesConfigMessage.java
+++ /dev/null
@@ -1,101 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.sasl.aes;
-
-import java.nio.ByteBuffer;
-
-import io.netty.buffer.ByteBuf;
-import io.netty.buffer.Unpooled;
-
-import org.apache.spark.network.protocol.Encodable;
-import org.apache.spark.network.protocol.Encoders;
-
-/**
- * The AES cipher options for encryption negotiation.
- */
-public class AesConfigMessage implements Encodable {
- /** Serialization tag used to catch incorrect payloads. */
- private static final byte TAG_BYTE = (byte) 0xEB;
-
- public byte[] inKey;
- public byte[] outKey;
- public byte[] inIv;
- public byte[] outIv;
-
- public AesConfigMessage(byte[] inKey, byte[] inIv, byte[] outKey, byte[] outIv) {
- if (inKey == null || inIv == null || outKey == null || outIv == null) {
- throw new IllegalArgumentException("Cipher Key or IV must not be null!");
- }
-
- this.inKey = inKey;
- this.inIv = inIv;
- this.outKey = outKey;
- this.outIv = outIv;
- }
-
- @Override
- public int encodedLength() {
- return 1 +
- Encoders.ByteArrays.encodedLength(inKey) + Encoders.ByteArrays.encodedLength(outKey) +
- Encoders.ByteArrays.encodedLength(inIv) + Encoders.ByteArrays.encodedLength(outIv);
- }
-
- @Override
- public void encode(ByteBuf buf) {
- buf.writeByte(TAG_BYTE);
- Encoders.ByteArrays.encode(buf, inKey);
- Encoders.ByteArrays.encode(buf, inIv);
- Encoders.ByteArrays.encode(buf, outKey);
- Encoders.ByteArrays.encode(buf, outIv);
- }
-
- /**
- * Encode the config message.
- * @return ByteBuffer which contains encoded config message.
- */
- public ByteBuffer encodeMessage(){
- ByteBuffer buf = ByteBuffer.allocate(encodedLength());
-
- ByteBuf wrappedBuf = Unpooled.wrappedBuffer(buf);
- wrappedBuf.clear();
- encode(wrappedBuf);
-
- return buf;
- }
-
- /**
- * Decode the config message from buffer
- * @param buffer the buffer contain encoded config message
- * @return config message
- */
- public static AesConfigMessage decodeMessage(ByteBuffer buffer) {
- ByteBuf buf = Unpooled.wrappedBuffer(buffer);
-
- if (buf.readByte() != TAG_BYTE) {
- throw new IllegalStateException("Expected AesConfigMessage, received something else"
- + " (maybe your client does not have AES enabled?)");
- }
-
- byte[] outKey = Encoders.ByteArrays.decode(buf);
- byte[] outIv = Encoders.ByteArrays.decode(buf);
- byte[] inKey = Encoders.ByteArrays.decode(buf);
- byte[] inIv = Encoders.ByteArrays.decode(buf);
- return new AesConfigMessage(inKey, inIv, outKey, outIv);
- }
-
-}
diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java
index 6a557fa75d..c226d8f3bc 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java
@@ -117,9 +117,10 @@ public class TransportConf {
/** Send buffer size (SO_SNDBUF). */
public int sendBuf() { return conf.getInt(SPARK_NETWORK_IO_SENDBUFFER_KEY, -1); }
- /** Timeout for a single round trip of SASL token exchange, in milliseconds. */
- public int saslRTTimeoutMs() {
- return (int) JavaUtils.timeStringAsSec(conf.get(SPARK_NETWORK_SASL_TIMEOUT_KEY, "30s")) * 1000;
+ /** Timeout for a single round trip of auth message exchange, in milliseconds. */
+ public int authRTTimeoutMs() {
+ return (int) JavaUtils.timeStringAsSec(conf.get("spark.network.auth.rpcTimeout",
+ conf.get(SPARK_NETWORK_SASL_TIMEOUT_KEY, "30s"))) * 1000;
}
/**
@@ -162,40 +163,95 @@ public class TransportConf {
}
/**
- * Maximum number of bytes to be encrypted at a time when SASL encryption is enabled.
+ * Enables strong encryption. Also enables the new auth protocol, used to negotiate keys.
*/
- public int maxSaslEncryptedBlockSize() {
- return Ints.checkedCast(JavaUtils.byteStringAsBytes(
- conf.get("spark.network.sasl.maxEncryptedBlockSize", "64k")));
+ public boolean encryptionEnabled() {
+ return conf.getBoolean("spark.network.crypto.enabled", false);
}
/**
- * Whether the server should enforce encryption on SASL-authenticated connections.
+ * The cipher transformation to use for encrypting session data.
*/
- public boolean saslServerAlwaysEncrypt() {
- return conf.getBoolean("spark.network.sasl.serverAlwaysEncrypt", false);
+ public String cipherTransformation() {
+ return conf.get("spark.network.crypto.cipher", "AES/CTR/NoPadding");
+ }
+
+ /**
+ * The key generation algorithm. This should be an algorithm that accepts a "PBEKeySpec"
+ * as input. The default value (PBKDF2WithHmacSHA1) is available in Java 7.
+ */
+ public String keyFactoryAlgorithm() {
+ return conf.get("spark.network.crypto.keyFactoryAlgorithm", "PBKDF2WithHmacSHA1");
+ }
+
+ /**
+ * How many iterations to run when generating keys.
+ *
+ * See some discussion about this at: http://security.stackexchange.com/q/3959
+ * The default value was picked for speed, since it assumes that the secret has good entropy
+ * (128 bits by default), which is not generally the case with user passwords.
+ */
+ public int keyFactoryIterations() {
+ return conf.getInt("spark.networy.crypto.keyFactoryIterations", 1024);
+ }
+
+ /**
+ * Encryption key length, in bits.
+ */
+ public int encryptionKeyLength() {
+ return conf.getInt("spark.network.crypto.keyLength", 128);
+ }
+
+ /**
+ * Initial vector length, in bytes.
+ */
+ public int ivLength() {
+ return conf.getInt("spark.network.crypto.ivLength", 16);
+ }
+
+ /**
+ * The algorithm for generated secret keys. Nobody should really need to change this,
+ * but configurable just in case.
+ */
+ public String keyAlgorithm() {
+ return conf.get("spark.network.crypto.keyAlgorithm", "AES");
+ }
+
+ /**
+ * Whether to fall back to SASL if the new auth protocol fails. Enabled by default for
+ * backwards compatibility.
+ */
+ public boolean saslFallback() {
+ return conf.getBoolean("spark.network.crypto.saslFallback", true);
}
/**
- * The trigger for enabling AES encryption.
+ * Whether to enable SASL-based encryption when authenticating using SASL.
*/
- public boolean aesEncryptionEnabled() {
- return conf.getBoolean("spark.network.aes.enabled", false);
+ public boolean saslEncryption() {
+ return conf.getBoolean("spark.authenticate.enableSaslEncryption", false);
}
/**
- * The key size to use when AES cipher is enabled. Notice that the length should be 16, 24 or 32
- * bytes.
+ * Maximum number of bytes to be encrypted at a time when SASL encryption is used.
*/
- public int aesCipherKeySize() {
- return conf.getInt("spark.network.aes.keySize", 16);
+ public int maxSaslEncryptedBlockSize() {
+ return Ints.checkedCast(JavaUtils.byteStringAsBytes(
+ conf.get("spark.network.sasl.maxEncryptedBlockSize", "64k")));
+ }
+
+ /**
+ * Whether the server should enforce encryption on SASL-authenticated connections.
+ */
+ public boolean saslServerAlwaysEncrypt() {
+ return conf.getBoolean("spark.network.sasl.serverAlwaysEncrypt", false);
}
/**
* The commons-crypto configuration for the module.
*/
public Properties cryptoConf() {
- return CryptoUtils.toCryptoConf("spark.network.aes.config.", conf.getAll());
+ return CryptoUtils.toCryptoConf("spark.network.crypto.config.", conf.getAll());
}
}