aboutsummaryrefslogtreecommitdiff
path: root/common
diff options
context:
space:
mode:
authorJunjie Chen <junjie.j.chen@intel.com>2016-11-11 10:37:58 -0800
committerMarcelo Vanzin <vanzin@cloudera.com>2016-11-11 10:37:58 -0800
commit4f15d94cfec86130f8dab28ae2e228ded8124020 (patch)
tree3ef2fe046d53074f27e7951103a6522718193014 /common
parent5ddf69470b93c0b8a28bb4ac905e7670d9c50a95 (diff)
downloadspark-4f15d94cfec86130f8dab28ae2e228ded8124020.tar.gz
spark-4f15d94cfec86130f8dab28ae2e228ded8124020.tar.bz2
spark-4f15d94cfec86130f8dab28ae2e228ded8124020.zip
[SPARK-13331] AES support for over-the-wire encryption
## What changes were proposed in this pull request? DIGEST-MD5 mechanism is used for SASL authentication and secure communication. DIGEST-MD5 mechanism supports 3DES, DES, and RC4 ciphers. However, 3DES, DES and RC4 are slow relatively. AES provide better performance and security by design and is a replacement for 3DES according to NIST. Apache Common Crypto is a cryptographic library optimized with AES-NI, this patch employ Apache Common Crypto as enc/dec backend for SASL authentication and secure channel to improve spark RPC. ## How was this patch tested? Unit tests and Integration test. Author: Junjie Chen <junjie.j.chen@intel.com> Closes #15172 from cjjnjust/shuffle_rpc_encrypt.
Diffstat (limited to 'common')
-rw-r--r--common/network-common/pom.xml4
-rw-r--r--common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java23
-rw-r--r--common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java101
-rw-r--r--common/network-common/src/main/java/org/apache/spark/network/sasl/aes/AesCipher.java294
-rw-r--r--common/network-common/src/main/java/org/apache/spark/network/sasl/aes/AesConfigMessage.java101
-rw-r--r--common/network-common/src/main/java/org/apache/spark/network/util/ByteArrayReadableChannel.java62
-rw-r--r--common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java22
-rw-r--r--common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java93
8 files changed, 663 insertions, 37 deletions
diff --git a/common/network-common/pom.xml b/common/network-common/pom.xml
index fcefe64d59..ca99fa89eb 100644
--- a/common/network-common/pom.xml
+++ b/common/network-common/pom.xml
@@ -76,6 +76,10 @@
<artifactId>guava</artifactId>
<scope>compile</scope>
</dependency>
+ <dependency>
+ <groupId>org.apache.commons</groupId>
+ <artifactId>commons-crypto</artifactId>
+ </dependency>
<!-- Test dependencies -->
<dependency>
diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java
index 9e5c616ee5..a1bb453657 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java
@@ -30,6 +30,8 @@ import org.slf4j.LoggerFactory;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.client.TransportClientBootstrap;
+import org.apache.spark.network.sasl.aes.AesCipher;
+import org.apache.spark.network.sasl.aes.AesConfigMessage;
import org.apache.spark.network.util.JavaUtils;
import org.apache.spark.network.util.TransportConf;
@@ -88,9 +90,26 @@ public class SaslClientBootstrap implements TransportClientBootstrap {
throw new RuntimeException(
new SaslException("Encryption requests by negotiated non-encrypted connection."));
}
- SaslEncryption.addToChannel(channel, saslClient, conf.maxSaslEncryptedBlockSize());
+
+ if (conf.aesEncryptionEnabled()) {
+ // Generate a request config message to send to server.
+ AesConfigMessage configMessage = AesCipher.createConfigMessage(conf);
+ ByteBuffer buf = configMessage.encodeMessage();
+
+ // Encrypted the config message.
+ byte[] toEncrypt = JavaUtils.bufferToArray(buf);
+ ByteBuffer encrypted = ByteBuffer.wrap(saslClient.wrap(toEncrypt, 0, toEncrypt.length));
+
+ client.sendRpcSync(encrypted, conf.saslRTTimeoutMs());
+ AesCipher cipher = new AesCipher(configMessage, conf);
+ logger.info("Enabling AES cipher for client channel {}", client);
+ cipher.addToChannel(channel);
+ saslClient.dispose();
+ } else {
+ SaslEncryption.addToChannel(channel, saslClient, conf.maxSaslEncryptedBlockSize());
+ }
saslClient = null;
- logger.debug("Channel {} configured for SASL encryption.", client);
+ logger.debug("Channel {} configured for encryption.", client);
}
} catch (IOException ioe) {
throw new RuntimeException(ioe);
diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java
index c41f5b6873..b2f3ef214b 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java
@@ -29,6 +29,8 @@ import org.slf4j.LoggerFactory;
import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.sasl.aes.AesCipher;
+import org.apache.spark.network.sasl.aes.AesConfigMessage;
import org.apache.spark.network.server.RpcHandler;
import org.apache.spark.network.server.StreamManager;
import org.apache.spark.network.util.JavaUtils;
@@ -59,6 +61,7 @@ class SaslRpcHandler extends RpcHandler {
private SparkSaslServer saslServer;
private boolean isComplete;
+ private boolean isAuthenticated;
SaslRpcHandler(
TransportConf conf,
@@ -71,6 +74,7 @@ class SaslRpcHandler extends RpcHandler {
this.secretKeyHolder = secretKeyHolder;
this.saslServer = null;
this.isComplete = false;
+ this.isAuthenticated = false;
}
@Override
@@ -80,30 +84,31 @@ class SaslRpcHandler extends RpcHandler {
delegate.receive(client, message, callback);
return;
}
+ if (saslServer == null || !saslServer.isComplete()) {
+ ByteBuf nettyBuf = Unpooled.wrappedBuffer(message);
+ SaslMessage saslMessage;
+ try {
+ saslMessage = SaslMessage.decode(nettyBuf);
+ } finally {
+ nettyBuf.release();
+ }
- ByteBuf nettyBuf = Unpooled.wrappedBuffer(message);
- SaslMessage saslMessage;
- try {
- saslMessage = SaslMessage.decode(nettyBuf);
- } finally {
- nettyBuf.release();
- }
-
- if (saslServer == null) {
- // First message in the handshake, setup the necessary state.
- client.setClientId(saslMessage.appId);
- saslServer = new SparkSaslServer(saslMessage.appId, secretKeyHolder,
- conf.saslServerAlwaysEncrypt());
- }
+ if (saslServer == null) {
+ // First message in the handshake, setup the necessary state.
+ client.setClientId(saslMessage.appId);
+ saslServer = new SparkSaslServer(saslMessage.appId, secretKeyHolder,
+ conf.saslServerAlwaysEncrypt());
+ }
- byte[] response;
- try {
- response = saslServer.response(JavaUtils.bufferToArray(
- saslMessage.body().nioByteBuffer()));
- } catch (IOException ioe) {
- throw new RuntimeException(ioe);
+ byte[] response;
+ try {
+ response = saslServer.response(JavaUtils.bufferToArray(
+ saslMessage.body().nioByteBuffer()));
+ } catch (IOException ioe) {
+ throw new RuntimeException(ioe);
+ }
+ callback.onSuccess(ByteBuffer.wrap(response));
}
- callback.onSuccess(ByteBuffer.wrap(response));
// Setup encryption after the SASL response is sent, otherwise the client can't parse the
// response. It's ok to change the channel pipeline here since we are processing an incoming
@@ -111,15 +116,42 @@ class SaslRpcHandler extends RpcHandler {
// method returns. This assumes that the code ensures, through other means, that no outbound
// messages are being written to the channel while negotiation is still going on.
if (saslServer.isComplete()) {
- logger.debug("SASL authentication successful for channel {}", client);
- isComplete = true;
- if (SparkSaslServer.QOP_AUTH_CONF.equals(saslServer.getNegotiatedProperty(Sasl.QOP))) {
+ if (!SparkSaslServer.QOP_AUTH_CONF.equals(saslServer.getNegotiatedProperty(Sasl.QOP))) {
+ logger.debug("SASL authentication successful for channel {}", client);
+ complete(true);
+ return;
+ }
+
+ if (!conf.aesEncryptionEnabled()) {
logger.debug("Enabling encryption for channel {}", client);
SaslEncryption.addToChannel(channel, saslServer, conf.maxSaslEncryptedBlockSize());
- saslServer = null;
- } else {
- saslServer.dispose();
- saslServer = null;
+ complete(false);
+ return;
+ }
+
+ // Extra negotiation should happen after authentication, so return directly while
+ // processing authenticate.
+ if (!isAuthenticated) {
+ logger.debug("SASL authentication successful for channel {}", client);
+ isAuthenticated = true;
+ return;
+ }
+
+ // Create AES cipher when it is authenticated
+ try {
+ byte[] encrypted = JavaUtils.bufferToArray(message);
+ ByteBuffer decrypted = ByteBuffer.wrap(saslServer.unwrap(encrypted, 0 , encrypted.length));
+
+ AesConfigMessage configMessage = AesConfigMessage.decodeMessage(decrypted);
+ AesCipher cipher = new AesCipher(configMessage, conf);
+
+ // Send response back to client to confirm that server accept config.
+ callback.onSuccess(JavaUtils.stringToBytes(AesCipher.TRANSFORM));
+ logger.info("Enabling AES cipher for Server channel {}", client);
+ cipher.addToChannel(channel);
+ complete(true);
+ } catch (IOException ioe) {
+ throw new RuntimeException(ioe);
}
}
}
@@ -155,4 +187,17 @@ class SaslRpcHandler extends RpcHandler {
delegate.exceptionCaught(cause, client);
}
+ private void complete(boolean dispose) {
+ if (dispose) {
+ try {
+ saslServer.dispose();
+ } catch (RuntimeException e) {
+ logger.error("Error while disposing SASL server", e);
+ }
+ }
+
+ saslServer = null;
+ isComplete = true;
+ }
+
}
diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/aes/AesCipher.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/aes/AesCipher.java
new file mode 100644
index 0000000000..78034a69f7
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/aes/AesCipher.java
@@ -0,0 +1,294 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.sasl.aes;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.channels.ReadableByteChannel;
+import java.nio.channels.WritableByteChannel;
+import java.util.Properties;
+import javax.crypto.spec.SecretKeySpec;
+import javax.crypto.spec.IvParameterSpec;
+
+import com.google.common.base.Preconditions;
+import com.google.common.base.Throwables;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+import io.netty.channel.*;
+import io.netty.util.AbstractReferenceCounted;
+import org.apache.commons.crypto.cipher.CryptoCipherFactory;
+import org.apache.commons.crypto.random.CryptoRandom;
+import org.apache.commons.crypto.random.CryptoRandomFactory;
+import org.apache.commons.crypto.stream.CryptoInputStream;
+import org.apache.commons.crypto.stream.CryptoOutputStream;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.util.ByteArrayReadableChannel;
+import org.apache.spark.network.util.ByteArrayWritableChannel;
+import org.apache.spark.network.util.TransportConf;
+
+/**
+ * AES cipher for encryption and decryption.
+ */
+public class AesCipher {
+ private static final Logger logger = LoggerFactory.getLogger(AesCipher.class);
+ public static final String ENCRYPTION_HANDLER_NAME = "AesEncryption";
+ public static final String DECRYPTION_HANDLER_NAME = "AesDecryption";
+ public static final int STREAM_BUFFER_SIZE = 1024 * 32;
+ public static final String TRANSFORM = "AES/CTR/NoPadding";
+
+ private final SecretKeySpec inKeySpec;
+ private final IvParameterSpec inIvSpec;
+ private final SecretKeySpec outKeySpec;
+ private final IvParameterSpec outIvSpec;
+ private final Properties properties;
+
+ public AesCipher(AesConfigMessage configMessage, TransportConf conf) throws IOException {
+ this.properties = CryptoStreamUtils.toCryptoConf(conf);
+ this.inKeySpec = new SecretKeySpec(configMessage.inKey, "AES");
+ this.inIvSpec = new IvParameterSpec(configMessage.inIv);
+ this.outKeySpec = new SecretKeySpec(configMessage.outKey, "AES");
+ this.outIvSpec = new IvParameterSpec(configMessage.outIv);
+ }
+
+ /**
+ * Create AES crypto output stream
+ * @param ch The underlying channel to write out.
+ * @return Return output crypto stream for encryption.
+ * @throws IOException
+ */
+ private CryptoOutputStream createOutputStream(WritableByteChannel ch) throws IOException {
+ return new CryptoOutputStream(TRANSFORM, properties, ch, outKeySpec, outIvSpec);
+ }
+
+ /**
+ * Create AES crypto input stream
+ * @param ch The underlying channel used to read data.
+ * @return Return input crypto stream for decryption.
+ * @throws IOException
+ */
+ private CryptoInputStream createInputStream(ReadableByteChannel ch) throws IOException {
+ return new CryptoInputStream(TRANSFORM, properties, ch, inKeySpec, inIvSpec);
+ }
+
+ /**
+ * Add handlers to channel
+ * @param ch the channel for adding handlers
+ * @throws IOException
+ */
+ public void addToChannel(Channel ch) throws IOException {
+ ch.pipeline()
+ .addFirst(ENCRYPTION_HANDLER_NAME, new AesEncryptHandler(this))
+ .addFirst(DECRYPTION_HANDLER_NAME, new AesDecryptHandler(this));
+ }
+
+ /**
+ * Create the configuration message
+ * @param conf is the local transport configuration.
+ * @return Config message for sending.
+ */
+ public static AesConfigMessage createConfigMessage(TransportConf conf) {
+ int keySize = conf.aesCipherKeySize();
+ Properties properties = CryptoStreamUtils.toCryptoConf(conf);
+
+ try {
+ int paramLen = CryptoCipherFactory.getCryptoCipher(AesCipher.TRANSFORM, properties)
+ .getBlockSize();
+ byte[] inKey = new byte[keySize];
+ byte[] outKey = new byte[keySize];
+ byte[] inIv = new byte[paramLen];
+ byte[] outIv = new byte[paramLen];
+
+ CryptoRandom random = CryptoRandomFactory.getCryptoRandom(properties);
+ random.nextBytes(inKey);
+ random.nextBytes(outKey);
+ random.nextBytes(inIv);
+ random.nextBytes(outIv);
+
+ return new AesConfigMessage(inKey, inIv, outKey, outIv);
+ } catch (Exception e) {
+ logger.error("AES config error", e);
+ throw Throwables.propagate(e);
+ }
+ }
+
+ /**
+ * CryptoStreamUtils is used to convert config from TransportConf to AES Crypto config.
+ */
+ private static class CryptoStreamUtils {
+ public static Properties toCryptoConf(TransportConf conf) {
+ Properties props = new Properties();
+ if (conf.aesCipherClass() != null) {
+ props.setProperty(CryptoCipherFactory.CLASSES_KEY, conf.aesCipherClass());
+ }
+ return props;
+ }
+ }
+
+ private static class AesEncryptHandler extends ChannelOutboundHandlerAdapter {
+ private final ByteArrayWritableChannel byteChannel;
+ private final CryptoOutputStream cos;
+
+ AesEncryptHandler(AesCipher cipher) throws IOException {
+ byteChannel = new ByteArrayWritableChannel(AesCipher.STREAM_BUFFER_SIZE);
+ cos = cipher.createOutputStream(byteChannel);
+ }
+
+ @Override
+ public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise)
+ throws Exception {
+ ctx.write(new EncryptedMessage(cos, msg, byteChannel), promise);
+ }
+
+ @Override
+ public void close(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception {
+ try {
+ cos.close();
+ } finally {
+ super.close(ctx, promise);
+ }
+ }
+ }
+
+ private static class AesDecryptHandler extends ChannelInboundHandlerAdapter {
+ private final CryptoInputStream cis;
+ private final ByteArrayReadableChannel byteChannel;
+
+ AesDecryptHandler(AesCipher cipher) throws IOException {
+ byteChannel = new ByteArrayReadableChannel();
+ cis = cipher.createInputStream(byteChannel);
+ }
+
+ @Override
+ public void channelRead(ChannelHandlerContext ctx, Object data) throws Exception {
+ byteChannel.feedData((ByteBuf) data);
+
+ byte[] decryptedData = new byte[byteChannel.readableBytes()];
+ int offset = 0;
+ while (offset < decryptedData.length) {
+ offset += cis.read(decryptedData, offset, decryptedData.length - offset);
+ }
+
+ ctx.fireChannelRead(Unpooled.wrappedBuffer(decryptedData, 0, decryptedData.length));
+ }
+
+ @Override
+ public void channelInactive(ChannelHandlerContext ctx) throws Exception {
+ try {
+ cis.close();
+ } finally {
+ super.channelInactive(ctx);
+ }
+ }
+ }
+
+ private static class EncryptedMessage extends AbstractReferenceCounted implements FileRegion {
+ private final boolean isByteBuf;
+ private final ByteBuf buf;
+ private final FileRegion region;
+ private long transferred;
+ private CryptoOutputStream cos;
+
+ // Due to streaming issue CRYPTO-125: https://issues.apache.org/jira/browse/CRYPTO-125, it has
+ // to utilize two helper ByteArrayWritableChannel for streaming. One is used to receive raw data
+ // from upper handler, another is used to store encrypted data.
+ private ByteArrayWritableChannel byteEncChannel;
+ private ByteArrayWritableChannel byteRawChannel;
+
+ private ByteBuffer currentEncrypted;
+
+ EncryptedMessage(CryptoOutputStream cos, Object msg, ByteArrayWritableChannel ch) {
+ Preconditions.checkArgument(msg instanceof ByteBuf || msg instanceof FileRegion,
+ "Unrecognized message type: %s", msg.getClass().getName());
+ this.isByteBuf = msg instanceof ByteBuf;
+ this.buf = isByteBuf ? (ByteBuf) msg : null;
+ this.region = isByteBuf ? null : (FileRegion) msg;
+ this.transferred = 0;
+ this.byteRawChannel = new ByteArrayWritableChannel(AesCipher.STREAM_BUFFER_SIZE);
+ this.cos = cos;
+ this.byteEncChannel = ch;
+ }
+
+ @Override
+ public long count() {
+ return isByteBuf ? buf.readableBytes() : region.count();
+ }
+
+ @Override
+ public long position() {
+ return 0;
+ }
+
+ @Override
+ public long transfered() {
+ return transferred;
+ }
+
+ @Override
+ public long transferTo(WritableByteChannel target, long position) throws IOException {
+ Preconditions.checkArgument(position == transfered(), "Invalid position.");
+
+ do {
+ if (currentEncrypted == null) {
+ encryptMore();
+ }
+
+ int bytesWritten = currentEncrypted.remaining();
+ target.write(currentEncrypted);
+ bytesWritten -= currentEncrypted.remaining();
+ transferred += bytesWritten;
+ if (!currentEncrypted.hasRemaining()) {
+ currentEncrypted = null;
+ byteEncChannel.reset();
+ }
+ } while (transferred < count());
+
+ return transferred;
+ }
+
+ private void encryptMore() throws IOException {
+ byteRawChannel.reset();
+
+ if (isByteBuf) {
+ int copied = byteRawChannel.write(buf.nioBuffer());
+ buf.skipBytes(copied);
+ } else {
+ region.transferTo(byteRawChannel, region.transfered());
+ }
+ cos.write(byteRawChannel.getData(), 0, byteRawChannel.length());
+ cos.flush();
+
+ currentEncrypted = ByteBuffer.wrap(byteEncChannel.getData(),
+ 0, byteEncChannel.length());
+ }
+
+ @Override
+ protected void deallocate() {
+ byteRawChannel.reset();
+ byteEncChannel.reset();
+ if (region != null) {
+ region.release();
+ }
+ if (buf != null) {
+ buf.release();
+ }
+ }
+ }
+
+}
diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/aes/AesConfigMessage.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/aes/AesConfigMessage.java
new file mode 100644
index 0000000000..3ef6f74a1f
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/aes/AesConfigMessage.java
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.sasl.aes;
+
+import java.nio.ByteBuffer;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+
+import org.apache.spark.network.protocol.Encodable;
+import org.apache.spark.network.protocol.Encoders;
+
+/**
+ * The AES cipher options for encryption negotiation.
+ */
+public class AesConfigMessage implements Encodable {
+ /** Serialization tag used to catch incorrect payloads. */
+ private static final byte TAG_BYTE = (byte) 0xEB;
+
+ public byte[] inKey;
+ public byte[] outKey;
+ public byte[] inIv;
+ public byte[] outIv;
+
+ public AesConfigMessage(byte[] inKey, byte[] inIv, byte[] outKey, byte[] outIv) {
+ if (inKey == null || inIv == null || outKey == null || outIv == null) {
+ throw new IllegalArgumentException("Cipher Key or IV must not be null!");
+ }
+
+ this.inKey = inKey;
+ this.inIv = inIv;
+ this.outKey = outKey;
+ this.outIv = outIv;
+ }
+
+ @Override
+ public int encodedLength() {
+ return 1 +
+ Encoders.ByteArrays.encodedLength(inKey) + Encoders.ByteArrays.encodedLength(outKey) +
+ Encoders.ByteArrays.encodedLength(inIv) + Encoders.ByteArrays.encodedLength(outIv);
+ }
+
+ @Override
+ public void encode(ByteBuf buf) {
+ buf.writeByte(TAG_BYTE);
+ Encoders.ByteArrays.encode(buf, inKey);
+ Encoders.ByteArrays.encode(buf, inIv);
+ Encoders.ByteArrays.encode(buf, outKey);
+ Encoders.ByteArrays.encode(buf, outIv);
+ }
+
+ /**
+ * Encode the config message.
+ * @return ByteBuffer which contains encoded config message.
+ */
+ public ByteBuffer encodeMessage(){
+ ByteBuffer buf = ByteBuffer.allocate(encodedLength());
+
+ ByteBuf wrappedBuf = Unpooled.wrappedBuffer(buf);
+ wrappedBuf.clear();
+ encode(wrappedBuf);
+
+ return buf;
+ }
+
+ /**
+ * Decode the config message from buffer
+ * @param buffer the buffer contain encoded config message
+ * @return config message
+ */
+ public static AesConfigMessage decodeMessage(ByteBuffer buffer) {
+ ByteBuf buf = Unpooled.wrappedBuffer(buffer);
+
+ if (buf.readByte() != TAG_BYTE) {
+ throw new IllegalStateException("Expected AesConfigMessage, received something else"
+ + " (maybe your client does not have AES enabled?)");
+ }
+
+ byte[] outKey = Encoders.ByteArrays.decode(buf);
+ byte[] outIv = Encoders.ByteArrays.decode(buf);
+ byte[] inKey = Encoders.ByteArrays.decode(buf);
+ byte[] inIv = Encoders.ByteArrays.decode(buf);
+ return new AesConfigMessage(inKey, inIv, outKey, outIv);
+ }
+
+}
diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/ByteArrayReadableChannel.java b/common/network-common/src/main/java/org/apache/spark/network/util/ByteArrayReadableChannel.java
new file mode 100644
index 0000000000..25d103d0e3
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/util/ByteArrayReadableChannel.java
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.util;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.channels.ReadableByteChannel;
+
+import io.netty.buffer.ByteBuf;
+
+public class ByteArrayReadableChannel implements ReadableByteChannel {
+ private ByteBuf data;
+
+ public int readableBytes() {
+ return data.readableBytes();
+ }
+
+ public void feedData(ByteBuf buf) {
+ data = buf;
+ }
+
+ @Override
+ public int read(ByteBuffer dst) throws IOException {
+ int totalRead = 0;
+ while (data.readableBytes() > 0 && dst.remaining() > 0) {
+ int bytesToRead = Math.min(data.readableBytes(), dst.remaining());
+ dst.put(data.readSlice(bytesToRead).nioBuffer());
+ totalRead += bytesToRead;
+ }
+
+ if (data.readableBytes() == 0) {
+ data.release();
+ }
+
+ return totalRead;
+ }
+
+ @Override
+ public void close() throws IOException {
+ }
+
+ @Override
+ public boolean isOpen() {
+ return true;
+ }
+
+}
diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java
index 64eaba103c..d0d072849d 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java
@@ -18,6 +18,7 @@
package org.apache.spark.network.util;
import com.google.common.primitives.Ints;
+import org.apache.commons.crypto.cipher.CryptoCipherFactory;
/**
* A central location that tracks all the settings we expose to users.
@@ -175,4 +176,25 @@ public class TransportConf {
return conf.getBoolean("spark.network.sasl.serverAlwaysEncrypt", false);
}
+ /**
+ * The trigger for enabling AES encryption.
+ */
+ public boolean aesEncryptionEnabled() {
+ return conf.getBoolean("spark.authenticate.encryption.aes.enabled", false);
+ }
+
+ /**
+ * The implementation class for crypto cipher
+ */
+ public String aesCipherClass() {
+ return conf.get("spark.authenticate.encryption.aes.cipher.class", null);
+ }
+
+ /**
+ * The bytes of AES cipher key which is effective when AES cipher is enabled. Notice that
+ * the length should be 16, 24 or 32 bytes.
+ */
+ public int aesCipherKeySize() {
+ return conf.getInt("spark.authenticate.encryption.aes.cipher.keySize", 16);
+ }
}
diff --git a/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java
index 45cc03df43..4e6146cf07 100644
--- a/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java
+++ b/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java
@@ -53,6 +53,7 @@ import org.apache.spark.network.client.ChunkReceivedCallback;
import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.client.TransportClientBootstrap;
+import org.apache.spark.network.sasl.aes.AesCipher;
import org.apache.spark.network.server.RpcHandler;
import org.apache.spark.network.server.StreamManager;
import org.apache.spark.network.server.TransportServer;
@@ -149,7 +150,7 @@ public class SparkSaslSuite {
.when(rpcHandler)
.receive(any(TransportClient.class), any(ByteBuffer.class), any(RpcResponseCallback.class));
- SaslTestCtx ctx = new SaslTestCtx(rpcHandler, encrypt, false);
+ SaslTestCtx ctx = new SaslTestCtx(rpcHandler, encrypt, false, false);
try {
ByteBuffer response = ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"),
TimeUnit.SECONDS.toMillis(10));
@@ -275,7 +276,7 @@ public class SparkSaslSuite {
new Random().nextBytes(data);
Files.write(data, file);
- ctx = new SaslTestCtx(rpcHandler, true, false);
+ ctx = new SaslTestCtx(rpcHandler, true, false, false);
final CountDownLatch lock = new CountDownLatch(1);
@@ -317,7 +318,7 @@ public class SparkSaslSuite {
SaslTestCtx ctx = null;
try {
- ctx = new SaslTestCtx(mock(RpcHandler.class), false, false);
+ ctx = new SaslTestCtx(mock(RpcHandler.class), false, false, false);
fail("Should have failed to connect without encryption.");
} catch (Exception e) {
assertTrue(e.getCause() instanceof SaslException);
@@ -336,7 +337,7 @@ public class SparkSaslSuite {
// able to understand RPCs sent to it and thus close the connection.
SaslTestCtx ctx = null;
try {
- ctx = new SaslTestCtx(mock(RpcHandler.class), true, true);
+ ctx = new SaslTestCtx(mock(RpcHandler.class), true, true, false);
ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"),
TimeUnit.SECONDS.toMillis(10));
fail("Should have failed to send RPC to server.");
@@ -374,6 +375,69 @@ public class SparkSaslSuite {
}
}
+ @Test
+ public void testAesEncryption() throws Exception {
+ final AtomicReference<ManagedBuffer> response = new AtomicReference<>();
+ final File file = File.createTempFile("sasltest", ".txt");
+ SaslTestCtx ctx = null;
+ try {
+ final TransportConf conf = new TransportConf("rpc", new SystemPropertyConfigProvider());
+ final TransportConf spyConf = spy(conf);
+ doReturn(true).when(spyConf).aesEncryptionEnabled();
+
+ StreamManager sm = mock(StreamManager.class);
+ when(sm.getChunk(anyLong(), anyInt())).thenAnswer(new Answer<ManagedBuffer>() {
+ @Override
+ public ManagedBuffer answer(InvocationOnMock invocation) {
+ return new FileSegmentManagedBuffer(spyConf, file, 0, file.length());
+ }
+ });
+
+ RpcHandler rpcHandler = mock(RpcHandler.class);
+ when(rpcHandler.getStreamManager()).thenReturn(sm);
+
+ byte[] data = new byte[256 * 1024 * 1024];
+ new Random().nextBytes(data);
+ Files.write(data, file);
+
+ ctx = new SaslTestCtx(rpcHandler, true, false, true);
+
+ final Object lock = new Object();
+
+ ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class);
+ doAnswer(new Answer<Void>() {
+ @Override
+ public Void answer(InvocationOnMock invocation) {
+ response.set((ManagedBuffer) invocation.getArguments()[1]);
+ response.get().retain();
+ synchronized (lock) {
+ lock.notifyAll();
+ }
+ return null;
+ }
+ }).when(callback).onSuccess(anyInt(), any(ManagedBuffer.class));
+
+ synchronized (lock) {
+ ctx.client.fetchChunk(0, 0, callback);
+ lock.wait(10 * 1000);
+ }
+
+ verify(callback, times(1)).onSuccess(anyInt(), any(ManagedBuffer.class));
+ verify(callback, never()).onFailure(anyInt(), any(Throwable.class));
+
+ byte[] received = ByteStreams.toByteArray(response.get().createInputStream());
+ assertTrue(Arrays.equals(data, received));
+ } finally {
+ file.delete();
+ if (ctx != null) {
+ ctx.close();
+ }
+ if (response.get() != null) {
+ response.get().release();
+ }
+ }
+ }
+
private static class SaslTestCtx {
final TransportClient client;
@@ -386,18 +450,28 @@ public class SparkSaslSuite {
SaslTestCtx(
RpcHandler rpcHandler,
boolean encrypt,
- boolean disableClientEncryption)
+ boolean disableClientEncryption,
+ boolean aesEnable)
throws Exception {
TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider());
+ if (aesEnable) {
+ conf = spy(conf);
+ doReturn(true).when(conf).aesEncryptionEnabled();
+ }
+
SecretKeyHolder keyHolder = mock(SecretKeyHolder.class);
when(keyHolder.getSaslUser(anyString())).thenReturn("user");
when(keyHolder.getSecretKey(anyString())).thenReturn("secret");
TransportContext ctx = new TransportContext(conf, rpcHandler);
- this.checker = new EncryptionCheckerBootstrap();
+ String encryptHandlerName = aesEnable ? AesCipher.ENCRYPTION_HANDLER_NAME :
+ SaslEncryption.ENCRYPTION_HANDLER_NAME;
+
+ this.checker = new EncryptionCheckerBootstrap(encryptHandlerName);
+
this.server = ctx.createServer(Arrays.asList(new SaslServerBootstrap(conf, keyHolder),
checker));
@@ -437,13 +511,18 @@ public class SparkSaslSuite {
implements TransportServerBootstrap {
boolean foundEncryptionHandler;
+ String encryptHandlerName;
+
+ public EncryptionCheckerBootstrap(String encryptHandlerName) {
+ this.encryptHandlerName = encryptHandlerName;
+ }
@Override
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise)
throws Exception {
if (!foundEncryptionHandler) {
foundEncryptionHandler =
- ctx.channel().pipeline().get(SaslEncryption.ENCRYPTION_HANDLER_NAME) != null;
+ ctx.channel().pipeline().get(encryptHandlerName) != null;
}
ctx.write(msg, promise);
}