author    Aaron Davidson <aaron@databricks.com>    2014-11-01 14:37:45 -0700
committer Reynold Xin <rxin@databricks.com>        2014-11-01 14:37:45 -0700
commit    f55218aeb1e9d638df6229b36a59a15ce5363482 (patch)
tree      84e4454c224b3f14b7fcbe8259c90d06b6fd969b /network
parent    1d4f3552037cb667971bea2e5078d8b3ce6c2eae (diff)
[SPARK-3796] Create external service which can serve shuffle files
This patch introduces the tooling necessary to construct an external shuffle
service which is independent of Spark executors, and then uses this service
inside Spark. An example (just for the sake of this PR) of the service
creation can be found in Worker, and the service itself is used by plugging
in the ExternalShuffleClient as Spark's ShuffleClient (set up in
BlockManager).

This PR continues the work from #2753, which extracted the transport layer
of Spark's block transfer into an independent package within Spark. A new
package was created which contains the Spark business logic necessary to
retrieve the actual shuffle data; it is completely independent of the
transport layer introduced in the previous patch. As with the transport
layer, this package must not depend on Spark, as we anticipate plugging this
service in as a lightweight process within, say, the YARN NodeManager, and
do not wish to include Spark's dependencies (including Scala itself).

There are several outstanding tasks which must be completed before this PR
can be merged:
- [x] Complete unit testing of the network/shuffle package.
- [x] Performance and correctness testing on a real cluster.
- [x] Remove example service instantiation from Worker.scala.

There are further shortcomings of this PR which should be addressed in
follow-up patches:
- Don't use the Java serializer for the RPC layer! It is not cross-version
  compatible.
- Handle shuffle file cleanup for dead executors once the application
  terminates or the ContextCleaner triggers.
- Document the feature in the Spark docs.
- Improve behavior if the shuffle service itself goes down (right now we
  don't blacklist it, and new executors cannot spawn on that machine).
- SSL and SASL integration.
- Nice to have: handle shuffle file consolidation (this would require
  changes to Spark's implementation).

Author: Aaron Davidson <aaron@databricks.com>

Closes #3001 from aarondav/shuffle-service and squashes the following
commits:

4d1f8c1 [Aaron Davidson] Remove changes to Worker
705748f [Aaron Davidson] Rename Standalone* to External*
fd3928b [Aaron Davidson] Do not unregister executor outputs unduly
9883918 [Aaron Davidson] Make suggested build changes
3d62679 [Aaron Davidson] Add Spark integration test
7fe51d5 [Aaron Davidson] Fix SBT integration
56caa50 [Aaron Davidson] Address comments
c8d1ac3 [Aaron Davidson] Add unit tests
2f70c0c [Aaron Davidson] Fix unit tests
5483e96 [Aaron Davidson] Fix unit tests
46a70bf [Aaron Davidson] Whoops, bracket
5ea4df6 [Aaron Davidson] [SPARK-3796] Create external service which can serve shuffle files
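For orientation, here is a minimal sketch of how the server-side pieces in this patch compose when stood up inside an arbitrary host process. The main() wrapper and the port value 7337 are illustrative assumptions, not part of this patch:

    import org.apache.spark.network.TransportContext;
    import org.apache.spark.network.server.TransportServer;
    import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler;
    import org.apache.spark.network.util.SystemPropertyConfigProvider;
    import org.apache.spark.network.util.TransportConf;

    // Hypothetical standalone host process for the shuffle service.
    public class ShuffleServiceMain {
      public static void main(String[] args) {
        TransportConf conf = new TransportConf(new SystemPropertyConfigProvider());
        // The handler supplies both the RPC dispatch and the StreamManager.
        ExternalShuffleBlockHandler handler = new ExternalShuffleBlockHandler();
        TransportContext context = new TransportContext(conf, handler);
        // Bind a fixed port so executors know where to register (7337 is illustrative).
        TransportServer server = context.createServer(7337);
      }
    }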
Diffstat (limited to 'network')
-rw-r--r--  network/common/pom.xml | 20
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/TransportContext.java | 14
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/client/TransportClient.java | 32
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java | 17
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java (renamed from network/common/src/test/java/org/apache/spark/network/NoOpRpcHandler.java) | 18
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java (renamed from network/common/src/main/java/org/apache/spark/network/server/DefaultStreamManager.java) | 8
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java | 6
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java | 9
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/server/TransportServer.java | 8
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java | 38
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/util/SystemPropertyConfigProvider.java (renamed from network/common/src/test/java/org/apache/spark/network/SystemPropertyConfigProvider.java) | 2
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/util/TransportConf.java | 3
-rw-r--r--  network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java | 16
-rw-r--r--  network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java | 9
-rw-r--r--  network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java | 7
-rw-r--r--  network/shuffle/pom.xml | 96
-rw-r--r--  network/shuffle/src/main/java/org/apache/spark/network/shuffle/BlockFetchingListener.java | 36
-rw-r--r--  network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExecutorShuffleInfo.java | 64
-rw-r--r--  network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java | 102
-rw-r--r--  network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManager.java | 154
-rw-r--r--  network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java | 88
-rw-r--r--  network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleMessages.java | 106
-rw-r--r--  network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java | 121
-rw-r--r--  network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java | 35
-rw-r--r--  network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleStreamHandle.java | 60
-rw-r--r--  network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java | 123
-rw-r--r--  network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManagerSuite.java | 125
-rw-r--r--  network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java | 291
-rw-r--r--  network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java | 167
-rw-r--r--  network/shuffle/src/test/java/org/apache/spark/network/shuffle/ShuffleMessagesSuite.java | 51
-rw-r--r--  network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java | 107
31 files changed, 1890 insertions, 43 deletions
diff --git a/network/common/pom.xml b/network/common/pom.xml
index a33e44b63d..ea887148d9 100644
--- a/network/common/pom.xml
+++ b/network/common/pom.xml
@@ -85,9 +85,25 @@
<outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
<testOutputDirectory>target/scala-${scala.binary.version}/test-classes</testOutputDirectory>
<plugins>
+ <!-- Create a test-jar so network-shuffle can depend on our test utilities. -->
<plugin>
- <groupId>org.scalatest</groupId>
- <artifactId>scalatest-maven-plugin</artifactId>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-jar-plugin</artifactId>
+ <version>2.2</version>
+ <executions>
+ <execution>
+ <goals>
+ <goal>test-jar</goal>
+ </goals>
+ </execution>
+ <execution>
+ <id>test-jar-on-test-compile</id>
+ <phase>test-compile</phase>
+ <goals>
+ <goal>test-jar</goal>
+ </goals>
+ </execution>
+ </executions>
</plugin>
</plugins>
</build>
diff --git a/network/common/src/main/java/org/apache/spark/network/TransportContext.java b/network/common/src/main/java/org/apache/spark/network/TransportContext.java
index 854aa6685f..a271841e4e 100644
--- a/network/common/src/main/java/org/apache/spark/network/TransportContext.java
+++ b/network/common/src/main/java/org/apache/spark/network/TransportContext.java
@@ -52,15 +52,13 @@ public class TransportContext {
private final Logger logger = LoggerFactory.getLogger(TransportContext.class);
private final TransportConf conf;
- private final StreamManager streamManager;
private final RpcHandler rpcHandler;
private final MessageEncoder encoder;
private final MessageDecoder decoder;
- public TransportContext(TransportConf conf, StreamManager streamManager, RpcHandler rpcHandler) {
+ public TransportContext(TransportConf conf, RpcHandler rpcHandler) {
this.conf = conf;
- this.streamManager = streamManager;
this.rpcHandler = rpcHandler;
this.encoder = new MessageEncoder();
this.decoder = new MessageDecoder();
@@ -70,8 +68,14 @@ public class TransportContext {
return new TransportClientFactory(this);
}
+ /** Create a server which will attempt to bind to a specific port. */
+ public TransportServer createServer(int port) {
+ return new TransportServer(this, port);
+ }
+
+ /** Creates a new server, binding to any available ephemeral port. */
public TransportServer createServer() {
- return new TransportServer(this);
+ return new TransportServer(this, 0);
}
/**
@@ -109,7 +113,7 @@ public class TransportContext {
TransportResponseHandler responseHandler = new TransportResponseHandler(channel);
TransportClient client = new TransportClient(channel, responseHandler);
TransportRequestHandler requestHandler = new TransportRequestHandler(channel, client,
- streamManager, rpcHandler);
+ rpcHandler);
return new TransportChannelHandler(client, responseHandler, requestHandler);
}
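
As a quick sketch of the two server-creation paths above (configuration via system properties, as in this patch's tests):

    TransportConf conf = new TransportConf(new SystemPropertyConfigProvider());
    TransportContext context = new TransportContext(conf, new NoOpRpcHandler());

    TransportServer fixed = context.createServer(12345); // attempts to bind port 12345
    TransportServer ephemeral = context.createServer();  // binds any free ephemeral port
    int actualPort = ephemeral.getPort();                // resolved once bound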
diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
index b1732fcde2..01c143fff4 100644
--- a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
+++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
@@ -19,9 +19,13 @@ package org.apache.spark.network.client;
import java.io.Closeable;
import java.util.UUID;
+import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
import com.google.common.base.Preconditions;
+import com.google.common.base.Throwables;
+import com.google.common.util.concurrent.SettableFuture;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
@@ -129,7 +133,7 @@ public class TransportClient implements Closeable {
final long startTime = System.currentTimeMillis();
logger.trace("Sending RPC to {}", serverAddr);
- final long requestId = UUID.randomUUID().getLeastSignificantBits();
+ final long requestId = Math.abs(UUID.randomUUID().getLeastSignificantBits());
handler.addRpcRequest(requestId, callback);
channel.writeAndFlush(new RpcRequest(requestId, message)).addListener(
@@ -151,6 +155,32 @@ public class TransportClient implements Closeable {
});
}
+ /**
+ * Synchronously sends an opaque message to the RpcHandler on the server-side, waiting for up to
+ * a specified timeout for a response.
+ */
+ public byte[] sendRpcSync(byte[] message, long timeoutMs) {
+ final SettableFuture<byte[]> result = SettableFuture.create();
+
+ sendRpc(message, new RpcResponseCallback() {
+ @Override
+ public void onSuccess(byte[] response) {
+ result.set(response);
+ }
+
+ @Override
+ public void onFailure(Throwable e) {
+ result.setException(e);
+ }
+ });
+
+ try {
+ return result.get(timeoutMs, TimeUnit.MILLISECONDS);
+ } catch (Exception e) {
+ throw Throwables.propagate(e);
+ }
+ }
+
@Override
public void close() {
// close is a local operation and should finish with milliseconds; timeout just to be safe
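
Usage of the new synchronous path might look like the following sketch; the message payload and the client variable are illustrative assumptions:

    // Serialize an application-defined message (illustrative) and wait for the reply.
    byte[] request = JavaUtils.serialize(message);
    // Blocks the caller for up to 5 seconds; failures and timeouts surface as RuntimeExceptions.
    byte[] response = client.sendRpcSync(request, 5000 /* timeoutMs */);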
diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
index 10eb9ef7a0..e7fa4f6bf3 100644
--- a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
+++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
@@ -78,15 +78,17 @@ public class TransportClientFactory implements Closeable {
*
* Concurrency: This method is safe to call from multiple threads.
*/
- public TransportClient createClient(String remoteHost, int remotePort) throws TimeoutException {
+ public TransportClient createClient(String remoteHost, int remotePort) {
// Get connection from the connection pool first.
// If it is not found or not active, create a new one.
final InetSocketAddress address = new InetSocketAddress(remoteHost, remotePort);
TransportClient cachedClient = connectionPool.get(address);
- if (cachedClient != null && cachedClient.isActive()) {
- return cachedClient;
- } else if (cachedClient != null) {
- connectionPool.remove(address, cachedClient); // Remove inactive clients.
+ if (cachedClient != null) {
+ if (cachedClient.isActive()) {
+ return cachedClient;
+ } else {
+ connectionPool.remove(address, cachedClient); // Remove inactive clients.
+ }
}
logger.debug("Creating new connection to " + address);
@@ -115,13 +117,14 @@ public class TransportClientFactory implements Closeable {
// Connect to the remote server
ChannelFuture cf = bootstrap.connect(address);
if (!cf.awaitUninterruptibly(conf.connectionTimeoutMs())) {
- throw new TimeoutException(
+ throw new RuntimeException(
String.format("Connecting to %s timed out (%s ms)", address, conf.connectionTimeoutMs()));
} else if (cf.cause() != null) {
throw new RuntimeException(String.format("Failed to connect to %s", address), cf.cause());
}
- // Successful connection
+ // Successful connection -- in the event that two threads raced to create a client, we will
+ // use the first one that was put into the connectionPool and close the one we made here.
assert client.get() != null : "Channel future completed successfully with null client";
TransportClient oldClient = connectionPool.putIfAbsent(address, client.get());
if (oldClient == null) {
diff --git a/network/common/src/test/java/org/apache/spark/network/NoOpRpcHandler.java b/network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java
index 7aa37efc58..5a3f003726 100644
--- a/network/common/src/test/java/org/apache/spark/network/NoOpRpcHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java
@@ -1,4 +1,6 @@
-package org.apache.spark.network;/*
+package org.apache.spark.network.server;
+
+/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
@@ -17,12 +19,20 @@ package org.apache.spark.network;/*
import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.TransportClient;
-import org.apache.spark.network.server.RpcHandler;
-/** Test RpcHandler which always returns a zero-sized success. */
+/** An RpcHandler suitable for a client-only TransportContext, which cannot receive RPCs. */
public class NoOpRpcHandler implements RpcHandler {
+ private final StreamManager streamManager;
+
+ public NoOpRpcHandler() {
+ streamManager = new OneForOneStreamManager();
+ }
+
@Override
public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) {
- callback.onSuccess(new byte[0]);
+ throw new UnsupportedOperationException("Cannot handle messages");
}
+
+ @Override
+ public StreamManager getStreamManager() { return streamManager; }
}
diff --git a/network/common/src/main/java/org/apache/spark/network/server/DefaultStreamManager.java b/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java
index 9688705569..731d48d4d9 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/DefaultStreamManager.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java
@@ -30,10 +30,10 @@ import org.apache.spark.network.buffer.ManagedBuffer;
/**
* StreamManager which allows registration of an Iterator<ManagedBuffer>, which are individually
- * fetched as chunks by the client.
+ * fetched as chunks by the client. Each registered buffer is one chunk.
*/
-public class DefaultStreamManager extends StreamManager {
- private final Logger logger = LoggerFactory.getLogger(DefaultStreamManager.class);
+public class OneForOneStreamManager extends StreamManager {
+ private final Logger logger = LoggerFactory.getLogger(OneForOneStreamManager.class);
private final AtomicLong nextStreamId;
private final Map<Long, StreamState> streams;
@@ -51,7 +51,7 @@ public class DefaultStreamManager extends StreamManager {
}
}
- public DefaultStreamManager() {
+ public OneForOneStreamManager() {
// For debugging purposes, start with a random stream id to help identifying different streams.
// This does not need to be globally unique, only unique to this class.
nextStreamId = new AtomicLong((long) new Random().nextInt(Integer.MAX_VALUE) * 1000);
diff --git a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java
index f54a696b8f..2369dc6203 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java
@@ -35,4 +35,10 @@ public interface RpcHandler {
* RPC.
*/
void receive(TransportClient client, byte[] message, RpcResponseCallback callback);
+
+ /**
+ * Returns the StreamManager which contains the state about which streams are currently being
+ * fetched by a TransportClient.
+ */
+ StreamManager getStreamManager();
}
diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
index 352f865935..17fe9001b3 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
@@ -56,24 +56,23 @@ public class TransportRequestHandler extends MessageHandler<RequestMessage> {
/** Client on the same channel allowing us to talk back to the requester. */
private final TransportClient reverseClient;
- /** Returns each chunk part of a stream. */
- private final StreamManager streamManager;
-
/** Handles all RPC messages. */
private final RpcHandler rpcHandler;
+ /** Returns each chunk part of a stream. */
+ private final StreamManager streamManager;
+
/** List of all stream ids that have been read on this handler, used for cleanup. */
private final Set<Long> streamIds;
public TransportRequestHandler(
Channel channel,
TransportClient reverseClient,
- StreamManager streamManager,
RpcHandler rpcHandler) {
this.channel = channel;
this.reverseClient = reverseClient;
- this.streamManager = streamManager;
this.rpcHandler = rpcHandler;
+ this.streamManager = rpcHandler.getStreamManager();
this.streamIds = Sets.newHashSet();
}
diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java b/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java
index 243070750d..d1a1877a98 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java
@@ -49,11 +49,11 @@ public class TransportServer implements Closeable {
private ChannelFuture channelFuture;
private int port = -1;
- public TransportServer(TransportContext context) {
+ public TransportServer(TransportContext context, int portToBind) {
this.context = context;
this.conf = context.getConf();
- init();
+ init(portToBind);
}
public int getPort() {
@@ -63,7 +63,7 @@ public class TransportServer implements Closeable {
return port;
}
- private void init() {
+ private void init(int portToBind) {
IOMode ioMode = IOMode.valueOf(conf.ioMode());
EventLoopGroup bossGroup =
@@ -95,7 +95,7 @@ public class TransportServer implements Closeable {
}
});
- channelFuture = bootstrap.bind(new InetSocketAddress(conf.serverPort()));
+ channelFuture = bootstrap.bind(new InetSocketAddress(portToBind));
channelFuture.syncUninterruptibly();
port = ((InetSocketAddress) channelFuture.channel().localAddress()).getPort();
diff --git a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java
index 32ba3f5b07..40b71b0c87 100644
--- a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java
+++ b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java
@@ -17,8 +17,12 @@
package org.apache.spark.network.util;
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
import java.io.Closeable;
import java.io.IOException;
+import java.io.ObjectInputStream;
+import java.io.ObjectOutputStream;
import com.google.common.io.Closeables;
import org.slf4j.Logger;
@@ -35,4 +39,38 @@ public class JavaUtils {
logger.error("IOException should not have been thrown.", e);
}
}
+
+ // TODO: Make this configurable, do not use Java serialization!
+ public static <T> T deserialize(byte[] bytes) {
+ try {
+ ObjectInputStream is = new ObjectInputStream(new ByteArrayInputStream(bytes));
+ Object out = is.readObject();
+ is.close();
+ return (T) out;
+ } catch (ClassNotFoundException e) {
+ throw new RuntimeException("Could not deserialize object", e);
+ } catch (IOException e) {
+ throw new RuntimeException("Could not deserialize object", e);
+ }
+ }
+
+ // TODO: Make this configurable, do not use Java serialization!
+ public static byte[] serialize(Object object) {
+ try {
+ ByteArrayOutputStream baos = new ByteArrayOutputStream();
+ ObjectOutputStream os = new ObjectOutputStream(baos);
+ os.writeObject(object);
+ os.close();
+ return baos.toByteArray();
+ } catch (IOException e) {
+ throw new RuntimeException("Could not serialize object", e);
+ }
+ }
+
+ /** Returns a hash consistent with Spark's Utils.nonNegativeHash(). */
+ public static int nonNegativeHash(Object obj) {
+ if (obj == null) { return 0; }
+ int hash = obj.hashCode();
+ return hash != Integer.MIN_VALUE ? Math.abs(hash) : 0;
+ }
}
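
A sketch of the intended round trip through these helpers, using the ExecutorShuffleInfo message added later in this patch (the local directory value is illustrative):

    ExecutorShuffleInfo info = new ExecutorShuffleInfo(
      new String[] { "/tmp/spark-local" },  // illustrative local dir
      64,
      "org.apache.spark.shuffle.sort.SortShuffleManager");
    byte[] bytes = JavaUtils.serialize(info);                // Java serialization (see TODO above)
    ExecutorShuffleInfo copy = JavaUtils.deserialize(bytes); // target type inferred at call site
    assert info.equals(copy);
    assert JavaUtils.nonNegativeHash("shuffle_0_1_2") >= 0;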
diff --git a/network/common/src/test/java/org/apache/spark/network/SystemPropertyConfigProvider.java b/network/common/src/main/java/org/apache/spark/network/util/SystemPropertyConfigProvider.java
index f4e0a2426a..5f20b70678 100644
--- a/network/common/src/test/java/org/apache/spark/network/SystemPropertyConfigProvider.java
+++ b/network/common/src/main/java/org/apache/spark/network/util/SystemPropertyConfigProvider.java
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.network;
+package org.apache.spark.network.util;
import java.util.NoSuchElementException;
diff --git a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java
index 80f65d9803..a68f38e0e9 100644
--- a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java
+++ b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java
@@ -27,9 +27,6 @@ public class TransportConf {
this.conf = conf;
}
- /** Port the server listens on. Default to a random port. */
- public int serverPort() { return conf.getInt("spark.shuffle.io.port", 0); }
-
/** IO mode: nio or epoll */
public String ioMode() { return conf.get("spark.shuffle.io.mode", "NIO").toUpperCase(); }
diff --git a/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java
index 738dca9b6a..c415883397 100644
--- a/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java
+++ b/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java
@@ -41,10 +41,13 @@ import org.apache.spark.network.buffer.FileSegmentManagedBuffer;
import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.buffer.NioManagedBuffer;
import org.apache.spark.network.client.ChunkReceivedCallback;
+import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.client.TransportClientFactory;
+import org.apache.spark.network.server.RpcHandler;
import org.apache.spark.network.server.TransportServer;
import org.apache.spark.network.server.StreamManager;
+import org.apache.spark.network.util.SystemPropertyConfigProvider;
import org.apache.spark.network.util.TransportConf;
public class ChunkFetchIntegrationSuite {
@@ -93,7 +96,18 @@ public class ChunkFetchIntegrationSuite {
}
}
};
- TransportContext context = new TransportContext(conf, streamManager, new NoOpRpcHandler());
+ RpcHandler handler = new RpcHandler() {
+ @Override
+ public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public StreamManager getStreamManager() {
+ return streamManager;
+ }
+ };
+ TransportContext context = new TransportContext(conf, handler);
server = context.createServer();
clientFactory = context.createClientFactory();
}
diff --git a/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java
index 9f216dd2d7..64b457b4b3 100644
--- a/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java
+++ b/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java
@@ -35,9 +35,11 @@ import static org.junit.Assert.*;
import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.client.TransportClientFactory;
-import org.apache.spark.network.server.DefaultStreamManager;
+import org.apache.spark.network.server.OneForOneStreamManager;
import org.apache.spark.network.server.RpcHandler;
+import org.apache.spark.network.server.StreamManager;
import org.apache.spark.network.server.TransportServer;
+import org.apache.spark.network.util.SystemPropertyConfigProvider;
import org.apache.spark.network.util.TransportConf;
public class RpcIntegrationSuite {
@@ -61,8 +63,11 @@ public class RpcIntegrationSuite {
throw new RuntimeException("Thrown: " + parts[1]);
}
}
+
+ @Override
+ public StreamManager getStreamManager() { return new OneForOneStreamManager(); }
};
- TransportContext context = new TransportContext(conf, new DefaultStreamManager(), rpcHandler);
+ TransportContext context = new TransportContext(conf, rpcHandler);
server = context.createServer();
clientFactory = context.createClientFactory();
}
diff --git a/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java b/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java
index 3ef964616f..5a10fdb384 100644
--- a/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java
+++ b/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java
@@ -28,11 +28,11 @@ import static org.junit.Assert.assertTrue;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.client.TransportClientFactory;
-import org.apache.spark.network.server.DefaultStreamManager;
+import org.apache.spark.network.server.NoOpRpcHandler;
import org.apache.spark.network.server.RpcHandler;
import org.apache.spark.network.server.TransportServer;
-import org.apache.spark.network.server.StreamManager;
import org.apache.spark.network.util.JavaUtils;
+import org.apache.spark.network.util.SystemPropertyConfigProvider;
import org.apache.spark.network.util.TransportConf;
public class TransportClientFactorySuite {
@@ -44,9 +44,8 @@ public class TransportClientFactorySuite {
@Before
public void setUp() {
conf = new TransportConf(new SystemPropertyConfigProvider());
- StreamManager streamManager = new DefaultStreamManager();
RpcHandler rpcHandler = new NoOpRpcHandler();
- context = new TransportContext(conf, streamManager, rpcHandler);
+ context = new TransportContext(conf, rpcHandler);
server1 = context.createServer();
server2 = context.createServer();
}
diff --git a/network/shuffle/pom.xml b/network/shuffle/pom.xml
new file mode 100644
index 0000000000..d271704d98
--- /dev/null
+++ b/network/shuffle/pom.xml
@@ -0,0 +1,96 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+ ~ Licensed to the Apache Software Foundation (ASF) under one or more
+ ~ contributor license agreements. See the NOTICE file distributed with
+ ~ this work for additional information regarding copyright ownership.
+ ~ The ASF licenses this file to You under the Apache License, Version 2.0
+ ~ (the "License"); you may not use this file except in compliance with
+ ~ the License. You may obtain a copy of the License at
+ ~
+ ~ http://www.apache.org/licenses/LICENSE-2.0
+ ~
+ ~ Unless required by applicable law or agreed to in writing, software
+ ~ distributed under the License is distributed on an "AS IS" BASIS,
+ ~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ ~ See the License for the specific language governing permissions and
+ ~ limitations under the License.
+ -->
+
+<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+ xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+ <modelVersion>4.0.0</modelVersion>
+ <parent>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-parent</artifactId>
+ <version>1.2.0-SNAPSHOT</version>
+ <relativePath>../../pom.xml</relativePath>
+ </parent>
+
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-network-shuffle_2.10</artifactId>
+ <packaging>jar</packaging>
+ <name>Spark Project Shuffle Streaming Service Code</name>
+ <url>http://spark.apache.org/</url>
+ <properties>
+ <sbt.project.name>network-shuffle</sbt.project.name>
+ </properties>
+
+ <dependencies>
+ <!-- Core dependencies -->
+ <dependency>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-network-common_2.10</artifactId>
+ <version>${project.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.slf4j</groupId>
+ <artifactId>slf4j-api</artifactId>
+ </dependency>
+
+ <!-- Provided dependencies -->
+ <dependency>
+ <groupId>com.google.guava</groupId>
+ <artifactId>guava</artifactId>
+ <scope>provided</scope>
+ </dependency>
+
+ <!-- Test dependencies -->
+ <dependency>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-network-common_2.10</artifactId>
+ <version>${project.version}</version>
+ <type>test-jar</type>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>junit</groupId>
+ <artifactId>junit</artifactId>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>com.novocode</groupId>
+ <artifactId>junit-interface</artifactId>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>log4j</groupId>
+ <artifactId>log4j</artifactId>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.mockito</groupId>
+ <artifactId>mockito-all</artifactId>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.scalatest</groupId>
+ <artifactId>scalatest_${scala.binary.version}</artifactId>
+ <scope>test</scope>
+ </dependency>
+ </dependencies>
+
+ <build>
+ <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
+ <testOutputDirectory>target/scala-${scala.binary.version}/test-classes</testOutputDirectory>
+ </build>
+</project>
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/BlockFetchingListener.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/BlockFetchingListener.java
new file mode 100644
index 0000000000..138fd5389c
--- /dev/null
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/BlockFetchingListener.java
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle;
+
+import java.util.EventListener;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+
+public interface BlockFetchingListener extends EventListener {
+ /**
+ * Called once per successfully fetched block. After this call returns, data will be released
+ * automatically. If the data will be passed to another thread, the receiver should retain()
+ * and release() the buffer on their own, or copy the data to a new buffer.
+ */
+ void onBlockFetchSuccess(String blockId, ManagedBuffer data);
+
+ /**
+ * Called at least once per block upon failures.
+ */
+ void onBlockFetchFailure(String blockId, Throwable exception);
+}
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExecutorShuffleInfo.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExecutorShuffleInfo.java
new file mode 100644
index 0000000000..d45e64656a
--- /dev/null
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExecutorShuffleInfo.java
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle;
+
+import java.io.Serializable;
+import java.util.Arrays;
+
+import com.google.common.base.Objects;
+
+/** Contains all configuration necessary for locating the shuffle files of an executor. */
+public class ExecutorShuffleInfo implements Serializable {
+ /** The base set of local directories that the executor stores its shuffle files in. */
+ final String[] localDirs;
+ /** Number of subdirectories created within each localDir. */
+ final int subDirsPerLocalDir;
+ /** Shuffle manager (SortShuffleManager or HashShuffleManager) that the executor is using. */
+ final String shuffleManager;
+
+ public ExecutorShuffleInfo(String[] localDirs, int subDirsPerLocalDir, String shuffleManager) {
+ this.localDirs = localDirs;
+ this.subDirsPerLocalDir = subDirsPerLocalDir;
+ this.shuffleManager = shuffleManager;
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hashCode(subDirsPerLocalDir, shuffleManager) * 41 + Arrays.hashCode(localDirs);
+ }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("localDirs", Arrays.toString(localDirs))
+ .add("subDirsPerLocalDir", subDirsPerLocalDir)
+ .add("shuffleManager", shuffleManager)
+ .toString();
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (other != null && other instanceof ExecutorShuffleInfo) {
+ ExecutorShuffleInfo o = (ExecutorShuffleInfo) other;
+ return Arrays.equals(localDirs, o.localDirs)
+ && Objects.equal(subDirsPerLocalDir, o.subDirsPerLocalDir)
+ && Objects.equal(shuffleManager, o.shuffleManager);
+ }
+ return false;
+ }
+}
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java
new file mode 100644
index 0000000000..a9dff31dec
--- /dev/null
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle;
+
+import java.util.List;
+
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.collect.Lists;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import static org.apache.spark.network.shuffle.ExternalShuffleMessages.*;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.client.RpcResponseCallback;
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.server.OneForOneStreamManager;
+import org.apache.spark.network.server.RpcHandler;
+import org.apache.spark.network.server.StreamManager;
+import org.apache.spark.network.util.JavaUtils;
+
+/**
+ * RPC Handler for a server which can serve shuffle blocks from outside of an Executor process.
+ *
+ * Handles registering executors and opening shuffle blocks from them. Shuffle blocks are registered
+ * with the "one-for-one" strategy, meaning each Transport-layer Chunk is equivalent to one Spark-
+ * level shuffle block.
+ */
+public class ExternalShuffleBlockHandler implements RpcHandler {
+ private final Logger logger = LoggerFactory.getLogger(ExternalShuffleBlockHandler.class);
+
+ private final ExternalShuffleBlockManager blockManager;
+ private final OneForOneStreamManager streamManager;
+
+ public ExternalShuffleBlockHandler() {
+ this(new OneForOneStreamManager(), new ExternalShuffleBlockManager());
+ }
+
+ /** Enables mocking out the StreamManager and BlockManager. */
+ @VisibleForTesting
+ ExternalShuffleBlockHandler(
+ OneForOneStreamManager streamManager,
+ ExternalShuffleBlockManager blockManager) {
+ this.streamManager = streamManager;
+ this.blockManager = blockManager;
+ }
+
+ @Override
+ public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) {
+ Object msgObj = JavaUtils.deserialize(message);
+
+ logger.trace("Received message: " + msgObj);
+
+ if (msgObj instanceof OpenShuffleBlocks) {
+ OpenShuffleBlocks msg = (OpenShuffleBlocks) msgObj;
+ List<ManagedBuffer> blocks = Lists.newArrayList();
+
+ for (String blockId : msg.blockIds) {
+ blocks.add(blockManager.getBlockData(msg.appId, msg.execId, blockId));
+ }
+ long streamId = streamManager.registerStream(blocks.iterator());
+ logger.trace("Registered streamId {} with {} buffers", streamId, msg.blockIds.length);
+ callback.onSuccess(JavaUtils.serialize(
+ new ShuffleStreamHandle(streamId, msg.blockIds.length)));
+
+ } else if (msgObj instanceof RegisterExecutor) {
+ RegisterExecutor msg = (RegisterExecutor) msgObj;
+ blockManager.registerExecutor(msg.appId, msg.execId, msg.executorInfo);
+ callback.onSuccess(new byte[0]);
+
+ } else {
+ throw new UnsupportedOperationException(String.format(
+ "Unexpected message: %s (class = %s)", msgObj, msgObj.getClass()));
+ }
+ }
+
+ @Override
+ public StreamManager getStreamManager() {
+ return streamManager;
+ }
+
+ /** For testing, clears all executors registered with "RegisterExecutor". */
+ @VisibleForTesting
+ public void clearRegisteredExecutors() {
+ blockManager.clearRegisteredExecutors();
+ }
+}
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManager.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManager.java
new file mode 100644
index 0000000000..6589889fe1
--- /dev/null
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManager.java
@@ -0,0 +1,154 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle;
+
+import java.io.DataInputStream;
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.IOException;
+import java.util.concurrent.ConcurrentHashMap;
+
+import com.google.common.annotations.VisibleForTesting;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.buffer.FileSegmentManagedBuffer;
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.util.JavaUtils;
+
+/**
+ * Manages converting shuffle BlockIds into physical segments of local files, from a process outside
+ * of Executors. Each Executor must register its own configuration about where it stores its files
+ * (local dirs) and how (shuffle manager). The logic for retrieval of individual files is replicated
+ * from Spark's FileShuffleBlockManager and IndexShuffleBlockManager.
+ *
+ * Executors with shuffle file consolidation are not currently supported, as the index is stored in
+ * the Executor's memory, unlike the IndexShuffleBlockManager.
+ */
+public class ExternalShuffleBlockManager {
+ private final Logger logger = LoggerFactory.getLogger(ExternalShuffleBlockManager.class);
+
+ // Map from "appId-execId" to the executor's configuration.
+ private final ConcurrentHashMap<String, ExecutorShuffleInfo> executors =
+ new ConcurrentHashMap<String, ExecutorShuffleInfo>();
+
+ // Returns an id suitable for a single executor within a single application.
+ private String getAppExecId(String appId, String execId) {
+ return appId + "-" + execId;
+ }
+
+ /** Registers a new Executor with all the configuration we need to find its shuffle files. */
+ public void registerExecutor(
+ String appId,
+ String execId,
+ ExecutorShuffleInfo executorInfo) {
+ String fullId = getAppExecId(appId, execId);
+ logger.info("Registered executor {} with {}", fullId, executorInfo);
+ executors.put(fullId, executorInfo);
+ }
+
+ /**
+ * Obtains a FileSegmentManagedBuffer from a shuffle block id. We expect the blockId has the
+ * format "shuffle_ShuffleId_MapId_ReduceId" (from ShuffleBlockId), and additionally make
+ * assumptions about how the hash and sort based shuffles store their data.
+ */
+ public ManagedBuffer getBlockData(String appId, String execId, String blockId) {
+ String[] blockIdParts = blockId.split("_");
+ if (blockIdParts.length < 4) {
+ throw new IllegalArgumentException("Unexpected block id format: " + blockId);
+ } else if (!blockIdParts[0].equals("shuffle")) {
+ throw new IllegalArgumentException("Expected shuffle block id, got: " + blockId);
+ }
+ int shuffleId = Integer.parseInt(blockIdParts[1]);
+ int mapId = Integer.parseInt(blockIdParts[2]);
+ int reduceId = Integer.parseInt(blockIdParts[3]);
+
+ ExecutorShuffleInfo executor = executors.get(getAppExecId(appId, execId));
+ if (executor == null) {
+ throw new RuntimeException(
+ String.format("Executor is not registered (appId=%s, execId=%s)", appId, execId));
+ }
+
+ if ("org.apache.spark.shuffle.hash.HashShuffleManager".equals(executor.shuffleManager)) {
+ return getHashBasedShuffleBlockData(executor, blockId);
+ } else if ("org.apache.spark.shuffle.sort.SortShuffleManager".equals(executor.shuffleManager)) {
+ return getSortBasedShuffleBlockData(executor, shuffleId, mapId, reduceId);
+ } else {
+ throw new UnsupportedOperationException(
+ "Unsupported shuffle manager: " + executor.shuffleManager);
+ }
+ }
+
+ /**
+ * Hash-based shuffle data is simply stored as one file per block.
+ * This logic is from FileShuffleBlockManager.
+ */
+ // TODO: Support consolidated hash shuffle files
+ private ManagedBuffer getHashBasedShuffleBlockData(ExecutorShuffleInfo executor, String blockId) {
+ File shuffleFile = getFile(executor.localDirs, executor.subDirsPerLocalDir, blockId);
+ return new FileSegmentManagedBuffer(shuffleFile, 0, shuffleFile.length());
+ }
+
+ /**
+ * Sort-based shuffle data uses an index called "shuffle_ShuffleId_MapId_0.index" into a data file
+ * called "shuffle_ShuffleId_MapId_0.data". This logic is from IndexShuffleBlockManager,
+ * and the block id format is from ShuffleDataBlockId and ShuffleIndexBlockId.
+ */
+ private ManagedBuffer getSortBasedShuffleBlockData(
+ ExecutorShuffleInfo executor, int shuffleId, int mapId, int reduceId) {
+ File indexFile = getFile(executor.localDirs, executor.subDirsPerLocalDir,
+ "shuffle_" + shuffleId + "_" + mapId + "_0.index");
+
+ DataInputStream in = null;
+ try {
+ in = new DataInputStream(new FileInputStream(indexFile));
+ in.skipBytes(reduceId * 8);
+ long offset = in.readLong();
+ long nextOffset = in.readLong();
+ return new FileSegmentManagedBuffer(
+ getFile(executor.localDirs, executor.subDirsPerLocalDir,
+ "shuffle_" + shuffleId + "_" + mapId + "_0.data"),
+ offset,
+ nextOffset - offset);
+ } catch (IOException e) {
+ throw new RuntimeException("Failed to open file: " + indexFile, e);
+ } finally {
+ if (in != null) {
+ JavaUtils.closeQuietly(in);
+ }
+ }
+ }
+
+ /**
+ * Hashes a filename into the corresponding local directory, in a manner consistent with
+ * Spark's DiskBlockManager.getFile().
+ */
+ @VisibleForTesting
+ static File getFile(String[] localDirs, int subDirsPerLocalDir, String filename) {
+ int hash = JavaUtils.nonNegativeHash(filename);
+ String localDir = localDirs[hash % localDirs.length];
+ int subDirId = (hash / localDirs.length) % subDirsPerLocalDir;
+ return new File(new File(localDir, String.format("%02x", subDirId)), filename);
+ }
+
+ /** For testing, clears all registered executors. */
+ @VisibleForTesting
+ void clearRegisteredExecutors() {
+ executors.clear();
+ }
+}
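
To make the index arithmetic above concrete: for reduce partition r, the block occupies the byte range [offset(r), offset(r+1)) of the .data file, where offset(r) is the r-th 8-byte long in the .index file. A self-contained sketch of just that read, under the same assumptions as getSortBasedShuffleBlockData:

    /** Returns {offset, length} for one reduce partition of a sort-shuffle index file. */
    static long[] readRange(File indexFile, int reduceId) throws IOException {
      DataInputStream in = new DataInputStream(new FileInputStream(indexFile));
      try {
        in.skipBytes(reduceId * 8);      // each index entry is one 8-byte long
        long offset = in.readLong();     // start of this partition's data
        long nextOffset = in.readLong(); // start of the next partition's data
        return new long[] { offset, nextOffset - offset };
      } finally {
        in.close();
      }
    }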
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
new file mode 100644
index 0000000000..cc2f6261ca
--- /dev/null
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
@@ -0,0 +1,88 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.TransportContext;
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.client.TransportClientFactory;
+import org.apache.spark.network.server.NoOpRpcHandler;
+import org.apache.spark.network.shuffle.ExternalShuffleMessages.RegisterExecutor;
+import org.apache.spark.network.util.JavaUtils;
+import org.apache.spark.network.util.TransportConf;
+
+/**
+ * Client for reading shuffle blocks which points to an external (outside of executor) server.
+ * This is instead of reading shuffle blocks directly from other executors (via
+ * BlockTransferService), which has the downside of losing the shuffle data if we lose the
+ * executors.
+ */
+public class ExternalShuffleClient implements ShuffleClient {
+ private final Logger logger = LoggerFactory.getLogger(ExternalShuffleClient.class);
+
+ private final TransportClientFactory clientFactory;
+ private final String appId;
+
+ public ExternalShuffleClient(TransportConf conf, String appId) {
+ TransportContext context = new TransportContext(conf, new NoOpRpcHandler());
+ this.clientFactory = context.createClientFactory();
+ this.appId = appId;
+ }
+
+ @Override
+ public void fetchBlocks(
+ String host,
+ int port,
+ String execId,
+ String[] blockIds,
+ BlockFetchingListener listener) {
+ logger.debug("External shuffle fetch from {}:{} (executor id {})", host, port, execId);
+ try {
+ TransportClient client = clientFactory.createClient(host, port);
+ new OneForOneBlockFetcher(client, blockIds, listener)
+ .start(new ExternalShuffleMessages.OpenShuffleBlocks(appId, execId, blockIds));
+ } catch (Exception e) {
+ logger.error("Exception while beginning fetchBlocks", e);
+ for (String blockId : blockIds) {
+ listener.onBlockFetchFailure(blockId, e);
+ }
+ }
+ }
+
+ /**
+ * Registers this executor with an external shuffle server. This registration is required to
+ * inform the shuffle server about where and how we store our shuffle files.
+ *
+ * @param host Host of shuffle server.
+ * @param port Port of shuffle server.
+ * @param execId This Executor's id.
+ * @param executorInfo Contains all info necessary for the service to find our shuffle files.
+ */
+ public void registerWithShuffleServer(
+ String host,
+ int port,
+ String execId,
+ ExecutorShuffleInfo executorInfo) {
+ TransportClient client = clientFactory.createClient(host, port);
+ byte[] registerExecutorMessage =
+ JavaUtils.serialize(new RegisterExecutor(appId, execId, executorInfo));
+ client.sendRpcSync(registerExecutorMessage, 5000 /* timeoutMs */);
+ }
+}
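
A sketch of the client-side flow this class enables; the host, port, executor id, block id, and directories are illustrative assumptions:

    TransportConf conf = new TransportConf(new SystemPropertyConfigProvider());
    ExternalShuffleClient client = new ExternalShuffleClient(conf, "app-20141101-0001");

    // Once per executor: tell the shuffle server where this executor's files live.
    client.registerWithShuffleServer("shuffle-host", 7337, "exec-1",
      new ExecutorShuffleInfo(new String[] { "/tmp/spark-local" }, 64,
        "org.apache.spark.shuffle.sort.SortShuffleManager"));

    // Later, from any node: fetch blocks written by that executor.
    client.fetchBlocks("shuffle-host", 7337, "exec-1",
      new String[] { "shuffle_0_1_2" },
      new BlockFetchingListener() {
        @Override
        public void onBlockFetchSuccess(String blockId, ManagedBuffer data) { /* consume data */ }
        @Override
        public void onBlockFetchFailure(String blockId, Throwable t) { /* log or retry */ }
      });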
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleMessages.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleMessages.java
new file mode 100644
index 0000000000..e79420ed82
--- /dev/null
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleMessages.java
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle;
+
+import java.io.Serializable;
+import java.util.Arrays;
+
+import com.google.common.base.Objects;
+
+/** Messages handled by the {@link ExternalShuffleBlockHandler}. */
+public class ExternalShuffleMessages {
+
+ /** Request to read a set of shuffle blocks. Returns [[ShuffleStreamHandle]]. */
+ public static class OpenShuffleBlocks implements Serializable {
+ public final String appId;
+ public final String execId;
+ public final String[] blockIds;
+
+ public OpenShuffleBlocks(String appId, String execId, String[] blockIds) {
+ this.appId = appId;
+ this.execId = execId;
+ this.blockIds = blockIds;
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hashCode(appId, execId) * 41 + Arrays.hashCode(blockIds);
+ }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("appId", appId)
+ .add("execId", execId)
+ .add("blockIds", Arrays.toString(blockIds))
+ .toString();
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (other != null && other instanceof OpenShuffleBlocks) {
+ OpenShuffleBlocks o = (OpenShuffleBlocks) other;
+ return Objects.equal(appId, o.appId)
+ && Objects.equal(execId, o.execId)
+ && Arrays.equals(blockIds, o.blockIds);
+ }
+ return false;
+ }
+ }
+
+ /** Initial registration message between an executor and its local shuffle server. */
+ public static class RegisterExecutor implements Serializable {
+ public final String appId;
+ public final String execId;
+ public final ExecutorShuffleInfo executorInfo;
+
+ public RegisterExecutor(
+ String appId,
+ String execId,
+ ExecutorShuffleInfo executorInfo) {
+ this.appId = appId;
+ this.execId = execId;
+ this.executorInfo = executorInfo;
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hashCode(appId, execId, executorInfo);
+ }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("appId", appId)
+ .add("execId", execId)
+ .add("executorInfo", executorInfo)
+ .toString();
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (other != null && other instanceof RegisterExecutor) {
+ RegisterExecutor o = (RegisterExecutor) other;
+ return Objects.equal(appId, o.appId)
+ && Objects.equal(execId, o.execId)
+ && Objects.equal(executorInfo, o.executorInfo);
+ }
+ return false;
+ }
+ }
+}
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
new file mode 100644
index 0000000000..39b6f30f92
--- /dev/null
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle;
+
+import java.util.Arrays;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.client.ChunkReceivedCallback;
+import org.apache.spark.network.client.RpcResponseCallback;
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.util.JavaUtils;
+
+/**
+ * Simple wrapper on top of a TransportClient which interprets each chunk as a whole block, and
+ * invokes the BlockFetchingListener appropriately. This class is agnostic to the actual RPC
+ * handler, as long as there is a single "open blocks" message which returns a ShuffleStreamHandle,
+ * and Java serialization is used.
+ *
+ * Note that this typically corresponds to a
+ * {@link org.apache.spark.network.server.OneForOneStreamManager} on the server side.
+ */
+public class OneForOneBlockFetcher {
+ private final Logger logger = LoggerFactory.getLogger(OneForOneBlockFetcher.class);
+
+ private final TransportClient client;
+ private final String[] blockIds;
+ private final BlockFetchingListener listener;
+ private final ChunkReceivedCallback chunkCallback;
+
+ private ShuffleStreamHandle streamHandle = null;
+
+ public OneForOneBlockFetcher(
+ TransportClient client,
+ String[] blockIds,
+ BlockFetchingListener listener) {
+ if (blockIds.length == 0) {
+ throw new IllegalArgumentException("Zero-sized blockIds array");
+ }
+ this.client = client;
+ this.blockIds = blockIds;
+ this.listener = listener;
+ this.chunkCallback = new ChunkCallback();
+ }
+
+ /** Callback invoked on receipt of each chunk. We equate a single chunk to a single block. */
+ private class ChunkCallback implements ChunkReceivedCallback {
+ @Override
+ public void onSuccess(int chunkIndex, ManagedBuffer buffer) {
+ // On receipt of a chunk, pass it upwards as a block.
+ listener.onBlockFetchSuccess(blockIds[chunkIndex], buffer);
+ }
+
+ @Override
+ public void onFailure(int chunkIndex, Throwable e) {
+ // On receipt of a failure, fail every block from chunkIndex onwards.
+ String[] remainingBlockIds = Arrays.copyOfRange(blockIds, chunkIndex, blockIds.length);
+ failRemainingBlocks(remainingBlockIds, e);
+ }
+ }
+
+ /**
+ * Begins the fetching process, calling the listener with every block fetched.
+ * The given message will be serialized with the Java serializer, and the RPC must return a
+ * {@link ShuffleStreamHandle}. We will send all fetch requests immediately, without throttling.
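+ *
+ * In this package the open-blocks message is typically an
+ * {@link ExternalShuffleMessages.OpenShuffleBlocks}, though any message the remote
+ * RpcHandler answers with a ShuffleStreamHandle will do.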
+ */
+ public void start(Object openBlocksMessage) {
+ client.sendRpc(JavaUtils.serialize(openBlocksMessage), new RpcResponseCallback() {
+ @Override
+ public void onSuccess(byte[] response) {
+ try {
+ streamHandle = JavaUtils.deserialize(response);
+ logger.trace("Successfully opened blocks {}, preparing to fetch chunks.", streamHandle);
+
+ // Immediately request all chunks -- we expect that the total size of the request is
+ // reasonable due to higher level chunking in [[ShuffleBlockFetcherIterator]].
+ for (int i = 0; i < streamHandle.numChunks; i++) {
+ client.fetchChunk(streamHandle.streamId, i, chunkCallback);
+ }
+ } catch (Exception e) {
+ logger.error("Failed while starting block fetches", e);
+ failRemainingBlocks(blockIds, e);
+ }
+ }
+
+ @Override
+ public void onFailure(Throwable e) {
+ logger.error("Failed while starting block fetches", e);
+ failRemainingBlocks(blockIds, e);
+ }
+ });
+ }
+
+ /** Invokes the "onBlockFetchFailure" callback for every listed block id. */
+ private void failRemainingBlocks(String[] failedBlockIds, Throwable e) {
+ for (String blockId : failedBlockIds) {
+ try {
+ listener.onBlockFetchFailure(blockId, e);
+ } catch (Exception e2) {
+ logger.error("Error in block fetch failure callback", e2);
+ }
+ }
+ }
+}
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java
new file mode 100644
index 0000000000..9fa87c2c6e
--- /dev/null
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle;
+
+/** Provides an interface for reading shuffle files, either from an Executor or an external service. */
+public interface ShuffleClient {
+ /**
+ * Fetch a sequence of blocks from a remote node asynchronously.
+ *
+ * Note that this API takes a sequence of block ids so the implementation can batch requests.
+ * It does not return a future, so the implementation can invoke onBlockFetchSuccess as soon
+ * as a block's data is fetched, rather than waiting for all blocks to be fetched.
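+ *
+ * A minimal caller sketch (illustrative only; the host, port, and block ids are placeholders):
+ * <pre>{@code
+ *   shuffleClient.fetchBlocks("shuffle-host", 7337, "exec-0",
+ *     new String[] { "shuffle_0_0_0", "shuffle_0_0_1" },
+ *     new BlockFetchingListener() {
+ *       public void onBlockFetchSuccess(String blockId, ManagedBuffer data) { data.retain(); }
+ *       public void onBlockFetchFailure(String blockId, Throwable t) { t.printStackTrace(); }
+ *     });
+ * }</pre>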
+ */
+ public void fetchBlocks(
+ String host,
+ int port,
+ String execId,
+ String[] blockIds,
+ BlockFetchingListener listener);
+}
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleStreamHandle.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleStreamHandle.java
new file mode 100644
index 0000000000..9c94691224
--- /dev/null
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleStreamHandle.java
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle;
+
+import java.io.Serializable;
+import java.util.Arrays;
+
+import com.google.common.base.Objects;
+
+/**
+ * Identifier for a fixed number of chunks to read from a stream created by an "open blocks"
+ * message. This is used by {@link OneForOneBlockFetcher}.
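+ *
+ * Chunk i of the stream carries the data of the i-th block id in the open-blocks request.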
+ */
+public class ShuffleStreamHandle implements Serializable {
+ public final long streamId;
+ public final int numChunks;
+
+ public ShuffleStreamHandle(long streamId, int numChunks) {
+ this.streamId = streamId;
+ this.numChunks = numChunks;
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hashCode(streamId, numChunks);
+ }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("streamId", streamId)
+ .add("numChunks", numChunks)
+ .toString();
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (other instanceof ShuffleStreamHandle) {
+ ShuffleStreamHandle o = (ShuffleStreamHandle) other;
+ return Objects.equal(streamId, o.streamId)
+ && Objects.equal(numChunks, o.numChunks);
+ }
+ return false;
+ }
+}
diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java
new file mode 100644
index 0000000000..7939cb4d32
--- /dev/null
+++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java
@@ -0,0 +1,123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle;
+
+import java.nio.ByteBuffer;
+import java.util.Iterator;
+
+import org.junit.Before;
+import org.junit.Test;
+import org.mockito.ArgumentCaptor;
+
+import static org.apache.spark.network.shuffle.ExternalShuffleMessages.OpenShuffleBlocks;
+import static org.apache.spark.network.shuffle.ExternalShuffleMessages.RegisterExecutor;
+import static org.junit.Assert.*;
+import static org.mockito.Matchers.any;
+import static org.mockito.Mockito.*;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.buffer.NioManagedBuffer;
+import org.apache.spark.network.client.RpcResponseCallback;
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.server.OneForOneStreamManager;
+import org.apache.spark.network.server.RpcHandler;
+import org.apache.spark.network.util.JavaUtils;
+
+public class ExternalShuffleBlockHandlerSuite {
+ TransportClient client = mock(TransportClient.class);
+
+ OneForOneStreamManager streamManager;
+ ExternalShuffleBlockManager blockManager;
+ RpcHandler handler;
+
+ @Before
+ public void beforeEach() {
+ streamManager = mock(OneForOneStreamManager.class);
+ blockManager = mock(ExternalShuffleBlockManager.class);
+ handler = new ExternalShuffleBlockHandler(streamManager, blockManager);
+ }
+
+ @Test
+ public void testRegisterExecutor() {
+ RpcResponseCallback callback = mock(RpcResponseCallback.class);
+
+ ExecutorShuffleInfo config = new ExecutorShuffleInfo(new String[] {"/a", "/b"}, 16, "sort");
+ byte[] registerMessage = JavaUtils.serialize(
+ new RegisterExecutor("app0", "exec1", config));
+ handler.receive(client, registerMessage, callback);
+ verify(blockManager, times(1)).registerExecutor("app0", "exec1", config);
+
+ verify(callback, times(1)).onSuccess((byte[]) any());
+ verify(callback, never()).onFailure((Throwable) any());
+ }
+
+ @SuppressWarnings("unchecked")
+ @Test
+ public void testOpenShuffleBlocks() {
+ RpcResponseCallback callback = mock(RpcResponseCallback.class);
+
+ ManagedBuffer block0Marker = new NioManagedBuffer(ByteBuffer.wrap(new byte[3]));
+ ManagedBuffer block1Marker = new NioManagedBuffer(ByteBuffer.wrap(new byte[7]));
+ when(blockManager.getBlockData("app0", "exec1", "b0")).thenReturn(block0Marker);
+ when(blockManager.getBlockData("app0", "exec1", "b1")).thenReturn(block1Marker);
+ byte[] openBlocksMessage = JavaUtils.serialize(
+ new OpenShuffleBlocks("app0", "exec1", new String[] { "b0", "b1" }));
+ handler.receive(client, openBlocksMessage, callback);
+ verify(blockManager, times(1)).getBlockData("app0", "exec1", "b0");
+ verify(blockManager, times(1)).getBlockData("app0", "exec1", "b1");
+
+ ArgumentCaptor<byte[]> response = ArgumentCaptor.forClass(byte[].class);
+ verify(callback, times(1)).onSuccess(response.capture());
+ verify(callback, never()).onFailure((Throwable) any());
+
+ ShuffleStreamHandle handle = JavaUtils.deserialize(response.getValue());
+ assertEquals(2, handle.numChunks);
+
+ ArgumentCaptor<Iterator> stream = ArgumentCaptor.forClass(Iterator.class);
+ verify(streamManager, times(1)).registerStream(stream.capture());
+ Iterator<ManagedBuffer> buffers = (Iterator<ManagedBuffer>) stream.getValue();
+ assertEquals(block0Marker, buffers.next());
+ assertEquals(block1Marker, buffers.next());
+ assertFalse(buffers.hasNext());
+ }
+
+ @Test
+ public void testBadMessages() {
+ RpcResponseCallback callback = mock(RpcResponseCallback.class);
+
+ byte[] unserializableMessage = new byte[] { 0x12, 0x34, 0x56 };
+ try {
+ handler.receive(client, unserializableMessage, callback);
+ fail("Should have thrown");
+ } catch (Exception e) {
+ // pass
+ }
+
+ byte[] unexpectedMessage = JavaUtils.serialize(
+ new ExecutorShuffleInfo(new String[] {"/a", "/b"}, 16, "sort"));
+ try {
+ handler.receive(client, unexpectedMessage, callback);
+ fail("Should have thrown");
+ } catch (UnsupportedOperationException e) {
+ // pass
+ }
+
+ verify(callback, never()).onSuccess((byte[]) any());
+ verify(callback, never()).onFailure((Throwable) any());
+ }
+}
diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManagerSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManagerSuite.java
new file mode 100644
index 0000000000..da54797e89
--- /dev/null
+++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManagerSuite.java
@@ -0,0 +1,125 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.InputStreamReader;
+
+import com.google.common.io.CharStreams;
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import static org.junit.Assert.*;
+
+public class ExternalShuffleBlockManagerSuite {
+ static String sortBlock0 = "Hello!";
+ static String sortBlock1 = "World!";
+
+ static String hashBlock0 = "Elementary";
+ static String hashBlock1 = "Tabular";
+
+ static TestShuffleDataContext dataContext;
+
+ @BeforeClass
+ public static void beforeAll() throws IOException {
+ dataContext = new TestShuffleDataContext(2, 5);
+
+ dataContext.create();
+ // Write some sort and hash data.
+ dataContext.insertSortShuffleData(0, 0,
+ new byte[][] { sortBlock0.getBytes(), sortBlock1.getBytes() } );
+ dataContext.insertHashShuffleData(1, 0,
+ new byte[][] { hashBlock0.getBytes(), hashBlock1.getBytes() } );
+ }
+
+ @AfterClass
+ public static void afterAll() {
+ dataContext.cleanup();
+ }
+
+ @Test
+ public void testBadRequests() {
+ ExternalShuffleBlockManager manager = new ExternalShuffleBlockManager();
+ // Unregistered executor
+ try {
+ manager.getBlockData("app0", "exec1", "shuffle_1_1_0");
+ fail("Should have failed");
+ } catch (RuntimeException e) {
+ assertTrue("Bad error message: " + e, e.getMessage().contains("not registered"));
+ }
+
+ // Invalid shuffle manager
+ manager.registerExecutor("app0", "exec2", dataContext.createExecutorInfo("foobar"));
+ try {
+ manager.getBlockData("app0", "exec2", "shuffle_1_1_0");
+ fail("Should have failed");
+ } catch (UnsupportedOperationException e) {
+ // pass
+ }
+
+ // Nonexistent shuffle block
+ manager.registerExecutor("app0", "exec3",
+ dataContext.createExecutorInfo("org.apache.spark.shuffle.sort.SortShuffleManager"));
+ try {
+ manager.getBlockData("app0", "exec3", "shuffle_1_1_0");
+ fail("Should have failed");
+ } catch (Exception e) {
+ // pass
+ }
+ }
+
+ @Test
+ public void testSortShuffleBlocks() throws IOException {
+ ExternalShuffleBlockManager manager = new ExternalShuffleBlockManager();
+ manager.registerExecutor("app0", "exec0",
+ dataContext.createExecutorInfo("org.apache.spark.shuffle.sort.SortShuffleManager"));
+
+ InputStream block0Stream =
+ manager.getBlockData("app0", "exec0", "shuffle_0_0_0").createInputStream();
+ String block0 = CharStreams.toString(new InputStreamReader(block0Stream));
+ block0Stream.close();
+ assertEquals(sortBlock0, block0);
+
+ InputStream block1Stream =
+ manager.getBlockData("app0", "exec0", "shuffle_0_0_1").createInputStream();
+ String block1 = CharStreams.toString(new InputStreamReader(block1Stream));
+ block1Stream.close();
+ assertEquals(sortBlock1, block1);
+ }
+
+ @Test
+ public void testHashShuffleBlocks() throws IOException {
+ ExternalShuffleBlockManager manager = new ExternalShuffleBlockManager();
+ manager.registerExecutor("app0", "exec0",
+ dataContext.createExecutorInfo("org.apache.spark.shuffle.hash.HashShuffleManager"));
+
+ InputStream block0Stream =
+ manager.getBlockData("app0", "exec0", "shuffle_1_0_0").createInputStream();
+ String block0 = CharStreams.toString(new InputStreamReader(block0Stream));
+ block0Stream.close();
+ assertEquals(hashBlock0, block0);
+
+ InputStream block1Stream =
+ manager.getBlockData("app0", "exec0", "shuffle_1_0_1").createInputStream();
+ String block1 = CharStreams.toString(new InputStreamReader(block1Stream));
+ block1Stream.close();
+ assertEquals(hashBlock1, block1);
+ }
+}
diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java
new file mode 100644
index 0000000000..b3bcf5fd68
--- /dev/null
+++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java
@@ -0,0 +1,291 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Random;
+import java.util.Set;
+import java.util.concurrent.Semaphore;
+import java.util.concurrent.TimeUnit;
+
+import com.google.common.collect.Lists;
+import com.google.common.collect.Sets;
+import org.junit.After;
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import static org.junit.Assert.*;
+
+import org.apache.spark.network.TestUtils;
+import org.apache.spark.network.TransportContext;
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.buffer.NioManagedBuffer;
+import org.apache.spark.network.server.TransportServer;
+import org.apache.spark.network.util.SystemPropertyConfigProvider;
+import org.apache.spark.network.util.TransportConf;
+
+public class ExternalShuffleIntegrationSuite {
+
+ static String APP_ID = "app-id";
+ static String SORT_MANAGER = "org.apache.spark.shuffle.sort.SortShuffleManager";
+ static String HASH_MANAGER = "org.apache.spark.shuffle.hash.HashShuffleManager";
+
+ // Executor 0 is sort-based
+ static TestShuffleDataContext dataContext0;
+ // Executor 1 is hash-based
+ static TestShuffleDataContext dataContext1;
+
+ static ExternalShuffleBlockHandler handler;
+ static TransportServer server;
+ static TransportConf conf;
+
+ static byte[][] exec0Blocks = new byte[][] {
+ new byte[123],
+ new byte[12345],
+ new byte[1234567],
+ };
+
+ static byte[][] exec1Blocks = new byte[][] {
+ new byte[321],
+ new byte[54321],
+ };
+
+ @BeforeClass
+ public static void beforeAll() throws IOException {
+ Random rand = new Random();
+
+ for (byte[] block : exec0Blocks) {
+ rand.nextBytes(block);
+ }
+ for (byte[] block : exec1Blocks) {
+ rand.nextBytes(block);
+ }
+
+ dataContext0 = new TestShuffleDataContext(2, 5);
+ dataContext0.create();
+ dataContext0.insertSortShuffleData(0, 0, exec0Blocks);
+
+ dataContext1 = new TestShuffleDataContext(6, 2);
+ dataContext1.create();
+ dataContext1.insertHashShuffleData(1, 0, exec1Blocks);
+
+ conf = new TransportConf(new SystemPropertyConfigProvider());
+ handler = new ExternalShuffleBlockHandler();
+ TransportContext transportContext = new TransportContext(conf, handler);
+ server = transportContext.createServer();
+ }
+
+ @AfterClass
+ public static void afterAll() {
+ dataContext0.cleanup();
+ dataContext1.cleanup();
+ server.close();
+ }
+
+ @After
+ public void afterEach() {
+ handler.clearRegisteredExecutors();
+ }
+
+ class FetchResult {
+ public Set<String> successBlocks;
+ public Set<String> failedBlocks;
+ public List<ManagedBuffer> buffers;
+
+ public void releaseBuffers() {
+ for (ManagedBuffer buffer : buffers) {
+ buffer.release();
+ }
+ }
+ }
+
+ // Fetch a set of blocks from a pre-registered executor.
+ private FetchResult fetchBlocks(String execId, String[] blockIds) throws Exception {
+ return fetchBlocks(execId, blockIds, server.getPort());
+ }
+
+ // Fetch a set of blocks from a pre-registered executor. Connects to the server on the given
+ // port, so tests can deliberately target a nonexistent server as well.
+ private FetchResult fetchBlocks(String execId, String[] blockIds, int port) throws Exception {
+ final FetchResult res = new FetchResult();
+ res.successBlocks = Collections.synchronizedSet(new HashSet<String>());
+ res.failedBlocks = Collections.synchronizedSet(new HashSet<String>());
+ res.buffers = Collections.synchronizedList(new LinkedList<ManagedBuffer>());
+
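+ // Each success or failure callback releases one permit, so the tryAcquire below blocks
+ // until every requested block has completed one way or the other (or the timeout fires).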
+ final Semaphore requestsRemaining = new Semaphore(0);
+
+ ExternalShuffleClient client = new ExternalShuffleClient(conf, APP_ID);
+ client.fetchBlocks(TestUtils.getLocalHost(), port, execId, blockIds,
+ new BlockFetchingListener() {
+ @Override
+ public void onBlockFetchSuccess(String blockId, ManagedBuffer data) {
+ synchronized (this) {
+ if (!res.successBlocks.contains(blockId) && !res.failedBlocks.contains(blockId)) {
+ data.retain();
+ res.successBlocks.add(blockId);
+ res.buffers.add(data);
+ requestsRemaining.release();
+ }
+ }
+ }
+
+ @Override
+ public void onBlockFetchFailure(String blockId, Throwable exception) {
+ synchronized (this) {
+ if (!res.successBlocks.contains(blockId) && !res.failedBlocks.contains(blockId)) {
+ res.failedBlocks.add(blockId);
+ requestsRemaining.release();
+ }
+ }
+ }
+ });
+
+ if (!requestsRemaining.tryAcquire(blockIds.length, 5, TimeUnit.SECONDS)) {
+ fail("Timeout getting response from the server");
+ }
+ return res;
+ }
+
+ @Test
+ public void testFetchOneSort() throws Exception {
+ registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER));
+ FetchResult exec0Fetch = fetchBlocks("exec-0", new String[] { "shuffle_0_0_0" });
+ assertEquals(Sets.newHashSet("shuffle_0_0_0"), exec0Fetch.successBlocks);
+ assertTrue(exec0Fetch.failedBlocks.isEmpty());
+ assertBufferListsEqual(exec0Fetch.buffers, Lists.newArrayList(exec0Blocks[0]));
+ exec0Fetch.releaseBuffers();
+ }
+
+ @Test
+ public void testFetchThreeSort() throws Exception {
+ registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER));
+ FetchResult exec0Fetch = fetchBlocks("exec-0",
+ new String[] { "shuffle_0_0_0", "shuffle_0_0_1", "shuffle_0_0_2" });
+ assertEquals(Sets.newHashSet("shuffle_0_0_0", "shuffle_0_0_1", "shuffle_0_0_2"),
+ exec0Fetch.successBlocks);
+ assertTrue(exec0Fetch.failedBlocks.isEmpty());
+ assertBufferListsEqual(exec0Fetch.buffers, Lists.newArrayList(exec0Blocks));
+ exec0Fetch.releaseBuffers();
+ }
+
+ @Test
+ public void testFetchHash() throws Exception {
+ registerExecutor("exec-1", dataContext1.createExecutorInfo(HASH_MANAGER));
+ FetchResult execFetch = fetchBlocks("exec-1",
+ new String[] { "shuffle_1_0_0", "shuffle_1_0_1" });
+ assertEquals(Sets.newHashSet("shuffle_1_0_0", "shuffle_1_0_1"), execFetch.successBlocks);
+ assertTrue(execFetch.failedBlocks.isEmpty());
+ assertBufferListsEqual(execFetch.buffers, Lists.newArrayList(exec1Blocks));
+ execFetch.releaseBuffers();
+ }
+
+ @Test
+ public void testFetchWrongShuffle() throws Exception {
+ registerExecutor("exec-1", dataContext1.createExecutorInfo(SORT_MANAGER /* wrong manager */));
+ FetchResult execFetch = fetchBlocks("exec-1",
+ new String[] { "shuffle_1_0_0", "shuffle_1_0_1" });
+ assertTrue(execFetch.successBlocks.isEmpty());
+ assertEquals(Sets.newHashSet("shuffle_1_0_0", "shuffle_1_0_1"), execFetch.failedBlocks);
+ }
+
+ @Test
+ public void testFetchInvalidShuffle() throws Exception {
+ registerExecutor("exec-1", dataContext1.createExecutorInfo("unknown sort manager"));
+ FetchResult execFetch = fetchBlocks("exec-1",
+ new String[] { "shuffle_1_0_0" });
+ assertTrue(execFetch.successBlocks.isEmpty());
+ assertEquals(Sets.newHashSet("shuffle_1_0_0"), execFetch.failedBlocks);
+ }
+
+ @Test
+ public void testFetchWrongBlockId() throws Exception {
+ registerExecutor("exec-1", dataContext1.createExecutorInfo(SORT_MANAGER /* wrong manager */));
+ FetchResult execFetch = fetchBlocks("exec-1",
+ new String[] { "rdd_1_0_0" });
+ assertTrue(execFetch.successBlocks.isEmpty());
+ assertEquals(Sets.newHashSet("rdd_1_0_0"), execFetch.failedBlocks);
+ }
+
+ @Test
+ public void testFetchNonexistent() throws Exception {
+ registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER));
+ FetchResult execFetch = fetchBlocks("exec-0",
+ new String[] { "shuffle_2_0_0" });
+ assertTrue(execFetch.successBlocks.isEmpty());
+ assertEquals(Sets.newHashSet("shuffle_2_0_0"), execFetch.failedBlocks);
+ }
+
+ @Test
+ public void testFetchWrongExecutor() throws Exception {
+ registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER));
+ FetchResult execFetch = fetchBlocks("exec-0",
+ new String[] { "shuffle_0_0_0" /* right */, "shuffle_1_0_0" /* wrong */ });
+ // Both blocks fail, because the open-blocks request checks every requested block up front.
+ assertTrue(execFetch.successBlocks.isEmpty());
+ assertEquals(Sets.newHashSet("shuffle_0_0_0", "shuffle_1_0_0"), execFetch.failedBlocks);
+ }
+
+ @Test
+ public void testFetchUnregisteredExecutor() throws Exception {
+ registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER));
+ FetchResult execFetch = fetchBlocks("exec-2",
+ new String[] { "shuffle_0_0_0", "shuffle_1_0_0" });
+ assertTrue(execFetch.successBlocks.isEmpty());
+ assertEquals(Sets.newHashSet("shuffle_0_0_0", "shuffle_1_0_0"), execFetch.failedBlocks);
+ }
+
+ @Test
+ public void testFetchNoServer() throws Exception {
+ registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER));
+ FetchResult execFetch = fetchBlocks("exec-0",
+ new String[] { "shuffle_1_0_0", "shuffle_1_0_1" }, 1 /* port */);
+ assertTrue(execFetch.successBlocks.isEmpty());
+ assertEquals(Sets.newHashSet("shuffle_1_0_0", "shuffle_1_0_1"), execFetch.failedBlocks);
+ }
+
+ private void registerExecutor(String executorId, ExecutorShuffleInfo executorInfo) {
+ ExternalShuffleClient client = new ExternalShuffleClient(conf, APP_ID);
+ client.registerWithShuffleServer(TestUtils.getLocalHost(), server.getPort(),
+ executorId, executorInfo);
+ }
+
+ private void assertBufferListsEqual(List<ManagedBuffer> list0, List<byte[]> list1)
+ throws Exception {
+ assertEquals(list0.size(), list1.size());
+ for (int i = 0; i < list0.size(); i++) {
+ assertBuffersEqual(list0.get(i), new NioManagedBuffer(ByteBuffer.wrap(list1.get(i))));
+ }
+ }
+
+ private void assertBuffersEqual(ManagedBuffer buffer0, ManagedBuffer buffer1) throws Exception {
+ ByteBuffer nio0 = buffer0.nioByteBuffer();
+ ByteBuffer nio1 = buffer1.nioByteBuffer();
+
+ int len = nio0.remaining();
+ assertEquals(nio0.remaining(), nio1.remaining());
+ for (int i = 0; i < len; i++) {
+ assertEquals(nio0.get(), nio1.get());
+ }
+ }
+}
diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java
new file mode 100644
index 0000000000..c18346f696
--- /dev/null
+++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java
@@ -0,0 +1,167 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle;
+
+import java.nio.ByteBuffer;
+import java.util.Iterator;
+import java.util.LinkedHashMap;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import com.google.common.collect.Maps;
+import io.netty.buffer.Unpooled;
+import org.junit.Test;
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
+
+import static org.junit.Assert.*;
+import static org.mockito.Matchers.any;
+import static org.mockito.Matchers.eq;
+import static org.mockito.Mockito.*;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.buffer.NettyManagedBuffer;
+import org.apache.spark.network.buffer.NioManagedBuffer;
+import org.apache.spark.network.client.ChunkReceivedCallback;
+import org.apache.spark.network.client.RpcResponseCallback;
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.util.JavaUtils;
+
+public class OneForOneBlockFetcherSuite {
+ @Test
+ public void testFetchOne() {
+ LinkedHashMap<String, ManagedBuffer> blocks = Maps.newLinkedHashMap();
+ blocks.put("shuffle_0_0_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[0])));
+
+ BlockFetchingListener listener = fetchBlocks(blocks);
+
+ verify(listener).onBlockFetchSuccess("shuffle_0_0_0", blocks.get("shuffle_0_0_0"));
+ }
+
+ @Test
+ public void testFetchThree() {
+ LinkedHashMap<String, ManagedBuffer> blocks = Maps.newLinkedHashMap();
+ blocks.put("b0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12])));
+ blocks.put("b1", new NioManagedBuffer(ByteBuffer.wrap(new byte[23])));
+ blocks.put("b2", new NettyManagedBuffer(Unpooled.wrappedBuffer(new byte[23])));
+
+ BlockFetchingListener listener = fetchBlocks(blocks);
+
+ for (int i = 0; i < 3; i++) {
+ verify(listener, times(1)).onBlockFetchSuccess("b" + i, blocks.get("b" + i));
+ }
+ }
+
+ @Test
+ public void testFailure() {
+ LinkedHashMap<String, ManagedBuffer> blocks = Maps.newLinkedHashMap();
+ blocks.put("b0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12])));
+ blocks.put("b1", null);
+ blocks.put("b2", null);
+
+ BlockFetchingListener listener = fetchBlocks(blocks);
+
+ // Each chunk failure fails every block from that index onwards, so later blocks can
+ // receive more than one failure callback.
+ verify(listener, times(1)).onBlockFetchSuccess("b0", blocks.get("b0"));
+ verify(listener, times(1)).onBlockFetchFailure(eq("b1"), (Throwable) any());
+ verify(listener, times(2)).onBlockFetchFailure(eq("b2"), (Throwable) any());
+ }
+
+ @Test
+ public void testFailureAndSuccess() {
+ LinkedHashMap<String, ManagedBuffer> blocks = Maps.newLinkedHashMap();
+ blocks.put("b0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12])));
+ blocks.put("b1", null);
+ blocks.put("b2", new NioManagedBuffer(ByteBuffer.wrap(new byte[21])));
+
+ BlockFetchingListener listener = fetchBlocks(blocks);
+
+ // We may call both success and failure for the same block.
+ verify(listener, times(1)).onBlockFetchSuccess("b0", blocks.get("b0"));
+ verify(listener, times(1)).onBlockFetchFailure(eq("b1"), (Throwable) any());
+ verify(listener, times(1)).onBlockFetchSuccess("b2", blocks.get("b2"));
+ verify(listener, times(1)).onBlockFetchFailure(eq("b2"), (Throwable) any());
+ }
+
+ @Test
+ public void testEmptyBlockFetch() {
+ try {
+ fetchBlocks(Maps.<String, ManagedBuffer>newLinkedHashMap());
+ fail();
+ } catch (IllegalArgumentException e) {
+ assertEquals("Zero-sized blockIds array", e.getMessage());
+ }
+ }
+
+ /**
+ * Begins a fetch on the given set of blocks by mocking out the server side of the RPC which
+ * simply returns the given (BlockId, Block) pairs.
+ * As "blocks" is a LinkedHashMap, the blocks are guaranteed to be returned in the same order
+ * that they were inserted in.
+ *
+ * If a block's buffer is "null", an exception will be thrown instead.
+ */
+ private BlockFetchingListener fetchBlocks(final LinkedHashMap<String, ManagedBuffer> blocks) {
+ TransportClient client = mock(TransportClient.class);
+ BlockFetchingListener listener = mock(BlockFetchingListener.class);
+ String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]);
+ OneForOneBlockFetcher fetcher = new OneForOneBlockFetcher(client, blockIds, listener);
+
+ // Respond to the "open blocks" message with an appropriate ShuffleStreamHandle (streamId 123).
+ doAnswer(new Answer<Void>() {
+ @Override
+ public Void answer(InvocationOnMock invocationOnMock) throws Throwable {
+ String message = JavaUtils.deserialize((byte[]) invocationOnMock.getArguments()[0]);
+ RpcResponseCallback callback = (RpcResponseCallback) invocationOnMock.getArguments()[1];
+ callback.onSuccess(JavaUtils.serialize(new ShuffleStreamHandle(123, blocks.size())));
+ assertEquals("OpenZeBlocks", message);
+ return null;
+ }
+ }).when(client).sendRpc((byte[]) any(), (RpcResponseCallback) any());
+
+ // Respond to each chunk request with a single buffer from our blocks array.
+ final AtomicInteger expectedChunkIndex = new AtomicInteger(0);
+ final Iterator<ManagedBuffer> blockIterator = blocks.values().iterator();
+ doAnswer(new Answer<Void>() {
+ @Override
+ public Void answer(InvocationOnMock invocation) throws Throwable {
+ try {
+ long streamId = (Long) invocation.getArguments()[0];
+ int myChunkIndex = (Integer) invocation.getArguments()[1];
+ assertEquals(123, streamId);
+ assertEquals(expectedChunkIndex.getAndIncrement(), myChunkIndex);
+
+ ChunkReceivedCallback callback = (ChunkReceivedCallback) invocation.getArguments()[2];
+ ManagedBuffer result = blockIterator.next();
+ if (result != null) {
+ callback.onSuccess(myChunkIndex, result);
+ } else {
+ callback.onFailure(myChunkIndex, new RuntimeException("Failed " + myChunkIndex));
+ }
+ } catch (Exception e) {
+ e.printStackTrace();
+ fail("Unexpected failure");
+ }
+ return null;
+ }
+ }).when(client).fetchChunk(anyLong(), anyInt(), (ChunkReceivedCallback) any());
+
+ fetcher.start("OpenZeBlocks");
+ return listener;
+ }
+}
diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ShuffleMessagesSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ShuffleMessagesSuite.java
new file mode 100644
index 0000000000..ee9482b49c
--- /dev/null
+++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ShuffleMessagesSuite.java
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle;
+
+import org.junit.Test;
+
+import static org.junit.Assert.*;
+
+import org.apache.spark.network.util.JavaUtils;
+
+import static org.apache.spark.network.shuffle.ExternalShuffleMessages.*;
+
+public class ShuffleMessagesSuite {
+ @Test
+ public void serializeOpenShuffleBlocks() {
+ OpenShuffleBlocks msg = new OpenShuffleBlocks("app-1", "exec-2",
+ new String[] { "block0", "block1" });
+ OpenShuffleBlocks msg2 = JavaUtils.deserialize(JavaUtils.serialize(msg));
+ assertEquals(msg, msg2);
+ }
+
+ @Test
+ public void serializeRegisterExecutor() {
+ RegisterExecutor msg = new RegisterExecutor("app-1", "exec-2", new ExecutorShuffleInfo(
+ new String[] { "/local1", "/local2" }, 32, "MyShuffleManager"));
+ RegisterExecutor msg2 = JavaUtils.deserialize(JavaUtils.serialize(msg));
+ assertEquals(msg, msg2);
+ }
+
+ @Test
+ public void serializeShuffleStreamHandle() {
+ ShuffleStreamHandle msg = new ShuffleStreamHandle(12345, 16);
+ ShuffleStreamHandle msg2 = JavaUtils.deserialize(JavaUtils.serialize(msg));
+ assertEquals(msg, msg2);
+ }
+}
diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java
new file mode 100644
index 0000000000..442b756467
--- /dev/null
+++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle;
+
+import java.io.DataOutputStream;
+import java.io.File;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.OutputStream;
+
+import com.google.common.io.Files;
+
+/**
+ * Manages some sort- and hash-based shuffle data, including the creation
+ * and cleanup of directories that can be read by the {@link ExternalShuffleBlockManager}.
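+ *
+ * Typical test usage, as a sketch (mirrors the suites in this package; "blocks" is a
+ * test-supplied byte[][] and the shuffle manager class name is a placeholder):
+ * <pre>{@code
+ *   TestShuffleDataContext ctx = new TestShuffleDataContext(2, 5);
+ *   ctx.create();
+ *   ctx.insertSortShuffleData(0, 0, blocks);
+ *   ExecutorShuffleInfo info = ctx.createExecutorInfo("sortShuffleManagerClassName");
+ *   ctx.cleanup();
+ * }</pre>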
+ */
+public class TestShuffleDataContext {
+ private final String[] localDirs;
+ private final int subDirsPerLocalDir;
+
+ public TestShuffleDataContext(int numLocalDirs, int subDirsPerLocalDir) {
+ this.localDirs = new String[numLocalDirs];
+ this.subDirsPerLocalDir = subDirsPerLocalDir;
+ }
+
+ public void create() {
+ for (int i = 0; i < localDirs.length; i++) {
+ localDirs[i] = Files.createTempDir().getAbsolutePath();
+
+ for (int p = 0; p < subDirsPerLocalDir; p++) {
+ new File(localDirs[i], String.format("%02x", p)).mkdirs();
+ }
+ }
+ }
+
+ public void cleanup() {
+ for (String localDir : localDirs) {
+ deleteRecursively(new File(localDir));
+ }
+ }
+
+ /** Creates reducer blocks in a sort-based data format within our local dirs. */
+ public void insertSortShuffleData(int shuffleId, int mapId, byte[][] blocks) throws IOException {
+ String blockId = "shuffle_" + shuffleId + "_" + mapId + "_0";
+
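+ // Sort-based shuffle layout, as read back by ExternalShuffleBlockManager: all reduce
+ // blocks of a map task sit back-to-back in one ".data" file, while the ".index" file
+ // stores cumulative offsets, so block i spans [index[i], index[i+1]) of the data file.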
+ OutputStream dataStream = new FileOutputStream(
+ ExternalShuffleBlockManager.getFile(localDirs, subDirsPerLocalDir, blockId + ".data"));
+ DataOutputStream indexStream = new DataOutputStream(new FileOutputStream(
+ ExternalShuffleBlockManager.getFile(localDirs, subDirsPerLocalDir, blockId + ".index")));
+
+ long offset = 0;
+ indexStream.writeLong(offset);
+ for (byte[] block : blocks) {
+ offset += block.length;
+ dataStream.write(block);
+ indexStream.writeLong(offset);
+ }
+
+ dataStream.close();
+ indexStream.close();
+ }
+
+ /** Creates reducer blocks in a hash-based data format within our local dirs. */
+ public void insertHashShuffleData(int shuffleId, int mapId, byte[][] blocks) throws IOException {
+ for (int i = 0; i < blocks.length; i++) {
+ String blockId = "shuffle_" + shuffleId + "_" + mapId + "_" + i;
+ Files.write(blocks[i],
+ ExternalShuffleBlockManager.getFile(localDirs, subDirsPerLocalDir, blockId));
+ }
+ }
+
+ /**
+ * Creates an ExecutorShuffleInfo object based on the given shuffle manager which targets this
+ * context's directories.
+ */
+ public ExecutorShuffleInfo createExecutorInfo(String shuffleManager) {
+ return new ExecutorShuffleInfo(localDirs, subDirsPerLocalDir, shuffleManager);
+ }
+
+ private static void deleteRecursively(File f) {
+ assert f != null;
+ if (f.isDirectory()) {
+ File[] children = f.listFiles();
+ if (children != null) {
+ for (File child : children) {
+ deleteRecursively(child);
+ }
+ }
+ }
+ f.delete();
+ }
+}