aboutsummaryrefslogtreecommitdiff
path: root/network/common
diff options
context:
space:
mode:
Diffstat (limited to 'network/common')
-rw-r--r--network/common/pom.xml20
-rw-r--r--network/common/src/main/java/org/apache/spark/network/TransportContext.java14
-rw-r--r--network/common/src/main/java/org/apache/spark/network/client/TransportClient.java32
-rw-r--r--network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java17
-rw-r--r--network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java (renamed from network/common/src/test/java/org/apache/spark/network/NoOpRpcHandler.java)18
-rw-r--r--network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java (renamed from network/common/src/main/java/org/apache/spark/network/server/DefaultStreamManager.java)8
-rw-r--r--network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java6
-rw-r--r--network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java9
-rw-r--r--network/common/src/main/java/org/apache/spark/network/server/TransportServer.java8
-rw-r--r--network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java38
-rw-r--r--network/common/src/main/java/org/apache/spark/network/util/SystemPropertyConfigProvider.java (renamed from network/common/src/test/java/org/apache/spark/network/SystemPropertyConfigProvider.java)2
-rw-r--r--network/common/src/main/java/org/apache/spark/network/util/TransportConf.java3
-rw-r--r--network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java16
-rw-r--r--network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java9
-rw-r--r--network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java7
15 files changed, 164 insertions, 43 deletions
diff --git a/network/common/pom.xml b/network/common/pom.xml
index a33e44b63d..ea887148d9 100644
--- a/network/common/pom.xml
+++ b/network/common/pom.xml
@@ -85,9 +85,25 @@
<outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
<testOutputDirectory>target/scala-${scala.binary.version}/test-classes</testOutputDirectory>
<plugins>
+ <!-- Create a test-jar so network-shuffle can depend on our test utilities. -->
<plugin>
- <groupId>org.scalatest</groupId>
- <artifactId>scalatest-maven-plugin</artifactId>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-jar-plugin</artifactId>
+ <version>2.2</version>
+ <executions>
+ <execution>
+ <goals>
+ <goal>test-jar</goal>
+ </goals>
+ </execution>
+ <execution>
+ <id>test-jar-on-test-compile</id>
+ <phase>test-compile</phase>
+ <goals>
+ <goal>test-jar</goal>
+ </goals>
+ </execution>
+ </executions>
</plugin>
</plugins>
</build>
diff --git a/network/common/src/main/java/org/apache/spark/network/TransportContext.java b/network/common/src/main/java/org/apache/spark/network/TransportContext.java
index 854aa6685f..a271841e4e 100644
--- a/network/common/src/main/java/org/apache/spark/network/TransportContext.java
+++ b/network/common/src/main/java/org/apache/spark/network/TransportContext.java
@@ -52,15 +52,13 @@ public class TransportContext {
private final Logger logger = LoggerFactory.getLogger(TransportContext.class);
private final TransportConf conf;
- private final StreamManager streamManager;
private final RpcHandler rpcHandler;
private final MessageEncoder encoder;
private final MessageDecoder decoder;
- public TransportContext(TransportConf conf, StreamManager streamManager, RpcHandler rpcHandler) {
+ public TransportContext(TransportConf conf, RpcHandler rpcHandler) {
this.conf = conf;
- this.streamManager = streamManager;
this.rpcHandler = rpcHandler;
this.encoder = new MessageEncoder();
this.decoder = new MessageDecoder();
@@ -70,8 +68,14 @@ public class TransportContext {
return new TransportClientFactory(this);
}
+ /** Create a server which will attempt to bind to a specific port. */
+ public TransportServer createServer(int port) {
+ return new TransportServer(this, port);
+ }
+
+ /** Creates a new server, binding to any available ephemeral port. */
public TransportServer createServer() {
- return new TransportServer(this);
+ return new TransportServer(this, 0);
}
/**
@@ -109,7 +113,7 @@ public class TransportContext {
TransportResponseHandler responseHandler = new TransportResponseHandler(channel);
TransportClient client = new TransportClient(channel, responseHandler);
TransportRequestHandler requestHandler = new TransportRequestHandler(channel, client,
- streamManager, rpcHandler);
+ rpcHandler);
return new TransportChannelHandler(client, responseHandler, requestHandler);
}
diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
index b1732fcde2..01c143fff4 100644
--- a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
+++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
@@ -19,9 +19,13 @@ package org.apache.spark.network.client;
import java.io.Closeable;
import java.util.UUID;
+import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
import com.google.common.base.Preconditions;
+import com.google.common.base.Throwables;
+import com.google.common.util.concurrent.SettableFuture;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
@@ -129,7 +133,7 @@ public class TransportClient implements Closeable {
final long startTime = System.currentTimeMillis();
logger.trace("Sending RPC to {}", serverAddr);
- final long requestId = UUID.randomUUID().getLeastSignificantBits();
+ final long requestId = Math.abs(UUID.randomUUID().getLeastSignificantBits());
handler.addRpcRequest(requestId, callback);
channel.writeAndFlush(new RpcRequest(requestId, message)).addListener(
@@ -151,6 +155,32 @@ public class TransportClient implements Closeable {
});
}
+ /**
+ * Synchronously sends an opaque message to the RpcHandler on the server-side, waiting for up to
+ * a specified timeout for a response.
+ */
+ public byte[] sendRpcSync(byte[] message, long timeoutMs) {
+ final SettableFuture<byte[]> result = SettableFuture.create();
+
+ sendRpc(message, new RpcResponseCallback() {
+ @Override
+ public void onSuccess(byte[] response) {
+ result.set(response);
+ }
+
+ @Override
+ public void onFailure(Throwable e) {
+ result.setException(e);
+ }
+ });
+
+ try {
+ return result.get(timeoutMs, TimeUnit.MILLISECONDS);
+ } catch (Exception e) {
+ throw Throwables.propagate(e);
+ }
+ }
+
@Override
public void close() {
// close is a local operation and should finish with milliseconds; timeout just to be safe
diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
index 10eb9ef7a0..e7fa4f6bf3 100644
--- a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
+++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
@@ -78,15 +78,17 @@ public class TransportClientFactory implements Closeable {
*
* Concurrency: This method is safe to call from multiple threads.
*/
- public TransportClient createClient(String remoteHost, int remotePort) throws TimeoutException {
+ public TransportClient createClient(String remoteHost, int remotePort) {
// Get connection from the connection pool first.
// If it is not found or not active, create a new one.
final InetSocketAddress address = new InetSocketAddress(remoteHost, remotePort);
TransportClient cachedClient = connectionPool.get(address);
- if (cachedClient != null && cachedClient.isActive()) {
- return cachedClient;
- } else if (cachedClient != null) {
- connectionPool.remove(address, cachedClient); // Remove inactive clients.
+ if (cachedClient != null) {
+ if (cachedClient.isActive()) {
+ return cachedClient;
+ } else {
+ connectionPool.remove(address, cachedClient); // Remove inactive clients.
+ }
}
logger.debug("Creating new connection to " + address);
@@ -115,13 +117,14 @@ public class TransportClientFactory implements Closeable {
// Connect to the remote server
ChannelFuture cf = bootstrap.connect(address);
if (!cf.awaitUninterruptibly(conf.connectionTimeoutMs())) {
- throw new TimeoutException(
+ throw new RuntimeException(
String.format("Connecting to %s timed out (%s ms)", address, conf.connectionTimeoutMs()));
} else if (cf.cause() != null) {
throw new RuntimeException(String.format("Failed to connect to %s", address), cf.cause());
}
- // Successful connection
+ // Successful connection -- in the event that two threads raced to create a client, we will
+ // use the first one that was put into the connectionPool and close the one we made here.
assert client.get() != null : "Channel future completed successfully with null client";
TransportClient oldClient = connectionPool.putIfAbsent(address, client.get());
if (oldClient == null) {
diff --git a/network/common/src/test/java/org/apache/spark/network/NoOpRpcHandler.java b/network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java
index 7aa37efc58..5a3f003726 100644
--- a/network/common/src/test/java/org/apache/spark/network/NoOpRpcHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java
@@ -1,4 +1,6 @@
-package org.apache.spark.network;/*
+package org.apache.spark.network.server;
+
+/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
@@ -17,12 +19,20 @@ package org.apache.spark.network;/*
import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.TransportClient;
-import org.apache.spark.network.server.RpcHandler;
-/** Test RpcHandler which always returns a zero-sized success. */
+/** An RpcHandler suitable for a client-only TransportContext, which cannot receive RPCs. */
public class NoOpRpcHandler implements RpcHandler {
+ private final StreamManager streamManager;
+
+ public NoOpRpcHandler() {
+ streamManager = new OneForOneStreamManager();
+ }
+
@Override
public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) {
- callback.onSuccess(new byte[0]);
+ throw new UnsupportedOperationException("Cannot handle messages");
}
+
+ @Override
+ public StreamManager getStreamManager() { return streamManager; }
}
diff --git a/network/common/src/main/java/org/apache/spark/network/server/DefaultStreamManager.java b/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java
index 9688705569..731d48d4d9 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/DefaultStreamManager.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java
@@ -30,10 +30,10 @@ import org.apache.spark.network.buffer.ManagedBuffer;
/**
* StreamManager which allows registration of an Iterator<ManagedBuffer>, which are individually
- * fetched as chunks by the client.
+ * fetched as chunks by the client. Each registered buffer is one chunk.
*/
-public class DefaultStreamManager extends StreamManager {
- private final Logger logger = LoggerFactory.getLogger(DefaultStreamManager.class);
+public class OneForOneStreamManager extends StreamManager {
+ private final Logger logger = LoggerFactory.getLogger(OneForOneStreamManager.class);
private final AtomicLong nextStreamId;
private final Map<Long, StreamState> streams;
@@ -51,7 +51,7 @@ public class DefaultStreamManager extends StreamManager {
}
}
- public DefaultStreamManager() {
+ public OneForOneStreamManager() {
// For debugging purposes, start with a random stream id to help identifying different streams.
// This does not need to be globally unique, only unique to this class.
nextStreamId = new AtomicLong((long) new Random().nextInt(Integer.MAX_VALUE) * 1000);
diff --git a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java
index f54a696b8f..2369dc6203 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java
@@ -35,4 +35,10 @@ public interface RpcHandler {
* RPC.
*/
void receive(TransportClient client, byte[] message, RpcResponseCallback callback);
+
+ /**
+ * Returns the StreamManager which contains the state about which streams are currently being
+ * fetched by a TransportClient.
+ */
+ StreamManager getStreamManager();
}
diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
index 352f865935..17fe9001b3 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
@@ -56,24 +56,23 @@ public class TransportRequestHandler extends MessageHandler<RequestMessage> {
/** Client on the same channel allowing us to talk back to the requester. */
private final TransportClient reverseClient;
- /** Returns each chunk part of a stream. */
- private final StreamManager streamManager;
-
/** Handles all RPC messages. */
private final RpcHandler rpcHandler;
+ /** Returns each chunk part of a stream. */
+ private final StreamManager streamManager;
+
/** List of all stream ids that have been read on this handler, used for cleanup. */
private final Set<Long> streamIds;
public TransportRequestHandler(
Channel channel,
TransportClient reverseClient,
- StreamManager streamManager,
RpcHandler rpcHandler) {
this.channel = channel;
this.reverseClient = reverseClient;
- this.streamManager = streamManager;
this.rpcHandler = rpcHandler;
+ this.streamManager = rpcHandler.getStreamManager();
this.streamIds = Sets.newHashSet();
}
diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java b/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java
index 243070750d..d1a1877a98 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java
@@ -49,11 +49,11 @@ public class TransportServer implements Closeable {
private ChannelFuture channelFuture;
private int port = -1;
- public TransportServer(TransportContext context) {
+ public TransportServer(TransportContext context, int portToBind) {
this.context = context;
this.conf = context.getConf();
- init();
+ init(portToBind);
}
public int getPort() {
@@ -63,7 +63,7 @@ public class TransportServer implements Closeable {
return port;
}
- private void init() {
+ private void init(int portToBind) {
IOMode ioMode = IOMode.valueOf(conf.ioMode());
EventLoopGroup bossGroup =
@@ -95,7 +95,7 @@ public class TransportServer implements Closeable {
}
});
- channelFuture = bootstrap.bind(new InetSocketAddress(conf.serverPort()));
+ channelFuture = bootstrap.bind(new InetSocketAddress(portToBind));
channelFuture.syncUninterruptibly();
port = ((InetSocketAddress) channelFuture.channel().localAddress()).getPort();
diff --git a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java
index 32ba3f5b07..40b71b0c87 100644
--- a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java
+++ b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java
@@ -17,8 +17,12 @@
package org.apache.spark.network.util;
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
import java.io.Closeable;
import java.io.IOException;
+import java.io.ObjectInputStream;
+import java.io.ObjectOutputStream;
import com.google.common.io.Closeables;
import org.slf4j.Logger;
@@ -35,4 +39,38 @@ public class JavaUtils {
logger.error("IOException should not have been thrown.", e);
}
}
+
+ // TODO: Make this configurable, do not use Java serialization!
+ public static <T> T deserialize(byte[] bytes) {
+ try {
+ ObjectInputStream is = new ObjectInputStream(new ByteArrayInputStream(bytes));
+ Object out = is.readObject();
+ is.close();
+ return (T) out;
+ } catch (ClassNotFoundException e) {
+ throw new RuntimeException("Could not deserialize object", e);
+ } catch (IOException e) {
+ throw new RuntimeException("Could not deserialize object", e);
+ }
+ }
+
+ // TODO: Make this configurable, do not use Java serialization!
+ public static byte[] serialize(Object object) {
+ try {
+ ByteArrayOutputStream baos = new ByteArrayOutputStream();
+ ObjectOutputStream os = new ObjectOutputStream(baos);
+ os.writeObject(object);
+ os.close();
+ return baos.toByteArray();
+ } catch (IOException e) {
+ throw new RuntimeException("Could not serialize object", e);
+ }
+ }
+
+ /** Returns a hash consistent with Spark's Utils.nonNegativeHash(). */
+ public static int nonNegativeHash(Object obj) {
+ if (obj == null) { return 0; }
+ int hash = obj.hashCode();
+ return hash != Integer.MIN_VALUE ? Math.abs(hash) : 0;
+ }
}
diff --git a/network/common/src/test/java/org/apache/spark/network/SystemPropertyConfigProvider.java b/network/common/src/main/java/org/apache/spark/network/util/SystemPropertyConfigProvider.java
index f4e0a2426a..5f20b70678 100644
--- a/network/common/src/test/java/org/apache/spark/network/SystemPropertyConfigProvider.java
+++ b/network/common/src/main/java/org/apache/spark/network/util/SystemPropertyConfigProvider.java
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.network;
+package org.apache.spark.network.util;
import java.util.NoSuchElementException;
diff --git a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java
index 80f65d9803..a68f38e0e9 100644
--- a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java
+++ b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java
@@ -27,9 +27,6 @@ public class TransportConf {
this.conf = conf;
}
- /** Port the server listens on. Default to a random port. */
- public int serverPort() { return conf.getInt("spark.shuffle.io.port", 0); }
-
/** IO mode: nio or epoll */
public String ioMode() { return conf.get("spark.shuffle.io.mode", "NIO").toUpperCase(); }
diff --git a/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java
index 738dca9b6a..c415883397 100644
--- a/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java
+++ b/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java
@@ -41,10 +41,13 @@ import org.apache.spark.network.buffer.FileSegmentManagedBuffer;
import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.buffer.NioManagedBuffer;
import org.apache.spark.network.client.ChunkReceivedCallback;
+import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.client.TransportClientFactory;
+import org.apache.spark.network.server.RpcHandler;
import org.apache.spark.network.server.TransportServer;
import org.apache.spark.network.server.StreamManager;
+import org.apache.spark.network.util.SystemPropertyConfigProvider;
import org.apache.spark.network.util.TransportConf;
public class ChunkFetchIntegrationSuite {
@@ -93,7 +96,18 @@ public class ChunkFetchIntegrationSuite {
}
}
};
- TransportContext context = new TransportContext(conf, streamManager, new NoOpRpcHandler());
+ RpcHandler handler = new RpcHandler() {
+ @Override
+ public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public StreamManager getStreamManager() {
+ return streamManager;
+ }
+ };
+ TransportContext context = new TransportContext(conf, handler);
server = context.createServer();
clientFactory = context.createClientFactory();
}
diff --git a/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java
index 9f216dd2d7..64b457b4b3 100644
--- a/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java
+++ b/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java
@@ -35,9 +35,11 @@ import static org.junit.Assert.*;
import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.client.TransportClientFactory;
-import org.apache.spark.network.server.DefaultStreamManager;
+import org.apache.spark.network.server.OneForOneStreamManager;
import org.apache.spark.network.server.RpcHandler;
+import org.apache.spark.network.server.StreamManager;
import org.apache.spark.network.server.TransportServer;
+import org.apache.spark.network.util.SystemPropertyConfigProvider;
import org.apache.spark.network.util.TransportConf;
public class RpcIntegrationSuite {
@@ -61,8 +63,11 @@ public class RpcIntegrationSuite {
throw new RuntimeException("Thrown: " + parts[1]);
}
}
+
+ @Override
+ public StreamManager getStreamManager() { return new OneForOneStreamManager(); }
};
- TransportContext context = new TransportContext(conf, new DefaultStreamManager(), rpcHandler);
+ TransportContext context = new TransportContext(conf, rpcHandler);
server = context.createServer();
clientFactory = context.createClientFactory();
}
diff --git a/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java b/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java
index 3ef964616f..5a10fdb384 100644
--- a/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java
+++ b/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java
@@ -28,11 +28,11 @@ import static org.junit.Assert.assertTrue;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.client.TransportClientFactory;
-import org.apache.spark.network.server.DefaultStreamManager;
+import org.apache.spark.network.server.NoOpRpcHandler;
import org.apache.spark.network.server.RpcHandler;
import org.apache.spark.network.server.TransportServer;
-import org.apache.spark.network.server.StreamManager;
import org.apache.spark.network.util.JavaUtils;
+import org.apache.spark.network.util.SystemPropertyConfigProvider;
import org.apache.spark.network.util.TransportConf;
public class TransportClientFactorySuite {
@@ -44,9 +44,8 @@ public class TransportClientFactorySuite {
@Before
public void setUp() {
conf = new TransportConf(new SystemPropertyConfigProvider());
- StreamManager streamManager = new DefaultStreamManager();
RpcHandler rpcHandler = new NoOpRpcHandler();
- context = new TransportContext(conf, streamManager, rpcHandler);
+ context = new TransportContext(conf, rpcHandler);
server1 = context.createServer();
server2 = context.createServer();
}