about summary refs log tree commit diff
path: root/network
diff options
context:
space:
mode:
Diffstat (limited to 'network')
-rw-r--r-- network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java | 124
-rw-r--r-- network/common/src/main/java/org/apache/spark/network/util/TransportConf.java | 5
-rw-r--r-- network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java | 97
3 files changed, 180 insertions, 46 deletions
diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
index 9afd5decd5..d26b9b4d60 100644
--- a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
+++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
@@ -22,6 +22,7 @@ import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.util.List;
+import java.util.Random;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicReference;
@@ -42,6 +43,7 @@ import org.slf4j.LoggerFactory;
import org.apache.spark.network.TransportContext;
import org.apache.spark.network.server.TransportChannelHandler;
import org.apache.spark.network.util.IOMode;
+import org.apache.spark.network.util.JavaUtils;
import org.apache.spark.network.util.NettyUtils;
import org.apache.spark.network.util.TransportConf;
@@ -56,12 +58,31 @@ import org.apache.spark.network.util.TransportConf;
* TransportClient, all given {@link TransportClientBootstrap}s will be run.
*/
public class TransportClientFactory implements Closeable {
+
+ /** A simple data structure to track the pool of clients between two peer nodes. */
+ private static class ClientPool {
+ TransportClient[] clients;
+ Object[] locks;
+
+ public ClientPool(int size) {
+ clients = new TransportClient[size];
+ locks = new Object[size];
+ for (int i = 0; i < size; i++) {
+ locks[i] = new Object();
+ }
+ }
+ }
+
private final Logger logger = LoggerFactory.getLogger(TransportClientFactory.class);
private final TransportContext context;
private final TransportConf conf;
private final List<TransportClientBootstrap> clientBootstraps;
- private final ConcurrentHashMap<SocketAddress, TransportClient> connectionPool;
+ private final ConcurrentHashMap<SocketAddress, ClientPool> connectionPool;
+
+ /** Random number generator for picking connections between peers. */
+ private final Random rand;
+ private final int numConnectionsPerPeer;
private final Class<? extends Channel> socketChannelClass;
private EventLoopGroup workerGroup;
@@ -73,7 +94,9 @@ public class TransportClientFactory implements Closeable {
this.context = Preconditions.checkNotNull(context);
this.conf = context.getConf();
this.clientBootstraps = Lists.newArrayList(Preconditions.checkNotNull(clientBootstraps));
- this.connectionPool = new ConcurrentHashMap<SocketAddress, TransportClient>();
+ this.connectionPool = new ConcurrentHashMap<SocketAddress, ClientPool>();
+ this.numConnectionsPerPeer = conf.numConnectionsPerPeer();
+ this.rand = new Random();
IOMode ioMode = IOMode.valueOf(conf.ioMode());
this.socketChannelClass = NettyUtils.getClientChannelClass(ioMode);
@@ -84,10 +107,14 @@ public class TransportClientFactory implements Closeable {
}
/**
- * Create a new {@link TransportClient} connecting to the given remote host / port. This will
- * reuse TransportClients if they are still active and are for the same remote address. Prior
- * to the creation of a new TransportClient, we will execute all {@link TransportClientBootstrap}s
- * that are registered with this factory.
+ * Create a {@link TransportClient} connecting to the given remote host / port.
+ *
+ * We maintain an array of clients (size determined by spark.shuffle.io.numConnectionsPerPeer)
+ * and randomly pick one to use. If the randomly selected spot holds no client, or holds one that
+ * is no longer active, this function creates a new client and places it there.
+ *
+ * Prior to the creation of a new TransportClient, we will execute all
+ * {@link TransportClientBootstrap}s that are registered with this factory.
*
* This blocks until a connection is successfully established and fully bootstrapped.
*
@@ -97,23 +124,48 @@ public class TransportClientFactory implements Closeable {
// Get connection from the connection pool first.
// If it is not found or not active, create a new one.
final InetSocketAddress address = new InetSocketAddress(remoteHost, remotePort);
- TransportClient cachedClient = connectionPool.get(address);
- if (cachedClient != null) {
- if (cachedClient.isActive()) {
- logger.trace("Returning cached connection to {}: {}", address, cachedClient);
- return cachedClient;
- } else {
- logger.info("Found inactive connection to {}, closing it.", address);
- connectionPool.remove(address, cachedClient); // Remove inactive clients.
+
+ // Create the ClientPool if we don't have it yet.
+ ClientPool clientPool = connectionPool.get(address);
+ if (clientPool == null) {
+ connectionPool.putIfAbsent(address, new ClientPool(numConnectionsPerPeer));
+ clientPool = connectionPool.get(address);
+ }
+
+ int clientIndex = rand.nextInt(numConnectionsPerPeer);
+ TransportClient cachedClient = clientPool.clients[clientIndex];
+
+ if (cachedClient != null && cachedClient.isActive()) {
+ logger.trace("Returning cached connection to {}: {}", address, cachedClient);
+ return cachedClient;
+ }
+
+ // If we reach here, we don't have an existing connection open. Let's create a new one.
+ // Multiple threads might race here to create new connections. Keep only one of them active.
+ synchronized (clientPool.locks[clientIndex]) {
+ cachedClient = clientPool.clients[clientIndex];
+
+ if (cachedClient != null) {
+ if (cachedClient.isActive()) {
+ logger.trace("Returning cached connection to {}: {}", address, cachedClient);
+ return cachedClient;
+ } else {
+ logger.info("Found inactive connection to {}, creating a new one.", address);
+ }
}
+ clientPool.clients[clientIndex] = createClient(address);
+ return clientPool.clients[clientIndex];
}
+ }
+ /** Create a completely new {@link TransportClient} to the remote address. */
+ private TransportClient createClient(InetSocketAddress address) throws IOException {
logger.debug("Creating new connection to " + address);
Bootstrap bootstrap = new Bootstrap();
bootstrap.group(workerGroup)
.channel(socketChannelClass)
- // Disable Nagle's Algorithm since we don't want packets to wait
+ // Disable Nagle's Algorithm since we don't want packets to wait
.option(ChannelOption.TCP_NODELAY, true)
.option(ChannelOption.SO_KEEPALIVE, true)
.option(ChannelOption.CONNECT_TIMEOUT_MILLIS, conf.connectionTimeoutMs())
@@ -130,7 +182,7 @@ public class TransportClientFactory implements Closeable {
});
// Connect to the remote server
- long preConnect = System.currentTimeMillis();
+ long preConnect = System.nanoTime();
ChannelFuture cf = bootstrap.connect(address);
if (!cf.awaitUninterruptibly(conf.connectionTimeoutMs())) {
throw new IOException(
@@ -143,43 +195,37 @@ public class TransportClientFactory implements Closeable {
assert client != null : "Channel future completed successfully with null client";
// Execute any client bootstraps synchronously before marking the Client as successful.
- long preBootstrap = System.currentTimeMillis();
+ long preBootstrap = System.nanoTime();
logger.debug("Connection to {} successful, running bootstraps...", address);
try {
for (TransportClientBootstrap clientBootstrap : clientBootstraps) {
clientBootstrap.doBootstrap(client);
}
} catch (Exception e) { // catch non-RuntimeExceptions too as bootstrap may be written in Scala
- long bootstrapTime = System.currentTimeMillis() - preBootstrap;
- logger.error("Exception while bootstrapping client after " + bootstrapTime + " ms", e);
+ long bootstrapTimeMs = (System.nanoTime() - preBootstrap) / 1000000;
+ logger.error("Exception while bootstrapping client after " + bootstrapTimeMs + " ms", e);
client.close();
throw Throwables.propagate(e);
}
- long postBootstrap = System.currentTimeMillis();
-
- // Successful connection & bootstrap -- in the event that two threads raced to create a client,
- // use the first one that was put into the connectionPool and close the one we made here.
- TransportClient oldClient = connectionPool.putIfAbsent(address, client);
- if (oldClient == null) {
- logger.debug("Successfully created connection to {} after {} ms ({} ms spent in bootstraps)",
- address, postBootstrap - preConnect, postBootstrap - preBootstrap);
- return client;
- } else {
- logger.debug("Two clients were created concurrently after {} ms, second will be disposed.",
- postBootstrap - preConnect);
- client.close();
- return oldClient;
- }
+ long postBootstrap = System.nanoTime();
+
+ logger.debug("Successfully created connection to {} after {} ms ({} ms spent in bootstraps)",
+ address, (postBootstrap - preConnect) / 1000000, (postBootstrap - preBootstrap) / 1000000);
+
+ return client;
}
/** Close all connections in the connection pool, and shutdown the worker thread pool. */
@Override
public void close() {
- for (TransportClient client : connectionPool.values()) {
- try {
- client.close();
- } catch (RuntimeException e) {
- logger.warn("Ignoring exception during close", e);
+ // Go through all pooled clients and close every one that has been created (non-null).
+ for (ClientPool clientPool : connectionPool.values()) {
+ for (int i = 0; i < clientPool.clients.length; i++) {
+ TransportClient client = clientPool.clients[i];
+ if (client != null) {
+ clientPool.clients[i] = null;
+ JavaUtils.closeQuietly(client);
+ }
}
}
connectionPool.clear();
diff --git a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java
index 1af40acf8b..f60573998f 100644
--- a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java
+++ b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java
@@ -40,6 +40,11 @@ public class TransportConf {
return conf.getInt("spark.shuffle.io.connectionTimeout", 120) * 1000;
}
+ /** Number of concurrent connections between two nodes for fetching data. Defaults to 2. */
+ public int numConnectionsPerPeer() {
+ return conf.getInt("spark.shuffle.io.numConnectionsPerPeer", 2);
+ }
+
/** Requested maximum length of the queue of incoming connections. Default -1 for no backlog. */
public int backLog() { return conf.getInt("spark.shuffle.io.backLog", -1); }
diff --git a/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java b/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java
index 822bef1d81..416dc1b969 100644
--- a/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java
+++ b/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java
@@ -18,7 +18,11 @@
package org.apache.spark.network;
import java.io.IOException;
-import java.util.concurrent.TimeoutException;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.NoSuchElementException;
+import java.util.Set;
+import java.util.concurrent.atomic.AtomicInteger;
import org.junit.After;
import org.junit.Before;
@@ -32,6 +36,7 @@ import org.apache.spark.network.client.TransportClientFactory;
import org.apache.spark.network.server.NoOpRpcHandler;
import org.apache.spark.network.server.RpcHandler;
import org.apache.spark.network.server.TransportServer;
+import org.apache.spark.network.util.ConfigProvider;
import org.apache.spark.network.util.JavaUtils;
import org.apache.spark.network.util.SystemPropertyConfigProvider;
import org.apache.spark.network.util.TransportConf;
@@ -57,16 +62,94 @@ public class TransportClientFactorySuite {
JavaUtils.closeQuietly(server2);
}
+ /**
+ * Request a bunch of clients to a single server to test that
+ * we create up to maxConnections distinct clients.
+ *
+ * If concurrent is true, create multiple threads to create clients in parallel.
+ */
+ private void testClientReuse(final int maxConnections, boolean concurrent)
+ throws IOException, InterruptedException {
+ TransportConf conf = new TransportConf(new ConfigProvider() {
+ @Override
+ public String get(String name) {
+ if (name.equals("spark.shuffle.io.numConnectionsPerPeer")) {
+ return Integer.toString(maxConnections);
+ } else {
+ throw new NoSuchElementException();
+ }
+ }
+ });
+
+ RpcHandler rpcHandler = new NoOpRpcHandler();
+ TransportContext context = new TransportContext(conf, rpcHandler);
+ final TransportClientFactory factory = context.createClientFactory();
+ final Set<TransportClient> clients = Collections.synchronizedSet(
+ new HashSet<TransportClient>());
+
+ final AtomicInteger failed = new AtomicInteger();
+ Thread[] attempts = new Thread[maxConnections * 10];
+
+ // Launch a bunch of threads to create new clients.
+ for (int i = 0; i < attempts.length; i++) {
+ attempts[i] = new Thread() {
+ @Override
+ public void run() {
+ try {
+ TransportClient client =
+ factory.createClient(TestUtils.getLocalHost(), server1.getPort());
+ assert (client.isActive());
+ clients.add(client);
+ } catch (IOException e) {
+ failed.incrementAndGet();
+ }
+ }
+ };
+
+ if (concurrent) {
+ attempts[i].start();
+ } else {
+ attempts[i].run();
+ }
+ }
+
+ // Wait until all the threads complete.
+ for (int i = 0; i < attempts.length; i++) {
+ attempts[i].join();
+ }
+
+ assert(failed.get() == 0);
+ assert(clients.size() == maxConnections);
+
+ for (TransportClient client : clients) {
+ client.close();
+ }
+ }
+
+ @Test
+ public void reuseClientsUpToConfigVariable() throws Exception {
+ testClientReuse(1, false);
+ testClientReuse(2, false);
+ testClientReuse(3, false);
+ testClientReuse(4, false);
+ }
+
@Test
- public void createAndReuseBlockClients() throws IOException {
+ public void reuseClientsUpToConfigVariableConcurrent() throws Exception {
+ testClientReuse(1, true);
+ testClientReuse(2, true);
+ testClientReuse(3, true);
+ testClientReuse(4, true);
+ }
+
+ @Test
+ public void returnDifferentClientsForDifferentServers() throws IOException {
TransportClientFactory factory = context.createClientFactory();
TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort());
- TransportClient c2 = factory.createClient(TestUtils.getLocalHost(), server1.getPort());
- TransportClient c3 = factory.createClient(TestUtils.getLocalHost(), server2.getPort());
+ TransportClient c2 = factory.createClient(TestUtils.getLocalHost(), server2.getPort());
assertTrue(c1.isActive());
- assertTrue(c3.isActive());
- assertTrue(c1 == c2);
- assertTrue(c1 != c3);
+ assertTrue(c2.isActive());
+ assertTrue(c1 != c2);
factory.close();
}