diff options
Diffstat (limited to 'network/common')
15 files changed, 164 insertions, 43 deletions
diff --git a/network/common/pom.xml b/network/common/pom.xml index a33e44b63d..ea887148d9 100644 --- a/network/common/pom.xml +++ b/network/common/pom.xml @@ -85,9 +85,25 @@ <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory> <testOutputDirectory>target/scala-${scala.binary.version}/test-classes</testOutputDirectory> <plugins> + <!-- Create a test-jar so network-shuffle can depend on our test utilities. --> <plugin> - <groupId>org.scalatest</groupId> - <artifactId>scalatest-maven-plugin</artifactId> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-jar-plugin</artifactId> + <version>2.2</version> + <executions> + <execution> + <goals> + <goal>test-jar</goal> + </goals> + </execution> + <execution> + <id>test-jar-on-test-compile</id> + <phase>test-compile</phase> + <goals> + <goal>test-jar</goal> + </goals> + </execution> + </executions> </plugin> </plugins> </build> diff --git a/network/common/src/main/java/org/apache/spark/network/TransportContext.java b/network/common/src/main/java/org/apache/spark/network/TransportContext.java index 854aa6685f..a271841e4e 100644 --- a/network/common/src/main/java/org/apache/spark/network/TransportContext.java +++ b/network/common/src/main/java/org/apache/spark/network/TransportContext.java @@ -52,15 +52,13 @@ public class TransportContext { private final Logger logger = LoggerFactory.getLogger(TransportContext.class); private final TransportConf conf; - private final StreamManager streamManager; private final RpcHandler rpcHandler; private final MessageEncoder encoder; private final MessageDecoder decoder; - public TransportContext(TransportConf conf, StreamManager streamManager, RpcHandler rpcHandler) { + public TransportContext(TransportConf conf, RpcHandler rpcHandler) { this.conf = conf; - this.streamManager = streamManager; this.rpcHandler = rpcHandler; this.encoder = new MessageEncoder(); this.decoder = new MessageDecoder(); @@ -70,8 +68,14 @@ public class TransportContext { return new TransportClientFactory(this); } + /** Create a server which will attempt to bind to a specific port. */ + public TransportServer createServer(int port) { + return new TransportServer(this, port); + } + + /** Creates a new server, binding to any available ephemeral port. */ public TransportServer createServer() { - return new TransportServer(this); + return new TransportServer(this, 0); } /** @@ -109,7 +113,7 @@ public class TransportContext { TransportResponseHandler responseHandler = new TransportResponseHandler(channel); TransportClient client = new TransportClient(channel, responseHandler); TransportRequestHandler requestHandler = new TransportRequestHandler(channel, client, - streamManager, rpcHandler); + rpcHandler); return new TransportChannelHandler(client, responseHandler, requestHandler); } diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java index b1732fcde2..01c143fff4 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java @@ -19,9 +19,13 @@ package org.apache.spark.network.client; import java.io.Closeable; import java.util.UUID; +import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; import com.google.common.base.Preconditions; +import com.google.common.base.Throwables; +import com.google.common.util.concurrent.SettableFuture; import io.netty.channel.Channel; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFutureListener; @@ -129,7 +133,7 @@ public class TransportClient implements Closeable { final long startTime = System.currentTimeMillis(); logger.trace("Sending RPC to {}", serverAddr); - final long requestId = UUID.randomUUID().getLeastSignificantBits(); + final long requestId = Math.abs(UUID.randomUUID().getLeastSignificantBits()); handler.addRpcRequest(requestId, callback); channel.writeAndFlush(new RpcRequest(requestId, message)).addListener( @@ -151,6 +155,32 @@ public class TransportClient implements Closeable { }); } + /** + * Synchronously sends an opaque message to the RpcHandler on the server-side, waiting for up to + * a specified timeout for a response. + */ + public byte[] sendRpcSync(byte[] message, long timeoutMs) { + final SettableFuture<byte[]> result = SettableFuture.create(); + + sendRpc(message, new RpcResponseCallback() { + @Override + public void onSuccess(byte[] response) { + result.set(response); + } + + @Override + public void onFailure(Throwable e) { + result.setException(e); + } + }); + + try { + return result.get(timeoutMs, TimeUnit.MILLISECONDS); + } catch (Exception e) { + throw Throwables.propagate(e); + } + } + @Override public void close() { // close is a local operation and should finish with milliseconds; timeout just to be safe diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java index 10eb9ef7a0..e7fa4f6bf3 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java @@ -78,15 +78,17 @@ public class TransportClientFactory implements Closeable { * * Concurrency: This method is safe to call from multiple threads. */ - public TransportClient createClient(String remoteHost, int remotePort) throws TimeoutException { + public TransportClient createClient(String remoteHost, int remotePort) { // Get connection from the connection pool first. // If it is not found or not active, create a new one. final InetSocketAddress address = new InetSocketAddress(remoteHost, remotePort); TransportClient cachedClient = connectionPool.get(address); - if (cachedClient != null && cachedClient.isActive()) { - return cachedClient; - } else if (cachedClient != null) { - connectionPool.remove(address, cachedClient); // Remove inactive clients. + if (cachedClient != null) { + if (cachedClient.isActive()) { + return cachedClient; + } else { + connectionPool.remove(address, cachedClient); // Remove inactive clients. + } } logger.debug("Creating new connection to " + address); @@ -115,13 +117,14 @@ public class TransportClientFactory implements Closeable { // Connect to the remote server ChannelFuture cf = bootstrap.connect(address); if (!cf.awaitUninterruptibly(conf.connectionTimeoutMs())) { - throw new TimeoutException( + throw new RuntimeException( String.format("Connecting to %s timed out (%s ms)", address, conf.connectionTimeoutMs())); } else if (cf.cause() != null) { throw new RuntimeException(String.format("Failed to connect to %s", address), cf.cause()); } - // Successful connection + // Successful connection -- in the event that two threads raced to create a client, we will + // use the first one that was put into the connectionPool and close the one we made here. assert client.get() != null : "Channel future completed successfully with null client"; TransportClient oldClient = connectionPool.putIfAbsent(address, client.get()); if (oldClient == null) { diff --git a/network/common/src/test/java/org/apache/spark/network/NoOpRpcHandler.java b/network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java index 7aa37efc58..5a3f003726 100644 --- a/network/common/src/test/java/org/apache/spark/network/NoOpRpcHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java @@ -1,4 +1,6 @@ -package org.apache.spark.network;/* +package org.apache.spark.network.server; + +/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. @@ -17,12 +19,20 @@ package org.apache.spark.network;/* import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; -import org.apache.spark.network.server.RpcHandler; -/** Test RpcHandler which always returns a zero-sized success. */ +/** An RpcHandler suitable for a client-only TransportContext, which cannot receive RPCs. */ public class NoOpRpcHandler implements RpcHandler { + private final StreamManager streamManager; + + public NoOpRpcHandler() { + streamManager = new OneForOneStreamManager(); + } + @Override public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { - callback.onSuccess(new byte[0]); + throw new UnsupportedOperationException("Cannot handle messages"); } + + @Override + public StreamManager getStreamManager() { return streamManager; } } diff --git a/network/common/src/main/java/org/apache/spark/network/server/DefaultStreamManager.java b/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java index 9688705569..731d48d4d9 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/DefaultStreamManager.java +++ b/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java @@ -30,10 +30,10 @@ import org.apache.spark.network.buffer.ManagedBuffer; /** * StreamManager which allows registration of an Iterator<ManagedBuffer>, which are individually - * fetched as chunks by the client. + * fetched as chunks by the client. Each registered buffer is one chunk. */ -public class DefaultStreamManager extends StreamManager { - private final Logger logger = LoggerFactory.getLogger(DefaultStreamManager.class); +public class OneForOneStreamManager extends StreamManager { + private final Logger logger = LoggerFactory.getLogger(OneForOneStreamManager.class); private final AtomicLong nextStreamId; private final Map<Long, StreamState> streams; @@ -51,7 +51,7 @@ public class DefaultStreamManager extends StreamManager { } } - public DefaultStreamManager() { + public OneForOneStreamManager() { // For debugging purposes, start with a random stream id to help identifying different streams. // This does not need to be globally unique, only unique to this class. nextStreamId = new AtomicLong((long) new Random().nextInt(Integer.MAX_VALUE) * 1000); diff --git a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java index f54a696b8f..2369dc6203 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java @@ -35,4 +35,10 @@ public interface RpcHandler { * RPC. */ void receive(TransportClient client, byte[] message, RpcResponseCallback callback); + + /** + * Returns the StreamManager which contains the state about which streams are currently being + * fetched by a TransportClient. + */ + StreamManager getStreamManager(); } diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java index 352f865935..17fe9001b3 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java @@ -56,24 +56,23 @@ public class TransportRequestHandler extends MessageHandler<RequestMessage> { /** Client on the same channel allowing us to talk back to the requester. */ private final TransportClient reverseClient; - /** Returns each chunk part of a stream. */ - private final StreamManager streamManager; - /** Handles all RPC messages. */ private final RpcHandler rpcHandler; + /** Returns each chunk part of a stream. */ + private final StreamManager streamManager; + /** List of all stream ids that have been read on this handler, used for cleanup. */ private final Set<Long> streamIds; public TransportRequestHandler( Channel channel, TransportClient reverseClient, - StreamManager streamManager, RpcHandler rpcHandler) { this.channel = channel; this.reverseClient = reverseClient; - this.streamManager = streamManager; this.rpcHandler = rpcHandler; + this.streamManager = rpcHandler.getStreamManager(); this.streamIds = Sets.newHashSet(); } diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java b/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java index 243070750d..d1a1877a98 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java @@ -49,11 +49,11 @@ public class TransportServer implements Closeable { private ChannelFuture channelFuture; private int port = -1; - public TransportServer(TransportContext context) { + public TransportServer(TransportContext context, int portToBind) { this.context = context; this.conf = context.getConf(); - init(); + init(portToBind); } public int getPort() { @@ -63,7 +63,7 @@ public class TransportServer implements Closeable { return port; } - private void init() { + private void init(int portToBind) { IOMode ioMode = IOMode.valueOf(conf.ioMode()); EventLoopGroup bossGroup = @@ -95,7 +95,7 @@ public class TransportServer implements Closeable { } }); - channelFuture = bootstrap.bind(new InetSocketAddress(conf.serverPort())); + channelFuture = bootstrap.bind(new InetSocketAddress(portToBind)); channelFuture.syncUninterruptibly(); port = ((InetSocketAddress) channelFuture.channel().localAddress()).getPort(); diff --git a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java index 32ba3f5b07..40b71b0c87 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java +++ b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java @@ -17,8 +17,12 @@ package org.apache.spark.network.util; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; import java.io.Closeable; import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; import com.google.common.io.Closeables; import org.slf4j.Logger; @@ -35,4 +39,38 @@ public class JavaUtils { logger.error("IOException should not have been thrown.", e); } } + + // TODO: Make this configurable, do not use Java serialization! + public static <T> T deserialize(byte[] bytes) { + try { + ObjectInputStream is = new ObjectInputStream(new ByteArrayInputStream(bytes)); + Object out = is.readObject(); + is.close(); + return (T) out; + } catch (ClassNotFoundException e) { + throw new RuntimeException("Could not deserialize object", e); + } catch (IOException e) { + throw new RuntimeException("Could not deserialize object", e); + } + } + + // TODO: Make this configurable, do not use Java serialization! + public static byte[] serialize(Object object) { + try { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + ObjectOutputStream os = new ObjectOutputStream(baos); + os.writeObject(object); + os.close(); + return baos.toByteArray(); + } catch (IOException e) { + throw new RuntimeException("Could not serialize object", e); + } + } + + /** Returns a hash consistent with Spark's Utils.nonNegativeHash(). */ + public static int nonNegativeHash(Object obj) { + if (obj == null) { return 0; } + int hash = obj.hashCode(); + return hash != Integer.MIN_VALUE ? Math.abs(hash) : 0; + } } diff --git a/network/common/src/test/java/org/apache/spark/network/SystemPropertyConfigProvider.java b/network/common/src/main/java/org/apache/spark/network/util/SystemPropertyConfigProvider.java index f4e0a2426a..5f20b70678 100644 --- a/network/common/src/test/java/org/apache/spark/network/SystemPropertyConfigProvider.java +++ b/network/common/src/main/java/org/apache/spark/network/util/SystemPropertyConfigProvider.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.network; +package org.apache.spark.network.util; import java.util.NoSuchElementException; diff --git a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java index 80f65d9803..a68f38e0e9 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java +++ b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java @@ -27,9 +27,6 @@ public class TransportConf { this.conf = conf; } - /** Port the server listens on. Default to a random port. */ - public int serverPort() { return conf.getInt("spark.shuffle.io.port", 0); } - /** IO mode: nio or epoll */ public String ioMode() { return conf.get("spark.shuffle.io.mode", "NIO").toUpperCase(); } diff --git a/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java index 738dca9b6a..c415883397 100644 --- a/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java @@ -41,10 +41,13 @@ import org.apache.spark.network.buffer.FileSegmentManagedBuffer; import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.buffer.NioManagedBuffer; import org.apache.spark.network.client.ChunkReceivedCallback; +import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.client.TransportClientFactory; +import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.TransportServer; import org.apache.spark.network.server.StreamManager; +import org.apache.spark.network.util.SystemPropertyConfigProvider; import org.apache.spark.network.util.TransportConf; public class ChunkFetchIntegrationSuite { @@ -93,7 +96,18 @@ public class ChunkFetchIntegrationSuite { } } }; - TransportContext context = new TransportContext(conf, streamManager, new NoOpRpcHandler()); + RpcHandler handler = new RpcHandler() { + @Override + public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { + throw new UnsupportedOperationException(); + } + + @Override + public StreamManager getStreamManager() { + return streamManager; + } + }; + TransportContext context = new TransportContext(conf, handler); server = context.createServer(); clientFactory = context.createClientFactory(); } diff --git a/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java index 9f216dd2d7..64b457b4b3 100644 --- a/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java @@ -35,9 +35,11 @@ import static org.junit.Assert.*; import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.client.TransportClientFactory; -import org.apache.spark.network.server.DefaultStreamManager; +import org.apache.spark.network.server.OneForOneStreamManager; import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.server.StreamManager; import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.util.SystemPropertyConfigProvider; import org.apache.spark.network.util.TransportConf; public class RpcIntegrationSuite { @@ -61,8 +63,11 @@ public class RpcIntegrationSuite { throw new RuntimeException("Thrown: " + parts[1]); } } + + @Override + public StreamManager getStreamManager() { return new OneForOneStreamManager(); } }; - TransportContext context = new TransportContext(conf, new DefaultStreamManager(), rpcHandler); + TransportContext context = new TransportContext(conf, rpcHandler); server = context.createServer(); clientFactory = context.createClientFactory(); } diff --git a/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java b/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java index 3ef964616f..5a10fdb384 100644 --- a/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java +++ b/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java @@ -28,11 +28,11 @@ import static org.junit.Assert.assertTrue; import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.client.TransportClientFactory; -import org.apache.spark.network.server.DefaultStreamManager; +import org.apache.spark.network.server.NoOpRpcHandler; import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.TransportServer; -import org.apache.spark.network.server.StreamManager; import org.apache.spark.network.util.JavaUtils; +import org.apache.spark.network.util.SystemPropertyConfigProvider; import org.apache.spark.network.util.TransportConf; public class TransportClientFactorySuite { @@ -44,9 +44,8 @@ public class TransportClientFactorySuite { @Before public void setUp() { conf = new TransportConf(new SystemPropertyConfigProvider()); - StreamManager streamManager = new DefaultStreamManager(); RpcHandler rpcHandler = new NoOpRpcHandler(); - context = new TransportContext(conf, streamManager, rpcHandler); + context = new TransportContext(conf, rpcHandler); server1 = context.createServer(); server2 = context.createServer(); } |