diff options
Diffstat (limited to 'common/network-common/src/main/java/org/apache/spark')
12 files changed, 1122 insertions, 267 deletions
diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthClientBootstrap.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthClientBootstrap.java new file mode 100644 index 0000000000..980525dbf0 --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthClientBootstrap.java @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.crypto; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.security.GeneralSecurityException; +import java.security.Key; +import javax.crypto.KeyGenerator; +import javax.crypto.Mac; +import static java.nio.charset.StandardCharsets.UTF_8; + +import com.google.common.base.Preconditions; +import com.google.common.base.Throwables; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.client.TransportClientBootstrap; +import org.apache.spark.network.sasl.SaslClientBootstrap; +import org.apache.spark.network.sasl.SecretKeyHolder; +import org.apache.spark.network.util.JavaUtils; +import org.apache.spark.network.util.TransportConf; + +/** + * Bootstraps a {@link TransportClient} by performing authentication using Spark's auth protocol. + * + * This bootstrap falls back to using the SASL bootstrap if the server throws an error during + * authentication, and the configuration allows it. This is used for backwards compatibility + * with external shuffle services that do not support the new protocol. + * + * It also automatically falls back to SASL if the new encryption backend is disabled, so that + * callers only need to install this bootstrap when authentication is enabled. + */ +public class AuthClientBootstrap implements TransportClientBootstrap { + + private static final Logger LOG = LoggerFactory.getLogger(AuthClientBootstrap.class); + + private final TransportConf conf; + private final String appId; + private final String authUser; + private final SecretKeyHolder secretKeyHolder; + + public AuthClientBootstrap( + TransportConf conf, + String appId, + SecretKeyHolder secretKeyHolder) { + this.conf = conf; + // TODO: right now this behaves like the SASL backend, because when executors start up + // they don't necessarily know the app ID. So they send a hardcoded "user" that is defined + // in the SecurityManager, which will also always return the same secret (regardless of the + // user name). All that's needed here is for this "user" to match on both sides, since that's + // required by the protocol. At some point, though, it would be better for the actual app ID + // to be provided here. + this.appId = appId; + this.authUser = secretKeyHolder.getSaslUser(appId); + this.secretKeyHolder = secretKeyHolder; + } + + @Override + public void doBootstrap(TransportClient client, Channel channel) { + if (!conf.encryptionEnabled()) { + LOG.debug("AES encryption disabled, using old auth protocol."); + doSaslAuth(client, channel); + return; + } + + try { + doSparkAuth(client, channel); + } catch (GeneralSecurityException | IOException e) { + throw Throwables.propagate(e); + } catch (RuntimeException e) { + // There isn't a good exception that can be caught here to know whether it's really + // OK to switch back to SASL (because the server doesn't speak the new protocol). So + // try it anyway, and in the worst case things will fail again. + if (conf.saslFallback()) { + LOG.warn("New auth protocol failed, trying SASL.", e); + doSaslAuth(client, channel); + } else { + throw e; + } + } + } + + private void doSparkAuth(TransportClient client, Channel channel) + throws GeneralSecurityException, IOException { + + AuthEngine engine = new AuthEngine(authUser, secretKeyHolder.getSecretKey(authUser), conf); + try { + ClientChallenge challenge = engine.challenge(); + ByteBuf challengeData = Unpooled.buffer(challenge.encodedLength()); + challenge.encode(challengeData); + + ByteBuffer responseData = client.sendRpcSync(challengeData.nioBuffer(), + conf.authRTTimeoutMs()); + ServerResponse response = ServerResponse.decodeMessage(responseData); + + engine.validate(response); + engine.sessionCipher().addToChannel(channel); + } finally { + engine.close(); + } + } + + private void doSaslAuth(TransportClient client, Channel channel) { + SaslClientBootstrap sasl = new SaslClientBootstrap(conf, appId, secretKeyHolder); + sasl.doBootstrap(client, channel); + } + +} diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthEngine.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthEngine.java new file mode 100644 index 0000000000..b769ebeba3 --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthEngine.java @@ -0,0 +1,284 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.crypto; + +import java.io.Closeable; +import java.io.IOException; +import java.math.BigInteger; +import java.security.GeneralSecurityException; +import java.util.Arrays; +import java.util.Properties; +import javax.crypto.Cipher; +import javax.crypto.SecretKey; +import javax.crypto.SecretKeyFactory; +import javax.crypto.ShortBufferException; +import javax.crypto.spec.IvParameterSpec; +import javax.crypto.spec.PBEKeySpec; +import javax.crypto.spec.SecretKeySpec; +import static java.nio.charset.StandardCharsets.UTF_8; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Preconditions; +import com.google.common.primitives.Bytes; +import org.apache.commons.crypto.cipher.CryptoCipher; +import org.apache.commons.crypto.cipher.CryptoCipherFactory; +import org.apache.commons.crypto.random.CryptoRandom; +import org.apache.commons.crypto.random.CryptoRandomFactory; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.util.TransportConf; + +/** + * A helper class for abstracting authentication and key negotiation details. This is used by + * both client and server sides, since the operations are basically the same. + */ +class AuthEngine implements Closeable { + + private static final Logger LOG = LoggerFactory.getLogger(AuthEngine.class); + private static final BigInteger ONE = new BigInteger(new byte[] { 0x1 }); + + private final byte[] appId; + private final char[] secret; + private final TransportConf conf; + private final Properties cryptoConf; + private final CryptoRandom random; + + private byte[] authNonce; + + @VisibleForTesting + byte[] challenge; + + private TransportCipher sessionCipher; + private CryptoCipher encryptor; + private CryptoCipher decryptor; + + AuthEngine(String appId, String secret, TransportConf conf) throws GeneralSecurityException { + this.appId = appId.getBytes(UTF_8); + this.conf = conf; + this.cryptoConf = conf.cryptoConf(); + this.secret = secret.toCharArray(); + this.random = CryptoRandomFactory.getCryptoRandom(cryptoConf); + } + + /** + * Create the client challenge. + * + * @return A challenge to be sent the remote side. + */ + ClientChallenge challenge() throws GeneralSecurityException, IOException { + this.authNonce = randomBytes(conf.encryptionKeyLength() / Byte.SIZE); + SecretKeySpec authKey = generateKey(conf.keyFactoryAlgorithm(), conf.keyFactoryIterations(), + authNonce, conf.encryptionKeyLength()); + initializeForAuth(conf.cipherTransformation(), authNonce, authKey); + + this.challenge = randomBytes(conf.encryptionKeyLength() / Byte.SIZE); + return new ClientChallenge(new String(appId, UTF_8), + conf.keyFactoryAlgorithm(), + conf.keyFactoryIterations(), + conf.cipherTransformation(), + conf.encryptionKeyLength(), + authNonce, + challenge(appId, authNonce, challenge)); + } + + /** + * Validates the client challenge, and create the encryption backend for the channel from the + * parameters sent by the client. + * + * @param clientChallenge The challenge from the client. + * @return A response to be sent to the client. + */ + ServerResponse respond(ClientChallenge clientChallenge) + throws GeneralSecurityException, IOException { + + SecretKeySpec authKey = generateKey(clientChallenge.kdf, clientChallenge.iterations, + clientChallenge.nonce, clientChallenge.keyLength); + initializeForAuth(clientChallenge.cipher, clientChallenge.nonce, authKey); + + byte[] challenge = validateChallenge(clientChallenge.nonce, clientChallenge.challenge); + byte[] response = challenge(appId, clientChallenge.nonce, rawResponse(challenge)); + byte[] sessionNonce = randomBytes(conf.encryptionKeyLength() / Byte.SIZE); + byte[] inputIv = randomBytes(conf.ivLength()); + byte[] outputIv = randomBytes(conf.ivLength()); + + SecretKeySpec sessionKey = generateKey(clientChallenge.kdf, clientChallenge.iterations, + sessionNonce, clientChallenge.keyLength); + this.sessionCipher = new TransportCipher(cryptoConf, clientChallenge.cipher, sessionKey, + inputIv, outputIv); + + // Note the IVs are swapped in the response. + return new ServerResponse(response, encrypt(sessionNonce), encrypt(outputIv), encrypt(inputIv)); + } + + /** + * Validates the server response and initializes the cipher to use for the session. + * + * @param serverResponse The response from the server. + */ + void validate(ServerResponse serverResponse) throws GeneralSecurityException { + byte[] response = validateChallenge(authNonce, serverResponse.response); + + byte[] expected = rawResponse(challenge); + Preconditions.checkArgument(Arrays.equals(expected, response)); + + byte[] nonce = decrypt(serverResponse.nonce); + byte[] inputIv = decrypt(serverResponse.inputIv); + byte[] outputIv = decrypt(serverResponse.outputIv); + + SecretKeySpec sessionKey = generateKey(conf.keyFactoryAlgorithm(), conf.keyFactoryIterations(), + nonce, conf.encryptionKeyLength()); + this.sessionCipher = new TransportCipher(cryptoConf, conf.cipherTransformation(), sessionKey, + inputIv, outputIv); + } + + TransportCipher sessionCipher() { + Preconditions.checkState(sessionCipher != null); + return sessionCipher; + } + + @Override + public void close() throws IOException { + // Close ciphers (by calling "doFinal()" with dummy data) and the random instance so that + // internal state is cleaned up. Error handling here is just for paranoia, and not meant to + // accurately report the errors when they happen. + RuntimeException error = null; + byte[] dummy = new byte[8]; + try { + doCipherOp(encryptor, dummy, true); + } catch (Exception e) { + error = new RuntimeException(e); + } + try { + doCipherOp(decryptor, dummy, true); + } catch (Exception e) { + error = new RuntimeException(e); + } + random.close(); + + if (error != null) { + throw error; + } + } + + @VisibleForTesting + byte[] challenge(byte[] appId, byte[] nonce, byte[] challenge) throws GeneralSecurityException { + return encrypt(Bytes.concat(appId, nonce, challenge)); + } + + @VisibleForTesting + byte[] rawResponse(byte[] challenge) { + BigInteger orig = new BigInteger(challenge); + BigInteger response = orig.add(ONE); + return response.toByteArray(); + } + + private byte[] decrypt(byte[] in) throws GeneralSecurityException { + return doCipherOp(decryptor, in, false); + } + + private byte[] encrypt(byte[] in) throws GeneralSecurityException { + return doCipherOp(encryptor, in, false); + } + + private void initializeForAuth(String cipher, byte[] nonce, SecretKeySpec key) + throws GeneralSecurityException { + + // commons-crypto currently only supports ciphers that require an initial vector; so + // create a dummy vector so that we can initialize the ciphers. In the future, if + // different ciphers are supported, this will have to be configurable somehow. + byte[] iv = new byte[conf.ivLength()]; + System.arraycopy(nonce, 0, iv, 0, Math.min(nonce.length, iv.length)); + + encryptor = CryptoCipherFactory.getCryptoCipher(cipher, cryptoConf); + encryptor.init(Cipher.ENCRYPT_MODE, key, new IvParameterSpec(iv)); + + decryptor = CryptoCipherFactory.getCryptoCipher(cipher, cryptoConf); + decryptor.init(Cipher.DECRYPT_MODE, key, new IvParameterSpec(iv)); + } + + /** + * Validates an encrypted challenge as defined in the protocol, and returns the byte array + * that corresponds to the actual challenge data. + */ + private byte[] validateChallenge(byte[] nonce, byte[] encryptedChallenge) + throws GeneralSecurityException { + + byte[] challenge = decrypt(encryptedChallenge); + checkSubArray(appId, challenge, 0); + checkSubArray(nonce, challenge, appId.length); + return Arrays.copyOfRange(challenge, appId.length + nonce.length, challenge.length); + } + + private SecretKeySpec generateKey(String kdf, int iterations, byte[] salt, int keyLength) + throws GeneralSecurityException { + + SecretKeyFactory factory = SecretKeyFactory.getInstance(kdf); + PBEKeySpec spec = new PBEKeySpec(secret, salt, iterations, keyLength); + + long start = System.nanoTime(); + SecretKey key = factory.generateSecret(spec); + long end = System.nanoTime(); + + LOG.debug("Generated key with {} iterations in {} us.", conf.keyFactoryIterations(), + (end - start) / 1000); + + return new SecretKeySpec(key.getEncoded(), conf.keyAlgorithm()); + } + + private byte[] doCipherOp(CryptoCipher cipher, byte[] in, boolean isFinal) + throws GeneralSecurityException { + + Preconditions.checkState(cipher != null); + + int scale = 1; + while (true) { + int size = in.length * scale; + byte[] buffer = new byte[size]; + try { + int outSize = isFinal ? cipher.doFinal(in, 0, in.length, buffer, 0) + : cipher.update(in, 0, in.length, buffer, 0); + if (outSize != buffer.length) { + byte[] output = new byte[outSize]; + System.arraycopy(buffer, 0, output, 0, output.length); + return output; + } else { + return buffer; + } + } catch (ShortBufferException e) { + // Try again with a bigger buffer. + scale *= 2; + } + } + } + + private byte[] randomBytes(int count) { + byte[] bytes = new byte[count]; + random.nextBytes(bytes); + return bytes; + } + + /** Checks that the "test" array is in the data array starting at the given offset. */ + private void checkSubArray(byte[] test, byte[] data, int offset) { + Preconditions.checkArgument(data.length >= test.length + offset); + for (int i = 0; i < test.length; i++) { + Preconditions.checkArgument(test[i] == data[i + offset]); + } + } + +} diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java new file mode 100644 index 0000000000..991d8ba95f --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java @@ -0,0 +1,170 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.crypto; + +import java.io.IOException; +import java.nio.ByteBuffer; +import javax.security.sasl.Sasl; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Throwables; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.sasl.SecretKeyHolder; +import org.apache.spark.network.sasl.SaslRpcHandler; +import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.server.StreamManager; +import org.apache.spark.network.util.JavaUtils; +import org.apache.spark.network.util.TransportConf; + +/** + * RPC Handler which performs authentication using Spark's auth protocol before delegating to a + * child RPC handler. If the configuration allows, this handler will delegate messages to a SASL + * RPC handler for further authentication, to support for clients that do not support Spark's + * protocol. + * + * The delegate will only receive messages if the given connection has been successfully + * authenticated. A connection may be authenticated at most once. + */ +class AuthRpcHandler extends RpcHandler { + private static final Logger LOG = LoggerFactory.getLogger(AuthRpcHandler.class); + + /** Transport configuration. */ + private final TransportConf conf; + + /** The client channel. */ + private final Channel channel; + + /** + * RpcHandler we will delegate to for authenticated connections. When falling back to SASL + * this will be replaced with the SASL RPC handler. + */ + @VisibleForTesting + RpcHandler delegate; + + /** Class which provides secret keys which are shared by server and client on a per-app basis. */ + private final SecretKeyHolder secretKeyHolder; + + /** Whether auth is done and future calls should be delegated. */ + @VisibleForTesting + boolean doDelegate; + + AuthRpcHandler( + TransportConf conf, + Channel channel, + RpcHandler delegate, + SecretKeyHolder secretKeyHolder) { + this.conf = conf; + this.channel = channel; + this.delegate = delegate; + this.secretKeyHolder = secretKeyHolder; + } + + @Override + public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) { + if (doDelegate) { + delegate.receive(client, message, callback); + return; + } + + int position = message.position(); + int limit = message.limit(); + + ClientChallenge challenge; + try { + challenge = ClientChallenge.decodeMessage(message); + LOG.debug("Received new auth challenge for client {}.", channel.remoteAddress()); + } catch (RuntimeException e) { + if (conf.saslFallback()) { + LOG.warn("Failed to parse new auth challenge, reverting to SASL for client {}.", + channel.remoteAddress()); + delegate = new SaslRpcHandler(conf, channel, delegate, secretKeyHolder); + message.position(position); + message.limit(limit); + delegate.receive(client, message, callback); + doDelegate = true; + } else { + LOG.debug("Unexpected challenge message from client {}, closing channel.", + channel.remoteAddress()); + callback.onFailure(new IllegalArgumentException("Unknown challenge message.")); + channel.close(); + } + return; + } + + // Here we have the client challenge, so perform the new auth protocol and set up the channel. + AuthEngine engine = null; + try { + engine = new AuthEngine(challenge.appId, secretKeyHolder.getSecretKey(challenge.appId), conf); + ServerResponse response = engine.respond(challenge); + ByteBuf responseData = Unpooled.buffer(response.encodedLength()); + response.encode(responseData); + callback.onSuccess(responseData.nioBuffer()); + engine.sessionCipher().addToChannel(channel); + } catch (Exception e) { + // This is a fatal error: authentication has failed. Close the channel explicitly. + LOG.debug("Authentication failed for client {}, closing channel.", channel.remoteAddress()); + callback.onFailure(new IllegalArgumentException("Authentication failed.")); + channel.close(); + return; + } finally { + if (engine != null) { + try { + engine.close(); + } catch (Exception e) { + throw Throwables.propagate(e); + } + } + } + + LOG.debug("Authorization successful for client {}.", channel.remoteAddress()); + doDelegate = true; + } + + @Override + public void receive(TransportClient client, ByteBuffer message) { + delegate.receive(client, message); + } + + @Override + public StreamManager getStreamManager() { + return delegate.getStreamManager(); + } + + @Override + public void channelActive(TransportClient client) { + delegate.channelActive(client); + } + + @Override + public void channelInactive(TransportClient client) { + delegate.channelInactive(client); + } + + @Override + public void exceptionCaught(Throwable cause, TransportClient client) { + delegate.exceptionCaught(cause, client); + } + +} diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthServerBootstrap.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthServerBootstrap.java new file mode 100644 index 0000000000..77a2a6af4d --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthServerBootstrap.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.crypto; + +import io.netty.channel.Channel; + +import org.apache.spark.network.sasl.SaslServerBootstrap; +import org.apache.spark.network.sasl.SecretKeyHolder; +import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.server.TransportServerBootstrap; +import org.apache.spark.network.util.TransportConf; + +/** + * A bootstrap which is executed on a TransportServer's client channel once a client connects + * to the server, enabling authentication using Spark's auth protocol (and optionally SASL for + * clients that don't support the new protocol). + * + * It also automatically falls back to SASL if the new encryption backend is disabled, so that + * callers only need to install this bootstrap when authentication is enabled. + */ +public class AuthServerBootstrap implements TransportServerBootstrap { + + private final TransportConf conf; + private final SecretKeyHolder secretKeyHolder; + + public AuthServerBootstrap(TransportConf conf, SecretKeyHolder secretKeyHolder) { + this.conf = conf; + this.secretKeyHolder = secretKeyHolder; + } + + public RpcHandler doBootstrap(Channel channel, RpcHandler rpcHandler) { + if (!conf.encryptionEnabled()) { + TransportServerBootstrap sasl = new SaslServerBootstrap(conf, secretKeyHolder); + return sasl.doBootstrap(channel, rpcHandler); + } + + return new AuthRpcHandler(conf, channel, rpcHandler, secretKeyHolder); + } + +} diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/ClientChallenge.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/ClientChallenge.java new file mode 100644 index 0000000000..3312a5bd81 --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/ClientChallenge.java @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.crypto; + +import java.nio.ByteBuffer; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; + +import org.apache.spark.network.protocol.Encodable; +import org.apache.spark.network.protocol.Encoders; + +/** + * The client challenge message, used to initiate authentication. + * + * @see README.md + */ +public class ClientChallenge implements Encodable { + /** Serialization tag used to catch incorrect payloads. */ + private static final byte TAG_BYTE = (byte) 0xFA; + + public final String appId; + public final String kdf; + public final int iterations; + public final String cipher; + public final int keyLength; + public final byte[] nonce; + public final byte[] challenge; + + public ClientChallenge( + String appId, + String kdf, + int iterations, + String cipher, + int keyLength, + byte[] nonce, + byte[] challenge) { + this.appId = appId; + this.kdf = kdf; + this.iterations = iterations; + this.cipher = cipher; + this.keyLength = keyLength; + this.nonce = nonce; + this.challenge = challenge; + } + + @Override + public int encodedLength() { + return 1 + 4 + 4 + + Encoders.Strings.encodedLength(appId) + + Encoders.Strings.encodedLength(kdf) + + Encoders.Strings.encodedLength(cipher) + + Encoders.ByteArrays.encodedLength(nonce) + + Encoders.ByteArrays.encodedLength(challenge); + } + + @Override + public void encode(ByteBuf buf) { + buf.writeByte(TAG_BYTE); + Encoders.Strings.encode(buf, appId); + Encoders.Strings.encode(buf, kdf); + buf.writeInt(iterations); + Encoders.Strings.encode(buf, cipher); + buf.writeInt(keyLength); + Encoders.ByteArrays.encode(buf, nonce); + Encoders.ByteArrays.encode(buf, challenge); + } + + public static ClientChallenge decodeMessage(ByteBuffer buffer) { + ByteBuf buf = Unpooled.wrappedBuffer(buffer); + + if (buf.readByte() != TAG_BYTE) { + throw new IllegalArgumentException("Expected ClientChallenge, received something else."); + } + + return new ClientChallenge( + Encoders.Strings.decode(buf), + Encoders.Strings.decode(buf), + buf.readInt(), + Encoders.Strings.decode(buf), + buf.readInt(), + Encoders.ByteArrays.decode(buf), + Encoders.ByteArrays.decode(buf)); + } + +} diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/README.md b/common/network-common/src/main/java/org/apache/spark/network/crypto/README.md new file mode 100644 index 0000000000..14df703270 --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/README.md @@ -0,0 +1,158 @@ +Spark Auth Protocol and AES Encryption Support +============================================== + +This file describes an auth protocol used by Spark as a more secure alternative to DIGEST-MD5. This +protocol is built on symmetric key encryption, based on the assumption that the two endpoints being +authenticated share a common secret, which is how Spark authentication currently works. The protocol +provides mutual authentication, meaning that after the negotiation both parties know that the remote +side knows the shared secret. The protocol is influenced by the ISO/IEC 9798 protocol, although it's +not an implementation of it. + +This protocol could be replaced with TLS PSK, except no PSK ciphers are available in the currently +released JREs. + +The protocol aims at solving the following shortcomings in Spark's current usage of DIGEST-MD5: + +- MD5 is an aging hash algorithm with known weaknesses, and a more secure alternative is desired. +- DIGEST-MD5 has a pre-defined set of ciphers for which it can generate keys. The only + viable, supported cipher these days is 3DES, and a more modern alternative is desired. +- Encrypting AES session keys with 3DES doesn't solve the issue, since the weakest link + in the negotiation would still be MD5 and 3DES. + +The protocol assumes that the shared secret is generated and distributed in a secure manner. + +The protocol always negotiates encryption keys. If encryption is not desired, the existing +SASL-based authentication, or no authentication at all, can be chosen instead. + +When messages are described below, it's expected that the implementation should support +arbitrary sizes for fields that don't have a fixed size. + +Client Challenge +---------------- + +The auth negotiation is started by the client. The client starts by generating an encryption +key based on the application's shared secret, and a nonce. + + KEY = KDF(SECRET, SALT, KEY_LENGTH) + +Where: +- KDF(): a key derivation function that takes a secret, a salt, a configurable number of + iterations, and a configurable key length. +- SALT: a byte sequence used to salt the key derivation function. +- KEY_LENGTH: length of the encryption key to generate. + + +The client generates a message with the following content: + + CLIENT_CHALLENGE = ( + APP_ID, + KDF, + ITERATIONS, + CIPHER, + KEY_LENGTH, + ANONCE, + ENC(APP_ID || ANONCE || CHALLENGE)) + +Where: + +- APP_ID: the application ID which the server uses to identify the shared secret. +- KDF: the key derivation function described above. +- ITERATIONS: number of iterations to run the KDF when generating keys. +- CIPHER: the cipher used to encrypt data. +- KEY_LENGTH: length of the encryption keys to generate, in bits. +- ANONCE: the nonce used as the salt when generating the auth key. +- ENC(): an encryption function that uses the cipher and the generated key. This function + will also be used in the definition of other messages below. +- CHALLENGE: a byte sequence used as a challenge to the server. +- ||: concatenation operator. + +When strings are used where byte arrays are expected, the UTF-8 representation of the string +is assumed. + +To respond to the challenge, the server should consider the byte array as representing an +arbitrary-length integer, and respond with the value of the integer plus one. + + +Server Response And Challenge +----------------------------- + +Once the client challenge is received, the server will generate the same auth key by +using the same algorithm the client has used. It will then verify the client challenge: +if the APP_ID and ANONCE fields match, the server knows that the client has the shared +secret. The server then creates a response to the client challenge, to prove that it also +has the secret key, and provides parameters to be used when creating the session key. + +The following describes the response from the server: + + SERVER_CHALLENGE = ( + ENC(APP_ID || ANONCE || RESPONSE), + ENC(SNONCE), + ENC(INIV), + ENC(OUTIV)) + +Where: + +- RESPONSE: the server's response to the client challenge. +- SNONCE: a nonce to be used as salt when generating the session key. +- INIV: initialization vector used to initialize the input channel of the client. +- OUTIV: initialization vector used to initialize the output channel of the client. + +At this point the server considers the client to be authenticated, and will try to +decrypt any data further sent by the client using the session key. + + +Default Algorithms +------------------ + +Configuration options are available for the KDF and cipher algorithms to use. + +The default KDF is "PBKDF2WithHmacSHA1". Users should be able to select any algorithm +from those supported by the `javax.crypto.SecretKeyFactory` class, as long as they support +PBEKeySpec when generating keys. The default number of iterations was chosen to take a +reasonable amount of time on modern CPUs. See the documentation in TransportConf for more +details. + +The default cipher algorithm is "AES/CTR/NoPadding". Users should be able to select any +algorithm supported by the commons-crypto library. It should allow the cipher to operate +in stream mode. + +The default key length is 128 (bits). + + +Implementation Details +---------------------- + +The commons-crypto library currently only supports AES ciphers, and requires an initialization +vector (IV). This first version of the protocol does not explicitly include the IV in the client +challenge message. Instead, the IV should be derived from the nonce, including the needed bytes, and +padding the IV with zeroes in case the nonce is not long enough. + +Future versions of the protocol might add support for new ciphers and explicitly include needed +configuration parameters in the messages. + + +Threat Assessment +----------------- + +The protocol is secure against different forms of attack: + +* Eavesdropping: the protocol is built on the assumption that it's computationally infeasible + to calculate the original secret from the encrypted messages. Neither the secret nor any + encryption keys are transmitted on the wire, encrypted or not. + +* Man-in-the-middle: because the protocol performs mutual authentication, both ends need to + know the shared secret to be able to decrypt session data. Even if an attacker is able to insert a + malicious "proxy" between endpoints, the attacker won't be able to read any of the data exchanged + between client and server, nor insert arbitrary commands for the server to execute. + +* Replay attacks: the use of nonces when generating keys prevents an attacker from being able to + just replay messages sniffed from the communication channel. + +An attacker may replay the client challenge and successfully "prove" to a server that it "knows" the +shared secret. But the attacker won't be able to decrypt the server's response, and thus won't be +able to generate a session key, which will make it hard to craft a valid, encrypted message that the +server will be able to understand. This will cause the server to close the connection as soon as the +attacker tries to send any command to the server. The attacker can just hold the channel open for +some time, which will be closed when the server times out the channel. These issues could be +separately mitigated by adding a shorter timeout for the first message after authentication, and +potentially by adding host blacklists if a possible attack is detected from a particular host. diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/ServerResponse.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/ServerResponse.java new file mode 100644 index 0000000000..affdbf450b --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/ServerResponse.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.crypto; + +import java.nio.ByteBuffer; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; + +import org.apache.spark.network.protocol.Encodable; +import org.apache.spark.network.protocol.Encoders; + +/** + * Server's response to client's challenge. + * + * @see README.md + */ +public class ServerResponse implements Encodable { + /** Serialization tag used to catch incorrect payloads. */ + private static final byte TAG_BYTE = (byte) 0xFB; + + public final byte[] response; + public final byte[] nonce; + public final byte[] inputIv; + public final byte[] outputIv; + + public ServerResponse( + byte[] response, + byte[] nonce, + byte[] inputIv, + byte[] outputIv) { + this.response = response; + this.nonce = nonce; + this.inputIv = inputIv; + this.outputIv = outputIv; + } + + @Override + public int encodedLength() { + return 1 + + Encoders.ByteArrays.encodedLength(response) + + Encoders.ByteArrays.encodedLength(nonce) + + Encoders.ByteArrays.encodedLength(inputIv) + + Encoders.ByteArrays.encodedLength(outputIv); + } + + @Override + public void encode(ByteBuf buf) { + buf.writeByte(TAG_BYTE); + Encoders.ByteArrays.encode(buf, response); + Encoders.ByteArrays.encode(buf, nonce); + Encoders.ByteArrays.encode(buf, inputIv); + Encoders.ByteArrays.encode(buf, outputIv); + } + + public static ServerResponse decodeMessage(ByteBuffer buffer) { + ByteBuf buf = Unpooled.wrappedBuffer(buffer); + + if (buf.readByte() != TAG_BYTE) { + throw new IllegalArgumentException("Expected ServerResponse, received something else."); + } + + return new ServerResponse( + Encoders.ByteArrays.decode(buf), + Encoders.ByteArrays.decode(buf), + Encoders.ByteArrays.decode(buf), + Encoders.ByteArrays.decode(buf)); + } + +} diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/aes/AesCipher.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java index 340986a63b..7376d1ddc4 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/sasl/aes/AesCipher.java +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.network.sasl.aes; +package org.apache.spark.network.crypto; import java.io.IOException; import java.nio.ByteBuffer; @@ -25,115 +25,91 @@ import java.util.Properties; import javax.crypto.spec.SecretKeySpec; import javax.crypto.spec.IvParameterSpec; +import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; -import com.google.common.base.Throwables; import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; import io.netty.channel.*; import io.netty.util.AbstractReferenceCounted; -import org.apache.commons.crypto.cipher.CryptoCipherFactory; -import org.apache.commons.crypto.random.CryptoRandom; -import org.apache.commons.crypto.random.CryptoRandomFactory; import org.apache.commons.crypto.stream.CryptoInputStream; import org.apache.commons.crypto.stream.CryptoOutputStream; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; import org.apache.spark.network.util.ByteArrayReadableChannel; import org.apache.spark.network.util.ByteArrayWritableChannel; -import org.apache.spark.network.util.TransportConf; /** - * AES cipher for encryption and decryption. + * Cipher for encryption and decryption. */ -public class AesCipher { - private static final Logger logger = LoggerFactory.getLogger(AesCipher.class); - public static final String ENCRYPTION_HANDLER_NAME = "AesEncryption"; - public static final String DECRYPTION_HANDLER_NAME = "AesDecryption"; - public static final int STREAM_BUFFER_SIZE = 1024 * 32; - public static final String TRANSFORM = "AES/CTR/NoPadding"; - - private final SecretKeySpec inKeySpec; - private final IvParameterSpec inIvSpec; - private final SecretKeySpec outKeySpec; - private final IvParameterSpec outIvSpec; - private final Properties properties; - - public AesCipher(AesConfigMessage configMessage, TransportConf conf) throws IOException { - this.properties = conf.cryptoConf(); - this.inKeySpec = new SecretKeySpec(configMessage.inKey, "AES"); - this.inIvSpec = new IvParameterSpec(configMessage.inIv); - this.outKeySpec = new SecretKeySpec(configMessage.outKey, "AES"); - this.outIvSpec = new IvParameterSpec(configMessage.outIv); +public class TransportCipher { + @VisibleForTesting + static final String ENCRYPTION_HANDLER_NAME = "TransportEncryption"; + private static final String DECRYPTION_HANDLER_NAME = "TransportDecryption"; + private static final int STREAM_BUFFER_SIZE = 1024 * 32; + + private final Properties conf; + private final String cipher; + private final SecretKeySpec key; + private final byte[] inIv; + private final byte[] outIv; + + public TransportCipher( + Properties conf, + String cipher, + SecretKeySpec key, + byte[] inIv, + byte[] outIv) { + this.conf = conf; + this.cipher = cipher; + this.key = key; + this.inIv = inIv; + this.outIv = outIv; + } + + public String getCipherTransformation() { + return cipher; + } + + @VisibleForTesting + SecretKeySpec getKey() { + return key; + } + + /** The IV for the input channel (i.e. output channel of the remote side). */ + public byte[] getInputIv() { + return inIv; + } + + /** The IV for the output channel (i.e. input channel of the remote side). */ + public byte[] getOutputIv() { + return outIv; } - /** - * Create AES crypto output stream - * @param ch The underlying channel to write out. - * @return Return output crypto stream for encryption. - * @throws IOException - */ private CryptoOutputStream createOutputStream(WritableByteChannel ch) throws IOException { - return new CryptoOutputStream(TRANSFORM, properties, ch, outKeySpec, outIvSpec); + return new CryptoOutputStream(cipher, conf, ch, key, new IvParameterSpec(outIv)); } - /** - * Create AES crypto input stream - * @param ch The underlying channel used to read data. - * @return Return input crypto stream for decryption. - * @throws IOException - */ private CryptoInputStream createInputStream(ReadableByteChannel ch) throws IOException { - return new CryptoInputStream(TRANSFORM, properties, ch, inKeySpec, inIvSpec); + return new CryptoInputStream(cipher, conf, ch, key, new IvParameterSpec(inIv)); } /** - * Add handlers to channel + * Add handlers to channel. + * * @param ch the channel for adding handlers * @throws IOException */ public void addToChannel(Channel ch) throws IOException { ch.pipeline() - .addFirst(ENCRYPTION_HANDLER_NAME, new AesEncryptHandler(this)) - .addFirst(DECRYPTION_HANDLER_NAME, new AesDecryptHandler(this)); - } - - /** - * Create the configuration message - * @param conf is the local transport configuration. - * @return Config message for sending. - */ - public static AesConfigMessage createConfigMessage(TransportConf conf) { - int keySize = conf.aesCipherKeySize(); - Properties properties = conf.cryptoConf(); - - try { - int paramLen = CryptoCipherFactory.getCryptoCipher(AesCipher.TRANSFORM, properties) - .getBlockSize(); - byte[] inKey = new byte[keySize]; - byte[] outKey = new byte[keySize]; - byte[] inIv = new byte[paramLen]; - byte[] outIv = new byte[paramLen]; - - CryptoRandom random = CryptoRandomFactory.getCryptoRandom(properties); - random.nextBytes(inKey); - random.nextBytes(outKey); - random.nextBytes(inIv); - random.nextBytes(outIv); - - return new AesConfigMessage(inKey, inIv, outKey, outIv); - } catch (Exception e) { - logger.error("AES config error", e); - throw Throwables.propagate(e); - } + .addFirst(ENCRYPTION_HANDLER_NAME, new EncryptionHandler(this)) + .addFirst(DECRYPTION_HANDLER_NAME, new DecryptionHandler(this)); } - private static class AesEncryptHandler extends ChannelOutboundHandlerAdapter { + private static class EncryptionHandler extends ChannelOutboundHandlerAdapter { private final ByteArrayWritableChannel byteChannel; private final CryptoOutputStream cos; - AesEncryptHandler(AesCipher cipher) throws IOException { - byteChannel = new ByteArrayWritableChannel(AesCipher.STREAM_BUFFER_SIZE); + EncryptionHandler(TransportCipher cipher) throws IOException { + byteChannel = new ByteArrayWritableChannel(STREAM_BUFFER_SIZE); cos = cipher.createOutputStream(byteChannel); } @@ -153,11 +129,11 @@ public class AesCipher { } } - private static class AesDecryptHandler extends ChannelInboundHandlerAdapter { + private static class DecryptionHandler extends ChannelInboundHandlerAdapter { private final CryptoInputStream cis; private final ByteArrayReadableChannel byteChannel; - AesDecryptHandler(AesCipher cipher) throws IOException { + DecryptionHandler(TransportCipher cipher) throws IOException { byteChannel = new ByteArrayReadableChannel(); cis = cipher.createInputStream(byteChannel); } @@ -207,7 +183,7 @@ public class AesCipher { this.buf = isByteBuf ? (ByteBuf) msg : null; this.region = isByteBuf ? null : (FileRegion) msg; this.transferred = 0; - this.byteRawChannel = new ByteArrayWritableChannel(AesCipher.STREAM_BUFFER_SIZE); + this.byteRawChannel = new ByteArrayWritableChannel(STREAM_BUFFER_SIZE); this.cos = cos; this.byteEncChannel = ch; } diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java index a1bb453657..6478137722 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java @@ -30,8 +30,6 @@ import org.slf4j.LoggerFactory; import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.client.TransportClientBootstrap; -import org.apache.spark.network.sasl.aes.AesCipher; -import org.apache.spark.network.sasl.aes.AesConfigMessage; import org.apache.spark.network.util.JavaUtils; import org.apache.spark.network.util.TransportConf; @@ -42,24 +40,14 @@ import org.apache.spark.network.util.TransportConf; public class SaslClientBootstrap implements TransportClientBootstrap { private static final Logger logger = LoggerFactory.getLogger(SaslClientBootstrap.class); - private final boolean encrypt; private final TransportConf conf; private final String appId; private final SecretKeyHolder secretKeyHolder; public SaslClientBootstrap(TransportConf conf, String appId, SecretKeyHolder secretKeyHolder) { - this(conf, appId, secretKeyHolder, false); - } - - public SaslClientBootstrap( - TransportConf conf, - String appId, - SecretKeyHolder secretKeyHolder, - boolean encrypt) { this.conf = conf; this.appId = appId; this.secretKeyHolder = secretKeyHolder; - this.encrypt = encrypt; } /** @@ -69,7 +57,7 @@ public class SaslClientBootstrap implements TransportClientBootstrap { */ @Override public void doBootstrap(TransportClient client, Channel channel) { - SparkSaslClient saslClient = new SparkSaslClient(appId, secretKeyHolder, encrypt); + SparkSaslClient saslClient = new SparkSaslClient(appId, secretKeyHolder, conf.saslEncryption()); try { byte[] payload = saslClient.firstToken(); @@ -79,35 +67,19 @@ public class SaslClientBootstrap implements TransportClientBootstrap { msg.encode(buf); buf.writeBytes(msg.body().nioByteBuffer()); - ByteBuffer response = client.sendRpcSync(buf.nioBuffer(), conf.saslRTTimeoutMs()); + ByteBuffer response = client.sendRpcSync(buf.nioBuffer(), conf.authRTTimeoutMs()); payload = saslClient.response(JavaUtils.bufferToArray(response)); } client.setClientId(appId); - if (encrypt) { + if (conf.saslEncryption()) { if (!SparkSaslServer.QOP_AUTH_CONF.equals(saslClient.getNegotiatedProperty(Sasl.QOP))) { throw new RuntimeException( new SaslException("Encryption requests by negotiated non-encrypted connection.")); } - if (conf.aesEncryptionEnabled()) { - // Generate a request config message to send to server. - AesConfigMessage configMessage = AesCipher.createConfigMessage(conf); - ByteBuffer buf = configMessage.encodeMessage(); - - // Encrypted the config message. - byte[] toEncrypt = JavaUtils.bufferToArray(buf); - ByteBuffer encrypted = ByteBuffer.wrap(saslClient.wrap(toEncrypt, 0, toEncrypt.length)); - - client.sendRpcSync(encrypted, conf.saslRTTimeoutMs()); - AesCipher cipher = new AesCipher(configMessage, conf); - logger.info("Enabling AES cipher for client channel {}", client); - cipher.addToChannel(channel); - saslClient.dispose(); - } else { - SaslEncryption.addToChannel(channel, saslClient, conf.maxSaslEncryptedBlockSize()); - } + SaslEncryption.addToChannel(channel, saslClient, conf.maxSaslEncryptedBlockSize()); saslClient = null; logger.debug("Channel {} configured for encryption.", client); } diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java index b2f3ef214b..0231428318 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java @@ -29,8 +29,6 @@ import org.slf4j.LoggerFactory; import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; -import org.apache.spark.network.sasl.aes.AesCipher; -import org.apache.spark.network.sasl.aes.AesConfigMessage; import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.StreamManager; import org.apache.spark.network.util.JavaUtils; @@ -44,7 +42,7 @@ import org.apache.spark.network.util.TransportConf; * Note that the authentication process consists of multiple challenge-response pairs, each of * which are individual RPCs. */ -class SaslRpcHandler extends RpcHandler { +public class SaslRpcHandler extends RpcHandler { private static final Logger logger = LoggerFactory.getLogger(SaslRpcHandler.class); /** Transport configuration. */ @@ -63,7 +61,7 @@ class SaslRpcHandler extends RpcHandler { private boolean isComplete; private boolean isAuthenticated; - SaslRpcHandler( + public SaslRpcHandler( TransportConf conf, Channel channel, RpcHandler delegate, @@ -122,37 +120,10 @@ class SaslRpcHandler extends RpcHandler { return; } - if (!conf.aesEncryptionEnabled()) { - logger.debug("Enabling encryption for channel {}", client); - SaslEncryption.addToChannel(channel, saslServer, conf.maxSaslEncryptedBlockSize()); - complete(false); - return; - } - - // Extra negotiation should happen after authentication, so return directly while - // processing authenticate. - if (!isAuthenticated) { - logger.debug("SASL authentication successful for channel {}", client); - isAuthenticated = true; - return; - } - - // Create AES cipher when it is authenticated - try { - byte[] encrypted = JavaUtils.bufferToArray(message); - ByteBuffer decrypted = ByteBuffer.wrap(saslServer.unwrap(encrypted, 0 , encrypted.length)); - - AesConfigMessage configMessage = AesConfigMessage.decodeMessage(decrypted); - AesCipher cipher = new AesCipher(configMessage, conf); - - // Send response back to client to confirm that server accept config. - callback.onSuccess(JavaUtils.stringToBytes(AesCipher.TRANSFORM)); - logger.info("Enabling AES cipher for Server channel {}", client); - cipher.addToChannel(channel); - complete(true); - } catch (IOException ioe) { - throw new RuntimeException(ioe); - } + logger.debug("Enabling encryption for channel {}", client); + SaslEncryption.addToChannel(channel, saslServer, conf.maxSaslEncryptedBlockSize()); + complete(false); + return; } } diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/aes/AesConfigMessage.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/aes/AesConfigMessage.java deleted file mode 100644 index 3ef6f74a1f..0000000000 --- a/common/network-common/src/main/java/org/apache/spark/network/sasl/aes/AesConfigMessage.java +++ /dev/null @@ -1,101 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.sasl.aes; - -import java.nio.ByteBuffer; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.Unpooled; - -import org.apache.spark.network.protocol.Encodable; -import org.apache.spark.network.protocol.Encoders; - -/** - * The AES cipher options for encryption negotiation. - */ -public class AesConfigMessage implements Encodable { - /** Serialization tag used to catch incorrect payloads. */ - private static final byte TAG_BYTE = (byte) 0xEB; - - public byte[] inKey; - public byte[] outKey; - public byte[] inIv; - public byte[] outIv; - - public AesConfigMessage(byte[] inKey, byte[] inIv, byte[] outKey, byte[] outIv) { - if (inKey == null || inIv == null || outKey == null || outIv == null) { - throw new IllegalArgumentException("Cipher Key or IV must not be null!"); - } - - this.inKey = inKey; - this.inIv = inIv; - this.outKey = outKey; - this.outIv = outIv; - } - - @Override - public int encodedLength() { - return 1 + - Encoders.ByteArrays.encodedLength(inKey) + Encoders.ByteArrays.encodedLength(outKey) + - Encoders.ByteArrays.encodedLength(inIv) + Encoders.ByteArrays.encodedLength(outIv); - } - - @Override - public void encode(ByteBuf buf) { - buf.writeByte(TAG_BYTE); - Encoders.ByteArrays.encode(buf, inKey); - Encoders.ByteArrays.encode(buf, inIv); - Encoders.ByteArrays.encode(buf, outKey); - Encoders.ByteArrays.encode(buf, outIv); - } - - /** - * Encode the config message. - * @return ByteBuffer which contains encoded config message. - */ - public ByteBuffer encodeMessage(){ - ByteBuffer buf = ByteBuffer.allocate(encodedLength()); - - ByteBuf wrappedBuf = Unpooled.wrappedBuffer(buf); - wrappedBuf.clear(); - encode(wrappedBuf); - - return buf; - } - - /** - * Decode the config message from buffer - * @param buffer the buffer contain encoded config message - * @return config message - */ - public static AesConfigMessage decodeMessage(ByteBuffer buffer) { - ByteBuf buf = Unpooled.wrappedBuffer(buffer); - - if (buf.readByte() != TAG_BYTE) { - throw new IllegalStateException("Expected AesConfigMessage, received something else" - + " (maybe your client does not have AES enabled?)"); - } - - byte[] outKey = Encoders.ByteArrays.decode(buf); - byte[] outIv = Encoders.ByteArrays.decode(buf); - byte[] inKey = Encoders.ByteArrays.decode(buf); - byte[] inIv = Encoders.ByteArrays.decode(buf); - return new AesConfigMessage(inKey, inIv, outKey, outIv); - } - -} diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java index 6a557fa75d..c226d8f3bc 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java +++ b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java @@ -117,9 +117,10 @@ public class TransportConf { /** Send buffer size (SO_SNDBUF). */ public int sendBuf() { return conf.getInt(SPARK_NETWORK_IO_SENDBUFFER_KEY, -1); } - /** Timeout for a single round trip of SASL token exchange, in milliseconds. */ - public int saslRTTimeoutMs() { - return (int) JavaUtils.timeStringAsSec(conf.get(SPARK_NETWORK_SASL_TIMEOUT_KEY, "30s")) * 1000; + /** Timeout for a single round trip of auth message exchange, in milliseconds. */ + public int authRTTimeoutMs() { + return (int) JavaUtils.timeStringAsSec(conf.get("spark.network.auth.rpcTimeout", + conf.get(SPARK_NETWORK_SASL_TIMEOUT_KEY, "30s"))) * 1000; } /** @@ -162,40 +163,95 @@ public class TransportConf { } /** - * Maximum number of bytes to be encrypted at a time when SASL encryption is enabled. + * Enables strong encryption. Also enables the new auth protocol, used to negotiate keys. */ - public int maxSaslEncryptedBlockSize() { - return Ints.checkedCast(JavaUtils.byteStringAsBytes( - conf.get("spark.network.sasl.maxEncryptedBlockSize", "64k"))); + public boolean encryptionEnabled() { + return conf.getBoolean("spark.network.crypto.enabled", false); } /** - * Whether the server should enforce encryption on SASL-authenticated connections. + * The cipher transformation to use for encrypting session data. */ - public boolean saslServerAlwaysEncrypt() { - return conf.getBoolean("spark.network.sasl.serverAlwaysEncrypt", false); + public String cipherTransformation() { + return conf.get("spark.network.crypto.cipher", "AES/CTR/NoPadding"); + } + + /** + * The key generation algorithm. This should be an algorithm that accepts a "PBEKeySpec" + * as input. The default value (PBKDF2WithHmacSHA1) is available in Java 7. + */ + public String keyFactoryAlgorithm() { + return conf.get("spark.network.crypto.keyFactoryAlgorithm", "PBKDF2WithHmacSHA1"); + } + + /** + * How many iterations to run when generating keys. + * + * See some discussion about this at: http://security.stackexchange.com/q/3959 + * The default value was picked for speed, since it assumes that the secret has good entropy + * (128 bits by default), which is not generally the case with user passwords. + */ + public int keyFactoryIterations() { + return conf.getInt("spark.networy.crypto.keyFactoryIterations", 1024); + } + + /** + * Encryption key length, in bits. + */ + public int encryptionKeyLength() { + return conf.getInt("spark.network.crypto.keyLength", 128); + } + + /** + * Initial vector length, in bytes. + */ + public int ivLength() { + return conf.getInt("spark.network.crypto.ivLength", 16); + } + + /** + * The algorithm for generated secret keys. Nobody should really need to change this, + * but configurable just in case. + */ + public String keyAlgorithm() { + return conf.get("spark.network.crypto.keyAlgorithm", "AES"); + } + + /** + * Whether to fall back to SASL if the new auth protocol fails. Enabled by default for + * backwards compatibility. + */ + public boolean saslFallback() { + return conf.getBoolean("spark.network.crypto.saslFallback", true); } /** - * The trigger for enabling AES encryption. + * Whether to enable SASL-based encryption when authenticating using SASL. */ - public boolean aesEncryptionEnabled() { - return conf.getBoolean("spark.network.aes.enabled", false); + public boolean saslEncryption() { + return conf.getBoolean("spark.authenticate.enableSaslEncryption", false); } /** - * The key size to use when AES cipher is enabled. Notice that the length should be 16, 24 or 32 - * bytes. + * Maximum number of bytes to be encrypted at a time when SASL encryption is used. */ - public int aesCipherKeySize() { - return conf.getInt("spark.network.aes.keySize", 16); + public int maxSaslEncryptedBlockSize() { + return Ints.checkedCast(JavaUtils.byteStringAsBytes( + conf.get("spark.network.sasl.maxEncryptedBlockSize", "64k"))); + } + + /** + * Whether the server should enforce encryption on SASL-authenticated connections. + */ + public boolean saslServerAlwaysEncrypt() { + return conf.getBoolean("spark.network.sasl.serverAlwaysEncrypt", false); } /** * The commons-crypto configuration for the module. */ public Properties cryptoConf() { - return CryptoUtils.toCryptoConf("spark.network.aes.config.", conf.getAll()); + return CryptoUtils.toCryptoConf("spark.network.crypto.config.", conf.getAll()); } } |