author     Marcelo Vanzin <vanzin@cloudera.com>      2015-05-01 19:01:46 -0700
committer  Reynold Xin <rxin@databricks.com>         2015-05-01 19:01:46 -0700
commit     38d4e9e446b425ca6a8fe8d8080f387b08683842 (patch)
tree       091ca906d1281ce5f95772f1748ed79bd44ca655 /network
parent     8f50a07d2188ccc5315d979755188b1e5d5b5471 (diff)
[SPARK-6229] Add SASL encryption to network library.
There are two main parts to this change:

- Extending the bootstrap mechanism in the network library to add a server-side bootstrap (which works a little differently than the client-side bootstrap), and to allow the bootstraps to modify the underlying channel.
- Using SASL to encrypt data going through the RPC channel.

The second item requires some non-optimal code to work around the fact that the outbound path in netty is not thread-safe, and ordering is very important when encryption is in the picture.

A lot of the changes outside the network/common library are just adjustments to the changed API for initializing the RPC server.

Author: Marcelo Vanzin <vanzin@cloudera.com>

Closes #5377 from vanzin/SPARK-6229 and squashes the following commits:

ff01966 [Marcelo Vanzin] Use fancy new size config style.
be53f32 [Marcelo Vanzin] Merge branch 'master' into SPARK-6229
47d4aff [Marcelo Vanzin] Merge branch 'master' into SPARK-6229
7a2a805 [Marcelo Vanzin] Clean up some unneeded changes.
2f92237 [Marcelo Vanzin] Add comment.
67bb0c6 [Marcelo Vanzin] Revert "Avoid exposing ByteArrayWritableChannel outside of test code."
065f684 [Marcelo Vanzin] Add test to verify chunking.
3d1695d [Marcelo Vanzin] Minor cleanups.
73cff0e [Marcelo Vanzin] Skip bytes in decode path too.
318ad23 [Marcelo Vanzin] Avoid exposing ByteArrayWritableChannel outside of test code.
346f829 [Marcelo Vanzin] Avoid trip through channel selector by not reporting 0 bytes written.
a4a5938 [Marcelo Vanzin] Review feedback.
4797519 [Marcelo Vanzin] Remove unused import.
9908ada [Marcelo Vanzin] Fix test, SASL backend disposal.
7fe1489 [Marcelo Vanzin] Add a test that makes sure encryption is actually enabled.
adb6f9d [Marcelo Vanzin] Review feedback.
cf2a605 [Marcelo Vanzin] Clean up some code.
8584323 [Marcelo Vanzin] Fix a comment.
e98bc55 [Marcelo Vanzin] Add option to only allow encrypted connections to the server.
dad42fc [Marcelo Vanzin] Make encryption thread-safe, less memory-intensive.
b00999a [Marcelo Vanzin] Consolidate ByteArrayWritableChannel, fix SASL code to match master changes.
b923cae [Marcelo Vanzin] Make SASL encryption handler thread-safe, handle FileRegion messages.
39539a7 [Marcelo Vanzin] Add config option to enable SASL encryption.
351a86f [Marcelo Vanzin] Add SASL encryption to network library.
fbe6ccb [Marcelo Vanzin] Add TransportServerBootstrap, make SASL code use it.
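To make the API change concrete, here is a minimal sketch (modeled on the SparkSaslSuite test further down) of standing up a SASL-protected server and client with the new bootstrap lists. conf, appRpcHandler, keyHolder, host, and "app-id" are placeholders for values the caller already has; this is illustrative, not part of the change:

    import java.util.Arrays;
    import java.util.List;

    import org.apache.spark.network.TransportContext;
    import org.apache.spark.network.client.TransportClient;
    import org.apache.spark.network.client.TransportClientBootstrap;
    import org.apache.spark.network.sasl.SaslClientBootstrap;
    import org.apache.spark.network.sasl.SaslServerBootstrap;
    import org.apache.spark.network.sasl.SecretKeyHolder;
    import org.apache.spark.network.server.RpcHandler;
    import org.apache.spark.network.server.TransportServer;
    import org.apache.spark.network.server.TransportServerBootstrap;
    import org.apache.spark.network.util.TransportConf;

    class SaslSetupSketch {
      static TransportClient connect(
          TransportConf conf,
          RpcHandler appRpcHandler,
          SecretKeyHolder keyHolder,
          String host) throws Exception {
        TransportContext context = new TransportContext(conf, appRpcHandler);

        // Server side: the SASL handshake wraps the app handler via a server bootstrap.
        List<TransportServerBootstrap> serverBootstraps =
            Arrays.<TransportServerBootstrap>asList(new SaslServerBootstrap(conf, keyHolder));
        TransportServer server = context.createServer(serverBootstraps);

        // Client side: authenticate (and request encryption) before the client is returned.
        List<TransportClientBootstrap> clientBootstraps =
            Arrays.<TransportClientBootstrap>asList(
                new SaslClientBootstrap(conf, "app-id", keyHolder, true /* encrypt */));
        return context.createClientFactory(clientBootstraps)
            .createClient(host, server.getPort());
      }
    }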
Diffstat (limited to 'network')
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/TransportContext.java  26
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/client/TransportClientBootstrap.java  4
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java  5
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java  41
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java  291
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/sasl/SaslEncryptionBackend.java  33
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java  56
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/sasl/SaslServerBootstrap.java  49
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java  33
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java  49
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/server/TransportServer.java  19
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/server/TransportServerBootstrap.java  36
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/util/ByteArrayWritableChannel.java (renamed from network/common/src/test/java/org/apache/spark/network/ByteArrayWritableChannel.java)  26
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/util/TransportConf.java  18
-rw-r--r--  network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java  1
-rw-r--r--  network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java  2
-rw-r--r--  network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java  358
-rw-r--r--  network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java  11
-rw-r--r--  network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java  9
-rw-r--r--  network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java  4
-rw-r--r--  network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java  27
-rw-r--r--  network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java  15
22 files changed, 1029 insertions, 84 deletions
diff --git a/network/common/src/main/java/org/apache/spark/network/TransportContext.java b/network/common/src/main/java/org/apache/spark/network/TransportContext.java
index 3fe69b1bd8..b8d073fa16 100644
--- a/network/common/src/main/java/org/apache/spark/network/TransportContext.java
+++ b/network/common/src/main/java/org/apache/spark/network/TransportContext.java
@@ -36,6 +36,7 @@ import org.apache.spark.network.server.RpcHandler;
import org.apache.spark.network.server.TransportChannelHandler;
import org.apache.spark.network.server.TransportRequestHandler;
import org.apache.spark.network.server.TransportServer;
+import org.apache.spark.network.server.TransportServerBootstrap;
import org.apache.spark.network.util.NettyUtils;
import org.apache.spark.network.util.TransportConf;
@@ -82,13 +83,21 @@ public class TransportContext {
}
/** Create a server which will attempt to bind to a specific port. */
- public TransportServer createServer(int port) {
- return new TransportServer(this, port);
+ public TransportServer createServer(int port, List<TransportServerBootstrap> bootstraps) {
+ return new TransportServer(this, port, rpcHandler, bootstraps);
}
/** Creates a new server, binding to any available ephemeral port. */
+ public TransportServer createServer(List<TransportServerBootstrap> bootstraps) {
+ return createServer(0, bootstraps);
+ }
+
public TransportServer createServer() {
- return new TransportServer(this, 0);
+ return createServer(0, Lists.<TransportServerBootstrap>newArrayList());
+ }
+
+ public TransportChannelHandler initializePipeline(SocketChannel channel) {
+ return initializePipeline(channel, rpcHandler);
}
/**
@@ -96,13 +105,18 @@ public class TransportContext {
* has a {@link org.apache.spark.network.server.TransportChannelHandler} to handle request or
* response messages.
*
+ * @param channel The channel to initialize.
+ * @param channelRpcHandler The RPC handler to use for the channel.
+ *
* @return Returns the created TransportChannelHandler, which includes a TransportClient that can
* be used to communicate on this channel. The TransportClient is directly associated with a
* ChannelHandler to ensure all users of the same channel get the same TransportClient object.
*/
- public TransportChannelHandler initializePipeline(SocketChannel channel) {
+ public TransportChannelHandler initializePipeline(
+ SocketChannel channel,
+ RpcHandler channelRpcHandler) {
try {
- TransportChannelHandler channelHandler = createChannelHandler(channel);
+ TransportChannelHandler channelHandler = createChannelHandler(channel, channelRpcHandler);
channel.pipeline()
.addLast("encoder", encoder)
.addLast("frameDecoder", NettyUtils.createFrameDecoder())
@@ -123,7 +137,7 @@ public class TransportContext {
* ResponseMessages. The channel is expected to have been successfully created, though certain
* properties (such as the remoteAddress()) may not be available yet.
*/
- private TransportChannelHandler createChannelHandler(Channel channel) {
+ private TransportChannelHandler createChannelHandler(Channel channel, RpcHandler rpcHandler) {
TransportResponseHandler responseHandler = new TransportResponseHandler(channel);
TransportClient client = new TransportClient(channel, responseHandler);
TransportRequestHandler requestHandler = new TransportRequestHandler(channel, client,
diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClientBootstrap.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClientBootstrap.java
index 65e8020e34..eaae2ee043 100644
--- a/network/common/src/main/java/org/apache/spark/network/client/TransportClientBootstrap.java
+++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClientBootstrap.java
@@ -17,6 +17,8 @@
package org.apache.spark.network.client;
+import io.netty.channel.Channel;
+
/**
* A bootstrap which is executed on a TransportClient before it is returned to the user.
* This enables an initial exchange of information (e.g., SASL authentication tokens) on a once-per-
@@ -28,5 +30,5 @@ package org.apache.spark.network.client;
*/
public interface TransportClientBootstrap {
/** Performs the bootstrapping operation, throwing an exception on failure. */
- public void doBootstrap(TransportClient client) throws RuntimeException;
+ void doBootstrap(TransportClient client, Channel channel) throws RuntimeException;
}
diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
index d26b9b4d60..4952ffb44b 100644
--- a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
+++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
@@ -172,12 +172,14 @@ public class TransportClientFactory implements Closeable {
.option(ChannelOption.ALLOCATOR, pooledAllocator);
final AtomicReference<TransportClient> clientRef = new AtomicReference<TransportClient>();
+ final AtomicReference<Channel> channelRef = new AtomicReference<Channel>();
bootstrap.handler(new ChannelInitializer<SocketChannel>() {
@Override
public void initChannel(SocketChannel ch) {
TransportChannelHandler clientHandler = context.initializePipeline(ch);
clientRef.set(clientHandler.getClient());
+ channelRef.set(ch);
}
});
@@ -192,6 +194,7 @@ public class TransportClientFactory implements Closeable {
}
TransportClient client = clientRef.get();
+ Channel channel = channelRef.get();
assert client != null : "Channel future completed successfully with null client";
// Execute any client bootstraps synchronously before marking the Client as successful.
@@ -199,7 +202,7 @@ public class TransportClientFactory implements Closeable {
logger.debug("Connection to {} successful, running bootstraps...", address);
try {
for (TransportClientBootstrap clientBootstrap : clientBootstraps) {
- clientBootstrap.doBootstrap(client);
+ clientBootstrap.doBootstrap(client, channel);
}
} catch (Exception e) { // catch non-RuntimeExceptions too as bootstrap may be written in Scala
long bootstrapTimeMs = (System.nanoTime() - preBootstrap) / 1000000;
diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java b/network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java
index 33aa134434..185ba2ef3b 100644
--- a/network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java
+++ b/network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java
@@ -17,8 +17,12 @@
package org.apache.spark.network.sasl;
+import javax.security.sasl.Sasl;
+import javax.security.sasl.SaslException;
+
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
+import io.netty.channel.Channel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -33,14 +37,24 @@ import org.apache.spark.network.util.TransportConf;
public class SaslClientBootstrap implements TransportClientBootstrap {
private final Logger logger = LoggerFactory.getLogger(SaslClientBootstrap.class);
+ private final boolean encrypt;
private final TransportConf conf;
private final String appId;
private final SecretKeyHolder secretKeyHolder;
public SaslClientBootstrap(TransportConf conf, String appId, SecretKeyHolder secretKeyHolder) {
+ this(conf, appId, secretKeyHolder, false);
+ }
+
+ public SaslClientBootstrap(
+ TransportConf conf,
+ String appId,
+ SecretKeyHolder secretKeyHolder,
+ boolean encrypt) {
this.conf = conf;
this.appId = appId;
this.secretKeyHolder = secretKeyHolder;
+ this.encrypt = encrypt;
}
/**
@@ -49,8 +63,8 @@ public class SaslClientBootstrap implements TransportClientBootstrap {
* due to mismatch.
*/
@Override
- public void doBootstrap(TransportClient client) {
- SparkSaslClient saslClient = new SparkSaslClient(appId, secretKeyHolder);
+ public void doBootstrap(TransportClient client, Channel channel) {
+ SparkSaslClient saslClient = new SparkSaslClient(appId, secretKeyHolder, encrypt);
try {
byte[] payload = saslClient.firstToken();
@@ -62,13 +76,26 @@ public class SaslClientBootstrap implements TransportClientBootstrap {
byte[] response = client.sendRpcSync(buf.array(), conf.saslRTTimeoutMs());
payload = saslClient.response(response);
}
+
+ if (encrypt) {
+ if (!SparkSaslServer.QOP_AUTH_CONF.equals(saslClient.getNegotiatedProperty(Sasl.QOP))) {
+ throw new RuntimeException(
+ new SaslException("Encryption requested but the negotiated connection is not encrypted."));
+ }
+ SaslEncryption.addToChannel(channel, saslClient, conf.maxSaslEncryptedBlockSize());
+ saslClient = null;
+ logger.debug("Channel {} configured for SASL encryption.", client);
+ }
} finally {
- try {
- // Once authentication is complete, the server will trust all remaining communication.
- saslClient.dispose();
- } catch (RuntimeException e) {
- logger.error("Error while disposing SASL client", e);
+ if (saslClient != null) {
+ try {
+ // Once authentication is complete, the server will trust all remaining communication.
+ saslClient.dispose();
+ } catch (RuntimeException e) {
+ logger.error("Error while disposing SASL client", e);
+ }
}
}
}
+
}
diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java b/network/common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java
new file mode 100644
index 0000000000..127335e4d3
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java
@@ -0,0 +1,291 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.sasl;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.channels.WritableByteChannel;
+import java.util.List;
+
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Preconditions;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+import io.netty.channel.Channel;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelOutboundHandlerAdapter;
+import io.netty.channel.ChannelPromise;
+import io.netty.channel.FileRegion;
+import io.netty.handler.codec.MessageToMessageDecoder;
+import io.netty.util.AbstractReferenceCounted;
+import io.netty.util.ReferenceCountUtil;
+
+import org.apache.spark.network.util.ByteArrayWritableChannel;
+import org.apache.spark.network.util.NettyUtils;
+
+/**
+ * Provides SASL-based encryption for transport channels. The single method exposed by this
+ * class installs the needed channel handlers on a connected channel.
+ */
+class SaslEncryption {
+
+ @VisibleForTesting
+ static final String ENCRYPTION_HANDLER_NAME = "saslEncryption";
+
+ /**
+ * Adds channel handlers that perform encryption / decryption of data using SASL.
+ *
+ * @param channel The channel.
+ * @param backend The SASL backend.
+ * @param maxOutboundBlockSize Max size in bytes of outgoing encrypted blocks, to control
+ * memory usage.
+ */
+ static void addToChannel(
+ Channel channel,
+ SaslEncryptionBackend backend,
+ int maxOutboundBlockSize) {
+ channel.pipeline()
+ .addFirst(ENCRYPTION_HANDLER_NAME, new EncryptionHandler(backend, maxOutboundBlockSize))
+ .addFirst("saslDecryption", new DecryptionHandler(backend))
+ .addFirst("saslFrameDecoder", NettyUtils.createFrameDecoder());
+ }
+
+ private static class EncryptionHandler extends ChannelOutboundHandlerAdapter {
+
+ private final int maxOutboundBlockSize;
+ private final SaslEncryptionBackend backend;
+
+ EncryptionHandler(SaslEncryptionBackend backend, int maxOutboundBlockSize) {
+ this.backend = backend;
+ this.maxOutboundBlockSize = maxOutboundBlockSize;
+ }
+
+ /**
+ * Wrap the incoming message in an implementation that will perform encryption lazily. This is
+ * needed to guarantee ordering of the outgoing encrypted packets - they need to be decrypted in
+ * the same order, and netty doesn't have an atomic ChannelHandlerContext.write() API, so it
+ * does not guarantee any ordering.
+ */
+ @Override
+ public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise)
+ throws Exception {
+
+ ctx.write(new EncryptedMessage(backend, msg, maxOutboundBlockSize), promise);
+ }
+
+ @Override
+ public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
+ try {
+ backend.dispose();
+ } finally {
+ super.handlerRemoved(ctx);
+ }
+ }
+
+ }
+
+ private static class DecryptionHandler extends MessageToMessageDecoder<ByteBuf> {
+
+ private final SaslEncryptionBackend backend;
+
+ DecryptionHandler(SaslEncryptionBackend backend) {
+ this.backend = backend;
+ }
+
+ @Override
+ protected void decode(ChannelHandlerContext ctx, ByteBuf msg, List<Object> out)
+ throws Exception {
+
+ byte[] data;
+ int offset;
+ int length = msg.readableBytes();
+ if (msg.hasArray()) {
+ data = msg.array();
+ offset = msg.arrayOffset();
+ msg.skipBytes(length);
+ } else {
+ data = new byte[length];
+ msg.readBytes(data);
+ offset = 0;
+ }
+
+ out.add(Unpooled.wrappedBuffer(backend.unwrap(data, offset, length)));
+ }
+
+ }
+
+ @VisibleForTesting
+ static class EncryptedMessage extends AbstractReferenceCounted implements FileRegion {
+
+ private final SaslEncryptionBackend backend;
+ private final boolean isByteBuf;
+ private final ByteBuf buf;
+ private final FileRegion region;
+
+ /**
+ * A channel used to buffer input data for encryption. The channel has an upper size bound
+ * so that if the input is larger than the allowed buffer, it will be broken into multiple
+ * chunks.
+ */
+ private final ByteArrayWritableChannel byteChannel;
+
+ private ByteBuf currentHeader;
+ private ByteBuffer currentChunk;
+ private long currentChunkSize;
+ private long currentReportedBytes;
+ private long unencryptedChunkSize;
+ private long transferred;
+
+ EncryptedMessage(SaslEncryptionBackend backend, Object msg, int maxOutboundBlockSize) {
+ Preconditions.checkArgument(msg instanceof ByteBuf || msg instanceof FileRegion,
+ "Unrecognized message type: %s", msg.getClass().getName());
+ this.backend = backend;
+ this.isByteBuf = msg instanceof ByteBuf;
+ this.buf = isByteBuf ? (ByteBuf) msg : null;
+ this.region = isByteBuf ? null : (FileRegion) msg;
+ this.byteChannel = new ByteArrayWritableChannel(maxOutboundBlockSize);
+ }
+
+ /**
+ * Returns the size of the original (unencrypted) message.
+ *
+ * This makes assumptions about how netty treats FileRegion instances, because there's no way
+ * to know beforehand what will be the size of the encrypted message. Namely, it assumes
+ * that netty will try to transfer data from this message while
+ * <code>transfered() < count()</code>. So these two methods return, technically, wrong data,
+ * but netty doesn't know better.
+ */
+ @Override
+ public long count() {
+ return isByteBuf ? buf.readableBytes() : region.count();
+ }
+
+ @Override
+ public long position() {
+ return 0;
+ }
+
+ /**
+ * Returns an approximation of the amount of data transferred. See {@link #count()}.
+ */
+ @Override
+ public long transfered() {
+ return transferred;
+ }
+
+ /**
+ * Transfers data from the original message to the channel, encrypting it in the process.
+ *
+ * This method also breaks down the original message into smaller chunks when needed. This
+ * is done to keep memory usage under control. This avoids having to copy the whole message
+ * data into memory at once, and can avoid ballooning memory usage when transferring large
+ * messages such as shuffle blocks.
+ *
+ * The {@link #transfered()} counter also behaves a little funny, in that it won't go forward
+ * until a whole chunk has been written. This is done because the code can't use the actual
+ * number of bytes written to the channel as the transferred count (see {@link #count()}).
+ * Instead, once an encrypted chunk is written to the output (including its header), the
+ * size of the original block will be added to the {@link #transfered()} amount.
+ */
+ @Override
+ public long transferTo(final WritableByteChannel target, final long position)
+ throws IOException {
+
+ Preconditions.checkArgument(position == transfered(), "Invalid position.");
+
+ long reportedWritten = 0L;
+ long actuallyWritten = 0L;
+ do {
+ if (currentChunk == null) {
+ nextChunk();
+ }
+
+ if (currentHeader.readableBytes() > 0) {
+ int bytesWritten = target.write(currentHeader.nioBuffer());
+ currentHeader.skipBytes(bytesWritten);
+ actuallyWritten += bytesWritten;
+ if (currentHeader.readableBytes() > 0) {
+ // Break out of loop if there are still header bytes left to write.
+ break;
+ }
+ }
+
+ actuallyWritten += target.write(currentChunk);
+ if (!currentChunk.hasRemaining()) {
+ // Only update the count of written bytes once a full chunk has been written.
+ // See method javadoc.
+ long chunkBytesRemaining = unencryptedChunkSize - currentReportedBytes;
+ reportedWritten += chunkBytesRemaining;
+ transferred += chunkBytesRemaining;
+ currentHeader.release();
+ currentHeader = null;
+ currentChunk = null;
+ currentChunkSize = 0;
+ currentReportedBytes = 0;
+ }
+ } while (currentChunk == null && transfered() + reportedWritten < count());
+
+ // Returning 0 triggers a backoff mechanism in netty which may harm performance. Instead,
+ // we return 1 until we can (i.e. until the reported count would actually match the size
+ // of the current chunk), at which point we resort to returning 0 so that the counts still
+ // match, at the cost of some performance. That situation should be rare, though.
+ if (reportedWritten != 0L) {
+ return reportedWritten;
+ }
+
+ if (actuallyWritten > 0 && currentReportedBytes < currentChunkSize - 1) {
+ transferred += 1L;
+ currentReportedBytes += 1L;
+ return 1L;
+ }
+
+ return 0L;
+ }
+
+ private void nextChunk() throws IOException {
+ byteChannel.reset();
+ if (isByteBuf) {
+ int copied = byteChannel.write(buf.nioBuffer());
+ buf.skipBytes(copied);
+ } else {
+ region.transferTo(byteChannel, region.transfered());
+ }
+
+ byte[] encrypted = backend.wrap(byteChannel.getData(), 0, byteChannel.length());
+ this.currentChunk = ByteBuffer.wrap(encrypted);
+ this.currentChunkSize = encrypted.length;
+ this.currentHeader = Unpooled.copyLong(8 + currentChunkSize);
+ this.unencryptedChunkSize = byteChannel.length();
+ }
+
+ @Override
+ protected void deallocate() {
+ if (currentHeader != null) {
+ currentHeader.release();
+ }
+ if (buf != null) {
+ buf.release();
+ }
+ if (region != null) {
+ region.release();
+ }
+ }
+
+ }
+
+}
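Each chunk that EncryptedMessage emits is a length-prefixed frame: an 8-byte length (which, per nextChunk() above, counts the prefix itself) followed by the SASL-wrapped bytes. The saslFrameDecoder installed by addToChannel() reassembles these frames on the receiving side before DecryptionHandler unwraps them. A standalone, purely illustrative sketch of that framing, assuming big-endian longs (the Unpooled default):

    import java.nio.ByteBuffer;

    // Sketch of the per-chunk wire format: [8-byte length, counting itself][payload].
    class SaslFrameSketch {
      static ByteBuffer frame(byte[] encrypted) {
        ByteBuffer out = ByteBuffer.allocate(8 + encrypted.length);
        out.putLong(8 + encrypted.length);  // mirrors Unpooled.copyLong(8 + currentChunkSize)
        out.put(encrypted);
        out.flip();
        return out;
      }

      static byte[] unframe(ByteBuffer in) {
        long total = in.getLong();                   // total frame size, prefix included
        byte[] payload = new byte[(int) (total - 8)];
        in.get(payload);
        return payload;                              // would then go to backend.unwrap()
      }
    }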
diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslEncryptionBackend.java b/network/common/src/main/java/org/apache/spark/network/sasl/SaslEncryptionBackend.java
new file mode 100644
index 0000000000..89b78bc7e1
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/sasl/SaslEncryptionBackend.java
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.sasl;
+
+import javax.security.sasl.SaslException;
+
+interface SaslEncryptionBackend {
+
+ /** Disposes of resources used by the backend. */
+ void dispose();
+
+ /** Encrypt data. */
+ byte[] wrap(byte[] data, int offset, int len) throws SaslException;
+
+ /** Decrypt data. */
+ byte[] unwrap(byte[] data, int offset, int len) throws SaslException;
+
+}
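The interface is small enough that a pass-through backend is easy to sketch for exercising the chunking and framing paths. This is hypothetical and provides no security; it would also need to live in org.apache.spark.network.sasl, since the interface is package-private:

    package org.apache.spark.network.sasl;

    import java.util.Arrays;
    import javax.security.sasl.SaslException;

    // Hypothetical pass-through backend: "wraps" by copying bytes verbatim.
    class IdentityBackend implements SaslEncryptionBackend {
      @Override
      public void dispose() { }

      @Override
      public byte[] wrap(byte[] data, int offset, int len) throws SaslException {
        return Arrays.copyOfRange(data, offset, offset + len);
      }

      @Override
      public byte[] unwrap(byte[] data, int offset, int len) throws SaslException {
        return Arrays.copyOfRange(data, offset, offset + len);
      }
    }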
diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java b/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java
index 026cbd260d..be6165caf3 100644
--- a/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java
@@ -17,10 +17,10 @@
package org.apache.spark.network.sasl;
-import java.util.concurrent.ConcurrentMap;
+import javax.security.sasl.Sasl;
-import com.google.common.collect.Maps;
import io.netty.buffer.Unpooled;
+import io.netty.channel.Channel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -28,6 +28,7 @@ import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.server.RpcHandler;
import org.apache.spark.network.server.StreamManager;
+import org.apache.spark.network.util.TransportConf;
/**
* RPC Handler which performs SASL authentication before delegating to a child RPC handler.
@@ -37,8 +38,14 @@ import org.apache.spark.network.server.StreamManager;
* Note that the authentication process consists of multiple challenge-response pairs, each of
* which are individual RPCs.
*/
-public class SaslRpcHandler extends RpcHandler {
- private final Logger logger = LoggerFactory.getLogger(SaslRpcHandler.class);
+class SaslRpcHandler extends RpcHandler {
+ private static final Logger logger = LoggerFactory.getLogger(SaslRpcHandler.class);
+
+ /** Transport configuration. */
+ private final TransportConf conf;
+
+ /** The client channel. */
+ private final Channel channel;
/** RpcHandler we will delegate to for authenticated connections. */
private final RpcHandler delegate;
@@ -46,19 +53,25 @@ public class SaslRpcHandler extends RpcHandler {
/** Class which provides secret keys which are shared by server and client on a per-app basis. */
private final SecretKeyHolder secretKeyHolder;
- /** Maps each channel to its SASL authentication state. */
- private final ConcurrentMap<TransportClient, SparkSaslServer> channelAuthenticationMap;
+ private SparkSaslServer saslServer;
+ private boolean isComplete;
- public SaslRpcHandler(RpcHandler delegate, SecretKeyHolder secretKeyHolder) {
+ SaslRpcHandler(
+ TransportConf conf,
+ Channel channel,
+ RpcHandler delegate,
+ SecretKeyHolder secretKeyHolder) {
+ this.conf = conf;
+ this.channel = channel;
this.delegate = delegate;
this.secretKeyHolder = secretKeyHolder;
- this.channelAuthenticationMap = Maps.newConcurrentMap();
+ this.saslServer = null;
+ this.isComplete = false;
}
@Override
public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) {
- SparkSaslServer saslServer = channelAuthenticationMap.get(client);
- if (saslServer != null && saslServer.isComplete()) {
+ if (isComplete) {
// Authentication complete, delegate to base handler.
delegate.receive(client, message, callback);
return;
@@ -68,15 +81,30 @@ public class SaslRpcHandler extends RpcHandler {
if (saslServer == null) {
// First message in the handshake, setup the necessary state.
- saslServer = new SparkSaslServer(saslMessage.appId, secretKeyHolder);
- channelAuthenticationMap.put(client, saslServer);
+ saslServer = new SparkSaslServer(saslMessage.appId, secretKeyHolder,
+ conf.saslServerAlwaysEncrypt());
}
byte[] response = saslServer.response(saslMessage.payload);
+ callback.onSuccess(response);
+
+ // Setup encryption after the SASL response is sent, otherwise the client can't parse the
+ // response. It's ok to change the channel pipeline here since we are processing an incoming
+ // message, so the pipeline is busy and no new incoming messages will be fed to it before this
+ // method returns. This assumes that the code ensures, through other means, that no outbound
+ // messages are being written to the channel while negotiation is still going on.
if (saslServer.isComplete()) {
logger.debug("SASL authentication successful for channel {}", client);
+ isComplete = true;
+ if (SparkSaslServer.QOP_AUTH_CONF.equals(saslServer.getNegotiatedProperty(Sasl.QOP))) {
+ logger.debug("Enabling encryption for channel {}", client);
+ SaslEncryption.addToChannel(channel, saslServer, conf.maxSaslEncryptedBlockSize());
+ saslServer = null;
+ } else {
+ saslServer.dispose();
+ saslServer = null;
+ }
}
- callback.onSuccess(response);
}
@Override
@@ -86,9 +114,9 @@ public class SaslRpcHandler extends RpcHandler {
@Override
public void connectionTerminated(TransportClient client) {
- SparkSaslServer saslServer = channelAuthenticationMap.remove(client);
if (saslServer != null) {
saslServer.dispose();
}
}
+
}
diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslServerBootstrap.java b/network/common/src/main/java/org/apache/spark/network/sasl/SaslServerBootstrap.java
new file mode 100644
index 0000000000..f2f983856f
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/sasl/SaslServerBootstrap.java
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.sasl;
+
+import io.netty.channel.Channel;
+
+import org.apache.spark.network.server.RpcHandler;
+import org.apache.spark.network.server.TransportServerBootstrap;
+import org.apache.spark.network.util.TransportConf;
+
+/**
+ * A bootstrap which is executed on a TransportServer's client channel once a client connects
+ * to the server. This allows customizing the client channel to allow for things such as SASL
+ * authentication.
+ */
+public class SaslServerBootstrap implements TransportServerBootstrap {
+
+ private final TransportConf conf;
+ private final SecretKeyHolder secretKeyHolder;
+
+ public SaslServerBootstrap(TransportConf conf, SecretKeyHolder secretKeyHolder) {
+ this.conf = conf;
+ this.secretKeyHolder = secretKeyHolder;
+ }
+
+ /**
+ * Wrap the given application handler in a SaslRpcHandler that will handle the initial SASL
+ * negotiation.
+ */
+ public RpcHandler doBootstrap(Channel channel, RpcHandler rpcHandler) {
+ return new SaslRpcHandler(conf, channel, rpcHandler, secretKeyHolder);
+ }
+
+}
diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java b/network/common/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java
index 9abad1f30a..94685e91b8 100644
--- a/network/common/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java
+++ b/network/common/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java
@@ -17,6 +17,8 @@
package org.apache.spark.network.sasl;
+import java.io.IOException;
+import java.util.Map;
import javax.security.auth.callback.Callback;
import javax.security.auth.callback.CallbackHandler;
import javax.security.auth.callback.NameCallback;
@@ -27,9 +29,9 @@ import javax.security.sasl.RealmChoiceCallback;
import javax.security.sasl.Sasl;
import javax.security.sasl.SaslClient;
import javax.security.sasl.SaslException;
-import java.io.IOException;
import com.google.common.base.Throwables;
+import com.google.common.collect.ImmutableMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -40,19 +42,25 @@ import static org.apache.spark.network.sasl.SparkSaslServer.*;
* initial state to the "authenticated" state. This client initializes the protocol via a
* firstToken, which is then followed by a set of challenges and responses.
*/
-public class SparkSaslClient {
+public class SparkSaslClient implements SaslEncryptionBackend {
private final Logger logger = LoggerFactory.getLogger(SparkSaslClient.class);
private final String secretKeyId;
private final SecretKeyHolder secretKeyHolder;
+ private final String expectedQop;
private SaslClient saslClient;
- public SparkSaslClient(String secretKeyId, SecretKeyHolder secretKeyHolder) {
+ public SparkSaslClient(String secretKeyId, SecretKeyHolder secretKeyHolder, boolean encrypt) {
this.secretKeyId = secretKeyId;
this.secretKeyHolder = secretKeyHolder;
+ this.expectedQop = encrypt ? QOP_AUTH_CONF : QOP_AUTH;
+
+ Map<String, String> saslProps = ImmutableMap.<String, String>builder()
+ .put(Sasl.QOP, expectedQop)
+ .build();
try {
this.saslClient = Sasl.createSaslClient(new String[] { DIGEST }, null, null, DEFAULT_REALM,
- SASL_PROPS, new ClientCallbackHandler());
+ saslProps, new ClientCallbackHandler());
} catch (SaslException e) {
throw Throwables.propagate(e);
}
@@ -76,6 +84,11 @@ public class SparkSaslClient {
return saslClient != null && saslClient.isComplete();
}
+ /** Returns the value of a negotiated property. */
+ public Object getNegotiatedProperty(String name) {
+ return saslClient.getNegotiatedProperty(name);
+ }
+
/**
* Respond to server's SASL token.
* @param token contains server's SASL token
@@ -93,6 +106,7 @@ public class SparkSaslClient {
* Disposes of any system resources or security-sensitive information the
* SaslClient might be using.
*/
+ @Override
public synchronized void dispose() {
if (saslClient != null) {
try {
@@ -134,4 +148,15 @@ public class SparkSaslClient {
}
}
}
+
+ @Override
+ public byte[] wrap(byte[] data, int offset, int len) throws SaslException {
+ return saslClient.wrap(data, offset, len);
+ }
+
+ @Override
+ public byte[] unwrap(byte[] data, int offset, int len) throws SaslException {
+ return saslClient.unwrap(data, offset, len);
+ }
+
}
diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java b/network/common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java
index e87b17ead1..431cb67a2a 100644
--- a/network/common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java
+++ b/network/common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java
@@ -44,7 +44,7 @@ import org.slf4j.LoggerFactory;
* initial state to the "authenticated" state. (It is not a server in the sense of accepting
* connections on some socket.)
*/
-public class SparkSaslServer {
+public class SparkSaslServer implements SaslEncryptionBackend {
private final Logger logger = LoggerFactory.getLogger(SparkSaslServer.class);
/**
@@ -60,26 +60,37 @@ public class SparkSaslServer {
static final String DIGEST = "DIGEST-MD5";
/**
- * The quality of protection is just "auth". This means that we are doing
- * authentication only, we are not supporting integrity or privacy protection of the
- * communication channel after authentication. This could be changed to be configurable
- * in the future.
+ * Quality of protection value that includes encryption.
*/
- static final Map<String, String> SASL_PROPS = ImmutableMap.<String, String>builder()
- .put(Sasl.QOP, "auth")
- .put(Sasl.SERVER_AUTH, "true")
- .build();
+ static final String QOP_AUTH_CONF = "auth-conf";
+
+ /**
+ * Quality of protection value that does not include encryption.
+ */
+ static final String QOP_AUTH = "auth";
/** Identifier for a certain secret key within the secretKeyHolder. */
private final String secretKeyId;
private final SecretKeyHolder secretKeyHolder;
private SaslServer saslServer;
- public SparkSaslServer(String secretKeyId, SecretKeyHolder secretKeyHolder) {
+ public SparkSaslServer(
+ String secretKeyId,
+ SecretKeyHolder secretKeyHolder,
+ boolean alwaysEncrypt) {
this.secretKeyId = secretKeyId;
this.secretKeyHolder = secretKeyHolder;
+
+ // Sasl.QOP is a comma-separated list of supported values. The value that allows encryption
+ // is listed first since it's preferred over the non-encrypted one (if the client also
+ // lists both in the request).
+ String qop = alwaysEncrypt ? QOP_AUTH_CONF : String.format("%s,%s", QOP_AUTH_CONF, QOP_AUTH);
+ Map<String, String> saslProps = ImmutableMap.<String, String>builder()
+ .put(Sasl.SERVER_AUTH, "true")
+ .put(Sasl.QOP, qop)
+ .build();
try {
- this.saslServer = Sasl.createSaslServer(DIGEST, null, DEFAULT_REALM, SASL_PROPS,
+ this.saslServer = Sasl.createSaslServer(DIGEST, null, DEFAULT_REALM, saslProps,
new DigestCallbackHandler());
} catch (SaslException e) {
throw Throwables.propagate(e);
@@ -93,6 +104,11 @@ public class SparkSaslServer {
return saslServer != null && saslServer.isComplete();
}
+ /** Returns the value of a negotiated property. */
+ public Object getNegotiatedProperty(String name) {
+ return saslServer.getNegotiatedProperty(name);
+ }
+
/**
* Used to respond to server SASL tokens.
* @param token Server's SASL token
@@ -110,6 +126,7 @@ public class SparkSaslServer {
* Disposes of any system resources or security-sensitive information the
* SaslServer might be using.
*/
+ @Override
public synchronized void dispose() {
if (saslServer != null) {
try {
@@ -122,6 +139,16 @@ public class SparkSaslServer {
}
}
+ @Override
+ public byte[] wrap(byte[] data, int offset, int len) throws SaslException {
+ return saslServer.wrap(data, offset, len);
+ }
+
+ @Override
+ public byte[] unwrap(byte[] data, int offset, int len) throws SaslException {
+ return saslServer.unwrap(data, offset, len);
+ }
+
/**
* Implementation of javax.security.auth.callback.CallbackHandler for SASL DIGEST-MD5 mechanism.
*/
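Both sides decide whether to install the encryption handlers by checking the negotiated QOP once the handshake completes, as in the SaslClientBootstrap and SaslRpcHandler hunks above. That decision reduces to a check like this sketch:

    import javax.security.sasl.Sasl;

    import org.apache.spark.network.sasl.SparkSaslServer;

    class QopCheckSketch {
      // "auth-conf" means confidentiality (encryption) was negotiated;
      // plain "auth" means authentication only, so the backend can be disposed.
      static boolean negotiatedEncryption(SparkSaslServer saslServer) {
        return "auth-conf".equals(saslServer.getNegotiatedProperty(Sasl.QOP));
      }
    }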
diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java b/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java
index b7ce8541e5..941ef95772 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java
@@ -19,8 +19,11 @@ package org.apache.spark.network.server;
import java.io.Closeable;
import java.net.InetSocketAddress;
+import java.util.List;
import java.util.concurrent.TimeUnit;
+import com.google.common.base.Preconditions;
+import com.google.common.collect.Lists;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.buffer.PooledByteBufAllocator;
import io.netty.channel.ChannelFuture;
@@ -44,15 +47,23 @@ public class TransportServer implements Closeable {
private final TransportContext context;
private final TransportConf conf;
+ private final RpcHandler appRpcHandler;
+ private final List<TransportServerBootstrap> bootstraps;
private ServerBootstrap bootstrap;
private ChannelFuture channelFuture;
private int port = -1;
/** Creates a TransportServer that binds to the given port, or to any available if 0. */
- public TransportServer(TransportContext context, int portToBind) {
+ public TransportServer(
+ TransportContext context,
+ int portToBind,
+ RpcHandler appRpcHandler,
+ List<TransportServerBootstrap> bootstraps) {
this.context = context;
this.conf = context.getConf();
+ this.appRpcHandler = appRpcHandler;
+ this.bootstraps = Lists.newArrayList(Preconditions.checkNotNull(bootstraps));
init(portToBind);
}
@@ -95,7 +106,11 @@ public class TransportServer implements Closeable {
bootstrap.childHandler(new ChannelInitializer<SocketChannel>() {
@Override
protected void initChannel(SocketChannel ch) throws Exception {
- context.initializePipeline(ch);
+ RpcHandler rpcHandler = appRpcHandler;
+ for (TransportServerBootstrap bootstrap : bootstraps) {
+ rpcHandler = bootstrap.doBootstrap(ch, rpcHandler);
+ }
+ context.initializePipeline(ch, rpcHandler);
}
});
diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportServerBootstrap.java b/network/common/src/main/java/org/apache/spark/network/server/TransportServerBootstrap.java
new file mode 100644
index 0000000000..05803ab1bb
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/server/TransportServerBootstrap.java
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.server;
+
+import io.netty.channel.Channel;
+
+/**
+ * A bootstrap which is executed on a TransportServer's client channel once a client connects
+ * to the server. This allows customizing the client channel to allow for things such as SASL
+ * authentication.
+ */
+public interface TransportServerBootstrap {
+ /**
+ * Customizes the channel to include new features, if needed.
+ *
+ * @param channel The connected channel opened by the client.
+ * @param rpcHandler The RPC handler for the server.
+ * @return The RPC handler to use for the channel.
+ */
+ RpcHandler doBootstrap(Channel channel, RpcHandler rpcHandler);
+}
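Since TransportServer.init() (above) threads the RpcHandler through each bootstrap in turn, bootstraps compose: each may return the handler unchanged or wrap it, as SaslServerBootstrap does. A hypothetical bootstrap that merely logs connections illustrates the contract:

    import io.netty.channel.Channel;
    import org.slf4j.Logger;
    import org.slf4j.LoggerFactory;

    import org.apache.spark.network.server.RpcHandler;
    import org.apache.spark.network.server.TransportServerBootstrap;

    // Hypothetical bootstrap: observes each new client channel, delegates RPCs unchanged.
    public class LoggingServerBootstrap implements TransportServerBootstrap {
      private static final Logger logger = LoggerFactory.getLogger(LoggingServerBootstrap.class);

      @Override
      public RpcHandler doBootstrap(Channel channel, RpcHandler rpcHandler) {
        logger.info("Client connected: {}", channel.remoteAddress());
        return rpcHandler;  // no wrapping; SaslServerBootstrap instead returns a SaslRpcHandler
      }
    }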
diff --git a/network/common/src/test/java/org/apache/spark/network/ByteArrayWritableChannel.java b/network/common/src/main/java/org/apache/spark/network/util/ByteArrayWritableChannel.java
index b525ed69fc..b141572004 100644
--- a/network/common/src/test/java/org/apache/spark/network/ByteArrayWritableChannel.java
+++ b/network/common/src/main/java/org/apache/spark/network/util/ByteArrayWritableChannel.java
@@ -15,11 +15,14 @@
* limitations under the License.
*/
-package org.apache.spark.network;
+package org.apache.spark.network.util;
import java.nio.ByteBuffer;
import java.nio.channels.WritableByteChannel;
+/**
+ * A writable channel that stores the written data in a byte array in memory.
+ */
public class ByteArrayWritableChannel implements WritableByteChannel {
private final byte[] data;
@@ -27,19 +30,30 @@ public class ByteArrayWritableChannel implements WritableByteChannel {
public ByteArrayWritableChannel(int size) {
this.data = new byte[size];
- this.offset = 0;
}
public byte[] getData() {
return data;
}
+ public int length() {
+ return offset;
+ }
+
+ /** Resets the channel so that writing to it will overwrite the existing buffer. */
+ public void reset() {
+ offset = 0;
+ }
+
+ /**
+ * Reads from the given buffer into the internal byte array.
+ */
@Override
public int write(ByteBuffer src) {
- int available = src.remaining();
- src.get(data, offset, available);
- offset += available;
- return available;
+ int toTransfer = Math.min(src.remaining(), data.length - offset);
+ src.get(data, offset, toTransfer);
+ offset += toTransfer;
+ return toTransfer;
}
@Override
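The reworked write() is now bounded by the space left in the buffer instead of assuming the source fits, which is what lets EncryptedMessage split a large message into fixed-size chunks via reset(). A small sketch of that refill loop (the wrap() call marks where EncryptedMessage would encrypt each chunk):

    import java.nio.ByteBuffer;

    import org.apache.spark.network.util.ByteArrayWritableChannel;

    class ChunkingSketch {
      // Drains `src` through a bounded channel, at most `blockSize` bytes per pass,
      // the way EncryptedMessage.nextChunk() feeds data to backend.wrap().
      static void drain(ByteBuffer src, int blockSize) {
        ByteArrayWritableChannel channel = new ByteArrayWritableChannel(blockSize);
        while (src.hasRemaining()) {
          channel.reset();                  // overwrite the previous chunk
          channel.write(src);               // copies min(remaining, free space) bytes
          byte[] chunk = channel.getData(); // valid up to channel.length()
          int chunkLen = channel.length();
          // ... EncryptedMessage would call backend.wrap(chunk, 0, chunkLen) here ...
        }
      }
    }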
diff --git a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java
index 0aef7f1987..3b2eff3779 100644
--- a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java
+++ b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java
@@ -17,6 +17,8 @@
package org.apache.spark.network.util;
+import com.google.common.primitives.Ints;
+
/**
* A central location that tracks all the settings we expose to users.
*/
@@ -112,4 +114,20 @@ public class TransportConf {
public int portMaxRetries() {
return conf.getInt("spark.port.maxRetries", 16);
}
+
+ /**
+ * Maximum number of bytes to be encrypted at a time when SASL encryption is enabled.
+ */
+ public int maxSaslEncryptedBlockSize() {
+ return Ints.checkedCast(JavaUtils.byteStringAsBytes(
+ conf.get("spark.network.sasl.maxEncryptedBlockSize", "64k")));
+ }
+
+ /**
+ * Whether the server should enforce encryption on SASL-authenticated connections.
+ */
+ public boolean saslServerAlwaysEncrypt() {
+ return conf.getBoolean("spark.network.sasl.serverAlwaysEncrypt", false);
+ }
+
}
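The new settings can be exercised the same way the test suite below does, through system properties read by SystemPropertyConfigProvider; size values accept suffixed strings such as "64k" via JavaUtils.byteStringAsBytes. A minimal sketch:

    import org.apache.spark.network.util.SystemPropertyConfigProvider;
    import org.apache.spark.network.util.TransportConf;

    class ConfSketch {
      public static void main(String[] args) {
        // Same mechanism the tests use: configure the SASL knobs via system properties.
        System.setProperty("spark.network.sasl.maxEncryptedBlockSize", "1k");
        System.setProperty("spark.network.sasl.serverAlwaysEncrypt", "true");

        TransportConf conf = new TransportConf(new SystemPropertyConfigProvider());
        System.out.println(conf.maxSaslEncryptedBlockSize());  // 1024
        System.out.println(conf.saslServerAlwaysEncrypt());    // true
      }
    }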
diff --git a/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java b/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java
index 860dd6d9b3..d500bc3c98 100644
--- a/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java
+++ b/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java
@@ -39,6 +39,7 @@ import org.apache.spark.network.protocol.RpcFailure;
import org.apache.spark.network.protocol.RpcRequest;
import org.apache.spark.network.protocol.RpcResponse;
import org.apache.spark.network.protocol.StreamChunkId;
+import org.apache.spark.network.util.ByteArrayWritableChannel;
import org.apache.spark.network.util.NettyUtils;
public class ProtocolSuite {
diff --git a/network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java b/network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java
index ff985096d7..6c98e733b4 100644
--- a/network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java
+++ b/network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java
@@ -29,7 +29,7 @@ import org.junit.Test;
import static org.junit.Assert.*;
-import org.apache.spark.network.ByteArrayWritableChannel;
+import org.apache.spark.network.util.ByteArrayWritableChannel;
public class MessageWithHeaderSuite {
diff --git a/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java
index 23b4e06f06..be6632bb8c 100644
--- a/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java
+++ b/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java
@@ -17,12 +17,47 @@
package org.apache.spark.network.sasl;
-import static org.junit.Assert.assertFalse;
-import static org.junit.Assert.assertTrue;
-import static org.junit.Assert.fail;
+import static com.google.common.base.Charsets.UTF_8;
+import static org.junit.Assert.*;
+import static org.mockito.Mockito.*;
+import java.io.File;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Random;
+import java.util.concurrent.TimeoutException;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicReference;
+import javax.security.sasl.SaslException;
+
+import com.google.common.collect.Lists;
+import com.google.common.io.ByteStreams;
+import com.google.common.io.Files;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+import io.netty.channel.Channel;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelOutboundHandlerAdapter;
+import io.netty.channel.ChannelPromise;
import org.junit.Test;
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
+import org.apache.spark.network.TestUtils;
+import org.apache.spark.network.TransportContext;
+import org.apache.spark.network.buffer.FileSegmentManagedBuffer;
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.client.ChunkReceivedCallback;
+import org.apache.spark.network.client.RpcResponseCallback;
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.client.TransportClientBootstrap;
+import org.apache.spark.network.server.RpcHandler;
+import org.apache.spark.network.server.StreamManager;
+import org.apache.spark.network.server.TransportServer;
+import org.apache.spark.network.server.TransportServerBootstrap;
+import org.apache.spark.network.util.ByteArrayWritableChannel;
+import org.apache.spark.network.util.SystemPropertyConfigProvider;
+import org.apache.spark.network.util.TransportConf;
/**
* Jointly tests SparkSaslClient and SparkSaslServer, as both are black boxes.
@@ -44,8 +79,8 @@ public class SparkSaslSuite {
@Test
public void testMatching() {
- SparkSaslClient client = new SparkSaslClient("shared-secret", secretKeyHolder);
- SparkSaslServer server = new SparkSaslServer("shared-secret", secretKeyHolder);
+ SparkSaslClient client = new SparkSaslClient("shared-secret", secretKeyHolder, false);
+ SparkSaslServer server = new SparkSaslServer("shared-secret", secretKeyHolder, false);
assertFalse(client.isComplete());
assertFalse(server.isComplete());
@@ -64,11 +99,10 @@ public class SparkSaslSuite {
assertFalse(client.isComplete());
}
-
@Test
public void testNonMatching() {
- SparkSaslClient client = new SparkSaslClient("my-secret", secretKeyHolder);
- SparkSaslServer server = new SparkSaslServer("your-secret", secretKeyHolder);
+ SparkSaslClient client = new SparkSaslClient("my-secret", secretKeyHolder, false);
+ SparkSaslServer server = new SparkSaslServer("your-secret", secretKeyHolder, false);
assertFalse(client.isComplete());
assertFalse(server.isComplete());
@@ -86,4 +120,312 @@ public class SparkSaslSuite {
assertFalse(server.isComplete());
}
}
+
+ @Test
+ public void testSaslAuthentication() throws Exception {
+ testBasicSasl(false);
+ }
+
+ @Test
+ public void testSaslEncryption() throws Exception {
+ testBasicSasl(true);
+ }
+
+ private void testBasicSasl(boolean encrypt) throws Exception {
+ RpcHandler rpcHandler = mock(RpcHandler.class);
+ doAnswer(new Answer<Void>() {
+ @Override
+ public Void answer(InvocationOnMock invocation) {
+ byte[] message = (byte[]) invocation.getArguments()[1];
+ RpcResponseCallback cb = (RpcResponseCallback) invocation.getArguments()[2];
+ assertEquals("Ping", new String(message, UTF_8));
+ cb.onSuccess("Pong".getBytes(UTF_8));
+ return null;
+ }
+ })
+ .when(rpcHandler)
+ .receive(any(TransportClient.class), any(byte[].class), any(RpcResponseCallback.class));
+
+ SaslTestCtx ctx = new SaslTestCtx(rpcHandler, encrypt, false);
+ try {
+ byte[] response = ctx.client.sendRpcSync("Ping".getBytes(UTF_8), TimeUnit.SECONDS.toMillis(10));
+ assertEquals("Pong", new String(response, UTF_8));
+ } finally {
+ ctx.close();
+ }
+ }
+
+ @Test
+ public void testEncryptedMessage() throws Exception {
+ SaslEncryptionBackend backend = mock(SaslEncryptionBackend.class);
+ byte[] data = new byte[1024];
+ new Random().nextBytes(data);
+ when(backend.wrap(any(byte[].class), anyInt(), anyInt())).thenReturn(data);
+
+ ByteBuf msg = Unpooled.buffer();
+ try {
+ msg.writeBytes(data);
+
+ // Create a channel with a really small buffer compared to the data. This means that on each
+ // call, the outbound data will not be fully written, so the write() method should return a
+ // dummy count to keep the channel alive when possible.
+ ByteArrayWritableChannel channel = new ByteArrayWritableChannel(32);
+
+ SaslEncryption.EncryptedMessage emsg =
+ new SaslEncryption.EncryptedMessage(backend, msg, 1024);
+ long count = emsg.transferTo(channel, emsg.transfered());
+ assertTrue(count < data.length);
+ assertTrue(count > 0);
+
+ // Here, the output buffer is full so nothing should be transferred.
+ assertEquals(0, emsg.transferTo(channel, emsg.transfered()));
+
+ // Now there's room in the buffer, but not enough to transfer all the remaining data,
+ // so the dummy count should be returned.
+ channel.reset();
+ assertEquals(1, emsg.transferTo(channel, emsg.transfered()));
+
+ // Eventually, the whole message should be transferred.
+ for (int i = 0; i < data.length / 32 - 2; i++) {
+ channel.reset();
+ assertEquals(1, emsg.transferTo(channel, emsg.transfered()));
+ }
+
+ channel.reset();
+ count = emsg.transferTo(channel, emsg.transfered());
+ assertTrue("Unexpected count: " + count, count > 1 && count < data.length);
+ assertEquals(data.length, emsg.transfered());
+ } finally {
+ msg.release();
+ }
+ }
+
+ @Test
+ public void testEncryptedMessageChunking() throws Exception {
+ File file = File.createTempFile("sasltest", ".txt");
+ try {
+ TransportConf conf = new TransportConf(new SystemPropertyConfigProvider());
+
+ byte[] data = new byte[8 * 1024];
+ new Random().nextBytes(data);
+ Files.write(data, file);
+
+ SaslEncryptionBackend backend = mock(SaslEncryptionBackend.class);
+ // It doesn't really matter what we return here, as long as it's not null.
+ when(backend.wrap(any(byte[].class), anyInt(), anyInt())).thenReturn(data);
+
+ FileSegmentManagedBuffer msg = new FileSegmentManagedBuffer(conf, file, 0, file.length());
+ SaslEncryption.EncryptedMessage emsg =
+ new SaslEncryption.EncryptedMessage(backend, msg.convertToNetty(), data.length / 8);
+
+ ByteArrayWritableChannel channel = new ByteArrayWritableChannel(data.length);
+ while (emsg.transfered() < emsg.count()) {
+ channel.reset();
+ emsg.transferTo(channel, emsg.transfered());
+ }
+
+ verify(backend, times(8)).wrap(any(byte[].class), anyInt(), anyInt());
+ } finally {
+ file.delete();
+ }
+ }
+
+ @Test
+ public void testFileRegionEncryption() throws Exception {
+ final String blockSizeConf = "spark.network.sasl.maxEncryptedBlockSize";
+ System.setProperty(blockSizeConf, "1k");
+
+ final AtomicReference<ManagedBuffer> response = new AtomicReference();
+ final File file = File.createTempFile("sasltest", ".txt");
+ SaslTestCtx ctx = null;
+ try {
+ final TransportConf conf = new TransportConf(new SystemPropertyConfigProvider());
+ StreamManager sm = mock(StreamManager.class);
+ when(sm.getChunk(anyLong(), anyInt())).thenAnswer(new Answer<ManagedBuffer>() {
+ @Override
+ public ManagedBuffer answer(InvocationOnMock invocation) {
+ return new FileSegmentManagedBuffer(conf, file, 0, file.length());
+ }
+ });
+
+ RpcHandler rpcHandler = mock(RpcHandler.class);
+ when(rpcHandler.getStreamManager()).thenReturn(sm);
+
+ byte[] data = new byte[8 * 1024];
+ new Random().nextBytes(data);
+ Files.write(data, file);
+
+ ctx = new SaslTestCtx(rpcHandler, true, false);
+
+ final Object lock = new Object();
+
+ ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class);
+ doAnswer(new Answer<Void>() {
+ @Override
+ public Void answer(InvocationOnMock invocation) {
+ response.set((ManagedBuffer) invocation.getArguments()[1]);
+ response.get().retain();
+ synchronized (lock) {
+ lock.notifyAll();
+ }
+ return null;
+ }
+ }).when(callback).onSuccess(anyInt(), any(ManagedBuffer.class));
+
+ synchronized (lock) {
+ ctx.client.fetchChunk(0, 0, callback);
+ lock.wait(10 * 1000);
+ }
+
+ verify(callback, times(1)).onSuccess(anyInt(), any(ManagedBuffer.class));
+ verify(callback, never()).onFailure(anyInt(), any(Throwable.class));
+
+ byte[] received = ByteStreams.toByteArray(response.get().createInputStream());
+ assertTrue(Arrays.equals(data, received));
+ } finally {
+ file.delete();
+ if (ctx != null) {
+ ctx.close();
+ }
+ if (response.get() != null) {
+ response.get().release();
+ }
+ System.clearProperty(blockSizeConf);
+ }
+ }
+
+ @Test
+ public void testServerAlwaysEncrypt() throws Exception {
+ final String alwaysEncryptConfName = "spark.network.sasl.serverAlwaysEncrypt";
+ System.setProperty(alwaysEncryptConfName, "true");
+
+ SaslTestCtx ctx = null;
+ try {
+ ctx = new SaslTestCtx(mock(RpcHandler.class), false, false);
+ fail("Should have failed to connect without encryption.");
+ } catch (Exception e) {
+ assertTrue(e.getCause() instanceof SaslException);
+ } finally {
+ if (ctx != null) {
+ ctx.close();
+ }
+ System.clearProperty(alwaysEncryptConfName);
+ }
+ }
+
+ @Test
+ public void testDataEncryptionIsActuallyEnabled() throws Exception {
+ // This test sets up an encrypted connection but then, using a client bootstrap, removes
+ // the encryption handler from the client side. The server should then be unable to
+ // understand the RPCs sent to it and should close the connection.
+ SaslTestCtx ctx = null;
+ try {
+ ctx = new SaslTestCtx(mock(RpcHandler.class), true, true);
+ ctx.client.sendRpcSync("Ping".getBytes(UTF_8), TimeUnit.SECONDS.toMillis(10));
+ fail("Should have failed to send RPC to server.");
+ } catch (Exception e) {
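+ // The expected failure is the server closing the connection, not a
+ // client-side timeout waiting for a reply.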
+ assertFalse(e.getCause() instanceof TimeoutException);
+ } finally {
+ if (ctx != null) {
+ ctx.close();
+ }
+ }
+ }
+
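+ /**
+ * Sets up a SASL server / client pair for tests. The server side installs an
+ * EncryptionCheckerBootstrap so that close() can assert whether the SASL
+ * encryption handler was actually added to the channel pipeline.
+ */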
+ private static class SaslTestCtx {
+
+ final TransportClient client;
+ final TransportServer server;
+
+ private final boolean encrypt;
+ private final boolean disableClientEncryption;
+ private final EncryptionCheckerBootstrap checker;
+
+ SaslTestCtx(
+ RpcHandler rpcHandler,
+ boolean encrypt,
+ boolean disableClientEncryption)
+ throws Exception {
+
+ TransportConf conf = new TransportConf(new SystemPropertyConfigProvider());
+
+ SecretKeyHolder keyHolder = mock(SecretKeyHolder.class);
+ when(keyHolder.getSaslUser(anyString())).thenReturn("user");
+ when(keyHolder.getSecretKey(anyString())).thenReturn("secret");
+
+ TransportContext ctx = new TransportContext(conf, rpcHandler);
+
+ this.checker = new EncryptionCheckerBootstrap();
+ this.server = ctx.createServer(Arrays.asList(new SaslServerBootstrap(conf, keyHolder),
+ checker));
+
+ try {
+ List<TransportClientBootstrap> clientBootstraps = Lists.newArrayList();
+ clientBootstraps.add(new SaslClientBootstrap(conf, "user", keyHolder, encrypt));
+ if (disableClientEncryption) {
+ clientBootstraps.add(new EncryptionDisablerBootstrap());
+ }
+
+ this.client = ctx.createClientFactory(clientBootstraps)
+ .createClient(TestUtils.getLocalHost(), server.getPort());
+ } catch (Exception e) {
+ close();
+ throw e;
+ }
+
+ this.encrypt = encrypt;
+ this.disableClientEncryption = disableClientEncryption;
+ }
+
+ void close() {
+ if (!disableClientEncryption) {
+ assertEquals(encrypt, checker.foundEncryptionHandler);
+ }
+ if (client != null) {
+ client.close();
+ }
+ if (server != null) {
+ server.close();
+ }
+ }
+
+ }
+
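+ /**
+ * Server bootstrap that adds itself to the front of the channel pipeline and
+ * records whether the SASL encryption handler shows up on the outbound path.
+ */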
+ private static class EncryptionCheckerBootstrap extends ChannelOutboundHandlerAdapter
+ implements TransportServerBootstrap {
+
+ boolean foundEncryptionHandler;
+
+ @Override
+ public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise)
+ throws Exception {
+ if (!foundEncryptionHandler) {
+ foundEncryptionHandler =
+ ctx.channel().pipeline().get(SaslEncryption.ENCRYPTION_HANDLER_NAME) != null;
+ }
+ ctx.write(msg, promise);
+ }
+
+ @Override
+ public RpcHandler doBootstrap(Channel channel, RpcHandler rpcHandler) {
+ channel.pipeline().addFirst("encryptionChecker", this);
+ return rpcHandler;
+ }
+
+ }
+
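+ /**
+ * Client bootstrap that removes the SASL encryption handler after the
+ * handshake, simulating a client that negotiates encryption but then sends
+ * unencrypted data.
+ */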
+ private static class EncryptionDisablerBootstrap implements TransportClientBootstrap {
+
+ @Override
+ public void doBootstrap(TransportClient client, Channel channel) {
+ channel.pipeline().remove(SaslEncryption.ENCRYPTION_HANDLER_NAME);
+ }
+
+ }
+
}
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
index 6e8018b723..612bce571a 100644
--- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
@@ -20,6 +20,7 @@ package org.apache.spark.network.shuffle;
import java.io.IOException;
import java.util.List;
+import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -46,6 +47,7 @@ public class ExternalShuffleClient extends ShuffleClient {
private final TransportConf conf;
private final boolean saslEnabled;
+ private final boolean saslEncryptionEnabled;
private final SecretKeyHolder secretKeyHolder;
private TransportClientFactory clientFactory;
@@ -58,10 +60,15 @@ public class ExternalShuffleClient extends ShuffleClient {
public ExternalShuffleClient(
TransportConf conf,
SecretKeyHolder secretKeyHolder,
- boolean saslEnabled) {
+ boolean saslEnabled,
+ boolean saslEncryptionEnabled) {
+ Preconditions.checkArgument(
+ !saslEncryptionEnabled || saslEnabled,
+ "SASL encryption can only be enabled if SASL is also enabled.");
this.conf = conf;
this.secretKeyHolder = secretKeyHolder;
this.saslEnabled = saslEnabled;
+ this.saslEncryptionEnabled = saslEncryptionEnabled;
}
@Override
@@ -70,7 +77,7 @@ public class ExternalShuffleClient extends ShuffleClient {
TransportContext context = new TransportContext(conf, new NoOpRpcHandler());
List<TransportClientBootstrap> bootstraps = Lists.newArrayList();
if (saslEnabled) {
- bootstraps.add(new SaslClientBootstrap(conf, appId, secretKeyHolder));
+ bootstraps.add(new SaslClientBootstrap(conf, appId, secretKeyHolder, saslEncryptionEnabled));
}
clientFactory = context.createClientFactory(bootstraps);
}
diff --git a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java
index d25283e46e..382f613ecb 100644
--- a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java
+++ b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java
@@ -18,6 +18,7 @@
package org.apache.spark.network.sasl;
import java.io.IOException;
+import java.util.Arrays;
import com.google.common.collect.Lists;
import org.junit.After;
@@ -37,6 +38,7 @@ import org.apache.spark.network.server.OneForOneStreamManager;
import org.apache.spark.network.server.RpcHandler;
import org.apache.spark.network.server.StreamManager;
import org.apache.spark.network.server.TransportServer;
+import org.apache.spark.network.server.TransportServerBootstrap;
import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler;
import org.apache.spark.network.util.SystemPropertyConfigProvider;
import org.apache.spark.network.util.TransportConf;
@@ -72,10 +74,11 @@ public class SaslIntegrationSuite {
@BeforeClass
public static void beforeAll() throws IOException {
SecretKeyHolder secretKeyHolder = new TestSecretKeyHolder("good-key");
- SaslRpcHandler handler = new SaslRpcHandler(new TestRpcHandler(), secretKeyHolder);
conf = new TransportConf(new SystemPropertyConfigProvider());
- context = new TransportContext(conf, handler);
- server = context.createServer();
+ context = new TransportContext(conf, new TestRpcHandler());
+
+ TransportServerBootstrap bootstrap = new SaslServerBootstrap(conf, secretKeyHolder);
+ server = context.createServer(Arrays.asList(bootstrap));
}
diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java
index 02c10bcb7b..39aa49911d 100644
--- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java
+++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java
@@ -136,7 +136,7 @@ public class ExternalShuffleIntegrationSuite {
final Semaphore requestsRemaining = new Semaphore(0);
- ExternalShuffleClient client = new ExternalShuffleClient(conf, null, false);
+ ExternalShuffleClient client = new ExternalShuffleClient(conf, null, false, false);
client.init(APP_ID);
client.fetchBlocks(TestUtils.getLocalHost(), port, execId, blockIds,
new BlockFetchingListener() {
@@ -274,7 +274,7 @@ public class ExternalShuffleIntegrationSuite {
private void registerExecutor(String executorId, ExecutorShuffleInfo executorInfo)
throws IOException {
- ExternalShuffleClient client = new ExternalShuffleClient(conf, null, false);
+ ExternalShuffleClient client = new ExternalShuffleClient(conf, null, false, false);
client.init(APP_ID);
client.registerWithShuffleServer(TestUtils.getLocalHost(), server.getPort(),
executorId, executorInfo);
diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java
index 759a12910c..d4ec1956c1 100644
--- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java
+++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java
@@ -18,6 +18,7 @@
package org.apache.spark.network.shuffle;
import java.io.IOException;
+import java.util.Arrays;
import org.junit.After;
import org.junit.Before;
@@ -27,10 +28,11 @@ import static org.junit.Assert.*;
import org.apache.spark.network.TestUtils;
import org.apache.spark.network.TransportContext;
-import org.apache.spark.network.sasl.SaslRpcHandler;
+import org.apache.spark.network.sasl.SaslServerBootstrap;
import org.apache.spark.network.sasl.SecretKeyHolder;
import org.apache.spark.network.server.RpcHandler;
import org.apache.spark.network.server.TransportServer;
+import org.apache.spark.network.server.TransportServerBootstrap;
import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo;
import org.apache.spark.network.util.SystemPropertyConfigProvider;
import org.apache.spark.network.util.TransportConf;
@@ -42,10 +44,10 @@ public class ExternalShuffleSecuritySuite {
@Before
public void beforeEach() {
- RpcHandler handler = new SaslRpcHandler(new ExternalShuffleBlockHandler(conf),
- new TestSecretKeyHolder("my-app-id", "secret"));
- TransportContext context = new TransportContext(conf, handler);
- this.server = context.createServer();
+ TransportContext context = new TransportContext(conf, new ExternalShuffleBlockHandler(conf));
+ TransportServerBootstrap bootstrap = new SaslServerBootstrap(conf,
+ new TestSecretKeyHolder("my-app-id", "secret"));
+ this.server = context.createServer(Arrays.asList(bootstrap));
}
@After
@@ -58,13 +60,13 @@ public class ExternalShuffleSecuritySuite {
@Test
public void testValid() throws IOException {
- validate("my-app-id", "secret");
+ validate("my-app-id", "secret", false);
}
@Test
public void testBadAppId() {
try {
- validate("wrong-app-id", "secret");
+ validate("wrong-app-id", "secret", false);
} catch (Exception e) {
assertTrue(e.getMessage(), e.getMessage().contains("Wrong appId!"));
}
@@ -73,16 +75,21 @@ public class ExternalShuffleSecuritySuite {
@Test
public void testBadSecret() {
try {
- validate("my-app-id", "bad-secret");
+ validate("my-app-id", "bad-secret", false);
} catch (Exception e) {
assertTrue(e.getMessage(), e.getMessage().contains("Mismatched response"));
}
}
+ @Test
+ public void testEncryption() throws IOException {
+ validate("my-app-id", "secret", true);
+ }
+
/** Creates an ExternalShuffleClient and attempts to register with the server. */
- private void validate(String appId, String secretKey) throws IOException {
+ private void validate(String appId, String secretKey, boolean encrypt) throws IOException {
ExternalShuffleClient client =
- new ExternalShuffleClient(conf, new TestSecretKeyHolder(appId, secretKey), true);
+ new ExternalShuffleClient(conf, new TestSecretKeyHolder(appId, secretKey), true, encrypt);
client.init(appId);
// Registration either succeeds or throws an exception.
client.registerWithShuffleServer(TestUtils.getLocalHost(), server.getPort(), "exec0",
diff --git a/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java b/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java
index 63b21222e7..463f99ef33 100644
--- a/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java
+++ b/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java
@@ -17,9 +17,10 @@
package org.apache.spark.network.yarn;
-import java.lang.Override;
import java.nio.ByteBuffer;
+import java.util.List;
+import com.google.common.collect.Lists;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.yarn.api.records.ApplicationId;
import org.apache.hadoop.yarn.api.records.ContainerId;
@@ -32,10 +33,11 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.spark.network.TransportContext;
-import org.apache.spark.network.sasl.SaslRpcHandler;
+import org.apache.spark.network.sasl.SaslServerBootstrap;
import org.apache.spark.network.sasl.ShuffleSecretManager;
import org.apache.spark.network.server.RpcHandler;
import org.apache.spark.network.server.TransportServer;
+import org.apache.spark.network.server.TransportServerBootstrap;
import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler;
import org.apache.spark.network.util.TransportConf;
import org.apache.spark.network.yarn.util.HadoopConfigProvider;
@@ -103,16 +105,17 @@ public class YarnShuffleService extends AuxiliaryService {
// special RPC handler that filters out unauthenticated fetch requests
boolean authEnabled = conf.getBoolean(SPARK_AUTHENTICATE_KEY, DEFAULT_SPARK_AUTHENTICATE);
blockHandler = new ExternalShuffleBlockHandler(transportConf);
- RpcHandler rpcHandler = blockHandler;
+
+ List<TransportServerBootstrap> bootstraps = Lists.newArrayList();
if (authEnabled) {
secretManager = new ShuffleSecretManager();
- rpcHandler = new SaslRpcHandler(rpcHandler, secretManager);
+ bootstraps.add(new SaslServerBootstrap(transportConf, secretManager));
}
int port = conf.getInt(
SPARK_SHUFFLE_SERVICE_PORT_KEY, DEFAULT_SPARK_SHUFFLE_SERVICE_PORT);
- TransportContext transportContext = new TransportContext(transportConf, rpcHandler);
- shuffleServer = transportContext.createServer(port);
+ TransportContext transportContext = new TransportContext(transportConf, blockHandler);
+ shuffleServer = transportContext.createServer(port, bootstraps);
String authEnabledString = authEnabled ? "enabled" : "not enabled";
logger.info("Started YARN shuffle service for Spark on port {}. " +
"Authentication is {}.", port, authEnabledString);