diff options
author | Aaron Davidson <aaron@databricks.com> | 2014-11-04 16:15:38 -0800 |
---|---|---|
committer | Reynold Xin <rxin@databricks.com> | 2014-11-04 16:15:38 -0800 |
commit | 5e73138a0152b78380b3f1def4b969b58e70dd11 (patch) | |
tree | d899e389f29e7c87c8f40e84872477a9b5e52277 /network/common | |
parent | f90ad5d426cb726079c490a9bb4b1100e2b4e602 (diff) | |
download | spark-5e73138a0152b78380b3f1def4b969b58e70dd11.tar.gz spark-5e73138a0152b78380b3f1def4b969b58e70dd11.tar.bz2 spark-5e73138a0152b78380b3f1def4b969b58e70dd11.zip |
[SPARK-2938] Support SASL authentication in NettyBlockTransferService
Also lays the groundwork for supporting it inside the external shuffle service.
Author: Aaron Davidson <aaron@databricks.com>
Closes #3087 from aarondav/sasl and squashes the following commits:
3481718 [Aaron Davidson] Delete rogue println
44f8410 [Aaron Davidson] Delete documentation - muahaha!
eb9f065 [Aaron Davidson] Improve documentation and add end-to-end test at Spark-level
a6b95f1 [Aaron Davidson] Address comments
785bbde [Aaron Davidson] Cleanup
79973cb [Aaron Davidson] Remove unused file
151b3c5 [Aaron Davidson] Add docs, timeout config, better failure handling
f6177d7 [Aaron Davidson] Cleanup SASL state upon connection termination
7b42adb [Aaron Davidson] Add unit tests
8191bcb [Aaron Davidson] [SPARK-2938] Support SASL authentication in NettyBlockTransferService
Diffstat (limited to 'network/common')
8 files changed, 125 insertions, 22 deletions
diff --git a/network/common/src/main/java/org/apache/spark/network/TransportContext.java b/network/common/src/main/java/org/apache/spark/network/TransportContext.java index a271841e4e..5bc6e5a241 100644 --- a/network/common/src/main/java/org/apache/spark/network/TransportContext.java +++ b/network/common/src/main/java/org/apache/spark/network/TransportContext.java @@ -17,12 +17,16 @@ package org.apache.spark.network; +import java.util.List; + +import com.google.common.collect.Lists; import io.netty.channel.Channel; import io.netty.channel.socket.SocketChannel; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.client.TransportClientBootstrap; import org.apache.spark.network.client.TransportClientFactory; import org.apache.spark.network.client.TransportResponseHandler; import org.apache.spark.network.protocol.MessageDecoder; @@ -64,8 +68,17 @@ public class TransportContext { this.decoder = new MessageDecoder(); } + /** + * Initializes a ClientFactory which runs the given TransportClientBootstraps prior to returning + * a new Client. Bootstraps will be executed synchronously, and must run successfully in order + * to create a Client. + */ + public TransportClientFactory createClientFactory(List<TransportClientBootstrap> bootstraps) { + return new TransportClientFactory(this, bootstraps); + } + public TransportClientFactory createClientFactory() { - return new TransportClientFactory(this); + return createClientFactory(Lists.<TransportClientBootstrap>newArrayList()); } /** Create a server which will attempt to bind to a specific port. */ diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java index 01c143fff4..a08cee02dd 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java @@ -19,10 +19,9 @@ package org.apache.spark.network.client; import java.io.Closeable; import java.util.UUID; -import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; -import java.util.concurrent.TimeoutException; +import com.google.common.base.Objects; import com.google.common.base.Preconditions; import com.google.common.base.Throwables; import com.google.common.util.concurrent.SettableFuture; @@ -186,4 +185,12 @@ public class TransportClient implements Closeable { // close is a local operation and should finish with milliseconds; timeout just to be safe channel.close().awaitUninterruptibly(10, TimeUnit.SECONDS); } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("remoteAdress", channel.remoteAddress()) + .add("isActive", isActive()) + .toString(); + } } diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClientBootstrap.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClientBootstrap.java new file mode 100644 index 0000000000..65e8020e34 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClientBootstrap.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.client; + +/** + * A bootstrap which is executed on a TransportClient before it is returned to the user. + * This enables an initial exchange of information (e.g., SASL authentication tokens) on a once-per- + * connection basis. + * + * Since connections (and TransportClients) are reused as much as possible, it is generally + * reasonable to perform an expensive bootstrapping operation, as they often share a lifespan with + * the JVM itself. + */ +public interface TransportClientBootstrap { + /** Performs the bootstrapping operation, throwing an exception on failure. */ + public void doBootstrap(TransportClient client) throws RuntimeException; +} diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java index 0b4a1d8286..1723fed307 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java @@ -21,10 +21,14 @@ import java.io.Closeable; import java.lang.reflect.Field; import java.net.InetSocketAddress; import java.net.SocketAddress; +import java.util.List; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicReference; +import com.google.common.base.Preconditions; +import com.google.common.base.Throwables; +import com.google.common.collect.Lists; import io.netty.bootstrap.Bootstrap; import io.netty.buffer.PooledByteBufAllocator; import io.netty.channel.Channel; @@ -40,6 +44,7 @@ import org.slf4j.LoggerFactory; import org.apache.spark.network.TransportContext; import org.apache.spark.network.server.TransportChannelHandler; import org.apache.spark.network.util.IOMode; +import org.apache.spark.network.util.JavaUtils; import org.apache.spark.network.util.NettyUtils; import org.apache.spark.network.util.TransportConf; @@ -47,22 +52,29 @@ import org.apache.spark.network.util.TransportConf; * Factory for creating {@link TransportClient}s by using createClient. * * The factory maintains a connection pool to other hosts and should return the same - * {@link TransportClient} for the same remote host. It also shares a single worker thread pool for - * all {@link TransportClient}s. + * TransportClient for the same remote host. It also shares a single worker thread pool for + * all TransportClients. + * + * TransportClients will be reused whenever possible. Prior to completing the creation of a new + * TransportClient, all given {@link TransportClientBootstrap}s will be run. */ public class TransportClientFactory implements Closeable { private final Logger logger = LoggerFactory.getLogger(TransportClientFactory.class); private final TransportContext context; private final TransportConf conf; + private final List<TransportClientBootstrap> clientBootstraps; private final ConcurrentHashMap<SocketAddress, TransportClient> connectionPool; private final Class<? extends Channel> socketChannelClass; private EventLoopGroup workerGroup; - public TransportClientFactory(TransportContext context) { - this.context = context; + public TransportClientFactory( + TransportContext context, + List<TransportClientBootstrap> clientBootstraps) { + this.context = Preconditions.checkNotNull(context); this.conf = context.getConf(); + this.clientBootstraps = Lists.newArrayList(Preconditions.checkNotNull(clientBootstraps)); this.connectionPool = new ConcurrentHashMap<SocketAddress, TransportClient>(); IOMode ioMode = IOMode.valueOf(conf.ioMode()); @@ -72,9 +84,12 @@ public class TransportClientFactory implements Closeable { } /** - * Create a new BlockFetchingClient connecting to the given remote host / port. + * Create a new {@link TransportClient} connecting to the given remote host / port. This will + * reuse TransportClients if they are still active and are for the same remote address. Prior + * to the creation of a new TransportClient, we will execute all {@link TransportClientBootstrap}s + * that are registered with this factory. * - * This blocks until a connection is successfully established. + * This blocks until a connection is successfully established and fully bootstrapped. * * Concurrency: This method is safe to call from multiple threads. */ @@ -104,17 +119,18 @@ public class TransportClientFactory implements Closeable { // Use pooled buffers to reduce temporary buffer allocation bootstrap.option(ChannelOption.ALLOCATOR, createPooledByteBufAllocator()); - final AtomicReference<TransportClient> client = new AtomicReference<TransportClient>(); + final AtomicReference<TransportClient> clientRef = new AtomicReference<TransportClient>(); bootstrap.handler(new ChannelInitializer<SocketChannel>() { @Override public void initChannel(SocketChannel ch) { TransportChannelHandler clientHandler = context.initializePipeline(ch); - client.set(clientHandler.getClient()); + clientRef.set(clientHandler.getClient()); } }); // Connect to the remote server + long preConnect = System.currentTimeMillis(); ChannelFuture cf = bootstrap.connect(address); if (!cf.awaitUninterruptibly(conf.connectionTimeoutMs())) { throw new RuntimeException( @@ -123,15 +139,35 @@ public class TransportClientFactory implements Closeable { throw new RuntimeException(String.format("Failed to connect to %s", address), cf.cause()); } - // Successful connection -- in the event that two threads raced to create a client, we will + TransportClient client = clientRef.get(); + assert client != null : "Channel future completed successfully with null client"; + + // Execute any client bootstraps synchronously before marking the Client as successful. + long preBootstrap = System.currentTimeMillis(); + logger.debug("Connection to {} successful, running bootstraps...", address); + try { + for (TransportClientBootstrap clientBootstrap : clientBootstraps) { + clientBootstrap.doBootstrap(client); + } + } catch (Exception e) { // catch non-RuntimeExceptions too as bootstrap may be written in Scala + long bootstrapTime = System.currentTimeMillis() - preBootstrap; + logger.error("Exception while bootstrapping client after " + bootstrapTime + " ms", e); + client.close(); + throw Throwables.propagate(e); + } + long postBootstrap = System.currentTimeMillis(); + + // Successful connection & bootstrap -- in the event that two threads raced to create a client, // use the first one that was put into the connectionPool and close the one we made here. - assert client.get() != null : "Channel future completed successfully with null client"; - TransportClient oldClient = connectionPool.putIfAbsent(address, client.get()); + TransportClient oldClient = connectionPool.putIfAbsent(address, client); if (oldClient == null) { - return client.get(); + logger.debug("Successfully created connection to {} after {} ms ({} ms spent in bootstraps)", + address, postBootstrap - preConnect, postBootstrap - preBootstrap); + return client; } else { - logger.debug("Two clients were created concurrently, second one will be disposed."); - client.get().close(); + logger.debug("Two clients were created concurrently after {} ms, second will be disposed.", + postBootstrap - preConnect); + client.close(); return oldClient; } } diff --git a/network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java b/network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java index 5a3f003726..1502b7489e 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java @@ -21,7 +21,7 @@ import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; /** An RpcHandler suitable for a client-only TransportContext, which cannot receive RPCs. */ -public class NoOpRpcHandler implements RpcHandler { +public class NoOpRpcHandler extends RpcHandler { private final StreamManager streamManager; public NoOpRpcHandler() { diff --git a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java index 2369dc6203..2ba92a40f8 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java @@ -23,22 +23,33 @@ import org.apache.spark.network.client.TransportClient; /** * Handler for sendRPC() messages sent by {@link org.apache.spark.network.client.TransportClient}s. */ -public interface RpcHandler { +public abstract class RpcHandler { /** * Receive a single RPC message. Any exception thrown while in this method will be sent back to * the client in string form as a standard RPC failure. * + * This method will not be called in parallel for a single TransportClient (i.e., channel). + * * @param client A channel client which enables the handler to make requests back to the sender - * of this RPC. + * of this RPC. This will always be the exact same object for a particular channel. * @param message The serialized bytes of the RPC. * @param callback Callback which should be invoked exactly once upon success or failure of the * RPC. */ - void receive(TransportClient client, byte[] message, RpcResponseCallback callback); + public abstract void receive( + TransportClient client, + byte[] message, + RpcResponseCallback callback); /** * Returns the StreamManager which contains the state about which streams are currently being * fetched by a TransportClient. */ - StreamManager getStreamManager(); + public abstract StreamManager getStreamManager(); + + /** + * Invoked when the connection associated with the given client has been invalidated. + * No further requests will come from this client. + */ + public void connectionTerminated(TransportClient client) { } } diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java index 17fe9001b3..1580180cc1 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java @@ -86,6 +86,7 @@ public class TransportRequestHandler extends MessageHandler<RequestMessage> { for (long streamId : streamIds) { streamManager.connectionTerminated(streamId); } + rpcHandler.connectionTerminated(reverseClient); } @Override diff --git a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java index a68f38e0e9..823790dd3c 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java +++ b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java @@ -55,4 +55,7 @@ public class TransportConf { /** Send buffer size (SO_SNDBUF). */ public int sendBuf() { return conf.getInt("spark.shuffle.io.sendBuffer", -1); } + + /** Timeout for a single round trip of SASL token exchange, in milliseconds. */ + public int saslRTTimeout() { return conf.getInt("spark.shuffle.sasl.timeout", 30000); } } |