Author:    Aaron Davidson <aaron@databricks.com>  2014-11-04 16:15:38 -0800
Committer: Reynold Xin <rxin@databricks.com>      2014-11-04 16:15:38 -0800
Commit:    5e73138a0152b78380b3f1def4b969b58e70dd11
Tree:      d899e389f29e7c87c8f40e84872477a9b5e52277
Parent:    f90ad5d426cb726079c490a9bb4b1100e2b4e602
[SPARK-2938] Support SASL authentication in NettyBlockTransferService
Also lays the groundwork for supporting it inside the external shuffle service.

Author: Aaron Davidson <aaron@databricks.com>

Closes #3087 from aarondav/sasl and squashes the following commits:

3481718 [Aaron Davidson] Delete rogue println
44f8410 [Aaron Davidson] Delete documentation - muahaha!
eb9f065 [Aaron Davidson] Improve documentation and add end-to-end test at Spark-level
a6b95f1 [Aaron Davidson] Address comments
785bbde [Aaron Davidson] Cleanup
79973cb [Aaron Davidson] Remove unused file
151b3c5 [Aaron Davidson] Add docs, timeout config, better failure handling
f6177d7 [Aaron Davidson] Cleanup SASL state upon connection termination
7b42adb [Aaron Davidson] Add unit tests
8191bcb [Aaron Davidson] [SPARK-2938] Support SASL authentication in NettyBlockTransferService
Diffstat (limited to 'network')
-rw-r--r-- network/common/src/main/java/org/apache/spark/network/TransportContext.java | 15
-rw-r--r-- network/common/src/main/java/org/apache/spark/network/client/TransportClient.java | 11
-rw-r--r-- network/common/src/main/java/org/apache/spark/network/client/TransportClientBootstrap.java | 32
-rw-r--r-- network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java | 64
-rw-r--r-- network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java | 2
-rw-r--r-- network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java | 19
-rw-r--r-- network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java | 1
-rw-r--r-- network/common/src/main/java/org/apache/spark/network/util/TransportConf.java | 3
-rw-r--r-- network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java | 74
-rw-r--r-- network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslMessage.java | 74
-rw-r--r-- network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java | 97
-rw-r--r-- network/shuffle/src/main/java/org/apache/spark/network/sasl/SecretKeyHolder.java | 35
-rw-r--r-- network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java | 138
-rw-r--r-- network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java | 170
-rw-r--r-- network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java | 2
-rw-r--r-- network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java | 15
-rw-r--r-- network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java | 11
-rw-r--r-- network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java | 172
-rw-r--r-- network/shuffle/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java | 89
-rw-r--r-- network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java | 7
20 files changed, 999 insertions(+), 32 deletions(-)
diff --git a/network/common/src/main/java/org/apache/spark/network/TransportContext.java b/network/common/src/main/java/org/apache/spark/network/TransportContext.java
index a271841e4e..5bc6e5a241 100644
--- a/network/common/src/main/java/org/apache/spark/network/TransportContext.java
+++ b/network/common/src/main/java/org/apache/spark/network/TransportContext.java
@@ -17,12 +17,16 @@
package org.apache.spark.network;
+import java.util.List;
+
+import com.google.common.collect.Lists;
import io.netty.channel.Channel;
import io.netty.channel.socket.SocketChannel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.client.TransportClientBootstrap;
import org.apache.spark.network.client.TransportClientFactory;
import org.apache.spark.network.client.TransportResponseHandler;
import org.apache.spark.network.protocol.MessageDecoder;
@@ -64,8 +68,17 @@ public class TransportContext {
this.decoder = new MessageDecoder();
}
+ /**
+ * Initializes a ClientFactory which runs the given TransportClientBootstraps prior to returning
+ * a new Client. Bootstraps will be executed synchronously, and must run successfully in order
+ * to create a Client.
+ */
+ public TransportClientFactory createClientFactory(List<TransportClientBootstrap> bootstraps) {
+ return new TransportClientFactory(this, bootstraps);
+ }
+
public TransportClientFactory createClientFactory() {
- return new TransportClientFactory(this);
+ return createClientFactory(Lists.<TransportClientBootstrap>newArrayList());
}
/** Create a server which will attempt to bind to a specific port. */
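For illustration, a minimal sketch of how a caller might use the new overload; the helper class is hypothetical and not part of this patch, and `context`, `host`, and `port` are assumed to be supplied by the caller:

import com.google.common.collect.Lists;

import org.apache.spark.network.TransportContext;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.client.TransportClientBootstrap;
import org.apache.spark.network.client.TransportClientFactory;

// Hypothetical wiring: attach a bootstrap to every connection the factory creates.
class BootstrapWiringSketch {
  static TransportClient connect(
      TransportContext context,
      TransportClientBootstrap bootstrap,
      String host,
      int port) {
    TransportClientFactory factory = context.createClientFactory(
        Lists.<TransportClientBootstrap>newArrayList(bootstrap));
    // The bootstrap runs synchronously before createClient() returns a usable client.
    return factory.createClient(host, port);
  }
}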
diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
index 01c143fff4..a08cee02dd 100644
--- a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
+++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
@@ -19,10 +19,9 @@ package org.apache.spark.network.client;
import java.io.Closeable;
import java.util.UUID;
-import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
-import java.util.concurrent.TimeoutException;
+import com.google.common.base.Objects;
import com.google.common.base.Preconditions;
import com.google.common.base.Throwables;
import com.google.common.util.concurrent.SettableFuture;
@@ -186,4 +185,12 @@ public class TransportClient implements Closeable {
// close is a local operation and should finish with milliseconds; timeout just to be safe
channel.close().awaitUninterruptibly(10, TimeUnit.SECONDS);
}
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("remoteAdress", channel.remoteAddress())
+ .add("isActive", isActive())
+ .toString();
+ }
}
diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClientBootstrap.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClientBootstrap.java
new file mode 100644
index 0000000000..65e8020e34
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClientBootstrap.java
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.client;
+
+/**
+ * A bootstrap which is executed on a TransportClient before it is returned to the user.
+ * This enables an initial exchange of information (e.g., SASL authentication tokens) on a once-per-
+ * connection basis.
+ *
+ * Since connections (and TransportClients) are reused as much as possible, it is generally
+ * reasonable to perform an expensive bootstrapping operation, as they often share a lifespan with
+ * the JVM itself.
+ */
+public interface TransportClientBootstrap {
+ /** Performs the bootstrapping operation, throwing an exception on failure. */
+ public void doBootstrap(TransportClient client) throws RuntimeException;
+}
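As a hedged sketch of this contract, a bootstrap that performs one synchronous sanity-check RPC before the client is handed out; the class, payload, and timeout are invented example values:

import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.client.TransportClientBootstrap;

// Illustrative bootstrap (not part of this patch): runs one blocking RPC and
// fails the connection attempt if the exchange throws.
class PingBootstrap implements TransportClientBootstrap {
  @Override
  public void doBootstrap(TransportClient client) {
    // sendRpcSync throws on failure, which aborts client creation.
    client.sendRpcSync(new byte[0], 5000 /* timeout in ms */);
  }
}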
diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
index 0b4a1d8286..1723fed307 100644
--- a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
+++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
@@ -21,10 +21,14 @@ import java.io.Closeable;
import java.lang.reflect.Field;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
+import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicReference;
+import com.google.common.base.Preconditions;
+import com.google.common.base.Throwables;
+import com.google.common.collect.Lists;
import io.netty.bootstrap.Bootstrap;
import io.netty.buffer.PooledByteBufAllocator;
import io.netty.channel.Channel;
@@ -40,6 +44,7 @@ import org.slf4j.LoggerFactory;
import org.apache.spark.network.TransportContext;
import org.apache.spark.network.server.TransportChannelHandler;
import org.apache.spark.network.util.IOMode;
+import org.apache.spark.network.util.JavaUtils;
import org.apache.spark.network.util.NettyUtils;
import org.apache.spark.network.util.TransportConf;
@@ -47,22 +52,29 @@ import org.apache.spark.network.util.TransportConf;
* Factory for creating {@link TransportClient}s by using createClient.
*
* The factory maintains a connection pool to other hosts and should return the same
- * {@link TransportClient} for the same remote host. It also shares a single worker thread pool for
- * all {@link TransportClient}s.
+ * TransportClient for the same remote host. It also shares a single worker thread pool for
+ * all TransportClients.
+ *
+ * TransportClients will be reused whenever possible. Prior to completing the creation of a new
+ * TransportClient, all given {@link TransportClientBootstrap}s will be run.
*/
public class TransportClientFactory implements Closeable {
private final Logger logger = LoggerFactory.getLogger(TransportClientFactory.class);
private final TransportContext context;
private final TransportConf conf;
+ private final List<TransportClientBootstrap> clientBootstraps;
private final ConcurrentHashMap<SocketAddress, TransportClient> connectionPool;
private final Class<? extends Channel> socketChannelClass;
private EventLoopGroup workerGroup;
- public TransportClientFactory(TransportContext context) {
- this.context = context;
+ public TransportClientFactory(
+ TransportContext context,
+ List<TransportClientBootstrap> clientBootstraps) {
+ this.context = Preconditions.checkNotNull(context);
this.conf = context.getConf();
+ this.clientBootstraps = Lists.newArrayList(Preconditions.checkNotNull(clientBootstraps));
this.connectionPool = new ConcurrentHashMap<SocketAddress, TransportClient>();
IOMode ioMode = IOMode.valueOf(conf.ioMode());
@@ -72,9 +84,12 @@ public class TransportClientFactory implements Closeable {
}
/**
- * Create a new BlockFetchingClient connecting to the given remote host / port.
+ * Create a new {@link TransportClient} connecting to the given remote host / port. This will
+ * reuse TransportClients if they are still active and are for the same remote address. Prior
+ * to the creation of a new TransportClient, we will execute all {@link TransportClientBootstrap}s
+ * that are registered with this factory.
*
- * This blocks until a connection is successfully established.
+ * This blocks until a connection is successfully established and fully bootstrapped.
*
* Concurrency: This method is safe to call from multiple threads.
*/
@@ -104,17 +119,18 @@ public class TransportClientFactory implements Closeable {
// Use pooled buffers to reduce temporary buffer allocation
bootstrap.option(ChannelOption.ALLOCATOR, createPooledByteBufAllocator());
- final AtomicReference<TransportClient> client = new AtomicReference<TransportClient>();
+ final AtomicReference<TransportClient> clientRef = new AtomicReference<TransportClient>();
bootstrap.handler(new ChannelInitializer<SocketChannel>() {
@Override
public void initChannel(SocketChannel ch) {
TransportChannelHandler clientHandler = context.initializePipeline(ch);
- client.set(clientHandler.getClient());
+ clientRef.set(clientHandler.getClient());
}
});
// Connect to the remote server
+ long preConnect = System.currentTimeMillis();
ChannelFuture cf = bootstrap.connect(address);
if (!cf.awaitUninterruptibly(conf.connectionTimeoutMs())) {
throw new RuntimeException(
@@ -123,15 +139,35 @@ public class TransportClientFactory implements Closeable {
throw new RuntimeException(String.format("Failed to connect to %s", address), cf.cause());
}
- // Successful connection -- in the event that two threads raced to create a client, we will
+ TransportClient client = clientRef.get();
+ assert client != null : "Channel future completed successfully with null client";
+
+ // Execute any client bootstraps synchronously before marking the Client as successful.
+ long preBootstrap = System.currentTimeMillis();
+ logger.debug("Connection to {} successful, running bootstraps...", address);
+ try {
+ for (TransportClientBootstrap clientBootstrap : clientBootstraps) {
+ clientBootstrap.doBootstrap(client);
+ }
+ } catch (Exception e) { // catch non-RuntimeExceptions too as bootstrap may be written in Scala
+ long bootstrapTime = System.currentTimeMillis() - preBootstrap;
+ logger.error("Exception while bootstrapping client after " + bootstrapTime + " ms", e);
+ client.close();
+ throw Throwables.propagate(e);
+ }
+ long postBootstrap = System.currentTimeMillis();
+
+ // Successful connection & bootstrap -- in the event that two threads raced to create a client,
// use the first one that was put into the connectionPool and close the one we made here.
- assert client.get() != null : "Channel future completed successfully with null client";
- TransportClient oldClient = connectionPool.putIfAbsent(address, client.get());
+ TransportClient oldClient = connectionPool.putIfAbsent(address, client);
if (oldClient == null) {
- return client.get();
+ logger.debug("Successfully created connection to {} after {} ms ({} ms spent in bootstraps)",
+ address, postBootstrap - preConnect, postBootstrap - preBootstrap);
+ return client;
} else {
- logger.debug("Two clients were created concurrently, second one will be disposed.");
- client.get().close();
+ logger.debug("Two clients were created concurrently after {} ms, second will be disposed.",
+ postBootstrap - preConnect);
+ client.close();
return oldClient;
}
}
diff --git a/network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java b/network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java
index 5a3f003726..1502b7489e 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java
@@ -21,7 +21,7 @@ import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.TransportClient;
/** An RpcHandler suitable for a client-only TransportContext, which cannot receive RPCs. */
-public class NoOpRpcHandler implements RpcHandler {
+public class NoOpRpcHandler extends RpcHandler {
private final StreamManager streamManager;
public NoOpRpcHandler() {
diff --git a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java
index 2369dc6203..2ba92a40f8 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java
@@ -23,22 +23,33 @@ import org.apache.spark.network.client.TransportClient;
/**
* Handler for sendRPC() messages sent by {@link org.apache.spark.network.client.TransportClient}s.
*/
-public interface RpcHandler {
+public abstract class RpcHandler {
/**
* Receive a single RPC message. Any exception thrown while in this method will be sent back to
* the client in string form as a standard RPC failure.
*
+ * This method will not be called in parallel for a single TransportClient (i.e., channel).
+ *
* @param client A channel client which enables the handler to make requests back to the sender
- * of this RPC.
+ * of this RPC. This will always be the exact same object for a particular channel.
* @param message The serialized bytes of the RPC.
* @param callback Callback which should be invoked exactly once upon success or failure of the
* RPC.
*/
- void receive(TransportClient client, byte[] message, RpcResponseCallback callback);
+ public abstract void receive(
+ TransportClient client,
+ byte[] message,
+ RpcResponseCallback callback);
/**
* Returns the StreamManager which contains the state about which streams are currently being
* fetched by a TransportClient.
*/
- StreamManager getStreamManager();
+ public abstract StreamManager getStreamManager();
+
+ /**
+ * Invoked when the connection associated with the given client has been invalidated.
+ * No further requests will come from this client.
+ */
+ public void connectionTerminated(TransportClient client) { }
}
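To make the new lifecycle hook concrete, a hedged sketch of a handler that keeps per-channel state and releases it in connectionTerminated(); the class is illustrative only and leaves getStreamManager() abstract:

import java.util.concurrent.ConcurrentMap;

import com.google.common.collect.Maps;

import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.server.RpcHandler;

// Illustrative only: per-client state cleaned up when the connection dies.
abstract class StatefulRpcHandler extends RpcHandler {
  private final ConcurrentMap<TransportClient, Integer> messageCounts =
      Maps.newConcurrentMap();

  @Override
  public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) {
    // Safe without extra locking: receive() is not called in parallel per client.
    Integer prev = messageCounts.get(client);
    messageCounts.put(client, prev == null ? 1 : prev + 1);
    callback.onSuccess(message);
  }

  @Override
  public void connectionTerminated(TransportClient client) {
    messageCounts.remove(client); // no further requests will arrive from this client
  }
}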
diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
index 17fe9001b3..1580180cc1 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
@@ -86,6 +86,7 @@ public class TransportRequestHandler extends MessageHandler<RequestMessage> {
for (long streamId : streamIds) {
streamManager.connectionTerminated(streamId);
}
+ rpcHandler.connectionTerminated(reverseClient);
}
@Override
diff --git a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java
index a68f38e0e9..823790dd3c 100644
--- a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java
+++ b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java
@@ -55,4 +55,7 @@ public class TransportConf {
/** Send buffer size (SO_SNDBUF). */
public int sendBuf() { return conf.getInt("spark.shuffle.io.sendBuffer", -1); }
+
+ /** Timeout for a single round trip of SASL token exchange, in milliseconds. */
+ public int saslRTTimeout() { return conf.getInt("spark.shuffle.sasl.timeout", 30000); }
}
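Since the new key is read through TransportConf, a hedged sketch of overriding it in a test-style setup, using the SystemPropertyConfigProvider that appears in the test suites below; the 60-second value is an arbitrary example:

import org.apache.spark.network.util.SystemPropertyConfigProvider;
import org.apache.spark.network.util.TransportConf;

// Illustrative only: raise the per-round-trip SASL timeout to 60 seconds.
class SaslTimeoutConfigSketch {
  static TransportConf conf() {
    System.setProperty("spark.shuffle.sasl.timeout", "60000");
    return new TransportConf(new SystemPropertyConfigProvider());
  }
}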
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java
new file mode 100644
index 0000000000..7bc91e3753
--- /dev/null
+++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.sasl;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.client.TransportClientBootstrap;
+import org.apache.spark.network.util.TransportConf;
+
+/**
+ * Bootstraps a {@link TransportClient} by performing SASL authentication on the connection. The
+ * server should be setup with a {@link SaslRpcHandler} with matching keys for the given appId.
+ */
+public class SaslClientBootstrap implements TransportClientBootstrap {
+ private final Logger logger = LoggerFactory.getLogger(SaslClientBootstrap.class);
+
+ private final TransportConf conf;
+ private final String appId;
+ private final SecretKeyHolder secretKeyHolder;
+
+ public SaslClientBootstrap(TransportConf conf, String appId, SecretKeyHolder secretKeyHolder) {
+ this.conf = conf;
+ this.appId = appId;
+ this.secretKeyHolder = secretKeyHolder;
+ }
+
+ /**
+ * Performs SASL authentication by sending a token, and then proceeding with the SASL
+ * challenge-response tokens until we either successfully authenticate or throw an exception
+ * due to mismatch.
+ */
+ @Override
+ public void doBootstrap(TransportClient client) {
+ SparkSaslClient saslClient = new SparkSaslClient(appId, secretKeyHolder);
+ try {
+ byte[] payload = saslClient.firstToken();
+
+ while (!saslClient.isComplete()) {
+ SaslMessage msg = new SaslMessage(appId, payload);
+ ByteBuf buf = Unpooled.buffer(msg.encodedLength());
+ msg.encode(buf);
+
+ byte[] response = client.sendRpcSync(buf.array(), conf.saslRTTimeout());
+ payload = saslClient.response(response);
+ }
+ } finally {
+ try {
+ // Once authentication is complete, the server will trust all remaining communication.
+ saslClient.dispose();
+ } catch (RuntimeException e) {
+ logger.error("Error while disposing SASL client", e);
+ }
+ }
+ }
+}
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslMessage.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslMessage.java
new file mode 100644
index 0000000000..5b77e18c26
--- /dev/null
+++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslMessage.java
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.sasl;
+
+import com.google.common.base.Charsets;
+import io.netty.buffer.ByteBuf;
+
+import org.apache.spark.network.protocol.Encodable;
+
+/**
+ * Encodes a Sasl-related message which is attempting to authenticate using some credentials tagged
+ * with the given appId. This appId allows a single SaslRpcHandler to multiplex different
+ * applications which may be using different sets of credentials.
+ */
+class SaslMessage implements Encodable {
+
+ /** Serialization tag used to catch incorrect payloads. */
+ private static final byte TAG_BYTE = (byte) 0xEA;
+
+ public final String appId;
+ public final byte[] payload;
+
+ public SaslMessage(String appId, byte[] payload) {
+ this.appId = appId;
+ this.payload = payload;
+ }
+
+ @Override
+ public int encodedLength() {
+ // tag + appIdLength + appId + payloadLength + payload
+ return 1 + 4 + appId.getBytes(Charsets.UTF_8).length + 4 + payload.length;
+ }
+
+ @Override
+ public void encode(ByteBuf buf) {
+ buf.writeByte(TAG_BYTE);
+ byte[] idBytes = appId.getBytes(Charsets.UTF_8);
+ buf.writeInt(idBytes.length);
+ buf.writeBytes(idBytes);
+ buf.writeInt(payload.length);
+ buf.writeBytes(payload);
+ }
+
+ public static SaslMessage decode(ByteBuf buf) {
+ if (buf.readByte() != TAG_BYTE) {
+ throw new IllegalStateException("Expected SaslMessage, received something else");
+ }
+
+ int idLength = buf.readInt();
+ byte[] idBytes = new byte[idLength];
+ buf.readBytes(idBytes);
+
+ int payloadLength = buf.readInt();
+ byte[] payload = new byte[payloadLength];
+ buf.readBytes(payload);
+
+ return new SaslMessage(new String(idBytes, Charsets.UTF_8), payload);
+ }
+}
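A small round-trip sketch of the wire format above ([tag][idLength][appId][payloadLength][payload]); since SaslMessage is package-private, this would have to live in the same package, and the appId and payload values are arbitrary:

package org.apache.spark.network.sasl;

import com.google.common.base.Charsets;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;

// Encode a message into a fresh buffer, then decode it back out.
class SaslMessageRoundTrip {
  static void demo() {
    SaslMessage original = new SaslMessage("app-1", "token".getBytes(Charsets.UTF_8));
    ByteBuf buf = Unpooled.buffer(original.encodedLength());
    original.encode(buf);
    SaslMessage decoded = SaslMessage.decode(buf);
    assert decoded.appId.equals(original.appId);
  }
}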
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java
new file mode 100644
index 0000000000..3777a18e33
--- /dev/null
+++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.sasl;
+
+import java.util.concurrent.ConcurrentMap;
+
+import com.google.common.base.Charsets;
+import com.google.common.collect.Maps;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.client.RpcResponseCallback;
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.protocol.Encodable;
+import org.apache.spark.network.server.RpcHandler;
+import org.apache.spark.network.server.StreamManager;
+
+/**
+ * RPC Handler which performs SASL authentication before delegating to a child RPC handler.
+ * The delegate will only receive messages if the given connection has been successfully
+ * authenticated. A connection may be authenticated at most once.
+ *
+ * Note that the authentication process consists of multiple challenge-response pairs, each of
+ * which are individual RPCs.
+ */
+public class SaslRpcHandler extends RpcHandler {
+ private final Logger logger = LoggerFactory.getLogger(SaslRpcHandler.class);
+
+ /** RpcHandler we will delegate to for authenticated connections. */
+ private final RpcHandler delegate;
+
+ /** Class which provides secret keys which are shared by server and client on a per-app basis. */
+ private final SecretKeyHolder secretKeyHolder;
+
+ /** Maps each channel to its SASL authentication state. */
+ private final ConcurrentMap<TransportClient, SparkSaslServer> channelAuthenticationMap;
+
+ public SaslRpcHandler(RpcHandler delegate, SecretKeyHolder secretKeyHolder) {
+ this.delegate = delegate;
+ this.secretKeyHolder = secretKeyHolder;
+ this.channelAuthenticationMap = Maps.newConcurrentMap();
+ }
+
+ @Override
+ public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) {
+ SparkSaslServer saslServer = channelAuthenticationMap.get(client);
+ if (saslServer != null && saslServer.isComplete()) {
+ // Authentication complete, delegate to base handler.
+ delegate.receive(client, message, callback);
+ return;
+ }
+
+ SaslMessage saslMessage = SaslMessage.decode(Unpooled.wrappedBuffer(message));
+
+ if (saslServer == null) {
+ // First message in the handshake, setup the necessary state.
+ saslServer = new SparkSaslServer(saslMessage.appId, secretKeyHolder);
+ channelAuthenticationMap.put(client, saslServer);
+ }
+
+ byte[] response = saslServer.response(saslMessage.payload);
+ if (saslServer.isComplete()) {
+ logger.debug("SASL authentication successful for channel {}", client);
+ }
+ callback.onSuccess(response);
+ }
+
+ @Override
+ public StreamManager getStreamManager() {
+ return delegate.getStreamManager();
+ }
+
+ @Override
+ public void connectionTerminated(TransportClient client) {
+ SparkSaslServer saslServer = channelAuthenticationMap.remove(client);
+ if (saslServer != null) {
+ saslServer.dispose();
+ }
+ }
+}
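For context, a hedged sketch of the server-side wiring this handler implies: wrap the application handler so every channel must authenticate before the delegate sees traffic. `conf`, `appHandler`, and `keyHolder` are assumed inputs:

import org.apache.spark.network.TransportContext;
import org.apache.spark.network.sasl.SaslRpcHandler;
import org.apache.spark.network.sasl.SecretKeyHolder;
import org.apache.spark.network.server.RpcHandler;
import org.apache.spark.network.server.TransportServer;
import org.apache.spark.network.util.TransportConf;

// Illustrative only: a SASL-protected server.
class SaslServerWiring {
  static TransportServer start(
      TransportConf conf, RpcHandler appHandler, SecretKeyHolder keyHolder) {
    RpcHandler saslHandler = new SaslRpcHandler(appHandler, keyHolder);
    return new TransportContext(conf, saslHandler).createServer();
  }
}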
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SecretKeyHolder.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SecretKeyHolder.java
new file mode 100644
index 0000000000..81d5766794
--- /dev/null
+++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SecretKeyHolder.java
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.sasl;
+
+/**
+ * Interface for getting a secret key associated with some application.
+ */
+public interface SecretKeyHolder {
+ /**
+ * Gets an appropriate SASL User for the given appId.
+ * @throws IllegalArgumentException if the given appId is not associated with a SASL user.
+ */
+ String getSaslUser(String appId);
+
+ /**
+ * Gets an appropriate SASL secret key for the given appId.
+ * @throws IllegalArgumentException if the given appId is not associated with a SASL secret key.
+ */
+ String getSecretKey(String appId);
+}
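As an illustration of this contract (including the documented IllegalArgumentException), a minimal map-backed holder; a real deployment would consult its own credential store, and the names and keys here are invented:

import java.util.Map;

import com.google.common.collect.ImmutableMap;

import org.apache.spark.network.sasl.SecretKeyHolder;

// Illustrative holder with hard-coded example credentials.
class MapSecretKeyHolder implements SecretKeyHolder {
  private final Map<String, String> secrets =
      ImmutableMap.of("app-1", "secret-1", "app-2", "secret-2");

  @Override
  public String getSaslUser(String appId) {
    return "sparkSaslUser"; // a single SASL user suffices for this sketch
  }

  @Override
  public String getSecretKey(String appId) {
    String key = secrets.get(appId);
    if (key == null) {
      throw new IllegalArgumentException("Unknown appId: " + appId);
    }
    return key;
  }
}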
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java
new file mode 100644
index 0000000000..72ba737b99
--- /dev/null
+++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java
@@ -0,0 +1,138 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.sasl;
+
+import javax.security.auth.callback.Callback;
+import javax.security.auth.callback.CallbackHandler;
+import javax.security.auth.callback.NameCallback;
+import javax.security.auth.callback.PasswordCallback;
+import javax.security.auth.callback.UnsupportedCallbackException;
+import javax.security.sasl.RealmCallback;
+import javax.security.sasl.RealmChoiceCallback;
+import javax.security.sasl.Sasl;
+import javax.security.sasl.SaslClient;
+import javax.security.sasl.SaslException;
+import java.io.IOException;
+
+import com.google.common.base.Throwables;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import static org.apache.spark.network.sasl.SparkSaslServer.*;
+
+/**
+ * A SASL Client for Spark which simply keeps track of the state of a single SASL session, from the
+ * initial state to the "authenticated" state. This client initializes the protocol via a
+ * firstToken, which is then followed by a set of challenges and responses.
+ */
+public class SparkSaslClient {
+ private final Logger logger = LoggerFactory.getLogger(SparkSaslClient.class);
+
+ private final String secretKeyId;
+ private final SecretKeyHolder secretKeyHolder;
+ private SaslClient saslClient;
+
+ public SparkSaslClient(String secretKeyId, SecretKeyHolder secretKeyHolder) {
+ this.secretKeyId = secretKeyId;
+ this.secretKeyHolder = secretKeyHolder;
+ try {
+ this.saslClient = Sasl.createSaslClient(new String[] { DIGEST }, null, null, DEFAULT_REALM,
+ SASL_PROPS, new ClientCallbackHandler());
+ } catch (SaslException e) {
+ throw Throwables.propagate(e);
+ }
+ }
+
+ /** Used to initiate SASL handshake with server. */
+ public synchronized byte[] firstToken() {
+ if (saslClient != null && saslClient.hasInitialResponse()) {
+ try {
+ return saslClient.evaluateChallenge(new byte[0]);
+ } catch (SaslException e) {
+ throw Throwables.propagate(e);
+ }
+ } else {
+ return new byte[0];
+ }
+ }
+
+ /** Determines whether the authentication exchange has completed. */
+ public synchronized boolean isComplete() {
+ return saslClient != null && saslClient.isComplete();
+ }
+
+ /**
+ * Respond to server's SASL token.
+ * @param token contains server's SASL token
+ * @return client's response SASL token
+ */
+ public synchronized byte[] response(byte[] token) {
+ try {
+ return saslClient != null ? saslClient.evaluateChallenge(token) : new byte[0];
+ } catch (SaslException e) {
+ throw Throwables.propagate(e);
+ }
+ }
+
+ /**
+ * Disposes of any system resources or security-sensitive information the
+ * SaslClient might be using.
+ */
+ public synchronized void dispose() {
+ if (saslClient != null) {
+ try {
+ saslClient.dispose();
+ } catch (SaslException e) {
+ // ignore
+ } finally {
+ saslClient = null;
+ }
+ }
+ }
+
+ /**
+ * Implementation of javax.security.auth.callback.CallbackHandler
+ * that works with shared secrets.
+ */
+ private class ClientCallbackHandler implements CallbackHandler {
+ @Override
+ public void handle(Callback[] callbacks) throws IOException, UnsupportedCallbackException {
+
+ for (Callback callback : callbacks) {
+ if (callback instanceof NameCallback) {
+ logger.trace("SASL client callback: setting username");
+ NameCallback nc = (NameCallback) callback;
+ nc.setName(encodeIdentifier(secretKeyHolder.getSaslUser(secretKeyId)));
+ } else if (callback instanceof PasswordCallback) {
+ logger.trace("SASL client callback: setting password");
+ PasswordCallback pc = (PasswordCallback) callback;
+ pc.setPassword(encodePassword(secretKeyHolder.getSecretKey(secretKeyId)));
+ } else if (callback instanceof RealmCallback) {
+ logger.trace("SASL client callback: setting realm");
+ RealmCallback rc = (RealmCallback) callback;
+ rc.setText(rc.getDefaultText());
+ logger.info("Realm callback");
+ } else if (callback instanceof RealmChoiceCallback) {
+ // ignore (?)
+ } else {
+ throw new UnsupportedCallbackException(callback, "Unrecognized SASL DIGEST-MD5 Callback");
+ }
+ }
+ }
+ }
+}
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java
new file mode 100644
index 0000000000..2c0ce40c75
--- /dev/null
+++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java
@@ -0,0 +1,170 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.sasl;
+
+import javax.security.auth.callback.Callback;
+import javax.security.auth.callback.CallbackHandler;
+import javax.security.auth.callback.NameCallback;
+import javax.security.auth.callback.PasswordCallback;
+import javax.security.auth.callback.UnsupportedCallbackException;
+import javax.security.sasl.AuthorizeCallback;
+import javax.security.sasl.RealmCallback;
+import javax.security.sasl.Sasl;
+import javax.security.sasl.SaslException;
+import javax.security.sasl.SaslServer;
+import java.io.IOException;
+import java.util.Map;
+
+import com.google.common.base.Charsets;
+import com.google.common.base.Preconditions;
+import com.google.common.base.Throwables;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.io.BaseEncoding;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * A SASL Server for Spark which simply keeps track of the state of a single SASL session, from the
+ * initial state to the "authenticated" state. (It is not a server in the sense of accepting
+ * connections on some socket.)
+ */
+public class SparkSaslServer {
+ private final Logger logger = LoggerFactory.getLogger(SparkSaslServer.class);
+
+ /**
+ * This is passed as the server name when creating the sasl client/server.
+ * This could be changed to be configurable in the future.
+ */
+ static final String DEFAULT_REALM = "default";
+
+ /**
+ * The authentication mechanism used here is DIGEST-MD5. This could be changed to be
+ * configurable in the future.
+ */
+ static final String DIGEST = "DIGEST-MD5";
+
+ /**
+ * The quality of protection is just "auth". This means that we are doing
+ * authentication only, we are not supporting integrity or privacy protection of the
+ * communication channel after authentication. This could be changed to be configurable
+ * in the future.
+ */
+ static final Map<String, String> SASL_PROPS = ImmutableMap.<String, String>builder()
+ .put(Sasl.QOP, "auth")
+ .put(Sasl.SERVER_AUTH, "true")
+ .build();
+
+ /** Identifier for a certain secret key within the secretKeyHolder. */
+ private final String secretKeyId;
+ private final SecretKeyHolder secretKeyHolder;
+ private SaslServer saslServer;
+
+ public SparkSaslServer(String secretKeyId, SecretKeyHolder secretKeyHolder) {
+ this.secretKeyId = secretKeyId;
+ this.secretKeyHolder = secretKeyHolder;
+ try {
+ this.saslServer = Sasl.createSaslServer(DIGEST, null, DEFAULT_REALM, SASL_PROPS,
+ new DigestCallbackHandler());
+ } catch (SaslException e) {
+ throw Throwables.propagate(e);
+ }
+ }
+
+ /**
+ * Determines whether the authentication exchange has completed successfully.
+ */
+ public synchronized boolean isComplete() {
+ return saslServer != null && saslServer.isComplete();
+ }
+
+ /**
+ * Used to respond to client SASL tokens.
+ * @param token Client's SASL token
+ * @return response to send back to the client.
+ */
+ public synchronized byte[] response(byte[] token) {
+ try {
+ return saslServer != null ? saslServer.evaluateResponse(token) : new byte[0];
+ } catch (SaslException e) {
+ throw Throwables.propagate(e);
+ }
+ }
+
+ /**
+ * Disposes of any system resources or security-sensitive information the
+ * SaslServer might be using.
+ */
+ public synchronized void dispose() {
+ if (saslServer != null) {
+ try {
+ saslServer.dispose();
+ } catch (SaslException e) {
+ // ignore
+ } finally {
+ saslServer = null;
+ }
+ }
+ }
+
+ /**
+ * Implementation of javax.security.auth.callback.CallbackHandler for SASL DIGEST-MD5 mechanism.
+ */
+ private class DigestCallbackHandler implements CallbackHandler {
+ @Override
+ public void handle(Callback[] callbacks) throws IOException, UnsupportedCallbackException {
+ for (Callback callback : callbacks) {
+ if (callback instanceof NameCallback) {
+ logger.trace("SASL server callback: setting username");
+ NameCallback nc = (NameCallback) callback;
+ nc.setName(encodeIdentifier(secretKeyHolder.getSaslUser(secretKeyId)));
+ } else if (callback instanceof PasswordCallback) {
+ logger.trace("SASL server callback: setting password");
+ PasswordCallback pc = (PasswordCallback) callback;
+ pc.setPassword(encodePassword(secretKeyHolder.getSecretKey(secretKeyId)));
+ } else if (callback instanceof RealmCallback) {
+ logger.trace("SASL server callback: setting realm");
+ RealmCallback rc = (RealmCallback) callback;
+ rc.setText(rc.getDefaultText());
+ } else if (callback instanceof AuthorizeCallback) {
+ AuthorizeCallback ac = (AuthorizeCallback) callback;
+ String authId = ac.getAuthenticationID();
+ String authzId = ac.getAuthorizationID();
+ ac.setAuthorized(authId.equals(authzId));
+ if (ac.isAuthorized()) {
+ ac.setAuthorizedID(authzId);
+ }
+ logger.debug("SASL Authorization complete, authorized set to {}", ac.isAuthorized());
+ } else {
+ throw new UnsupportedCallbackException(callback, "Unrecognized SASL DIGEST-MD5 Callback");
+ }
+ }
+ }
+ }
+
+ /** Encode an identifier as a Base64-encoded string. */
+ public static String encodeIdentifier(String identifier) {
+ Preconditions.checkNotNull(identifier, "User cannot be null if SASL is enabled");
+ return BaseEncoding.base64().encode(identifier.getBytes(Charsets.UTF_8));
+ }
+
+ /** Encode a password as a Base64-encoded char array. */
+ public static char[] encodePassword(String password) {
+ Preconditions.checkNotNull(password, "Password cannot be null if SASL is enabled");
+ return BaseEncoding.base64().encode(password.getBytes(Charsets.UTF_8)).toCharArray();
+ }
+}
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java
index a9dff31dec..cd3fea85b1 100644
--- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java
@@ -41,7 +41,7 @@ import org.apache.spark.network.util.JavaUtils;
* with the "one-for-one" strategy, meaning each Transport-layer Chunk is equivalent to one Spark-
* level shuffle block.
*/
-public class ExternalShuffleBlockHandler implements RpcHandler {
+public class ExternalShuffleBlockHandler extends RpcHandler {
private final Logger logger = LoggerFactory.getLogger(ExternalShuffleBlockHandler.class);
private final ExternalShuffleBlockManager blockManager;
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
index 6bbabc44b9..b0b19ba67b 100644
--- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
@@ -17,8 +17,6 @@
package org.apache.spark.network.shuffle;
-import java.io.Closeable;
-
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -36,15 +34,20 @@ import org.apache.spark.network.util.TransportConf;
* BlockTransferService), which has the downside of losing the shuffle data if we lose the
* executors.
*/
-public class ExternalShuffleClient implements ShuffleClient {
+public class ExternalShuffleClient extends ShuffleClient {
private final Logger logger = LoggerFactory.getLogger(ExternalShuffleClient.class);
private final TransportClientFactory clientFactory;
- private final String appId;
- public ExternalShuffleClient(TransportConf conf, String appId) {
+ private String appId;
+
+ public ExternalShuffleClient(TransportConf conf) {
TransportContext context = new TransportContext(conf, new NoOpRpcHandler());
this.clientFactory = context.createClientFactory();
+ }
+
+ @Override
+ public void init(String appId) {
this.appId = appId;
}
@@ -55,6 +58,7 @@ public class ExternalShuffleClient implements ShuffleClient {
String execId,
String[] blockIds,
BlockFetchingListener listener) {
+ assert appId != null : "Called before init()";
logger.debug("External shuffle fetch from {}:{} (executor id {})", host, port, execId);
try {
TransportClient client = clientFactory.createClient(host, port);
@@ -82,6 +86,7 @@ public class ExternalShuffleClient implements ShuffleClient {
int port,
String execId,
ExecutorShuffleInfo executorInfo) {
+ assert appId != null : "Called before init()";
TransportClient client = clientFactory.createClient(host, port);
byte[] registerExecutorMessage =
JavaUtils.serialize(new RegisterExecutor(appId, execId, executorInfo));
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java
index d46a562394..f72ab40690 100644
--- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java
@@ -20,7 +20,14 @@ package org.apache.spark.network.shuffle;
import java.io.Closeable;
/** Provides an interface for reading shuffle files, either from an Executor or external service. */
-public interface ShuffleClient extends Closeable {
+public abstract class ShuffleClient implements Closeable {
+
+ /**
+ * Initializes the ShuffleClient, specifying this Executor's appId.
+ * Must be called before any other method on the ShuffleClient.
+ */
+ public void init(String appId) { }
+
/**
* Fetch a sequence of blocks from a remote node asynchronously,
*
@@ -28,7 +35,7 @@ public interface ShuffleClient extends Closeable {
* return a future so the underlying implementation can invoke onBlockFetchSuccess as soon as
* the data of a block is fetched, rather than waiting for all blocks to be fetched.
*/
- public void fetchBlocks(
+ public abstract void fetchBlocks(
String host,
int port,
String execId,
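A hedged usage sketch of the new lifecycle: init(appId) must precede any other call, which is what the asserts in ExternalShuffleClient enforce. The host, port, executor id, and block id below are placeholders:

import org.apache.spark.network.shuffle.BlockFetchingListener;
import org.apache.spark.network.shuffle.ExternalShuffleClient;
import org.apache.spark.network.util.TransportConf;

// Illustrative only: construct, init, then fetch.
class ShuffleClientLifecycleSketch {
  static void fetch(TransportConf conf, BlockFetchingListener listener) {
    ExternalShuffleClient client = new ExternalShuffleClient(conf);
    client.init("app-1"); // required before fetchBlocks or executor registration
    client.fetchBlocks("shuffle-host", 7337, "exec-1",
        new String[] { "shuffle_0_0_0" }, listener);
  }
}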
diff --git a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java
new file mode 100644
index 0000000000..8478120786
--- /dev/null
+++ b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java
@@ -0,0 +1,172 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.sasl;
+
+import java.io.IOException;
+
+import com.google.common.collect.Lists;
+import org.junit.After;
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import static org.junit.Assert.*;
+
+import org.apache.spark.network.TestUtils;
+import org.apache.spark.network.TransportContext;
+import org.apache.spark.network.client.RpcResponseCallback;
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.client.TransportClientBootstrap;
+import org.apache.spark.network.client.TransportClientFactory;
+import org.apache.spark.network.server.OneForOneStreamManager;
+import org.apache.spark.network.server.RpcHandler;
+import org.apache.spark.network.server.StreamManager;
+import org.apache.spark.network.server.TransportServer;
+import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler;
+import org.apache.spark.network.util.SystemPropertyConfigProvider;
+import org.apache.spark.network.util.TransportConf;
+
+public class SaslIntegrationSuite {
+ static ExternalShuffleBlockHandler handler;
+ static TransportServer server;
+ static TransportConf conf;
+ static TransportContext context;
+
+ TransportClientFactory clientFactory;
+
+ /** Provides a secret key holder which always returns the given secret key. */
+ static class TestSecretKeyHolder implements SecretKeyHolder {
+
+ private final String secretKey;
+
+ TestSecretKeyHolder(String secretKey) {
+ this.secretKey = secretKey;
+ }
+
+ @Override
+ public String getSaslUser(String appId) {
+ return "user";
+ }
+ @Override
+ public String getSecretKey(String appId) {
+ return secretKey;
+ }
+ }
+
+
+ @BeforeClass
+ public static void beforeAll() throws IOException {
+ SecretKeyHolder secretKeyHolder = new TestSecretKeyHolder("good-key");
+ SaslRpcHandler handler = new SaslRpcHandler(new TestRpcHandler(), secretKeyHolder);
+ conf = new TransportConf(new SystemPropertyConfigProvider());
+ context = new TransportContext(conf, handler);
+ server = context.createServer();
+ }
+
+
+ @AfterClass
+ public static void afterAll() {
+ server.close();
+ }
+
+ @After
+ public void afterEach() {
+ if (clientFactory != null) {
+ clientFactory.close();
+ clientFactory = null;
+ }
+ }
+
+ @Test
+ public void testGoodClient() {
+ clientFactory = context.createClientFactory(
+ Lists.<TransportClientBootstrap>newArrayList(
+ new SaslClientBootstrap(conf, "app-id", new TestSecretKeyHolder("good-key"))));
+
+ TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
+ String msg = "Hello, World!";
+ byte[] resp = client.sendRpcSync(msg.getBytes(), 1000);
+ assertEquals(msg, new String(resp)); // our rpc handler should just return the given msg
+ }
+
+ @Test
+ public void testBadClient() {
+ clientFactory = context.createClientFactory(
+ Lists.<TransportClientBootstrap>newArrayList(
+ new SaslClientBootstrap(conf, "app-id", new TestSecretKeyHolder("bad-key"))));
+
+ try {
+ // Bootstrap should fail on startup.
+ clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
+ } catch (Exception e) {
+ assertTrue(e.getMessage(), e.getMessage().contains("Mismatched response"));
+ }
+ }
+
+ @Test
+ public void testNoSaslClient() {
+ clientFactory = context.createClientFactory(
+ Lists.<TransportClientBootstrap>newArrayList());
+
+ TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
+ try {
+ client.sendRpcSync(new byte[13], 1000);
+ fail("Should have failed");
+ } catch (Exception e) {
+ assertTrue(e.getMessage(), e.getMessage().contains("Expected SaslMessage"));
+ }
+
+ try {
+ // Guessing the right tag byte doesn't magically get you in...
+ client.sendRpcSync(new byte[] { (byte) 0xEA }, 1000);
+ fail("Should have failed");
+ } catch (Exception e) {
+ assertTrue(e.getMessage(), e.getMessage().contains("java.lang.IndexOutOfBoundsException"));
+ }
+ }
+
+ @Test
+ public void testNoSaslServer() {
+ RpcHandler handler = new TestRpcHandler();
+ TransportContext context = new TransportContext(conf, handler);
+ clientFactory = context.createClientFactory(
+ Lists.<TransportClientBootstrap>newArrayList(
+ new SaslClientBootstrap(conf, "app-id", new TestSecretKeyHolder("key"))));
+ TransportServer server = context.createServer();
+ try {
+ clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
+ } catch (Exception e) {
+ assertTrue(e.getMessage(), e.getMessage().contains("Digest-challenge format violation"));
+ } finally {
+ server.close();
+ }
+ }
+
+ /** RPC handler which simply responds with the message it received. */
+ public static class TestRpcHandler extends RpcHandler {
+ @Override
+ public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) {
+ callback.onSuccess(message);
+ }
+
+ @Override
+ public StreamManager getStreamManager() {
+ return new OneForOneStreamManager();
+ }
+ }
+}
diff --git a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java
new file mode 100644
index 0000000000..67a07f38eb
--- /dev/null
+++ b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.sasl;
+
+import java.util.Map;
+
+import com.google.common.collect.ImmutableMap;
+import org.junit.Test;
+
+import static org.junit.Assert.*;
+
+/**
+ * Jointly tests SparkSaslClient and SparkSaslServer, as both are black boxes.
+ */
+public class SparkSaslSuite {
+
+ /** Provides a secret key holder which returns secret key == appId */
+ private SecretKeyHolder secretKeyHolder = new SecretKeyHolder() {
+ @Override
+ public String getSaslUser(String appId) {
+ return "user";
+ }
+
+ @Override
+ public String getSecretKey(String appId) {
+ return appId;
+ }
+ };
+
+ @Test
+ public void testMatching() {
+ SparkSaslClient client = new SparkSaslClient("shared-secret", secretKeyHolder);
+ SparkSaslServer server = new SparkSaslServer("shared-secret", secretKeyHolder);
+
+ assertFalse(client.isComplete());
+ assertFalse(server.isComplete());
+
+ byte[] clientMessage = client.firstToken();
+
+ while (!client.isComplete()) {
+ clientMessage = client.response(server.response(clientMessage));
+ }
+ assertTrue(server.isComplete());
+
+ // Disposal should invalidate
+ server.dispose();
+ assertFalse(server.isComplete());
+ client.dispose();
+ assertFalse(client.isComplete());
+ }
+
+
+ @Test
+ public void testNonMatching() {
+ SparkSaslClient client = new SparkSaslClient("my-secret", secretKeyHolder);
+ SparkSaslServer server = new SparkSaslServer("your-secret", secretKeyHolder);
+
+ assertFalse(client.isComplete());
+ assertFalse(server.isComplete());
+
+ byte[] clientMessage = client.firstToken();
+
+ try {
+ while (!client.isComplete()) {
+ clientMessage = client.response(server.response(clientMessage));
+ }
+ fail("Should not have completed");
+ } catch (Exception e) {
+ assertTrue(e.getMessage().contains("Mismatched response"));
+ assertFalse(client.isComplete());
+ assertFalse(server.isComplete());
+ }
+ }
+}
diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java
index b3bcf5fd68..bc101f5384 100644
--- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java
+++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java
@@ -135,7 +135,8 @@ public class ExternalShuffleIntegrationSuite {
final Semaphore requestsRemaining = new Semaphore(0);
- ExternalShuffleClient client = new ExternalShuffleClient(conf, APP_ID);
+ ExternalShuffleClient client = new ExternalShuffleClient(conf);
+ client.init(APP_ID);
client.fetchBlocks(TestUtils.getLocalHost(), port, execId, blockIds,
new BlockFetchingListener() {
@Override
@@ -164,6 +165,7 @@ public class ExternalShuffleIntegrationSuite {
if (!requestsRemaining.tryAcquire(blockIds.length, 5, TimeUnit.SECONDS)) {
fail("Timeout getting response from the server");
}
+ client.close();
return res;
}
@@ -265,7 +267,8 @@ public class ExternalShuffleIntegrationSuite {
}
private void registerExecutor(String executorId, ExecutorShuffleInfo executorInfo) {
- ExternalShuffleClient client = new ExternalShuffleClient(conf, APP_ID);
+ ExternalShuffleClient client = new ExternalShuffleClient(conf);
+ client.init(APP_ID);
client.registerWithShuffleServer(TestUtils.getLocalHost(), server.getPort(),
executorId, executorInfo);
}