aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author Reynold Xin <rxin@databricks.com> 2014-12-09 17:49:59 -0800
committer Reynold Xin <rxin@databricks.com> 2014-12-09 17:50:11 -0800
commit 441ec3451730c7ae3dbef8952e313071d6147ab6 (patch)
tree b9f64470f2ba45ad0d99c8f102775846981a779a
parent b0d64e57255e5ca545c90f18bd9d10a07ae43759 (diff)
download spark-441ec3451730c7ae3dbef8952e313071d6147ab6.tar.gz
spark-441ec3451730c7ae3dbef8952e313071d6147ab6.tar.bz2
spark-441ec3451730c7ae3dbef8952e313071d6147ab6.zip
[SPARK-4740] Create multiple concurrent connections between two peer nodes in Netty.
It's been reported that when the number of disks is large and the number of nodes is small, Netty network throughput is low compared with NIO. We suspect the problem is that only a small number of disks are utilized to serve shuffle files at any given point, due to connection reuse. This patch adds a new config parameter to specify the number of concurrent connections between two peer nodes, defaulting to 2.

Author: Reynold Xin <rxin@databricks.com>

Closes #3625 from rxin/SPARK-4740 and squashes the following commits:

ad4241a [Reynold Xin] Updated javadoc.
f33c72b [Reynold Xin] Code review feedback.
0fefabb [Reynold Xin] Use double check in synchronization.
41dfcb2 [Reynold Xin] Added test case.
9076b4a [Reynold Xin] Fixed two NPEs.
3e1306c [Reynold Xin] Minor style fix.
4f21673 [Reynold Xin] [SPARK-4740] Create multiple concurrent connections between two peer nodes in Netty.

