diff options
3 files changed, 180 insertions, 46 deletions
diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java index 9afd5decd5..d26b9b4d60 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java @@ -22,6 +22,7 @@ import java.io.IOException; import java.net.InetSocketAddress; import java.net.SocketAddress; import java.util.List; +import java.util.Random; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicReference; @@ -42,6 +43,7 @@ import org.slf4j.LoggerFactory; import org.apache.spark.network.TransportContext; import org.apache.spark.network.server.TransportChannelHandler; import org.apache.spark.network.util.IOMode; +import org.apache.spark.network.util.JavaUtils; import org.apache.spark.network.util.NettyUtils; import org.apache.spark.network.util.TransportConf; @@ -56,12 +58,31 @@ import org.apache.spark.network.util.TransportConf; * TransportClient, all given {@link TransportClientBootstrap}s will be run. */ public class TransportClientFactory implements Closeable { + + /** A simple data structure to track the pool of clients between two peer nodes. 
*/ + private static class ClientPool { + TransportClient[] clients; + Object[] locks; + + public ClientPool(int size) { + clients = new TransportClient[size]; + locks = new Object[size]; + for (int i = 0; i < size; i++) { + locks[i] = new Object(); + } + } + } + private final Logger logger = LoggerFactory.getLogger(TransportClientFactory.class); private final TransportContext context; private final TransportConf conf; private final List<TransportClientBootstrap> clientBootstraps; - private final ConcurrentHashMap<SocketAddress, TransportClient> connectionPool; + private final ConcurrentHashMap<SocketAddress, ClientPool> connectionPool; + + /** Random number generator for picking connections between peers. */ + private final Random rand; + private final int numConnectionsPerPeer; private final Class<? extends Channel> socketChannelClass; private EventLoopGroup workerGroup; @@ -73,7 +94,9 @@ public class TransportClientFactory implements Closeable { this.context = Preconditions.checkNotNull(context); this.conf = context.getConf(); this.clientBootstraps = Lists.newArrayList(Preconditions.checkNotNull(clientBootstraps)); - this.connectionPool = new ConcurrentHashMap<SocketAddress, TransportClient>(); + this.connectionPool = new ConcurrentHashMap<SocketAddress, ClientPool>(); + this.numConnectionsPerPeer = conf.numConnectionsPerPeer(); + this.rand = new Random(); IOMode ioMode = IOMode.valueOf(conf.ioMode()); this.socketChannelClass = NettyUtils.getClientChannelClass(ioMode); @@ -84,10 +107,14 @@ public class TransportClientFactory implements Closeable { } /** - * Create a new {@link TransportClient} connecting to the given remote host / port. This will - * reuse TransportClients if they are still active and are for the same remote address. Prior - * to the creation of a new TransportClient, we will execute all {@link TransportClientBootstrap}s - * that are registered with this factory. + * Create a {@link TransportClient} connecting to the given remote host / port. 
+ * + * We maintain an array of clients (size determined by spark.shuffle.io.numConnectionsPerPeer) + * and randomly pick one to use. If no client was previously created in the randomly selected + * spot, this function creates a new client and places it there. + * + * Prior to the creation of a new TransportClient, we will execute all + * {@link TransportClientBootstrap}s that are registered with this factory. * * This blocks until a connection is successfully established and fully bootstrapped. * @@ -97,23 +124,48 @@ public class TransportClientFactory implements Closeable { // Get connection from the connection pool first. // If it is not found or not active, create a new one. final InetSocketAddress address = new InetSocketAddress(remoteHost, remotePort); - TransportClient cachedClient = connectionPool.get(address); - if (cachedClient != null) { - if (cachedClient.isActive()) { - logger.trace("Returning cached connection to {}: {}", address, cachedClient); - return cachedClient; - } else { - logger.info("Found inactive connection to {}, closing it.", address); - connectionPool.remove(address, cachedClient); // Remove inactive clients. + + // Create the ClientPool if we don't have it yet. + ClientPool clientPool = connectionPool.get(address); + if (clientPool == null) { + connectionPool.putIfAbsent(address, new ClientPool(numConnectionsPerPeer)); + clientPool = connectionPool.get(address); + } + + int clientIndex = rand.nextInt(numConnectionsPerPeer); + TransportClient cachedClient = clientPool.clients[clientIndex]; + + if (cachedClient != null && cachedClient.isActive()) { + logger.trace("Returning cached connection to {}: {}", address, cachedClient); + return cachedClient; + } + + // If we reach here, we don't have an existing connection open. Let's create a new one. + // Multiple threads might race here to create new connections. Keep only one of them active. 
+ synchronized (clientPool.locks[clientIndex]) { + cachedClient = clientPool.clients[clientIndex]; + + if (cachedClient != null) { + if (cachedClient.isActive()) { + logger.trace("Returning cached connection to {}: {}", address, cachedClient); + return cachedClient; + } else { + logger.info("Found inactive connection to {}, creating a new one.", address); + } } + clientPool.clients[clientIndex] = createClient(address); + return clientPool.clients[clientIndex]; } + } + /** Create a completely new {@link TransportClient} to the remote address. */ + private TransportClient createClient(InetSocketAddress address) throws IOException { logger.debug("Creating new connection to " + address); Bootstrap bootstrap = new Bootstrap(); bootstrap.group(workerGroup) .channel(socketChannelClass) - // Disable Nagle's Algorithm since we don't want packets to wait + // Disable Nagle's Algorithm since we don't want packets to wait .option(ChannelOption.TCP_NODELAY, true) .option(ChannelOption.SO_KEEPALIVE, true) .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, conf.connectionTimeoutMs()) @@ -130,7 +182,7 @@ public class TransportClientFactory implements Closeable { }); // Connect to the remote server - long preConnect = System.currentTimeMillis(); + long preConnect = System.nanoTime(); ChannelFuture cf = bootstrap.connect(address); if (!cf.awaitUninterruptibly(conf.connectionTimeoutMs())) { throw new IOException( @@ -143,43 +195,37 @@ public class TransportClientFactory implements Closeable { assert client != null : "Channel future completed successfully with null client"; // Execute any client bootstraps synchronously before marking the Client as successful. 
- long preBootstrap = System.currentTimeMillis(); + long preBootstrap = System.nanoTime(); logger.debug("Connection to {} successful, running bootstraps...", address); try { for (TransportClientBootstrap clientBootstrap : clientBootstraps) { clientBootstrap.doBootstrap(client); } } catch (Exception e) { // catch non-RuntimeExceptions too as bootstrap may be written in Scala - long bootstrapTime = System.currentTimeMillis() - preBootstrap; - logger.error("Exception while bootstrapping client after " + bootstrapTime + " ms", e); + long bootstrapTimeMs = (System.nanoTime() - preBootstrap) / 1000000; + logger.error("Exception while bootstrapping client after " + bootstrapTimeMs + " ms", e); client.close(); throw Throwables.propagate(e); } - long postBootstrap = System.currentTimeMillis(); - - // Successful connection & bootstrap -- in the event that two threads raced to create a client, - // use the first one that was put into the connectionPool and close the one we made here. - TransportClient oldClient = connectionPool.putIfAbsent(address, client); - if (oldClient == null) { - logger.debug("Successfully created connection to {} after {} ms ({} ms spent in bootstraps)", - address, postBootstrap - preConnect, postBootstrap - preBootstrap); - return client; - } else { - logger.debug("Two clients were created concurrently after {} ms, second will be disposed.", - postBootstrap - preConnect); - client.close(); - return oldClient; - } + long postBootstrap = System.nanoTime(); + + logger.debug("Successfully created connection to {} after {} ms ({} ms spent in bootstraps)", + address, (postBootstrap - preConnect) / 1000000, (postBootstrap - preBootstrap) / 1000000); + + return client; } /** Close all connections in the connection pool, and shutdown the worker thread pool. 
*/ @Override public void close() { - for (TransportClient client : connectionPool.values()) { - try { - client.close(); - } catch (RuntimeException e) { - logger.warn("Ignoring exception during close", e); + // Go through all clients and close them if they are active. + for (ClientPool clientPool : connectionPool.values()) { + for (int i = 0; i < clientPool.clients.length; i++) { + TransportClient client = clientPool.clients[i]; + if (client != null) { + clientPool.clients[i] = null; + JavaUtils.closeQuietly(client); + } } } connectionPool.clear(); diff --git a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java index 1af40acf8b..f60573998f 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java +++ b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java @@ -40,6 +40,11 @@ public class TransportConf { return conf.getInt("spark.shuffle.io.connectionTimeout", 120) * 1000; } + /** Number of concurrent connections between two nodes for fetching data. **/ + public int numConnectionsPerPeer() { + return conf.getInt("spark.shuffle.io.numConnectionsPerPeer", 2); + } + /** Requested maximum length of the queue of incoming connections. Default -1 for no backlog. 
*/ public int backLog() { return conf.getInt("spark.shuffle.io.backLog", -1); } diff --git a/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java b/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java index 822bef1d81..416dc1b969 100644 --- a/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java +++ b/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java @@ -18,7 +18,11 @@ package org.apache.spark.network; import java.io.IOException; -import java.util.concurrent.TimeoutException; +import java.util.Collections; +import java.util.HashSet; +import java.util.NoSuchElementException; +import java.util.Set; +import java.util.concurrent.atomic.AtomicInteger; import org.junit.After; import org.junit.Before; @@ -32,6 +36,7 @@ import org.apache.spark.network.client.TransportClientFactory; import org.apache.spark.network.server.NoOpRpcHandler; import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.util.ConfigProvider; import org.apache.spark.network.util.JavaUtils; import org.apache.spark.network.util.SystemPropertyConfigProvider; import org.apache.spark.network.util.TransportConf; @@ -57,16 +62,94 @@ public class TransportClientFactorySuite { JavaUtils.closeQuietly(server2); } + /** + * Request a bunch of clients to a single server to test + * we create up to maxConnections of clients. + * + * If concurrent is true, create multiple threads to create clients in parallel. 
+ */ + private void testClientReuse(final int maxConnections, boolean concurrent) + throws IOException, InterruptedException { + TransportConf conf = new TransportConf(new ConfigProvider() { + @Override + public String get(String name) { + if (name.equals("spark.shuffle.io.numConnectionsPerPeer")) { + return Integer.toString(maxConnections); + } else { + throw new NoSuchElementException(); + } + } + }); + + RpcHandler rpcHandler = new NoOpRpcHandler(); + TransportContext context = new TransportContext(conf, rpcHandler); + final TransportClientFactory factory = context.createClientFactory(); + final Set<TransportClient> clients = Collections.synchronizedSet( + new HashSet<TransportClient>()); + + final AtomicInteger failed = new AtomicInteger(); + Thread[] attempts = new Thread[maxConnections * 10]; + + // Launch a bunch of threads to create new clients. + for (int i = 0; i < attempts.length; i++) { + attempts[i] = new Thread() { + @Override + public void run() { + try { + TransportClient client = + factory.createClient(TestUtils.getLocalHost(), server1.getPort()); + assert (client.isActive()); + clients.add(client); + } catch (IOException e) { + failed.incrementAndGet(); + } + } + }; + + if (concurrent) { + attempts[i].start(); + } else { + attempts[i].run(); + } + } + + // Wait until all the threads complete. 
+ for (int i = 0; i < attempts.length; i++) { + attempts[i].join(); + } + + assert(failed.get() == 0); + assert(clients.size() == maxConnections); + + for (TransportClient client : clients) { + client.close(); + } + } + + @Test + public void reuseClientsUpToConfigVariable() throws Exception { + testClientReuse(1, false); + testClientReuse(2, false); + testClientReuse(3, false); + testClientReuse(4, false); + } + @Test - public void createAndReuseBlockClients() throws IOException { + public void reuseClientsUpToConfigVariableConcurrent() throws Exception { + testClientReuse(1, true); + testClientReuse(2, true); + testClientReuse(3, true); + testClientReuse(4, true); + } + + @Test + public void returnDifferentClientsForDifferentServers() throws IOException { TransportClientFactory factory = context.createClientFactory(); TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort()); - TransportClient c2 = factory.createClient(TestUtils.getLocalHost(), server1.getPort()); - TransportClient c3 = factory.createClient(TestUtils.getLocalHost(), server2.getPort()); + TransportClient c2 = factory.createClient(TestUtils.getLocalHost(), server2.getPort()); assertTrue(c1.isActive()); - assertTrue(c3.isActive()); - assertTrue(c1 == c2); - assertTrue(c1 != c3); + assertTrue(c2.isActive()); + assertTrue(c1 != c2); factory.close(); } |