aboutsummaryrefslogtreecommitdiff
path: root/network/common
diff options
context:
space:
mode:
Diffstat (limited to 'network/common')
-rw-r--r--network/common/src/main/java/org/apache/spark/network/client/TransportClient.java16
-rw-r--r--network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java13
-rw-r--r--network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java3
-rw-r--r--network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java2
-rw-r--r--network/common/src/main/java/org/apache/spark/network/server/TransportServer.java8
-rw-r--r--network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java14
-rw-r--r--network/common/src/main/java/org/apache/spark/network/util/TransportConf.java17
-rw-r--r--network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java7
8 files changed, 60 insertions, 20 deletions
diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
index a08cee02dd..4e944114e8 100644
--- a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
+++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
@@ -18,7 +18,9 @@
package org.apache.spark.network.client;
import java.io.Closeable;
+import java.io.IOException;
import java.util.UUID;
+import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import com.google.common.base.Objects;
@@ -116,8 +118,12 @@ public class TransportClient implements Closeable {
serverAddr, future.cause());
logger.error(errorMsg, future.cause());
handler.removeFetchRequest(streamChunkId);
- callback.onFailure(chunkIndex, new RuntimeException(errorMsg, future.cause()));
channel.close();
+ try {
+ callback.onFailure(chunkIndex, new IOException(errorMsg, future.cause()));
+ } catch (Exception e) {
+ logger.error("Uncaught exception in RPC response callback handler!", e);
+ }
}
}
});
@@ -147,8 +153,12 @@ public class TransportClient implements Closeable {
serverAddr, future.cause());
logger.error(errorMsg, future.cause());
handler.removeRpcRequest(requestId);
- callback.onFailure(new RuntimeException(errorMsg, future.cause()));
channel.close();
+ try {
+ callback.onFailure(new IOException(errorMsg, future.cause()));
+ } catch (Exception e) {
+ logger.error("Uncaught exception in RPC response callback handler!", e);
+ }
}
}
});
@@ -175,6 +185,8 @@ public class TransportClient implements Closeable {
try {
return result.get(timeoutMs, TimeUnit.MILLISECONDS);
+ } catch (ExecutionException e) {
+ throw Throwables.propagate(e.getCause());
} catch (Exception e) {
throw Throwables.propagate(e);
}
diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
index 1723fed307..397d3a8455 100644
--- a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
+++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
@@ -18,12 +18,12 @@
package org.apache.spark.network.client;
import java.io.Closeable;
+import java.io.IOException;
import java.lang.reflect.Field;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
-import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicReference;
import com.google.common.base.Preconditions;
@@ -44,7 +44,6 @@ import org.slf4j.LoggerFactory;
import org.apache.spark.network.TransportContext;
import org.apache.spark.network.server.TransportChannelHandler;
import org.apache.spark.network.util.IOMode;
-import org.apache.spark.network.util.JavaUtils;
import org.apache.spark.network.util.NettyUtils;
import org.apache.spark.network.util.TransportConf;
@@ -93,15 +92,17 @@ public class TransportClientFactory implements Closeable {
*
* Concurrency: This method is safe to call from multiple threads.
*/
- public TransportClient createClient(String remoteHost, int remotePort) {
+ public TransportClient createClient(String remoteHost, int remotePort) throws IOException {
// Get connection from the connection pool first.
// If it is not found or not active, create a new one.
final InetSocketAddress address = new InetSocketAddress(remoteHost, remotePort);
TransportClient cachedClient = connectionPool.get(address);
if (cachedClient != null) {
if (cachedClient.isActive()) {
+ logger.trace("Returning cached connection to {}: {}", address, cachedClient);
return cachedClient;
} else {
+ logger.info("Found inactive connection to {}, closing it.", address);
connectionPool.remove(address, cachedClient); // Remove inactive clients.
}
}
@@ -133,10 +134,10 @@ public class TransportClientFactory implements Closeable {
long preConnect = System.currentTimeMillis();
ChannelFuture cf = bootstrap.connect(address);
if (!cf.awaitUninterruptibly(conf.connectionTimeoutMs())) {
- throw new RuntimeException(
+ throw new IOException(
String.format("Connecting to %s timed out (%s ms)", address, conf.connectionTimeoutMs()));
} else if (cf.cause() != null) {
- throw new RuntimeException(String.format("Failed to connect to %s", address), cf.cause());
+ throw new IOException(String.format("Failed to connect to %s", address), cf.cause());
}
TransportClient client = clientRef.get();
@@ -198,7 +199,7 @@ public class TransportClientFactory implements Closeable {
*/
private PooledByteBufAllocator createPooledByteBufAllocator() {
return new PooledByteBufAllocator(
- PlatformDependent.directBufferPreferred(),
+ conf.preferDirectBufs() && PlatformDependent.directBufferPreferred(),
getPrivateStaticField("DEFAULT_NUM_HEAP_ARENA"),
getPrivateStaticField("DEFAULT_NUM_DIRECT_ARENA"),
getPrivateStaticField("DEFAULT_PAGE_SIZE"),
diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java
index d8965590b3..2044afb0d8 100644
--- a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java
@@ -17,6 +17,7 @@
package org.apache.spark.network.client;
+import java.io.IOException;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
@@ -94,7 +95,7 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
String remoteAddress = NettyUtils.getRemoteAddress(channel);
logger.error("Still have {} requests outstanding when connection from {} is closed",
numOutstandingRequests(), remoteAddress);
- failOutstandingRequests(new RuntimeException("Connection from " + remoteAddress + " closed"));
+ failOutstandingRequests(new IOException("Connection from " + remoteAddress + " closed"));
}
}
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java b/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java
index 4cb8becc3e..91d1e8a538 100644
--- a/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java
@@ -66,7 +66,7 @@ public final class MessageEncoder extends MessageToMessageEncoder<Message> {
// All messages have the frame length, message type, and message itself.
int headerLength = 8 + msgType.encodedLength() + in.encodedLength();
long frameLength = headerLength + bodyLength;
- ByteBuf header = ctx.alloc().buffer(headerLength);
+ ByteBuf header = ctx.alloc().heapBuffer(headerLength);
header.writeLong(frameLength);
msgType.encode(header);
in.encode(header);
diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java b/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java
index 70da48ca8e..579676c2c3 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java
@@ -28,6 +28,7 @@ import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.socket.SocketChannel;
+import io.netty.util.internal.PlatformDependent;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -71,11 +72,14 @@ public class TransportServer implements Closeable {
NettyUtils.createEventLoop(ioMode, conf.serverThreads(), "shuffle-server");
EventLoopGroup workerGroup = bossGroup;
+ PooledByteBufAllocator allocator = new PooledByteBufAllocator(
+ conf.preferDirectBufs() && PlatformDependent.directBufferPreferred());
+
bootstrap = new ServerBootstrap()
.group(bossGroup, workerGroup)
.channel(NettyUtils.getServerChannelClass(ioMode))
- .option(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT)
- .childOption(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT);
+ .option(ChannelOption.ALLOCATOR, allocator)
+ .childOption(ChannelOption.ALLOCATOR, allocator);
if (conf.backLog() > 0) {
bootstrap.option(ChannelOption.SO_BACKLOG, conf.backLog());
diff --git a/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java b/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java
index b187234119..2a7664fe89 100644
--- a/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java
+++ b/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java
@@ -37,13 +37,17 @@ import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
* Utilities for creating various Netty constructs based on whether we're using EPOLL or NIO.
*/
public class NettyUtils {
- /** Creates a Netty EventLoopGroup based on the IOMode. */
- public static EventLoopGroup createEventLoop(IOMode mode, int numThreads, String threadPrefix) {
-
- ThreadFactory threadFactory = new ThreadFactoryBuilder()
+ /** Creates a new ThreadFactory which prefixes each thread with the given name. */
+ public static ThreadFactory createThreadFactory(String threadPoolPrefix) {
+ return new ThreadFactoryBuilder()
.setDaemon(true)
- .setNameFormat(threadPrefix + "-%d")
+ .setNameFormat(threadPoolPrefix + "-%d")
.build();
+ }
+
+ /** Creates a Netty EventLoopGroup based on the IOMode. */
+ public static EventLoopGroup createEventLoop(IOMode mode, int numThreads, String threadPrefix) {
+ ThreadFactory threadFactory = createThreadFactory(threadPrefix);
switch (mode) {
case NIO:
diff --git a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java
index 823790dd3c..787a8f0031 100644
--- a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java
+++ b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java
@@ -30,6 +30,11 @@ public class TransportConf {
/** IO mode: nio or epoll */
public String ioMode() { return conf.get("spark.shuffle.io.mode", "NIO").toUpperCase(); }
+ /** If true, we will prefer allocating off-heap byte buffers within Netty. */
+ public boolean preferDirectBufs() {
+ return conf.getBoolean("spark.shuffle.io.preferDirectBufs", true);
+ }
+
/** Connect timeout in secs. Default 120 secs. */
public int connectionTimeoutMs() {
return conf.getInt("spark.shuffle.io.connectionTimeout", 120) * 1000;
@@ -58,4 +63,16 @@ public class TransportConf {
/** Timeout for a single round trip of SASL token exchange, in milliseconds. */
public int saslRTTimeout() { return conf.getInt("spark.shuffle.sasl.timeout", 30000); }
+
+ /**
+ * Max number of times we will try IO exceptions (such as connection timeouts) per request.
+ * If set to 0, we will not do any retries.
+ */
+ public int maxIORetries() { return conf.getInt("spark.shuffle.io.maxRetries", 3); }
+
+ /**
+ * Time (in milliseconds) that we will wait in order to perform a retry after an IOException.
+ * Only relevant if maxIORetries > 0.
+ */
+ public int ioRetryWaitTime() { return conf.getInt("spark.shuffle.io.retryWaitMs", 5000); }
}
diff --git a/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java b/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java
index 5a10fdb384..822bef1d81 100644
--- a/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java
+++ b/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java
@@ -17,6 +17,7 @@
package org.apache.spark.network;
+import java.io.IOException;
import java.util.concurrent.TimeoutException;
import org.junit.After;
@@ -57,7 +58,7 @@ public class TransportClientFactorySuite {
}
@Test
- public void createAndReuseBlockClients() throws TimeoutException {
+ public void createAndReuseBlockClients() throws IOException {
TransportClientFactory factory = context.createClientFactory();
TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort());
TransportClient c2 = factory.createClient(TestUtils.getLocalHost(), server1.getPort());
@@ -70,7 +71,7 @@ public class TransportClientFactorySuite {
}
@Test
- public void neverReturnInactiveClients() throws Exception {
+ public void neverReturnInactiveClients() throws IOException, InterruptedException {
TransportClientFactory factory = context.createClientFactory();
TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort());
c1.close();
@@ -88,7 +89,7 @@ public class TransportClientFactorySuite {
}
@Test
- public void closeBlockClientsWithFactory() throws TimeoutException {
+ public void closeBlockClientsWithFactory() throws IOException {
TransportClientFactory factory = context.createClientFactory();
TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort());
TransportClient c2 = factory.createClient(TestUtils.getLocalHost(), server2.getPort());