(cherry picked from commit 2b9b72682e587909a84d3ace214c22cec830eeaf)
Signed-off-by: Reynold Xin <rxin@databricks.com>
-rw-r--r-- network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java 124
-rw-r--r-- network/common/src/main/java/org/apache/spark/network/util/TransportConf.java 5
-rw-r--r-- network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java 97
3 files changed, 180 insertions, 46 deletions
diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
index 9afd5decd5..d26b9b4d60 100644
--- a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
+++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
@@ -22,6 +22,7 @@ import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.util.List;
+import java.util.Random;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicReference;
@@ -42,6 +43,7 @@ import org.slf4j.LoggerFactory;
import org.apache.spark.network.TransportContext;
import org.apache.spark.network.server.TransportChannelHandler;
import org.apache.spark.network.util.IOMode;
+import org.apache.spark.network.util.JavaUtils;
import org.apache.spark.network.util.NettyUtils;
import org.apache.spark.network.util.TransportConf;
@@ -56,12 +58,31 @@ import org.apache.spark.network.util.TransportConf;
* TransportClient, all given {@link TransportClientBootstrap}s will be run.
*/
public class TransportClientFactory implements Closeable {
+
+ /** A simple data structure to track the pool of clients between two peer nodes. */
+ private static class ClientPool {
+ TransportClient[] clients;
+ Object[] locks;
+
+ public ClientPool(int size) {
+ clients = new TransportClient[size];
+ locks = new Object[size];
+ for (int i = 0; i < size; i++) {
+ locks[i] = new Object();
+ }
+ }
+ }
+
private final Logger logger = LoggerFactory.getLogger(TransportClientFactory.class);
private final TransportContext context;
private final TransportConf conf;
private final List<TransportClientBootstrap> clientBootstraps;
- private final ConcurrentHashMap<SocketAddress, TransportClient> connectionPool;
+ private final ConcurrentHashMap<SocketAddress, ClientPool> connectionPool;
+
+ /** Random number generator for picking connections between peers. */
+ private final Random rand;
+ private final int numConnectionsPerPeer;
private final Class<? extends Channel> socketChannelClass;
private EventLoopGroup workerGroup;
@@ -73,7 +94,9 @@ public class TransportClientFactory implements Closeable {
this.context = Preconditions.checkNotNull(context);
this.conf = context.getConf();
this.clientBootstraps = Lists.newArrayList(Preconditions.checkNotNull(clientBootstraps));
- this.connectionPool = new ConcurrentHashMap<SocketAddress, TransportClient>();
+ this.connectionPool = new ConcurrentHashMap<SocketAddress, ClientPool>();
+ this.numConnectionsPerPeer = conf.numConnectionsPerPeer();
+ this.rand = new Random();
IOMode ioMode = IOMode.valueOf(conf.ioMode());
this.socketChannelClass = NettyUtils.getClientChannelClass(ioMode);
@@ -84,10 +107,14 @@ public class TransportClientFactory implements Closeable {
}
/**
- * Create a new {@link TransportClient} connecting to the given remote host / port. This will
- * reuse TransportClients if they are still active and are for the same remote address. Prior
- * to the creation of a new TransportClient, we will execute all {@link TransportClientBootstrap}s
- * that are registered with this factory.
+ * Create a {@link TransportClient} connecting to the given remote host / port.
+ *
+ * We maintain an array of clients (size determined by spark.shuffle.io.numConnectionsPerPeer)
+ * and randomly pick one to use. If no client was previously created in the randomly selected
+ * spot, this function creates a new client and places it there.
+ *
+ * Prior to the creation of a new TransportClient, we will execute all
+ * {@link TransportClientBootstrap}s that are registered with this factory.
*
* This blocks until a connection is successfully established and fully bootstrapped.
*
@@ -97,23 +124,48 @@ public class TransportClientFactory implements Closeable {
// Get connection from the connection pool first.
// If it is not found or not active, create a new one.
final InetSocketAddress address = new InetSocketAddress(remoteHost, remotePort);
- TransportClient cachedClient = connectionPool.get(address);
- if (cachedClient != null) {
- if (cachedClient.isActive()) {
- logger.trace("Returning cached connection to {}: {}", address, cachedClient);
- return cachedClient;
- } else {
- logger.info("Found inactive connection to {}, closing it.", address);
- connectionPool.remove(address, cachedClient); // Remove inactive clients.
+
+ // Create the ClientPool if we don't have it yet.
+ ClientPool clientPool = connectionPool.get(address);
+ if (clientPool == null) {
+ connectionPool.putIfAbsent(address, new ClientPool(numConnectionsPerPeer));
+ clientPool = connectionPool.get(address);
+ }
+
+ int clientIndex = rand.nextInt(numConnectionsPerPeer);
+ TransportClient cachedClient = clientPool.clients[clientIndex];
+
+ if (cachedClient != null && cachedClient.isActive()) {
+ logger.trace("Returning cached connection to {}: {}", address, cachedClient);
+ return cachedClient;
+ }
+
+ // If we reach here, we don't have an existing connection open. Let's create a new one.
+ // Multiple threads might race here to create new connections. Keep only one of them active.
+ synchronized (clientPool.locks[clientIndex]) {
+ cachedClient = clientPool.clients[clientIndex];
+
+ if (cachedClient != null) {
+ if (cachedClient.isActive()) {
+ logger.trace("Returning cached connection to {}: {}", address, cachedClient);
+ return cachedClient;
+ } else {
+ logger.info("Found inactive connection to {}, creating a new one.", address);
+ }
}
+ clientPool.clients[clientIndex] = createClient(address);
+ return clientPool.clients[clientIndex];
}
+ }
+ /** Create a completely new {@link TransportClient} to the remote address. */
+ private TransportClient createClient(InetSocketAddress address) throws IOException {
logger.debug("Creating new connection to " + address);
Bootstrap bootstrap = new Bootstrap();
bootstrap.group(workerGroup)
.channel(socketChannelClass)
- // Disable Nagle's Algorithm since we don't want packets to wait
+ // Disable Nagle's Algorithm since we don't want packets to wait
.option(ChannelOption.TCP_NODELAY, true)
.option(ChannelOption.SO_KEEPALIVE, true)
.option(ChannelOption.CONNECT_TIMEOUT_MILLIS, conf.connectionTimeoutMs())
@@ -130,7 +182,7 @@ public class TransportClientFactory implements Closeable {
});
// Connect to the remote server
- long preConnect = System.currentTimeMillis();
+ long preConnect = System.nanoTime();
ChannelFuture cf = bootstrap.connect(address);
if (!cf.awaitUninterruptibly(conf.connectionTimeoutMs())) {
throw new IOException(
@@ -143,43 +195,37 @@ public class TransportClientFactory implements Closeable {
assert client != null : "Channel future completed successfully with null client";
// Execute any client bootstraps synchronously before marking the Client as successful.
- long preBootstrap = System.currentTimeMillis();
+ long preBootstrap = System.nanoTime();
logger.debug("Connection to {} successful, running bootstraps...", address);
try {
for (TransportClientBootstrap clientBootstrap : clientBootstraps) {
clientBootstrap.doBootstrap(client);
}
} catch (Exception e) { // catch non-RuntimeExceptions too as bootstrap may be written in Scala
- long bootstrapTime = System.currentTimeMillis() - preBootstrap;
- logger.error("Exception while bootstrapping client after " + bootstrapTime + " ms", e);
+ long bootstrapTimeMs = (System.nanoTime() - preBootstrap) / 1000000;
+ logger.error("Exception while bootstrapping client after " + bootstrapTimeMs + " ms", e);
client.close();
throw Throwables.propagate(e);
}
- long postBootstrap = System.currentTimeMillis();
-
- // Successful connection & bootstrap -- in the event that two threads raced to create a client,
- // use the first one that was put into the connectionPool and close the one we made here.
- TransportClient oldClient = connectionPool.putIfAbsent(address, client);
- if (oldClient == null) {
- logger.debug("Successfully created connection to {} after {} ms ({} ms spent in bootstraps)",
- address, postBootstrap - preConnect, postBootstrap - preBootstrap);
- return client;
- } else {
- logger.debug("Two clients were created concurrently after {} ms, second will be disposed.",
- postBootstrap - preConnect);
- client.close();
- return oldClient;
- }
+ long postBootstrap = System.nanoTime();
+
+ logger.debug("Successfully created connection to {} after {} ms ({} ms spent in bootstraps)",
+ address, (postBootstrap - preConnect) / 1000000, (postBootstrap - preBootstrap) / 1000000);
+
+ return client;
}
/** Close all connections in the connection pool, and shutdown the worker thread pool. */
@Override
public void close() {
- for (TransportClient client : connectionPool.values()) {
- try {
- client.close();
- } catch (RuntimeException e) {
- logger.warn("Ignoring exception during close", e);
+ // Go through all clients and close them if they are active.
+ for (ClientPool clientPool : connectionPool.values()) {
+ for (int i = 0; i < clientPool.clients.length; i++) {
+ TransportClient client = clientPool.clients[i];
+ if (client != null) {
+ clientPool.clients[i] = null;
+ JavaUtils.closeQuietly(client);
+ }
}
}
connectionPool.clear();
diff --git a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java
index 1af40acf8b..f60573998f 100644
--- a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java
+++ b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java
@@ -40,6 +40,11 @@ public class TransportConf {
return conf.getInt("spark.shuffle.io.connectionTimeout", 120) * 1000;
}
+ /** Number of concurrent connections between two nodes for fetching data. **/
+ public int numConnectionsPerPeer() {
+ return conf.getInt("spark.shuffle.io.numConnectionsPerPeer", 2);
+ }
+
/** Requested maximum length of the queue of incoming connections. Default -1 for no backlog. */
public int backLog() { return conf.getInt("spark.shuffle.io.backLog", -1); }
diff --git a/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java b/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java
index 822bef1d81..416dc1b969 100644
--- a/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java
+++ b/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java
@@ -18,7 +18,11 @@
package org.apache.spark.network;
import java.io.IOException;
-import java.util.concurrent.TimeoutException;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.NoSuchElementException;
+import java.util.Set;
+import java.util.concurrent.atomic.AtomicInteger;
import org.junit.After;
import org.junit.Before;
@@ -32,6 +36,7 @@ import org.apache.spark.network.client.TransportClientFactory;
import org.apache.spark.network.server.NoOpRpcHandler;
import org.apache.spark.network.server.RpcHandler;
import org.apache.spark.network.server.TransportServer;
+import org.apache.spark.network.util.ConfigProvider;
import org.apache.spark.network.util.JavaUtils;
import org.apache.spark.network.util.SystemPropertyConfigProvider;
import org.apache.spark.network.util.TransportConf;
@@ -57,16 +62,94 @@ public class TransportClientFactorySuite {
JavaUtils.closeQuietly(server2);
}
+ /**
+ * Request a bunch of clients to a single server to test that
+ * we create at most maxConnections clients.
+ *
+ * If concurrent is true, create multiple threads to create clients in parallel.
+ */
+ private void testClientReuse(final int maxConnections, boolean concurrent)
+ throws IOException, InterruptedException {
+ TransportConf conf = new TransportConf(new ConfigProvider() {
+ @Override
+ public String get(String name) {
+ if (name.equals("spark.shuffle.io.numConnectionsPerPeer")) {
+ return Integer.toString(maxConnections);
+ } else {
+ throw new NoSuchElementException();
+ }
+ }
+ });
+
+ RpcHandler rpcHandler = new NoOpRpcHandler();
+ TransportContext context = new TransportContext(conf, rpcHandler);
+ final TransportClientFactory factory = context.createClientFactory();
+ final Set<TransportClient> clients = Collections.synchronizedSet(
+ new HashSet<TransportClient>());
+
+ final AtomicInteger failed = new AtomicInteger();
+ Thread[] attempts = new Thread[maxConnections * 10];
+
+ // Launch a bunch of threads to create new clients.
+ for (int i = 0; i < attempts.length; i++) {
+ attempts[i] = new Thread() {
+ @Override
+ public void run() {
+ try {
+ TransportClient client =
+ factory.createClient(TestUtils.getLocalHost(), server1.getPort());
+ assert (client.isActive());
+ clients.add(client);
+ } catch (IOException e) {
+ failed.incrementAndGet();
+ }
+ }
+ };
+
+ if (concurrent) {
+ attempts[i].start();
+ } else {
+ attempts[i].run();
+ }
+ }
+
+ // Wait until all the threads complete.
+ for (int i = 0; i < attempts.length; i++) {
+ attempts[i].join();
+ }
+
+ assert(failed.get() == 0);
+ assert(clients.size() == maxConnections);
+
+ for (TransportClient client : clients) {
+ client.close();
+ }
+ }
+
+ @Test
+ public void reuseClientsUpToConfigVariable() throws Exception {
+ testClientReuse(1, false);
+ testClientReuse(2, false);
+ testClientReuse(3, false);
+ testClientReuse(4, false);
+ }
+
@Test
- public void createAndReuseBlockClients() throws IOException {
+ public void reuseClientsUpToConfigVariableConcurrent() throws Exception {
+ testClientReuse(1, true);
+ testClientReuse(2, true);
+ testClientReuse(3, true);
+ testClientReuse(4, true);
+ }
+
+ @Test
+ public void returnDifferentClientsForDifferentServers() throws IOException {
TransportClientFactory factory = context.createClientFactory();
TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort());
- TransportClient c2 = factory.createClient(TestUtils.getLocalHost(), server1.getPort());
- TransportClient c3 = factory.createClient(TestUtils.getLocalHost(), server2.getPort());
+ TransportClient c2 = factory.createClient(TestUtils.getLocalHost(), server2.getPort());
assertTrue(c1.isActive());
- assertTrue(c3.isActive());
- assertTrue(c1 == c2);
- assertTrue(c1 != c3);
+ assertTrue(c2.isActive());
+ assertTrue(c1 != c2);
factory.close();
}