path: root/common/network-shuffle
author    Reynold Xin <rxin@databricks.com>  2016-02-28 17:25:07 -0800
committer Reynold Xin <rxin@databricks.com>  2016-02-28 17:25:07 -0800
commit    9e01dcc6446f8648e61062f8afe62589b9d4b5ab (patch)
tree      ae3c7015e950de315a490ce58c181671bfd12907 /common/network-shuffle
parent    cca79fad66c4315b0ed6de59fd87700a540e6646 (diff)
download  spark-9e01dcc6446f8648e61062f8afe62589b9d4b5ab.tar.gz
          spark-9e01dcc6446f8648e61062f8afe62589b9d4b5ab.tar.bz2
          spark-9e01dcc6446f8648e61062f8afe62589b9d4b5ab.zip
[SPARK-13529][BUILD] Move network/* modules into common/network-*
## What changes were proposed in this pull request?

As the title says, this moves the three modules currently in network/ into common/network-*. This removes one top-level, non-user-facing folder.

## How was this patch tested?

Compilation and existing tests. We should run both SBT and Maven.

Author: Reynold Xin <rxin@databricks.com>

Closes #11409 from rxin/SPARK-13529.
Diffstat (limited to 'common/network-shuffle')
-rw-r--r--  common/network-shuffle/pom.xml | 101
-rw-r--r--  common/network-shuffle/src/main/java/org/apache/spark/network/sasl/ShuffleSecretManager.java | 97
-rw-r--r--  common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockFetchingListener.java | 36
-rw-r--r--  common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java | 140
-rw-r--r--  common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java | 449
-rw-r--r--  common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java | 154
-rw-r--r--  common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java | 129
-rw-r--r--  common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java | 234
-rw-r--r--  common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java | 44
-rw-r--r--  common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java | 73
-rw-r--r--  common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java | 81
-rw-r--r--  common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExecutorShuffleInfo.java | 94
-rw-r--r--  common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenBlocks.java | 90
-rw-r--r--  common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java | 94
-rw-r--r--  common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/StreamHandle.java | 81
-rw-r--r--  common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadBlock.java | 117
-rw-r--r--  common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/RegisterDriver.java | 63
-rw-r--r--  common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java | 294
-rw-r--r--  common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java | 44
-rw-r--r--  common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java | 127
-rw-r--r--  common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java | 156
-rw-r--r--  common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java | 149
-rw-r--r--  common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java | 301
-rw-r--r--  common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java | 124
-rw-r--r--  common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java | 176
-rw-r--r--  common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java | 313
-rw-r--r--  common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java | 117
27 files changed, 3878 insertions, 0 deletions
diff --git a/common/network-shuffle/pom.xml b/common/network-shuffle/pom.xml
new file mode 100644
index 0000000000..810ec10ca0
--- /dev/null
+++ b/common/network-shuffle/pom.xml
@@ -0,0 +1,101 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+ ~ Licensed to the Apache Software Foundation (ASF) under one or more
+ ~ contributor license agreements. See the NOTICE file distributed with
+ ~ this work for additional information regarding copyright ownership.
+ ~ The ASF licenses this file to You under the Apache License, Version 2.0
+ ~ (the "License"); you may not use this file except in compliance with
+ ~ the License. You may obtain a copy of the License at
+ ~
+ ~ http://www.apache.org/licenses/LICENSE-2.0
+ ~
+ ~ Unless required by applicable law or agreed to in writing, software
+ ~ distributed under the License is distributed on an "AS IS" BASIS,
+ ~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ ~ See the License for the specific language governing permissions and
+ ~ limitations under the License.
+ -->
+
+<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+ xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+ <modelVersion>4.0.0</modelVersion>
+ <parent>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-parent_2.11</artifactId>
+ <version>2.0.0-SNAPSHOT</version>
+ <relativePath>../../pom.xml</relativePath>
+ </parent>
+
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-network-shuffle_2.11</artifactId>
+ <packaging>jar</packaging>
+ <name>Spark Project Shuffle Streaming Service</name>
+ <url>http://spark.apache.org/</url>
+ <properties>
+ <sbt.project.name>network-shuffle</sbt.project.name>
+ </properties>
+
+ <dependencies>
+ <!-- Core dependencies -->
+ <dependency>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-network-common_${scala.binary.version}</artifactId>
+ <version>${project.version}</version>
+ </dependency>
+
+ <dependency>
+ <groupId>org.fusesource.leveldbjni</groupId>
+ <artifactId>leveldbjni-all</artifactId>
+ <version>1.8</version>
+ </dependency>
+
+ <dependency>
+ <groupId>com.fasterxml.jackson.core</groupId>
+ <artifactId>jackson-databind</artifactId>
+ </dependency>
+
+ <dependency>
+ <groupId>com.fasterxml.jackson.core</groupId>
+ <artifactId>jackson-annotations</artifactId>
+ </dependency>
+
+ <!-- Provided dependencies -->
+ <dependency>
+ <groupId>org.slf4j</groupId>
+ <artifactId>slf4j-api</artifactId>
+ <scope>provided</scope>
+ </dependency>
+ <dependency>
+ <groupId>com.google.guava</groupId>
+ <artifactId>guava</artifactId>
+ </dependency>
+
+ <!-- Test dependencies -->
+ <dependency>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-network-common_${scala.binary.version}</artifactId>
+ <version>${project.version}</version>
+ <type>test-jar</type>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-test-tags_${scala.binary.version}</artifactId>
+ </dependency>
+ <dependency>
+ <groupId>log4j</groupId>
+ <artifactId>log4j</artifactId>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.mockito</groupId>
+ <artifactId>mockito-core</artifactId>
+ <scope>test</scope>
+ </dependency>
+ </dependencies>
+
+ <build>
+ <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
+ <testOutputDirectory>target/scala-${scala.binary.version}/test-classes</testOutputDirectory>
+ </build>
+</project>
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/sasl/ShuffleSecretManager.java b/common/network-shuffle/src/main/java/org/apache/spark/network/sasl/ShuffleSecretManager.java
new file mode 100644
index 0000000000..351c7930a9
--- /dev/null
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/sasl/ShuffleSecretManager.java
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.sasl;
+
+import java.lang.Override;
+import java.nio.ByteBuffer;
+import java.util.concurrent.ConcurrentHashMap;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.sasl.SecretKeyHolder;
+import org.apache.spark.network.util.JavaUtils;
+
+/**
+ * A class that manages the shuffle secrets used by the external shuffle service.
+ */
+public class ShuffleSecretManager implements SecretKeyHolder {
+ private final Logger logger = LoggerFactory.getLogger(ShuffleSecretManager.class);
+ private final ConcurrentHashMap<String, String> shuffleSecretMap;
+
+ // Spark user used for authenticating SASL connections
+ // Note that this must match the value in org.apache.spark.SecurityManager
+ private static final String SPARK_SASL_USER = "sparkSaslUser";
+
+ public ShuffleSecretManager() {
+ shuffleSecretMap = new ConcurrentHashMap<String, String>();
+ }
+
+ /**
+ * Register an application with its secret.
+ * Executors need to first authenticate themselves with the same secret before
+ * fetching shuffle files written by other executors in this application.
+ */
+ public void registerApp(String appId, String shuffleSecret) {
+ if (!shuffleSecretMap.containsKey(appId)) {
+ shuffleSecretMap.put(appId, shuffleSecret);
+ logger.info("Registered shuffle secret for application {}", appId);
+ } else {
+ logger.debug("Application {} already registered", appId);
+ }
+ }
+
+ /**
+ * Register an application with its secret specified as a byte buffer.
+ */
+ public void registerApp(String appId, ByteBuffer shuffleSecret) {
+ registerApp(appId, JavaUtils.bytesToString(shuffleSecret));
+ }
+
+ /**
+ * Unregister an application along with its secret.
+ * This is called when the application terminates.
+ */
+ public void unregisterApp(String appId) {
+ if (shuffleSecretMap.containsKey(appId)) {
+ shuffleSecretMap.remove(appId);
+ logger.info("Unregistered shuffle secret for application {}", appId);
+ } else {
+ logger.warn("Attempted to unregister application {} when it is not registered", appId);
+ }
+ }
+
+ /**
+ * Return the Spark user for authenticating SASL connections.
+ */
+ @Override
+ public String getSaslUser(String appId) {
+ return SPARK_SASL_USER;
+ }
+
+ /**
+ * Return the secret key registered with the given application.
+ * This key is used to authenticate the executors before they can fetch shuffle files
+ * written by this application from the external shuffle service. If the specified
+ * application is not registered, return null.
+ */
+ @Override
+ public String getSecretKey(String appId) {
+ return shuffleSecretMap.get(appId);
+ }
+}
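A minimal usage sketch of the manager above (editor's illustration, not part of this commit; the appId and secret are hypothetical). The shuffle service registers an application's secret when the application starts, and the SASL layer later resolves the user and key through the SecretKeyHolder methods:

    ShuffleSecretManager secretManager = new ShuffleSecretManager();
    secretManager.registerApp("app-20160228-0001", "topsecret");        // hypothetical ids
    String saslUser = secretManager.getSaslUser("app-20160228-0001");   // always "sparkSaslUser"
    String secret = secretManager.getSecretKey("app-20160228-0001");    // "topsecret", or null if unknown
    secretManager.unregisterApp("app-20160228-0001");                   // on application termination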
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockFetchingListener.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockFetchingListener.java
new file mode 100644
index 0000000000..138fd5389c
--- /dev/null
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockFetchingListener.java
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle;
+
+import java.util.EventListener;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+
+public interface BlockFetchingListener extends EventListener {
+ /**
+ * Called once per successfully fetched block. After this call returns, the data will be
+ * released automatically. If the data is to be passed to another thread, the receiver should
+ * retain() and release() the buffer itself, or copy the data into a new buffer.
+ */
+ void onBlockFetchSuccess(String blockId, ManagedBuffer data);
+
+ /**
+ * Called at least once per block upon failures.
+ */
+ void onBlockFetchFailure(String blockId, Throwable exception);
+}
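To make the retain()/release() contract above concrete, here is a hedged sketch of a listener that hands each fetched buffer to another thread (editor's illustration; the worker pool is an assumed collaborator, not part of this commit):

    import java.util.concurrent.Executor;

    import org.apache.spark.network.buffer.ManagedBuffer;
    import org.apache.spark.network.shuffle.BlockFetchingListener;

    class HandoffBlockListener implements BlockFetchingListener {
      private final Executor workerPool;  // assumed pool for downstream processing

      HandoffBlockListener(Executor workerPool) { this.workerPool = workerPool; }

      @Override
      public void onBlockFetchSuccess(String blockId, final ManagedBuffer data) {
        // The buffer may be released once this call returns, so retain it before
        // crossing the thread boundary, and release it when processing is done.
        data.retain();
        workerPool.execute(new Runnable() {
          @Override
          public void run() {
            try {
              // process the block data here ...
            } finally {
              data.release();
            }
          }
        });
      }

      @Override
      public void onBlockFetchFailure(String blockId, Throwable exception) {
        // May be invoked more than once per block; keep this idempotent.
      }
    }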
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java
new file mode 100644
index 0000000000..f22187a01d
--- /dev/null
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java
@@ -0,0 +1,140 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle;
+
+import java.io.File;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.List;
+
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.collect.Lists;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.client.RpcResponseCallback;
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.server.OneForOneStreamManager;
+import org.apache.spark.network.server.RpcHandler;
+import org.apache.spark.network.server.StreamManager;
+import org.apache.spark.network.shuffle.ExternalShuffleBlockResolver.AppExecId;
+import org.apache.spark.network.shuffle.protocol.*;
+import org.apache.spark.network.util.TransportConf;
+
+
+/**
+ * RPC Handler for a server which can serve shuffle blocks from outside of an Executor process.
+ *
+ * Handles registering executors and opening shuffle blocks from them. Shuffle blocks are registered
+ * with the "one-for-one" strategy, meaning each Transport-layer Chunk is equivalent to one Spark-
+ * level shuffle block.
+ */
+public class ExternalShuffleBlockHandler extends RpcHandler {
+ private final Logger logger = LoggerFactory.getLogger(ExternalShuffleBlockHandler.class);
+
+ @VisibleForTesting
+ final ExternalShuffleBlockResolver blockManager;
+ private final OneForOneStreamManager streamManager;
+
+ public ExternalShuffleBlockHandler(TransportConf conf, File registeredExecutorFile) throws IOException {
+ this(new OneForOneStreamManager(),
+ new ExternalShuffleBlockResolver(conf, registeredExecutorFile));
+ }
+
+ /** Enables mocking out the StreamManager and BlockManager. */
+ @VisibleForTesting
+ public ExternalShuffleBlockHandler(
+ OneForOneStreamManager streamManager,
+ ExternalShuffleBlockResolver blockManager) {
+ this.streamManager = streamManager;
+ this.blockManager = blockManager;
+ }
+
+ @Override
+ public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) {
+ BlockTransferMessage msgObj = BlockTransferMessage.Decoder.fromByteBuffer(message);
+ handleMessage(msgObj, client, callback);
+ }
+
+ protected void handleMessage(
+ BlockTransferMessage msgObj,
+ TransportClient client,
+ RpcResponseCallback callback) {
+ if (msgObj instanceof OpenBlocks) {
+ OpenBlocks msg = (OpenBlocks) msgObj;
+ checkAuth(client, msg.appId);
+
+ List<ManagedBuffer> blocks = Lists.newArrayList();
+ for (String blockId : msg.blockIds) {
+ blocks.add(blockManager.getBlockData(msg.appId, msg.execId, blockId));
+ }
+ long streamId = streamManager.registerStream(client.getClientId(), blocks.iterator());
+ logger.trace("Registered streamId {} with {} buffers", streamId, msg.blockIds.length);
+ callback.onSuccess(new StreamHandle(streamId, msg.blockIds.length).toByteBuffer());
+
+ } else if (msgObj instanceof RegisterExecutor) {
+ RegisterExecutor msg = (RegisterExecutor) msgObj;
+ checkAuth(client, msg.appId);
+ blockManager.registerExecutor(msg.appId, msg.execId, msg.executorInfo);
+ callback.onSuccess(ByteBuffer.wrap(new byte[0]));
+
+ } else {
+ throw new UnsupportedOperationException("Unexpected message: " + msgObj);
+ }
+ }
+
+ @Override
+ public StreamManager getStreamManager() {
+ return streamManager;
+ }
+
+ /**
+ * Removes an application (once it has been terminated), and optionally will clean up any
+ * local directories associated with the executors of that application in a separate thread.
+ */
+ public void applicationRemoved(String appId, boolean cleanupLocalDirs) {
+ blockManager.applicationRemoved(appId, cleanupLocalDirs);
+ }
+
+ /**
+ * Register an (application, executor) with the given shuffle info.
+ *
+ * The "re-" is meant to highlight the intended use of this method -- when this service is
+ * restarted, this is used to restore the state of executors from before the restart. Normal
+ * registration will happen via a message handled in receive().
+ *
+ * @param appExecId the application and executor id pair being restored
+ * @param executorInfo metadata needed to locate this executor's shuffle files
+ */
+ public void reregisterExecutor(AppExecId appExecId, ExecutorShuffleInfo executorInfo) {
+ blockManager.registerExecutor(appExecId.appId, appExecId.execId, executorInfo);
+ }
+
+ public void close() {
+ blockManager.close();
+ }
+
+ private void checkAuth(TransportClient client, String appId) {
+ if (client.getClientId() != null && !client.getClientId().equals(appId)) {
+ throw new SecurityException(String.format(
+ "Client for %s not authorized for application %s.", client.getClientId(), appId));
+ }
+ }
+
+}
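A hedged wiring sketch for the handler above, mirroring how a standalone shuffle service would boot it (editor's illustration; it assumes network-common's TransportContext, TransportServer, and SystemPropertyConfigProvider, and elides the IOException the constructor declares):

    TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider());
    // Passing a null file skips leveldb-backed recovery of registered executors.
    ExternalShuffleBlockHandler handler = new ExternalShuffleBlockHandler(conf, null);
    TransportContext context = new TransportContext(conf, handler);
    TransportServer server = context.createServer();  // binds an ephemeral port
    // Executors now send RegisterExecutor messages, and reducers send OpenBlocks, over this server.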
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java
new file mode 100644
index 0000000000..fe933ed650
--- /dev/null
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java
@@ -0,0 +1,449 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle;
+
+import java.io.*;
+import java.util.*;
+import java.util.concurrent.ConcurrentMap;
+import java.util.concurrent.Executor;
+import java.util.concurrent.Executors;
+
+import com.fasterxml.jackson.annotation.JsonCreator;
+import com.fasterxml.jackson.annotation.JsonProperty;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Charsets;
+import com.google.common.base.Objects;
+import com.google.common.collect.Maps;
+import org.fusesource.leveldbjni.JniDBFactory;
+import org.fusesource.leveldbjni.internal.NativeDB;
+import org.iq80.leveldb.DB;
+import org.iq80.leveldb.DBIterator;
+import org.iq80.leveldb.Options;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.buffer.FileSegmentManagedBuffer;
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo;
+import org.apache.spark.network.util.JavaUtils;
+import org.apache.spark.network.util.NettyUtils;
+import org.apache.spark.network.util.TransportConf;
+
+/**
+ * Manages converting shuffle BlockIds into physical segments of local files, from a process outside
+ * of Executors. Each Executor must register its own configuration about where it stores its files
+ * (local dirs) and how (shuffle manager). The logic for retrieval of individual files is replicated
+ * from Spark's FileShuffleBlockResolver and IndexShuffleBlockResolver.
+ */
+public class ExternalShuffleBlockResolver {
+ private static final Logger logger = LoggerFactory.getLogger(ExternalShuffleBlockResolver.class);
+
+ private static final ObjectMapper mapper = new ObjectMapper();
+ /**
+ * This is a common prefix for the key of each app registration we store in leveldb, so the
+ * entries are easy to find, since leveldb lets you search by key prefix.
+ */
+ private static final String APP_KEY_PREFIX = "AppExecShuffleInfo";
+ private static final StoreVersion CURRENT_VERSION = new StoreVersion(1, 0);
+
+ // Map containing all registered executors' metadata.
+ @VisibleForTesting
+ final ConcurrentMap<AppExecId, ExecutorShuffleInfo> executors;
+
+ // Single-threaded Java executor used to perform expensive recursive directory deletion.
+ private final Executor directoryCleaner;
+
+ private final TransportConf conf;
+
+ @VisibleForTesting
+ final File registeredExecutorFile;
+ @VisibleForTesting
+ final DB db;
+
+ public ExternalShuffleBlockResolver(TransportConf conf, File registeredExecutorFile)
+ throws IOException {
+ this(conf, registeredExecutorFile, Executors.newSingleThreadExecutor(
+ // Add `spark` prefix because it will run in NM in Yarn mode.
+ NettyUtils.createThreadFactory("spark-shuffle-directory-cleaner")));
+ }
+
+ // Allows tests to have more control over when directories are cleaned up.
+ @VisibleForTesting
+ ExternalShuffleBlockResolver(
+ TransportConf conf,
+ File registeredExecutorFile,
+ Executor directoryCleaner) throws IOException {
+ this.conf = conf;
+ this.registeredExecutorFile = registeredExecutorFile;
+ if (registeredExecutorFile != null) {
+ Options options = new Options();
+ options.createIfMissing(false);
+ options.logger(new LevelDBLogger());
+ DB tmpDb;
+ try {
+ tmpDb = JniDBFactory.factory.open(registeredExecutorFile, options);
+ } catch (NativeDB.DBException e) {
+ if (e.isNotFound() || e.getMessage().contains(" does not exist ")) {
+ logger.info("Creating state database at " + registeredExecutorFile);
+ options.createIfMissing(true);
+ try {
+ tmpDb = JniDBFactory.factory.open(registeredExecutorFile, options);
+ } catch (NativeDB.DBException dbExc) {
+ throw new IOException("Unable to create state store", dbExc);
+ }
+ } else {
+ // The leveldb file seems to be corrupt somehow. Let's just blow it away and create a new
+ // one, so we can keep processing new apps
+ logger.error("error opening leveldb file {}. Creating new file, will not be able to " +
+ "recover state for existing applications", registeredExecutorFile, e);
+ if (registeredExecutorFile.isDirectory()) {
+ for (File f : registeredExecutorFile.listFiles()) {
+ if (!f.delete()) {
+ logger.warn("error deleting {}", f.getPath());
+ }
+ }
+ }
+ if (!registeredExecutorFile.delete()) {
+ logger.warn("error deleting {}", registeredExecutorFile.getPath());
+ }
+ options.createIfMissing(true);
+ try {
+ tmpDb = JniDBFactory.factory.open(registeredExecutorFile, options);
+ } catch (NativeDB.DBException dbExc) {
+ throw new IOException("Unable to create state store", dbExc);
+ }
+
+ }
+ }
+ // if there is a version mismatch, we throw an exception, which means the service is unusable
+ checkVersion(tmpDb);
+ executors = reloadRegisteredExecutors(tmpDb);
+ db = tmpDb;
+ } else {
+ db = null;
+ executors = Maps.newConcurrentMap();
+ }
+ this.directoryCleaner = directoryCleaner;
+ }
+
+ /** Registers a new Executor with all the configuration we need to find its shuffle files. */
+ public void registerExecutor(
+ String appId,
+ String execId,
+ ExecutorShuffleInfo executorInfo) {
+ AppExecId fullId = new AppExecId(appId, execId);
+ logger.info("Registered executor {} with {}", fullId, executorInfo);
+ try {
+ if (db != null) {
+ byte[] key = dbAppExecKey(fullId);
+ byte[] value = mapper.writeValueAsString(executorInfo).getBytes(Charsets.UTF_8);
+ db.put(key, value);
+ }
+ } catch (Exception e) {
+ logger.error("Error saving registered executors", e);
+ }
+ executors.put(fullId, executorInfo);
+ }
+
+ /**
+ * Obtains a FileSegmentManagedBuffer from a shuffle block id. We expect the blockId to have the
+ * format "shuffle_ShuffleId_MapId_ReduceId" (from ShuffleBlockId), and additionally make
+ * assumptions about how the hash and sort based shuffles store their data.
+ */
+ public ManagedBuffer getBlockData(String appId, String execId, String blockId) {
+ String[] blockIdParts = blockId.split("_");
+ if (blockIdParts.length < 4) {
+ throw new IllegalArgumentException("Unexpected block id format: " + blockId);
+ } else if (!blockIdParts[0].equals("shuffle")) {
+ throw new IllegalArgumentException("Expected shuffle block id, got: " + blockId);
+ }
+ int shuffleId = Integer.parseInt(blockIdParts[1]);
+ int mapId = Integer.parseInt(blockIdParts[2]);
+ int reduceId = Integer.parseInt(blockIdParts[3]);
+
+ ExecutorShuffleInfo executor = executors.get(new AppExecId(appId, execId));
+ if (executor == null) {
+ throw new RuntimeException(
+ String.format("Executor is not registered (appId=%s, execId=%s)", appId, execId));
+ }
+
+ if ("sort".equals(executor.shuffleManager) || "tungsten-sort".equals(executor.shuffleManager)) {
+ return getSortBasedShuffleBlockData(executor, shuffleId, mapId, reduceId);
+ } else if ("hash".equals(executor.shuffleManager)) {
+ return getHashBasedShuffleBlockData(executor, blockId);
+ } else {
+ throw new UnsupportedOperationException(
+ "Unsupported shuffle manager: " + executor.shuffleManager);
+ }
+ }
+
+ /**
+ * Removes our metadata of all executors registered for the given application, and optionally
+ * also deletes the local directories associated with the executors of that application in a
+ * separate thread.
+ *
+ * It is not valid to call registerExecutor() for an executor with this appId after invoking
+ * this method.
+ */
+ public void applicationRemoved(String appId, boolean cleanupLocalDirs) {
+ logger.info("Application {} removed, cleanupLocalDirs = {}", appId, cleanupLocalDirs);
+ Iterator<Map.Entry<AppExecId, ExecutorShuffleInfo>> it = executors.entrySet().iterator();
+ while (it.hasNext()) {
+ Map.Entry<AppExecId, ExecutorShuffleInfo> entry = it.next();
+ AppExecId fullId = entry.getKey();
+ final ExecutorShuffleInfo executor = entry.getValue();
+
+ // Only touch executors associated with the appId that was removed.
+ if (appId.equals(fullId.appId)) {
+ it.remove();
+ if (db != null) {
+ try {
+ db.delete(dbAppExecKey(fullId));
+ } catch (IOException e) {
+ logger.error("Error deleting {} from executor state db", appId, e);
+ }
+ }
+
+ if (cleanupLocalDirs) {
+ logger.info("Cleaning up executor {}'s {} local dirs", fullId, executor.localDirs.length);
+
+ // Execute the actual deletion in a different thread, as it may take some time.
+ directoryCleaner.execute(new Runnable() {
+ @Override
+ public void run() {
+ deleteExecutorDirs(executor.localDirs);
+ }
+ });
+ }
+ }
+ }
+ }
+
+ /**
+ * Synchronously deletes each directory one at a time.
+ * Should be executed in its own thread, as this may take a long time.
+ */
+ private void deleteExecutorDirs(String[] dirs) {
+ for (String localDir : dirs) {
+ try {
+ JavaUtils.deleteRecursively(new File(localDir));
+ logger.debug("Successfully cleaned up directory: " + localDir);
+ } catch (Exception e) {
+ logger.error("Failed to delete directory: " + localDir, e);
+ }
+ }
+ }
+
+ /**
+ * Hash-based shuffle data is simply stored as one file per block.
+ * This logic is from FileShuffleBlockResolver.
+ */
+ private ManagedBuffer getHashBasedShuffleBlockData(ExecutorShuffleInfo executor, String blockId) {
+ File shuffleFile = getFile(executor.localDirs, executor.subDirsPerLocalDir, blockId);
+ return new FileSegmentManagedBuffer(conf, shuffleFile, 0, shuffleFile.length());
+ }
+
+ /**
+ * Sort-based shuffle data uses an index file called "shuffle_ShuffleId_MapId_0.index" into a data file
+ * called "shuffle_ShuffleId_MapId_0.data". This logic is from IndexShuffleBlockResolver,
+ * and the block id format is from ShuffleDataBlockId and ShuffleIndexBlockId.
+ */
+ private ManagedBuffer getSortBasedShuffleBlockData(
+ ExecutorShuffleInfo executor, int shuffleId, int mapId, int reduceId) {
+ File indexFile = getFile(executor.localDirs, executor.subDirsPerLocalDir,
+ "shuffle_" + shuffleId + "_" + mapId + "_0.index");
+
+ DataInputStream in = null;
+ try {
+ in = new DataInputStream(new FileInputStream(indexFile));
+ in.skipBytes(reduceId * 8);
+ long offset = in.readLong();
+ long nextOffset = in.readLong();
+ return new FileSegmentManagedBuffer(
+ conf,
+ getFile(executor.localDirs, executor.subDirsPerLocalDir,
+ "shuffle_" + shuffleId + "_" + mapId + "_0.data"),
+ offset,
+ nextOffset - offset);
+ } catch (IOException e) {
+ throw new RuntimeException("Failed to open file: " + indexFile, e);
+ } finally {
+ if (in != null) {
+ JavaUtils.closeQuietly(in);
+ }
+ }
+ }
+
+ /**
+ * Hashes a filename into the corresponding local directory, in a manner consistent with
+ * Spark's DiskBlockManager.getFile().
+ */
+ @VisibleForTesting
+ static File getFile(String[] localDirs, int subDirsPerLocalDir, String filename) {
+ int hash = JavaUtils.nonNegativeHash(filename);
+ String localDir = localDirs[hash % localDirs.length];
+ int subDirId = (hash / localDirs.length) % subDirsPerLocalDir;
+ return new File(new File(localDir, String.format("%02x", subDirId)), filename);
+ }
+
+ void close() {
+ if (db != null) {
+ try {
+ db.close();
+ } catch (IOException e) {
+ logger.error("Exception closing leveldb with registered executors", e);
+ }
+ }
+ }
+
+ /** Simply encodes an executor's full ID, which is appId + execId. */
+ public static class AppExecId {
+ public final String appId;
+ public final String execId;
+
+ @JsonCreator
+ public AppExecId(@JsonProperty("appId") String appId, @JsonProperty("execId") String execId) {
+ this.appId = appId;
+ this.execId = execId;
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (o == null || getClass() != o.getClass()) return false;
+
+ AppExecId appExecId = (AppExecId) o;
+ return Objects.equal(appId, appExecId.appId) && Objects.equal(execId, appExecId.execId);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hashCode(appId, execId);
+ }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("appId", appId)
+ .add("execId", execId)
+ .toString();
+ }
+ }
+
+ private static byte[] dbAppExecKey(AppExecId appExecId) throws IOException {
+ // we stick a common prefix on all the keys so we can find them in the DB
+ String appExecJson = mapper.writeValueAsString(appExecId);
+ String key = (APP_KEY_PREFIX + ";" + appExecJson);
+ return key.getBytes(Charsets.UTF_8);
+ }
+
+ private static AppExecId parseDbAppExecKey(String s) throws IOException {
+ if (!s.startsWith(APP_KEY_PREFIX)) {
+ throw new IllegalArgumentException("expected a string starting with " + APP_KEY_PREFIX);
+ }
+ String json = s.substring(APP_KEY_PREFIX.length() + 1);
+ AppExecId parsed = mapper.readValue(json, AppExecId.class);
+ return parsed;
+ }
+
+ @VisibleForTesting
+ static ConcurrentMap<AppExecId, ExecutorShuffleInfo> reloadRegisteredExecutors(DB db)
+ throws IOException {
+ ConcurrentMap<AppExecId, ExecutorShuffleInfo> registeredExecutors = Maps.newConcurrentMap();
+ if (db != null) {
+ DBIterator itr = db.iterator();
+ itr.seek(APP_KEY_PREFIX.getBytes(Charsets.UTF_8));
+ while (itr.hasNext()) {
+ Map.Entry<byte[], byte[]> e = itr.next();
+ String key = new String(e.getKey(), Charsets.UTF_8);
+ if (!key.startsWith(APP_KEY_PREFIX)) {
+ break;
+ }
+ AppExecId id = parseDbAppExecKey(key);
+ ExecutorShuffleInfo shuffleInfo = mapper.readValue(e.getValue(), ExecutorShuffleInfo.class);
+ registeredExecutors.put(id, shuffleInfo);
+ }
+ }
+ return registeredExecutors;
+ }
+
+ private static class LevelDBLogger implements org.iq80.leveldb.Logger {
+ private static final Logger LOG = LoggerFactory.getLogger(LevelDBLogger.class);
+
+ @Override
+ public void log(String message) {
+ LOG.info(message);
+ }
+ }
+
+ /**
+ * Simple major.minor versioning scheme. Any incompatible changes should be across major
+ * versions. Minor version differences are allowed -- meaning we should be able to read
+ * dbs that are either earlier *or* later on the minor version.
+ */
+ private static void checkVersion(DB db) throws IOException {
+ byte[] bytes = db.get(StoreVersion.KEY);
+ if (bytes == null) {
+ storeVersion(db);
+ } else {
+ StoreVersion version = mapper.readValue(bytes, StoreVersion.class);
+ if (version.major != CURRENT_VERSION.major) {
+ throw new IOException("cannot read state DB with version " + version + ", incompatible " +
+ "with current version " + CURRENT_VERSION);
+ }
+ storeVersion(db);
+ }
+ }
+
+ private static void storeVersion(DB db) throws IOException {
+ db.put(StoreVersion.KEY, mapper.writeValueAsBytes(CURRENT_VERSION));
+ }
+
+
+ public static class StoreVersion {
+
+ static final byte[] KEY = "StoreVersion".getBytes(Charsets.UTF_8);
+
+ public final int major;
+ public final int minor;
+
+ @JsonCreator public StoreVersion(@JsonProperty("major") int major, @JsonProperty("minor") int minor) {
+ this.major = major;
+ this.minor = minor;
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (o == null || getClass() != o.getClass()) return false;
+
+ StoreVersion that = (StoreVersion) o;
+
+ return major == that.major && minor == that.minor;
+ }
+
+ @Override
+ public int hashCode() {
+ int result = major;
+ result = 31 * result + minor;
+ return result;
+ }
+ }
+
+}
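The sort-based lookup above reduces to simple arithmetic over the index file: entry r holds the start offset of reduce partition r in the data file, and entry r+1 holds its end. A self-contained sketch of that read, matching getSortBasedShuffleBlockData() (editor's illustration; the path is hypothetical):

    int reduceId = 3;
    DataInputStream in = new DataInputStream(
        new FileInputStream("/data/spark/local/0d/shuffle_0_5_0.index"));  // hypothetical path
    try {
      in.skipBytes(reduceId * 8);         // each index entry is one 8-byte long
      long offset = in.readLong();        // where partition 3 begins in shuffle_0_5_0.data
      long nextOffset = in.readLong();    // where partition 4 begins
      long length = nextOffset - offset;  // bytes to serve for block shuffle_0_5_3
    } finally {
      in.close();
    }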
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
new file mode 100644
index 0000000000..58ca87d9d3
--- /dev/null
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
@@ -0,0 +1,154 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.List;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.Lists;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.TransportContext;
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.client.TransportClientBootstrap;
+import org.apache.spark.network.client.TransportClientFactory;
+import org.apache.spark.network.sasl.SaslClientBootstrap;
+import org.apache.spark.network.sasl.SecretKeyHolder;
+import org.apache.spark.network.server.NoOpRpcHandler;
+import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo;
+import org.apache.spark.network.shuffle.protocol.RegisterExecutor;
+import org.apache.spark.network.util.TransportConf;
+
+/**
+ * Client for reading shuffle blocks which points to an external (outside of executor) server.
+ * This is instead of reading shuffle blocks directly from other executors (via
+ * BlockTransferService), which has the downside of losing the shuffle data if we lose the
+ * executors.
+ */
+public class ExternalShuffleClient extends ShuffleClient {
+ private final Logger logger = LoggerFactory.getLogger(ExternalShuffleClient.class);
+
+ private final TransportConf conf;
+ private final boolean saslEnabled;
+ private final boolean saslEncryptionEnabled;
+ private final SecretKeyHolder secretKeyHolder;
+
+ protected TransportClientFactory clientFactory;
+ protected String appId;
+
+ /**
+ * Creates an external shuffle client, with SASL optionally enabled. If SASL is not enabled,
+ * then secretKeyHolder may be null.
+ */
+ public ExternalShuffleClient(
+ TransportConf conf,
+ SecretKeyHolder secretKeyHolder,
+ boolean saslEnabled,
+ boolean saslEncryptionEnabled) {
+ Preconditions.checkArgument(
+ !saslEncryptionEnabled || saslEnabled,
+ "SASL encryption can only be enabled if SASL is also enabled.");
+ this.conf = conf;
+ this.secretKeyHolder = secretKeyHolder;
+ this.saslEnabled = saslEnabled;
+ this.saslEncryptionEnabled = saslEncryptionEnabled;
+ }
+
+ protected void checkInit() {
+ assert appId != null : "Called before init()";
+ }
+
+ @Override
+ public void init(String appId) {
+ this.appId = appId;
+ TransportContext context = new TransportContext(conf, new NoOpRpcHandler(), true);
+ List<TransportClientBootstrap> bootstraps = Lists.newArrayList();
+ if (saslEnabled) {
+ bootstraps.add(new SaslClientBootstrap(conf, appId, secretKeyHolder, saslEncryptionEnabled));
+ }
+ clientFactory = context.createClientFactory(bootstraps);
+ }
+
+ @Override
+ public void fetchBlocks(
+ final String host,
+ final int port,
+ final String execId,
+ String[] blockIds,
+ BlockFetchingListener listener) {
+ checkInit();
+ logger.debug("External shuffle fetch from {}:{} (executor id {})", host, port, execId);
+ try {
+ RetryingBlockFetcher.BlockFetchStarter blockFetchStarter =
+ new RetryingBlockFetcher.BlockFetchStarter() {
+ @Override
+ public void createAndStart(String[] blockIds, BlockFetchingListener listener)
+ throws IOException {
+ TransportClient client = clientFactory.createClient(host, port);
+ new OneForOneBlockFetcher(client, appId, execId, blockIds, listener).start();
+ }
+ };
+
+ int maxRetries = conf.maxIORetries();
+ if (maxRetries > 0) {
+ // Note this Fetcher will correctly handle maxRetries == 0; we avoid it just in case there's
+ // a bug in this code. We should remove the if statement once we're sure of the stability.
+ new RetryingBlockFetcher(conf, blockFetchStarter, blockIds, listener).start();
+ } else {
+ blockFetchStarter.createAndStart(blockIds, listener);
+ }
+ } catch (Exception e) {
+ logger.error("Exception while beginning fetchBlocks", e);
+ for (String blockId : blockIds) {
+ listener.onBlockFetchFailure(blockId, e);
+ }
+ }
+ }
+
+ /**
+ * Registers this executor with an external shuffle server. This registration is required to
+ * inform the shuffle server about where and how we store our shuffle files.
+ *
+ * @param host Host of shuffle server.
+ * @param port Port of shuffle server.
+ * @param execId This Executor's id.
+ * @param executorInfo Contains all info necessary for the service to find our shuffle files.
+ */
+ public void registerWithShuffleServer(
+ String host,
+ int port,
+ String execId,
+ ExecutorShuffleInfo executorInfo) throws IOException {
+ checkInit();
+ TransportClient client = clientFactory.createUnmanagedClient(host, port);
+ try {
+ ByteBuffer registerMessage = new RegisterExecutor(appId, execId, executorInfo).toByteBuffer();
+ client.sendRpcSync(registerMessage, 5000 /* timeoutMs */);
+ } finally {
+ client.close();
+ }
+ }
+
+ @Override
+ public void close() {
+ clientFactory.close();
+ }
+}
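A hedged end-to-end sketch of the client above (editor's illustration; the host, port, ids, and ExecutorShuffleInfo arguments are hypothetical, and `listener` is a BlockFetchingListener supplied by the caller):

    TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider());
    ExternalShuffleClient client =
        new ExternalShuffleClient(conf, null /* secretKeyHolder */, false, false);
    client.init("app-20160228-0001");
    // Executor side: tell the service where and how this executor stores its shuffle files.
    client.registerWithShuffleServer("shuffle-host", 7337, "exec-1",
        new ExecutorShuffleInfo(new String[] {"/data/spark/local"}, 64, "sort"));
    // Reducer side: fetch blocks through the service rather than from the executor itself.
    client.fetchBlocks("shuffle-host", 7337, "exec-1",
        new String[] {"shuffle_0_5_3"}, listener);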
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
new file mode 100644
index 0000000000..1b2ddbf1ed
--- /dev/null
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
@@ -0,0 +1,129 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle;
+
+import java.nio.ByteBuffer;
+import java.util.Arrays;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.client.ChunkReceivedCallback;
+import org.apache.spark.network.client.RpcResponseCallback;
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
+import org.apache.spark.network.shuffle.protocol.OpenBlocks;
+import org.apache.spark.network.shuffle.protocol.StreamHandle;
+
+/**
+ * Simple wrapper on top of a TransportClient which interprets each chunk as a whole block, and
+ * invokes the BlockFetchingListener appropriately. This class is agnostic to the actual RPC
+ * handler, as long as there is a single "open blocks" message which returns a StreamHandle,
+ * and Java serialization is used.
+ *
+ * Note that this typically corresponds to a
+ * {@link org.apache.spark.network.server.OneForOneStreamManager} on the server side.
+ */
+public class OneForOneBlockFetcher {
+ private final Logger logger = LoggerFactory.getLogger(OneForOneBlockFetcher.class);
+
+ private final TransportClient client;
+ private final OpenBlocks openMessage;
+ private final String[] blockIds;
+ private final BlockFetchingListener listener;
+ private final ChunkReceivedCallback chunkCallback;
+
+ private StreamHandle streamHandle = null;
+
+ public OneForOneBlockFetcher(
+ TransportClient client,
+ String appId,
+ String execId,
+ String[] blockIds,
+ BlockFetchingListener listener) {
+ this.client = client;
+ this.openMessage = new OpenBlocks(appId, execId, blockIds);
+ this.blockIds = blockIds;
+ this.listener = listener;
+ this.chunkCallback = new ChunkCallback();
+ }
+
+ /** Callback invoked on receipt of each chunk. We equate a single chunk to a single block. */
+ private class ChunkCallback implements ChunkReceivedCallback {
+ @Override
+ public void onSuccess(int chunkIndex, ManagedBuffer buffer) {
+ // On receipt of a chunk, pass it upwards as a block.
+ listener.onBlockFetchSuccess(blockIds[chunkIndex], buffer);
+ }
+
+ @Override
+ public void onFailure(int chunkIndex, Throwable e) {
+ // On receipt of a failure, fail every block from chunkIndex onwards.
+ String[] remainingBlockIds = Arrays.copyOfRange(blockIds, chunkIndex, blockIds.length);
+ failRemainingBlocks(remainingBlockIds, e);
+ }
+ }
+
+ /**
+ * Begins the fetching process, calling the listener with every block fetched.
+ * The given message will be serialized with the Java serializer, and the RPC must return a
+ * {@link StreamHandle}. We will send all fetch requests immediately, without throttling.
+ */
+ public void start() {
+ if (blockIds.length == 0) {
+ throw new IllegalArgumentException("Zero-sized blockIds array");
+ }
+
+ client.sendRpc(openMessage.toByteBuffer(), new RpcResponseCallback() {
+ @Override
+ public void onSuccess(ByteBuffer response) {
+ try {
+ streamHandle = (StreamHandle) BlockTransferMessage.Decoder.fromByteBuffer(response);
+ logger.trace("Successfully opened blocks {}, preparing to fetch chunks.", streamHandle);
+
+ // Immediately request all chunks -- we expect that the total size of the request is
+ // reasonable due to higher level chunking in [[ShuffleBlockFetcherIterator]].
+ for (int i = 0; i < streamHandle.numChunks; i++) {
+ client.fetchChunk(streamHandle.streamId, i, chunkCallback);
+ }
+ } catch (Exception e) {
+ logger.error("Failed while starting block fetches after success", e);
+ failRemainingBlocks(blockIds, e);
+ }
+ }
+
+ @Override
+ public void onFailure(Throwable e) {
+ logger.error("Failed while starting block fetches", e);
+ failRemainingBlocks(blockIds, e);
+ }
+ });
+ }
+
+ /** Invokes the "onBlockFetchFailure" callback for every listed block id. */
+ private void failRemainingBlocks(String[] failedBlockIds, Throwable e) {
+ for (String blockId : failedBlockIds) {
+ try {
+ listener.onBlockFetchFailure(blockId, e);
+ } catch (Exception e2) {
+ logger.error("Error in block fetch failure callback", e2);
+ }
+ }
+ }
+}
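A short sketch of driving the fetcher above directly (editor's illustration; `clientFactory` and `listener` are assumed to exist). One fetcher instance corresponds to one OpenBlocks request, and chunk i of the resulting stream maps back to blockIds[i]:

    TransportClient client = clientFactory.createClient("shuffle-host", 7337);
    new OneForOneBlockFetcher(client, "app-20160228-0001", "exec-1",
        new String[] {"shuffle_0_5_3", "shuffle_0_6_3"}, listener).start();
    // start() sends a single OpenBlocks RPC, then issues one fetchChunk() per block.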
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java
new file mode 100644
index 0000000000..4bb0498e5d
--- /dev/null
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java
@@ -0,0 +1,234 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.LinkedHashSet;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.TimeUnit;
+
+import com.google.common.collect.Sets;
+import com.google.common.util.concurrent.Uninterruptibles;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.util.NettyUtils;
+import org.apache.spark.network.util.TransportConf;
+
+/**
+ * Wraps another BlockFetcher with the ability to automatically retry fetches which fail due to
+ * IOExceptions, which we hope are due to transient network conditions.
+ *
+ * This fetcher provides stronger guarantees regarding the parent BlockFetchingListener. In
+ * particular, the listener will be invoked exactly once per blockId, with a success or failure.
+ */
+public class RetryingBlockFetcher {
+
+ /**
+ * Used to initiate the first fetch for all blocks, and subsequently for retrying the fetch on any
+ * remaining blocks.
+ */
+ public static interface BlockFetchStarter {
+ /**
+ * Creates a new BlockFetcher to fetch the given block ids which may do some synchronous
+ * bootstrapping followed by fully asynchronous block fetching.
+ * The BlockFetcher must eventually invoke the Listener on every input blockId, or else this
+ * method must throw an exception.
+ *
+ * This method should always attempt to get a new TransportClient from the
+ * {@link org.apache.spark.network.client.TransportClientFactory} in order to fix connection
+ * issues.
+ */
+ void createAndStart(String[] blockIds, BlockFetchingListener listener) throws IOException;
+ }
+
+ /** Shared executor service used for waiting and retrying. */
+ private static final ExecutorService executorService = Executors.newCachedThreadPool(
+ NettyUtils.createThreadFactory("Block Fetch Retry"));
+
+ private final Logger logger = LoggerFactory.getLogger(RetryingBlockFetcher.class);
+
+ /** Used to initiate new Block Fetches on our remaining blocks. */
+ private final BlockFetchStarter fetchStarter;
+
+ /** Parent listener which we delegate all successful or permanently failed block fetches to. */
+ private final BlockFetchingListener listener;
+
+ /** Max number of times we are allowed to retry. */
+ private final int maxRetries;
+
+ /** Milliseconds to wait before each retry. */
+ private final int retryWaitTime;
+
+ // NOTE:
+ // All of our non-final fields are synchronized under 'this' and should only be accessed/mutated
+ // while inside a synchronized block.
+ /** Number of times we've attempted to retry so far. */
+ private int retryCount = 0;
+
+ /**
+ * Set of all block ids that have neither been fetched successfully nor failed permanently with a non-IOException.
+ * A retry involves requesting every outstanding block. Note that since this is a LinkedHashSet,
+ * input ordering is preserved, so we always request blocks in the same order the user provided.
+ */
+ private final LinkedHashSet<String> outstandingBlocksIds;
+
+ /**
+ * The BlockFetchingListener that is active with our current BlockFetcher.
+ * When we start a retry, we immediately replace this with a new Listener, which causes any
+ * old Listeners to ignore all further responses.
+ */
+ private RetryingBlockFetchListener currentListener;
+
+ public RetryingBlockFetcher(
+ TransportConf conf,
+ BlockFetchStarter fetchStarter,
+ String[] blockIds,
+ BlockFetchingListener listener) {
+ this.fetchStarter = fetchStarter;
+ this.listener = listener;
+ this.maxRetries = conf.maxIORetries();
+ this.retryWaitTime = conf.ioRetryWaitTimeMs();
+ this.outstandingBlocksIds = Sets.newLinkedHashSet();
+ Collections.addAll(outstandingBlocksIds, blockIds);
+ this.currentListener = new RetryingBlockFetchListener();
+ }
+
+ /**
+ * Initiates the fetch of all blocks provided in the constructor, with possible retries in the
+ * event of transient IOExceptions.
+ */
+ public void start() {
+ fetchAllOutstanding();
+ }
+
+ /**
+ * Fires off a request to fetch all blocks that have not been fetched successfully or permanently
+ * failed (i.e., by a non-IOException).
+ */
+ private void fetchAllOutstanding() {
+ // Start by retrieving our shared state within a synchronized block.
+ String[] blockIdsToFetch;
+ int numRetries;
+ RetryingBlockFetchListener myListener;
+ synchronized (this) {
+ blockIdsToFetch = outstandingBlocksIds.toArray(new String[outstandingBlocksIds.size()]);
+ numRetries = retryCount;
+ myListener = currentListener;
+ }
+
+ // Now initiate the fetch on all outstanding blocks, possibly initiating a retry if that fails.
+ try {
+ fetchStarter.createAndStart(blockIdsToFetch, myListener);
+ } catch (Exception e) {
+ logger.error(String.format("Exception while beginning fetch of %s outstanding blocks %s",
+ blockIdsToFetch.length, numRetries > 0 ? "(after " + numRetries + " retries)" : ""), e);
+
+ if (shouldRetry(e)) {
+ initiateRetry();
+ } else {
+ for (String bid : blockIdsToFetch) {
+ listener.onBlockFetchFailure(bid, e);
+ }
+ }
+ }
+ }
+
+ /**
+ * Lightweight method which initiates a retry in a different thread. The retry will involve
+ * calling fetchAllOutstanding() after a configured wait time.
+ */
+ private synchronized void initiateRetry() {
+ retryCount += 1;
+ currentListener = new RetryingBlockFetchListener();
+
+ logger.info("Retrying fetch ({}/{}) for {} outstanding blocks after {} ms",
+ retryCount, maxRetries, outstandingBlocksIds.size(), retryWaitTime);
+
+ executorService.submit(new Runnable() {
+ @Override
+ public void run() {
+ Uninterruptibles.sleepUninterruptibly(retryWaitTime, TimeUnit.MILLISECONDS);
+ fetchAllOutstanding();
+ }
+ });
+ }
+
+ /**
+ * Returns true if we should retry due to a block fetch failure. We will retry if and only if
+ * the exception was an IOException and we haven't retried 'maxRetries' times already.
+ */
+ private synchronized boolean shouldRetry(Throwable e) {
+ boolean isIOException = e instanceof IOException
+ || (e.getCause() != null && e.getCause() instanceof IOException);
+ boolean hasRemainingRetries = retryCount < maxRetries;
+ return isIOException && hasRemainingRetries;
+ }
+
+ /**
+ * Our RetryListener intercepts block fetch responses and forwards them to our parent listener.
+ * Note that in the event of a retry, we will immediately replace the 'currentListener' field,
+ * indicating that any responses from non-current Listeners should be ignored.
+ */
+ private class RetryingBlockFetchListener implements BlockFetchingListener {
+ @Override
+ public void onBlockFetchSuccess(String blockId, ManagedBuffer data) {
+ // We will only forward this success message to our parent listener if this block request is
+ // outstanding and we are still the active listener.
+ boolean shouldForwardSuccess = false;
+ synchronized (RetryingBlockFetcher.this) {
+ if (this == currentListener && outstandingBlocksIds.contains(blockId)) {
+ outstandingBlocksIds.remove(blockId);
+ shouldForwardSuccess = true;
+ }
+ }
+
+ // Now actually invoke the parent listener, outside of the synchronized block.
+ if (shouldForwardSuccess) {
+ listener.onBlockFetchSuccess(blockId, data);
+ }
+ }
+
+ @Override
+ public void onBlockFetchFailure(String blockId, Throwable exception) {
+ // We will only forward this failure to our parent listener if this block request is
+ // outstanding, we are still the active listener, AND we cannot retry the fetch.
+ boolean shouldForwardFailure = false;
+ synchronized (RetryingBlockFetcher.this) {
+ if (this == currentListener && outstandingBlocksIds.contains(blockId)) {
+ if (shouldRetry(exception)) {
+ initiateRetry();
+ } else {
+ logger.error(String.format("Failed to fetch block %s, and will not retry (%s retries)",
+ blockId, retryCount), exception);
+ outstandingBlocksIds.remove(blockId);
+ shouldForwardFailure = true;
+ }
+ }
+ }
+
+ // Now actually invoke the parent listener, outside of the synchronized block.
+ if (shouldForwardFailure) {
+ listener.onBlockFetchFailure(blockId, exception);
+ }
+ }
+ }
+}
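Putting the retry contract above together (editor's illustration; `clientFactory`, `conf`, `blockIds`, and `listener` are assumed). The starter must create a fresh TransportClient on every call so that a retry also gets a fresh connection:

    RetryingBlockFetcher.BlockFetchStarter starter =
        new RetryingBlockFetcher.BlockFetchStarter() {
          @Override
          public void createAndStart(String[] blockIds, BlockFetchingListener listener)
              throws IOException {
            // A new client per attempt, per the BlockFetchStarter contract above.
            TransportClient client = clientFactory.createClient("shuffle-host", 7337);
            new OneForOneBlockFetcher(client, "app-20160228-0001", "exec-1", blockIds, listener)
                .start();
          }
        };
    new RetryingBlockFetcher(conf, starter, blockIds, listener).start();
    // With spark.shuffle.io.maxRetries > 0, an IOException re-runs the starter on only the
    // still-outstanding blocks after spark.shuffle.io.retryWait milliseconds.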
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java
new file mode 100644
index 0000000000..f72ab40690
--- /dev/null
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle;
+
+import java.io.Closeable;
+
+/** Provides an interface for reading shuffle files, either from an Executor or external service. */
+public abstract class ShuffleClient implements Closeable {
+
+ /**
+ * Initializes the ShuffleClient, specifying this Executor's appId.
+ * Must be called before any other method on the ShuffleClient.
+ */
+ public void init(String appId) { }
+
+ /**
+ * Fetch a sequence of blocks from a remote node asynchronously.
+ *
+ * Note that this API takes a sequence so the implementation can batch requests, and does not
+ * return a future so the underlying implementation can invoke onBlockFetchSuccess as soon as
+ * the data of a block is fetched, rather than waiting for all blocks to be fetched.
+ */
+ public abstract void fetchBlocks(
+ String host,
+ int port,
+ String execId,
+ String[] blockIds,
+ BlockFetchingListener listener);
+}
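Since `ShuffleClient` is abstract, a concrete subclass only has to honor the two contract points visible above: `init(appId)` runs before anything else, and `fetchBlocks` issues one listener callback per requested block. A hedged sketch (the class name and its fail-everything behavior are invented for illustration):

```java
import org.apache.spark.network.shuffle.BlockFetchingListener;
import org.apache.spark.network.shuffle.ShuffleClient;

// Hypothetical stub client; not part of the patch.
public class NoopShuffleClient extends ShuffleClient {
  private String appId;

  @Override
  public void init(String appId) {
    this.appId = appId;  // must be called before fetchBlocks(), per the contract
  }

  @Override
  public void fetchBlocks(
      String host,
      int port,
      String execId,
      String[] blockIds,
      BlockFetchingListener listener) {
    // One callback per requested block, delivered as results become available.
    for (String blockId : blockIds) {
      listener.onBlockFetchFailure(
          blockId, new UnsupportedOperationException(appId + ": stub client"));
    }
  }

  @Override
  public void close() { }
}
```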
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java
new file mode 100644
index 0000000000..675820308b
--- /dev/null
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle.mesos;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.client.RpcResponseCallback;
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.sasl.SecretKeyHolder;
+import org.apache.spark.network.shuffle.ExternalShuffleClient;
+import org.apache.spark.network.shuffle.protocol.mesos.RegisterDriver;
+import org.apache.spark.network.util.TransportConf;
+
+/**
+ * A client for talking to the external shuffle service in Mesos coarse-grained mode.
+ *
+ * This is used by the Spark driver to register with each external shuffle service on the cluster.
+ * The driver has to talk to the service so that shuffle files can be cleaned up reliably
+ * after the application exits. Mesos does not provide a good way to do this on its own, so
+ * Spark has to handle it itself.
+ */
+public class MesosExternalShuffleClient extends ExternalShuffleClient {
+ private final Logger logger = LoggerFactory.getLogger(MesosExternalShuffleClient.class);
+
+ /**
+ * Creates a Mesos external shuffle client that wraps the {@link ExternalShuffleClient}.
+ * Please refer to docs on {@link ExternalShuffleClient} for more information.
+ */
+ public MesosExternalShuffleClient(
+ TransportConf conf,
+ SecretKeyHolder secretKeyHolder,
+ boolean saslEnabled,
+ boolean saslEncryptionEnabled) {
+ super(conf, secretKeyHolder, saslEnabled, saslEncryptionEnabled);
+ }
+
+ public void registerDriverWithShuffleService(String host, int port) throws IOException {
+ checkInit();
+ ByteBuffer registerDriver = new RegisterDriver(appId).toByteBuffer();
+ TransportClient client = clientFactory.createClient(host, port);
+ client.sendRpc(registerDriver, new RpcResponseCallback() {
+ @Override
+ public void onSuccess(ByteBuffer response) {
+ logger.info("Successfully registered app " + appId + " with external shuffle service.");
+ }
+
+ @Override
+ public void onFailure(Throwable e) {
+ logger.warn("Unable to register app " + appId + " with external shuffle service. " +
+ "Please manually remove shuffle data after the driver exits. Error: " + e);
+ }
+ });
+ }
+}
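For context, a sketch of the expected call sequence on the driver side. The agent host names, the app ID, and port 7337 are placeholders, and passing a null `SecretKeyHolder` with SASL disabled is an assumption of the sketch, not something this patch prescribes:

```java
import org.apache.spark.network.shuffle.mesos.MesosExternalShuffleClient;
import org.apache.spark.network.util.SystemPropertyConfigProvider;
import org.apache.spark.network.util.TransportConf;

public class DriverRegistrationSketch {
  public static void main(String[] args) throws Exception {
    TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider());
    // SASL disabled in this sketch, so no SecretKeyHolder is supplied.
    MesosExternalShuffleClient client =
        new MesosExternalShuffleClient(conf, null, false, false);
    client.init("app-20160228-0001");  // appId must be set before any RPC
    for (String agent : new String[] {"agent-1", "agent-2"}) {
      // One registration per Mesos agent running an external shuffle service.
      client.registerDriverWithShuffleService(agent, 7337);  // placeholder port
    }
    client.close();
  }
}
```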
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java
new file mode 100644
index 0000000000..7fbe3384b4
--- /dev/null
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle.protocol;
+
+import java.nio.ByteBuffer;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+
+import org.apache.spark.network.protocol.Encodable;
+import org.apache.spark.network.shuffle.protocol.mesos.RegisterDriver;
+
+/**
+ * Messages handled by the {@link org.apache.spark.network.shuffle.ExternalShuffleBlockHandler}, or
+ * by Spark's NettyBlockTransferService.
+ *
+ * At a high level:
+ * - OpenBlocks is handled by both services, but shuffle files are only served by the
+ * external shuffle service. It returns a StreamHandle.
+ * - UploadBlock is only handled by the NettyBlockTransferService.
+ * - RegisterExecutor is only handled by the external shuffle service.
+ */
+public abstract class BlockTransferMessage implements Encodable {
+ protected abstract Type type();
+
+ /** Preceding every serialized message is its type, which allows us to deserialize it. */
+ public static enum Type {
+ OPEN_BLOCKS(0), UPLOAD_BLOCK(1), REGISTER_EXECUTOR(2), STREAM_HANDLE(3), REGISTER_DRIVER(4);
+
+ private final byte id;
+
+ private Type(int id) {
+ assert id < 128 : "Cannot have more than 128 message types";
+ this.id = (byte) id;
+ }
+
+ public byte id() { return id; }
+ }
+
+ // NB: Java does not support static methods in interfaces, so we must put this in a static class.
+ public static class Decoder {
+ /** Deserializes the 'type' byte followed by the message itself. */
+ public static BlockTransferMessage fromByteBuffer(ByteBuffer msg) {
+ ByteBuf buf = Unpooled.wrappedBuffer(msg);
+ byte type = buf.readByte();
+ switch (type) {
+ case 0: return OpenBlocks.decode(buf);
+ case 1: return UploadBlock.decode(buf);
+ case 2: return RegisterExecutor.decode(buf);
+ case 3: return StreamHandle.decode(buf);
+ case 4: return RegisterDriver.decode(buf);
+ default: throw new IllegalArgumentException("Unknown message type: " + type);
+ }
+ }
+ }
+
+ /** Serializes the 'type' byte followed by the message itself. */
+ public ByteBuffer toByteBuffer() {
+ // Allow room for encoded message, plus the type byte
+ ByteBuf buf = Unpooled.buffer(encodedLength() + 1);
+ buf.writeByte(type().id);
+ encode(buf);
+ assert buf.writableBytes() == 0 : "Writable bytes remain: " + buf.writableBytes();
+ return buf.nioBuffer();
+ }
+}
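Because every message is prefixed with its one-byte type tag, serialization is self-describing: the decoder needs no out-of-band type information. A round trip using only the API above (the same check the `BlockTransferMessagesSuite` later in this patch performs):

```java
import java.nio.ByteBuffer;

import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
import org.apache.spark.network.shuffle.protocol.OpenBlocks;

public class RoundTripSketch {
  public static void main(String[] args) {
    OpenBlocks msg = new OpenBlocks("app-1", "exec-2", new String[] {"b1", "b2"});
    ByteBuffer wire = msg.toByteBuffer();  // writes the type byte, then the payload
    BlockTransferMessage decoded = BlockTransferMessage.Decoder.fromByteBuffer(wire);
    assert decoded instanceof OpenBlocks && decoded.equals(msg);
  }
}
```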
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExecutorShuffleInfo.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExecutorShuffleInfo.java
new file mode 100644
index 0000000000..102d4efb8b
--- /dev/null
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExecutorShuffleInfo.java
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle.protocol;
+
+import java.util.Arrays;
+
+import com.fasterxml.jackson.annotation.JsonCreator;
+import com.fasterxml.jackson.annotation.JsonProperty;
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+
+import org.apache.spark.network.protocol.Encodable;
+import org.apache.spark.network.protocol.Encoders;
+
+/** Contains all configuration necessary for locating the shuffle files of an executor. */
+public class ExecutorShuffleInfo implements Encodable {
+ /** The base set of local directories that the executor stores its shuffle files in. */
+ public final String[] localDirs;
+ /** Number of subdirectories created within each localDir. */
+ public final int subDirsPerLocalDir;
+ /** Shuffle manager (SortShuffleManager or HashShuffleManager) that the executor is using. */
+ public final String shuffleManager;
+
+ @JsonCreator
+ public ExecutorShuffleInfo(
+ @JsonProperty("localDirs") String[] localDirs,
+ @JsonProperty("subDirsPerLocalDir") int subDirsPerLocalDir,
+ @JsonProperty("shuffleManager") String shuffleManager) {
+ this.localDirs = localDirs;
+ this.subDirsPerLocalDir = subDirsPerLocalDir;
+ this.shuffleManager = shuffleManager;
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hashCode(subDirsPerLocalDir, shuffleManager) * 41 + Arrays.hashCode(localDirs);
+ }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("localDirs", Arrays.toString(localDirs))
+ .add("subDirsPerLocalDir", subDirsPerLocalDir)
+ .add("shuffleManager", shuffleManager)
+ .toString();
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (other != null && other instanceof ExecutorShuffleInfo) {
+ ExecutorShuffleInfo o = (ExecutorShuffleInfo) other;
+ return Arrays.equals(localDirs, o.localDirs)
+ && Objects.equal(subDirsPerLocalDir, o.subDirsPerLocalDir)
+ && Objects.equal(shuffleManager, o.shuffleManager);
+ }
+ return false;
+ }
+
+ @Override
+ public int encodedLength() {
+ return Encoders.StringArrays.encodedLength(localDirs)
+ + 4 // int
+ + Encoders.Strings.encodedLength(shuffleManager);
+ }
+
+ @Override
+ public void encode(ByteBuf buf) {
+ Encoders.StringArrays.encode(buf, localDirs);
+ buf.writeInt(subDirsPerLocalDir);
+ Encoders.Strings.encode(buf, shuffleManager);
+ }
+
+ public static ExecutorShuffleInfo decode(ByteBuf buf) {
+ String[] localDirs = Encoders.StringArrays.decode(buf);
+ int subDirsPerLocalDir = buf.readInt();
+ String shuffleManager = Encoders.Strings.decode(buf);
+ return new ExecutorShuffleInfo(localDirs, subDirsPerLocalDir, shuffleManager);
+ }
+}
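Note how `encodedLength()` sizes the buffer exactly: after `encode()` no writable bytes remain, the same invariant `BlockTransferMessage.toByteBuffer()` asserts. A short round-trip sketch against the API above:

```java
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;

import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo;

public class ExecutorShuffleInfoSketch {
  public static void main(String[] args) {
    ExecutorShuffleInfo info =
        new ExecutorShuffleInfo(new String[] {"/local1", "/local2"}, 64, "sort");
    ByteBuf buf = Unpooled.buffer(info.encodedLength());  // exact-size buffer
    info.encode(buf);
    assert buf.writableBytes() == 0;                       // nothing left over
    assert ExecutorShuffleInfo.decode(buf).equals(info);   // lossless round trip
  }
}
```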
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenBlocks.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenBlocks.java
new file mode 100644
index 0000000000..ce954b8a28
--- /dev/null
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenBlocks.java
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle.protocol;
+
+import java.util.Arrays;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+
+import org.apache.spark.network.protocol.Encoders;
+
+// Needed by ScalaDoc. See SPARK-7726
+import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type;
+
+/** Request to read a set of blocks. Returns {@link StreamHandle}. */
+public class OpenBlocks extends BlockTransferMessage {
+ public final String appId;
+ public final String execId;
+ public final String[] blockIds;
+
+ public OpenBlocks(String appId, String execId, String[] blockIds) {
+ this.appId = appId;
+ this.execId = execId;
+ this.blockIds = blockIds;
+ }
+
+ @Override
+ protected Type type() { return Type.OPEN_BLOCKS; }
+
+ @Override
+ public int hashCode() {
+ return Objects.hashCode(appId, execId) * 41 + Arrays.hashCode(blockIds);
+ }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("appId", appId)
+ .add("execId", execId)
+ .add("blockIds", Arrays.toString(blockIds))
+ .toString();
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (other != null && other instanceof OpenBlocks) {
+ OpenBlocks o = (OpenBlocks) other;
+ return Objects.equal(appId, o.appId)
+ && Objects.equal(execId, o.execId)
+ && Arrays.equals(blockIds, o.blockIds);
+ }
+ return false;
+ }
+
+ @Override
+ public int encodedLength() {
+ return Encoders.Strings.encodedLength(appId)
+ + Encoders.Strings.encodedLength(execId)
+ + Encoders.StringArrays.encodedLength(blockIds);
+ }
+
+ @Override
+ public void encode(ByteBuf buf) {
+ Encoders.Strings.encode(buf, appId);
+ Encoders.Strings.encode(buf, execId);
+ Encoders.StringArrays.encode(buf, blockIds);
+ }
+
+ public static OpenBlocks decode(ByteBuf buf) {
+ String appId = Encoders.Strings.decode(buf);
+ String execId = Encoders.Strings.decode(buf);
+ String[] blockIds = Encoders.StringArrays.decode(buf);
+ return new OpenBlocks(appId, execId, blockIds);
+ }
+}
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java
new file mode 100644
index 0000000000..167ef33104
--- /dev/null
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle.protocol;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+
+import org.apache.spark.network.protocol.Encoders;
+
+// Needed by ScalaDoc. See SPARK-7726
+import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type;
+
+/**
+ * Initial registration message between an executor and its local shuffle server.
+ * Returns nothing (empty byte array).
+ */
+public class RegisterExecutor extends BlockTransferMessage {
+ public final String appId;
+ public final String execId;
+ public final ExecutorShuffleInfo executorInfo;
+
+ public RegisterExecutor(
+ String appId,
+ String execId,
+ ExecutorShuffleInfo executorInfo) {
+ this.appId = appId;
+ this.execId = execId;
+ this.executorInfo = executorInfo;
+ }
+
+ @Override
+ protected Type type() { return Type.REGISTER_EXECUTOR; }
+
+ @Override
+ public int hashCode() {
+ return Objects.hashCode(appId, execId, executorInfo);
+ }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("appId", appId)
+ .add("execId", execId)
+ .add("executorInfo", executorInfo)
+ .toString();
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (other != null && other instanceof RegisterExecutor) {
+ RegisterExecutor o = (RegisterExecutor) other;
+ return Objects.equal(appId, o.appId)
+ && Objects.equal(execId, o.execId)
+ && Objects.equal(executorInfo, o.executorInfo);
+ }
+ return false;
+ }
+
+ @Override
+ public int encodedLength() {
+ return Encoders.Strings.encodedLength(appId)
+ + Encoders.Strings.encodedLength(execId)
+ + executorInfo.encodedLength();
+ }
+
+ @Override
+ public void encode(ByteBuf buf) {
+ Encoders.Strings.encode(buf, appId);
+ Encoders.Strings.encode(buf, execId);
+ executorInfo.encode(buf);
+ }
+
+ public static RegisterExecutor decode(ByteBuf buf) {
+ String appId = Encoders.Strings.decode(buf);
+ String execId = Encoders.Strings.decode(buf);
+ ExecutorShuffleInfo executorShuffleInfo = ExecutorShuffleInfo.decode(buf);
+ return new RegisterExecutor(appId, execId, executorShuffleInfo);
+ }
+}
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/StreamHandle.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/StreamHandle.java
new file mode 100644
index 0000000000..1915295aa6
--- /dev/null
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/StreamHandle.java
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle.protocol;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+
+// Needed by ScalaDoc. See SPARK-7726
+import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type;
+
+/**
+ * Identifier for a fixed number of chunks to read from a stream created by an "open blocks"
+ * message. This is used by {@link org.apache.spark.network.shuffle.OneForOneBlockFetcher}.
+ */
+public class StreamHandle extends BlockTransferMessage {
+ public final long streamId;
+ public final int numChunks;
+
+ public StreamHandle(long streamId, int numChunks) {
+ this.streamId = streamId;
+ this.numChunks = numChunks;
+ }
+
+ @Override
+ protected Type type() { return Type.STREAM_HANDLE; }
+
+ @Override
+ public int hashCode() {
+ return Objects.hashCode(streamId, numChunks);
+ }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("streamId", streamId)
+ .add("numChunks", numChunks)
+ .toString();
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (other != null && other instanceof StreamHandle) {
+ StreamHandle o = (StreamHandle) other;
+ return Objects.equal(streamId, o.streamId)
+ && Objects.equal(numChunks, o.numChunks);
+ }
+ return false;
+ }
+
+ @Override
+ public int encodedLength() {
+ return 8 + 4;
+ }
+
+ @Override
+ public void encode(ByteBuf buf) {
+ buf.writeLong(streamId);
+ buf.writeInt(numChunks);
+ }
+
+ public static StreamHandle decode(ByteBuf buf) {
+ long streamId = buf.readLong();
+ int numChunks = buf.readInt();
+ return new StreamHandle(streamId, numChunks);
+ }
+}
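StreamHandle completes the open-then-stream handshake: the client sends OpenBlocks, the server registers a stream and replies with a StreamHandle, and the client then fetches `numChunks` chunks, where chunk index `i` corresponds to `blockIds[i]`. A message-level sketch of that flow (the already-connected `TransportClient` and the 10-second timeout are assumptions of the sketch):

```java
import java.nio.ByteBuffer;

import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
import org.apache.spark.network.shuffle.protocol.OpenBlocks;
import org.apache.spark.network.shuffle.protocol.StreamHandle;

public class OpenBlocksFlowSketch {
  static void openAndFetch(TransportClient client) {
    ByteBuffer request =
        new OpenBlocks("app-1", "exec-0", new String[] {"shuffle_0_0_0"}).toByteBuffer();
    ByteBuffer response = client.sendRpcSync(request, 10_000);
    StreamHandle handle =
        (StreamHandle) BlockTransferMessage.Decoder.fromByteBuffer(response);
    for (int i = 0; i < handle.numChunks; i++) {
      // Chunk i of the stream holds block i of the OpenBlocks request:
      // client.fetchChunk(handle.streamId, i, callback);  // async, per-chunk callback
    }
  }
}
```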
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadBlock.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadBlock.java
new file mode 100644
index 0000000000..3caed59d50
--- /dev/null
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadBlock.java
@@ -0,0 +1,117 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle.protocol;
+
+import java.util.Arrays;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+
+import org.apache.spark.network.protocol.Encoders;
+
+// Needed by ScalaDoc. See SPARK-7726
+import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type;
+
+/** Request to upload a block with a certain StorageLevel. Returns nothing (empty byte array). */
+public class UploadBlock extends BlockTransferMessage {
+ public final String appId;
+ public final String execId;
+ public final String blockId;
+ // TODO: StorageLevel is serialized separately in here because StorageLevel is not available in
+ // this package. We should avoid this hack.
+ public final byte[] metadata;
+ public final byte[] blockData;
+
+ /**
+ * @param metadata Meta-information about block, typically StorageLevel.
+ * @param blockData The actual block's bytes.
+ */
+ public UploadBlock(
+ String appId,
+ String execId,
+ String blockId,
+ byte[] metadata,
+ byte[] blockData) {
+ this.appId = appId;
+ this.execId = execId;
+ this.blockId = blockId;
+ this.metadata = metadata;
+ this.blockData = blockData;
+ }
+
+ @Override
+ protected Type type() { return Type.UPLOAD_BLOCK; }
+
+ @Override
+ public int hashCode() {
+ int objectsHashCode = Objects.hashCode(appId, execId, blockId);
+ return (objectsHashCode * 41 + Arrays.hashCode(metadata)) * 41 + Arrays.hashCode(blockData);
+ }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("appId", appId)
+ .add("execId", execId)
+ .add("blockId", blockId)
+ .add("metadata size", metadata.length)
+ .add("block size", blockData.length)
+ .toString();
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (other != null && other instanceof UploadBlock) {
+ UploadBlock o = (UploadBlock) other;
+ return Objects.equal(appId, o.appId)
+ && Objects.equal(execId, o.execId)
+ && Objects.equal(blockId, o.blockId)
+ && Arrays.equals(metadata, o.metadata)
+ && Arrays.equals(blockData, o.blockData);
+ }
+ return false;
+ }
+
+ @Override
+ public int encodedLength() {
+ return Encoders.Strings.encodedLength(appId)
+ + Encoders.Strings.encodedLength(execId)
+ + Encoders.Strings.encodedLength(blockId)
+ + Encoders.ByteArrays.encodedLength(metadata)
+ + Encoders.ByteArrays.encodedLength(blockData);
+ }
+
+ @Override
+ public void encode(ByteBuf buf) {
+ Encoders.Strings.encode(buf, appId);
+ Encoders.Strings.encode(buf, execId);
+ Encoders.Strings.encode(buf, blockId);
+ Encoders.ByteArrays.encode(buf, metadata);
+ Encoders.ByteArrays.encode(buf, blockData);
+ }
+
+ public static UploadBlock decode(ByteBuf buf) {
+ String appId = Encoders.Strings.decode(buf);
+ String execId = Encoders.Strings.decode(buf);
+ String blockId = Encoders.Strings.decode(buf);
+ byte[] metadata = Encoders.ByteArrays.decode(buf);
+ byte[] blockData = Encoders.ByteArrays.decode(buf);
+ return new UploadBlock(appId, execId, blockId, metadata, blockData);
+ }
+}
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/RegisterDriver.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/RegisterDriver.java
new file mode 100644
index 0000000000..94a61d6caa
--- /dev/null
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/RegisterDriver.java
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle.protocol.mesos;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+
+import org.apache.spark.network.protocol.Encoders;
+import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
+
+// Needed by ScalaDoc. See SPARK-7726
+import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type;
+
+/**
+ * A message sent from the driver to register with the MesosExternalShuffleService.
+ */
+public class RegisterDriver extends BlockTransferMessage {
+ private final String appId;
+
+ public RegisterDriver(String appId) {
+ this.appId = appId;
+ }
+
+ public String getAppId() { return appId; }
+
+ @Override
+ protected Type type() { return Type.REGISTER_DRIVER; }
+
+ @Override
+ public int encodedLength() {
+ return Encoders.Strings.encodedLength(appId);
+ }
+
+ @Override
+ public void encode(ByteBuf buf) {
+ Encoders.Strings.encode(buf, appId);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hashCode(appId);
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (other != null && other instanceof RegisterDriver) {
+ return Objects.equal(appId, ((RegisterDriver) other).appId);
+ }
+ return false;
+ }
+
+ public static RegisterDriver decode(ByteBuf buf) {
+ String appId = Encoders.Strings.decode(buf);
+ return new RegisterDriver(appId);
+ }
+}
diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java
new file mode 100644
index 0000000000..0ea631ea14
--- /dev/null
+++ b/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java
@@ -0,0 +1,294 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.sasl;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.Arrays;
+import java.util.concurrent.atomic.AtomicReference;
+
+import com.google.common.collect.Lists;
+import org.junit.After;
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import static org.junit.Assert.*;
+import static org.mockito.Mockito.*;
+
+import org.apache.spark.network.TestUtils;
+import org.apache.spark.network.TransportContext;
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.client.ChunkReceivedCallback;
+import org.apache.spark.network.client.RpcResponseCallback;
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.client.TransportClientBootstrap;
+import org.apache.spark.network.client.TransportClientFactory;
+import org.apache.spark.network.server.OneForOneStreamManager;
+import org.apache.spark.network.server.RpcHandler;
+import org.apache.spark.network.server.StreamManager;
+import org.apache.spark.network.server.TransportServer;
+import org.apache.spark.network.server.TransportServerBootstrap;
+import org.apache.spark.network.shuffle.BlockFetchingListener;
+import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler;
+import org.apache.spark.network.shuffle.ExternalShuffleBlockResolver;
+import org.apache.spark.network.shuffle.OneForOneBlockFetcher;
+import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
+import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo;
+import org.apache.spark.network.shuffle.protocol.OpenBlocks;
+import org.apache.spark.network.shuffle.protocol.RegisterExecutor;
+import org.apache.spark.network.shuffle.protocol.StreamHandle;
+import org.apache.spark.network.util.JavaUtils;
+import org.apache.spark.network.util.SystemPropertyConfigProvider;
+import org.apache.spark.network.util.TransportConf;
+
+public class SaslIntegrationSuite {
+
+ // Use a long timeout to account for slow / overloaded build machines. In the normal case,
+ // tests should finish way before the timeout expires.
+ private static final long TIMEOUT_MS = 10_000;
+
+ static TransportServer server;
+ static TransportConf conf;
+ static TransportContext context;
+ static SecretKeyHolder secretKeyHolder;
+
+ TransportClientFactory clientFactory;
+
+ @BeforeClass
+ public static void beforeAll() throws IOException {
+ conf = new TransportConf("shuffle", new SystemPropertyConfigProvider());
+ context = new TransportContext(conf, new TestRpcHandler());
+
+ secretKeyHolder = mock(SecretKeyHolder.class);
+ when(secretKeyHolder.getSaslUser(eq("app-1"))).thenReturn("app-1");
+ when(secretKeyHolder.getSecretKey(eq("app-1"))).thenReturn("app-1");
+ when(secretKeyHolder.getSaslUser(eq("app-2"))).thenReturn("app-2");
+ when(secretKeyHolder.getSecretKey(eq("app-2"))).thenReturn("app-2");
+ when(secretKeyHolder.getSaslUser(anyString())).thenReturn("other-app");
+ when(secretKeyHolder.getSecretKey(anyString())).thenReturn("correct-password");
+
+ TransportServerBootstrap bootstrap = new SaslServerBootstrap(conf, secretKeyHolder);
+ server = context.createServer(Arrays.asList(bootstrap));
+ }
+
+ @AfterClass
+ public static void afterAll() {
+ server.close();
+ }
+
+ @After
+ public void afterEach() {
+ if (clientFactory != null) {
+ clientFactory.close();
+ clientFactory = null;
+ }
+ }
+
+ @Test
+ public void testGoodClient() throws IOException {
+ clientFactory = context.createClientFactory(
+ Lists.<TransportClientBootstrap>newArrayList(
+ new SaslClientBootstrap(conf, "app-1", secretKeyHolder)));
+
+ TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
+ String msg = "Hello, World!";
+ ByteBuffer resp = client.sendRpcSync(JavaUtils.stringToBytes(msg), TIMEOUT_MS);
+ assertEquals(msg, JavaUtils.bytesToString(resp));
+ }
+
+ @Test
+ public void testBadClient() {
+ SecretKeyHolder badKeyHolder = mock(SecretKeyHolder.class);
+ when(badKeyHolder.getSaslUser(anyString())).thenReturn("other-app");
+ when(badKeyHolder.getSecretKey(anyString())).thenReturn("wrong-password");
+ clientFactory = context.createClientFactory(
+ Lists.<TransportClientBootstrap>newArrayList(
+ new SaslClientBootstrap(conf, "unknown-app", badKeyHolder)));
+
+ try {
+ // Bootstrap should fail on startup.
+ clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
+ fail("Connection should have failed.");
+ } catch (Exception e) {
+ assertTrue(e.getMessage(), e.getMessage().contains("Mismatched response"));
+ }
+ }
+
+ @Test
+ public void testNoSaslClient() throws IOException {
+ clientFactory = context.createClientFactory(
+ Lists.<TransportClientBootstrap>newArrayList());
+
+ TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
+ try {
+ client.sendRpcSync(ByteBuffer.allocate(13), TIMEOUT_MS);
+ fail("Should have failed");
+ } catch (Exception e) {
+ assertTrue(e.getMessage(), e.getMessage().contains("Expected SaslMessage"));
+ }
+
+ try {
+ // Guessing the right tag byte doesn't magically get you in...
+ client.sendRpcSync(ByteBuffer.wrap(new byte[] { (byte) 0xEA }), TIMEOUT_MS);
+ fail("Should have failed");
+ } catch (Exception e) {
+ assertTrue(e.getMessage(), e.getMessage().contains("java.lang.IndexOutOfBoundsException"));
+ }
+ }
+
+ @Test
+ public void testNoSaslServer() {
+ RpcHandler handler = new TestRpcHandler();
+ TransportContext context = new TransportContext(conf, handler);
+ clientFactory = context.createClientFactory(
+ Lists.<TransportClientBootstrap>newArrayList(
+ new SaslClientBootstrap(conf, "app-1", secretKeyHolder)));
+ TransportServer server = context.createServer();
+ try {
+ clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
+ } catch (Exception e) {
+ assertTrue(e.getMessage(), e.getMessage().contains("Digest-challenge format violation"));
+ } finally {
+ server.close();
+ }
+ }
+
+ /**
+ * This test is not actually testing SASL behavior, but testing that the shuffle service
+ * performs correct authorization checks based on the SASL authentication data.
+ */
+ @Test
+ public void testAppIsolation() throws Exception {
+ // Start a new server with the correct RPC handler to serve block data.
+ ExternalShuffleBlockResolver blockResolver = mock(ExternalShuffleBlockResolver.class);
+ ExternalShuffleBlockHandler blockHandler = new ExternalShuffleBlockHandler(
+ new OneForOneStreamManager(), blockResolver);
+ TransportServerBootstrap bootstrap = new SaslServerBootstrap(conf, secretKeyHolder);
+ TransportContext blockServerContext = new TransportContext(conf, blockHandler);
+ TransportServer blockServer = blockServerContext.createServer(Arrays.asList(bootstrap));
+
+ TransportClient client1 = null;
+ TransportClient client2 = null;
+ TransportClientFactory clientFactory2 = null;
+ try {
+ // Create a client, and make a request to fetch blocks from a different app.
+ clientFactory = blockServerContext.createClientFactory(
+ Lists.<TransportClientBootstrap>newArrayList(
+ new SaslClientBootstrap(conf, "app-1", secretKeyHolder)));
+ client1 = clientFactory.createClient(TestUtils.getLocalHost(),
+ blockServer.getPort());
+
+ final AtomicReference<Throwable> exception = new AtomicReference<>();
+
+ BlockFetchingListener listener = new BlockFetchingListener() {
+ @Override
+ public synchronized void onBlockFetchSuccess(String blockId, ManagedBuffer data) {
+ notifyAll();
+ }
+
+ @Override
+ public synchronized void onBlockFetchFailure(String blockId, Throwable t) {
+ exception.set(t);
+ notifyAll();
+ }
+ };
+
+ String[] blockIds = new String[] { "shuffle_2_3_4", "shuffle_6_7_8" };
+ OneForOneBlockFetcher fetcher = new OneForOneBlockFetcher(client1, "app-2", "0",
+ blockIds, listener);
+ synchronized (listener) {
+ fetcher.start();
+ listener.wait();
+ }
+ checkSecurityException(exception.get());
+
+ // Register an executor so that the next steps work.
+ ExecutorShuffleInfo executorInfo = new ExecutorShuffleInfo(
+ new String[] { System.getProperty("java.io.tmpdir") }, 1, "sort");
+ RegisterExecutor regmsg = new RegisterExecutor("app-1", "0", executorInfo);
+ client1.sendRpcSync(regmsg.toByteBuffer(), TIMEOUT_MS);
+
+ // Make a successful request to fetch blocks, which creates a new stream. But do not actually
+ // fetch any blocks, to keep the stream open.
+ OpenBlocks openMessage = new OpenBlocks("app-1", "0", blockIds);
+ ByteBuffer response = client1.sendRpcSync(openMessage.toByteBuffer(), TIMEOUT_MS);
+ StreamHandle stream = (StreamHandle) BlockTransferMessage.Decoder.fromByteBuffer(response);
+ long streamId = stream.streamId;
+
+ // Create a second client, authenticated with a different app ID, and try to read from
+ // the stream created for the previous app.
+ clientFactory2 = blockServerContext.createClientFactory(
+ Lists.<TransportClientBootstrap>newArrayList(
+ new SaslClientBootstrap(conf, "app-2", secretKeyHolder)));
+ client2 = clientFactory2.createClient(TestUtils.getLocalHost(),
+ blockServer.getPort());
+
+ ChunkReceivedCallback callback = new ChunkReceivedCallback() {
+ @Override
+ public synchronized void onSuccess(int chunkIndex, ManagedBuffer buffer) {
+ notifyAll();
+ }
+
+ @Override
+ public synchronized void onFailure(int chunkIndex, Throwable t) {
+ exception.set(t);
+ notifyAll();
+ }
+ };
+
+ exception.set(null);
+ synchronized (callback) {
+ client2.fetchChunk(streamId, 0, callback);
+ callback.wait();
+ }
+ checkSecurityException(exception.get());
+ } finally {
+ if (client1 != null) {
+ client1.close();
+ }
+ if (client2 != null) {
+ client2.close();
+ }
+ if (clientFactory2 != null) {
+ clientFactory2.close();
+ }
+ blockServer.close();
+ }
+ }
+
+ /** RPC handler which simply responds with the message it received. */
+ public static class TestRpcHandler extends RpcHandler {
+ @Override
+ public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) {
+ callback.onSuccess(message);
+ }
+
+ @Override
+ public StreamManager getStreamManager() {
+ return new OneForOneStreamManager();
+ }
+ }
+
+ private void checkSecurityException(Throwable t) {
+ assertNotNull("No exception was caught.", t);
+ assertTrue("Expected SecurityException.",
+ t.getMessage().contains(SecurityException.class.getName()));
+ }
+}
diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java
new file mode 100644
index 0000000000..86c8609e70
--- /dev/null
+++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle;
+
+import org.junit.Test;
+
+import static org.junit.Assert.*;
+
+import org.apache.spark.network.shuffle.protocol.*;
+
+/** Verifies that all BlockTransferMessages can be serialized correctly. */
+public class BlockTransferMessagesSuite {
+ @Test
+ public void serializeOpenShuffleBlocks() {
+ checkSerializeDeserialize(new OpenBlocks("app-1", "exec-2", new String[] { "b1", "b2" }));
+ checkSerializeDeserialize(new RegisterExecutor("app-1", "exec-2", new ExecutorShuffleInfo(
+ new String[] { "/local1", "/local2" }, 32, "MyShuffleManager")));
+ checkSerializeDeserialize(new UploadBlock("app-1", "exec-2", "block-3", new byte[] { 1, 2 },
+ new byte[] { 4, 5, 6, 7} ));
+ checkSerializeDeserialize(new StreamHandle(12345, 16));
+ }
+
+ private void checkSerializeDeserialize(BlockTransferMessage msg) {
+ BlockTransferMessage msg2 = BlockTransferMessage.Decoder.fromByteBuffer(msg.toByteBuffer());
+ assertEquals(msg, msg2);
+ assertEquals(msg.hashCode(), msg2.hashCode());
+ assertEquals(msg.toString(), msg2.toString());
+ }
+}
diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java
new file mode 100644
index 0000000000..9379412155
--- /dev/null
+++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java
@@ -0,0 +1,127 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle;
+
+import java.nio.ByteBuffer;
+import java.util.Iterator;
+
+import org.junit.Before;
+import org.junit.Test;
+import org.mockito.ArgumentCaptor;
+
+import static org.junit.Assert.*;
+import static org.mockito.Matchers.any;
+import static org.mockito.Mockito.*;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.buffer.NioManagedBuffer;
+import org.apache.spark.network.client.RpcResponseCallback;
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.server.OneForOneStreamManager;
+import org.apache.spark.network.server.RpcHandler;
+import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
+import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo;
+import org.apache.spark.network.shuffle.protocol.OpenBlocks;
+import org.apache.spark.network.shuffle.protocol.RegisterExecutor;
+import org.apache.spark.network.shuffle.protocol.StreamHandle;
+import org.apache.spark.network.shuffle.protocol.UploadBlock;
+
+public class ExternalShuffleBlockHandlerSuite {
+ TransportClient client = mock(TransportClient.class);
+
+ OneForOneStreamManager streamManager;
+ ExternalShuffleBlockResolver blockResolver;
+ RpcHandler handler;
+
+ @Before
+ public void beforeEach() {
+ streamManager = mock(OneForOneStreamManager.class);
+ blockResolver = mock(ExternalShuffleBlockResolver.class);
+ handler = new ExternalShuffleBlockHandler(streamManager, blockResolver);
+ }
+
+ @Test
+ public void testRegisterExecutor() {
+ RpcResponseCallback callback = mock(RpcResponseCallback.class);
+
+ ExecutorShuffleInfo config = new ExecutorShuffleInfo(new String[] {"/a", "/b"}, 16, "sort");
+ ByteBuffer registerMessage = new RegisterExecutor("app0", "exec1", config).toByteBuffer();
+ handler.receive(client, registerMessage, callback);
+ verify(blockResolver, times(1)).registerExecutor("app0", "exec1", config);
+
+ verify(callback, times(1)).onSuccess(any(ByteBuffer.class));
+ verify(callback, never()).onFailure(any(Throwable.class));
+ }
+
+ @SuppressWarnings("unchecked")
+ @Test
+ public void testOpenShuffleBlocks() {
+ RpcResponseCallback callback = mock(RpcResponseCallback.class);
+
+ ManagedBuffer block0Marker = new NioManagedBuffer(ByteBuffer.wrap(new byte[3]));
+ ManagedBuffer block1Marker = new NioManagedBuffer(ByteBuffer.wrap(new byte[7]));
+ when(blockResolver.getBlockData("app0", "exec1", "b0")).thenReturn(block0Marker);
+ when(blockResolver.getBlockData("app0", "exec1", "b1")).thenReturn(block1Marker);
+ ByteBuffer openBlocks = new OpenBlocks("app0", "exec1", new String[] { "b0", "b1" })
+ .toByteBuffer();
+ handler.receive(client, openBlocks, callback);
+ verify(blockResolver, times(1)).getBlockData("app0", "exec1", "b0");
+ verify(blockResolver, times(1)).getBlockData("app0", "exec1", "b1");
+
+ ArgumentCaptor<ByteBuffer> response = ArgumentCaptor.forClass(ByteBuffer.class);
+ verify(callback, times(1)).onSuccess(response.capture());
+ verify(callback, never()).onFailure((Throwable) any());
+
+ StreamHandle handle =
+ (StreamHandle) BlockTransferMessage.Decoder.fromByteBuffer(response.getValue());
+ assertEquals(2, handle.numChunks);
+
+ @SuppressWarnings("unchecked")
+ ArgumentCaptor<Iterator<ManagedBuffer>> stream = (ArgumentCaptor<Iterator<ManagedBuffer>>)
+ (ArgumentCaptor<?>) ArgumentCaptor.forClass(Iterator.class);
+ verify(streamManager, times(1)).registerStream(anyString(), stream.capture());
+ Iterator<ManagedBuffer> buffers = stream.getValue();
+ assertEquals(block0Marker, buffers.next());
+ assertEquals(block1Marker, buffers.next());
+ assertFalse(buffers.hasNext());
+ }
+
+ @Test
+ public void testBadMessages() {
+ RpcResponseCallback callback = mock(RpcResponseCallback.class);
+
+ ByteBuffer unserializableMsg = ByteBuffer.wrap(new byte[] { 0x12, 0x34, 0x56 });
+ try {
+ handler.receive(client, unserializableMsg, callback);
+ fail("Should have thrown");
+ } catch (Exception e) {
+ // pass
+ }
+
+ ByteBuffer unexpectedMsg = new UploadBlock("a", "e", "b", new byte[1], new byte[2]).toByteBuffer();
+ try {
+ handler.receive(client, unexpectedMsg, callback);
+ fail("Should have thrown");
+ } catch (UnsupportedOperationException e) {
+ // pass
+ }
+
+ verify(callback, never()).onSuccess(any(ByteBuffer.class));
+ verify(callback, never()).onFailure(any(Throwable.class));
+ }
+}
diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java
new file mode 100644
index 0000000000..60a1b8b045
--- /dev/null
+++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java
@@ -0,0 +1,156 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.InputStreamReader;
+
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.google.common.io.CharStreams;
+import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo;
+import org.apache.spark.network.util.SystemPropertyConfigProvider;
+import org.apache.spark.network.util.TransportConf;
+import org.apache.spark.network.shuffle.ExternalShuffleBlockResolver.AppExecId;
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import static org.junit.Assert.*;
+
+public class ExternalShuffleBlockResolverSuite {
+ static String sortBlock0 = "Hello!";
+ static String sortBlock1 = "World!";
+
+ static String hashBlock0 = "Elementary";
+ static String hashBlock1 = "Tabular";
+
+ static TestShuffleDataContext dataContext;
+
+ static TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider());
+
+ @BeforeClass
+ public static void beforeAll() throws IOException {
+ dataContext = new TestShuffleDataContext(2, 5);
+
+ dataContext.create();
+ // Write some sort and hash data.
+ dataContext.insertSortShuffleData(0, 0,
+ new byte[][] { sortBlock0.getBytes(), sortBlock1.getBytes() } );
+ dataContext.insertHashShuffleData(1, 0,
+ new byte[][] { hashBlock0.getBytes(), hashBlock1.getBytes() } );
+ }
+
+ @AfterClass
+ public static void afterAll() {
+ dataContext.cleanup();
+ }
+
+ @Test
+ public void testBadRequests() throws IOException {
+ ExternalShuffleBlockResolver resolver = new ExternalShuffleBlockResolver(conf, null);
+ // Unregistered executor
+ try {
+ resolver.getBlockData("app0", "exec1", "shuffle_1_1_0");
+ fail("Should have failed");
+ } catch (RuntimeException e) {
+ assertTrue("Bad error message: " + e, e.getMessage().contains("not registered"));
+ }
+
+ // Invalid shuffle manager
+ resolver.registerExecutor("app0", "exec2", dataContext.createExecutorInfo("foobar"));
+ try {
+ resolver.getBlockData("app0", "exec2", "shuffle_1_1_0");
+ fail("Should have failed");
+ } catch (UnsupportedOperationException e) {
+ // pass
+ }
+
+ // Nonexistent shuffle block
+ resolver.registerExecutor("app0", "exec3",
+ dataContext.createExecutorInfo("sort"));
+ try {
+ resolver.getBlockData("app0", "exec3", "shuffle_1_1_0");
+ fail("Should have failed");
+ } catch (Exception e) {
+ // pass
+ }
+ }
+
+ @Test
+ public void testSortShuffleBlocks() throws IOException {
+ ExternalShuffleBlockResolver resolver = new ExternalShuffleBlockResolver(conf, null);
+ resolver.registerExecutor("app0", "exec0",
+ dataContext.createExecutorInfo("sort"));
+
+ InputStream block0Stream =
+ resolver.getBlockData("app0", "exec0", "shuffle_0_0_0").createInputStream();
+ String block0 = CharStreams.toString(new InputStreamReader(block0Stream));
+ block0Stream.close();
+ assertEquals(sortBlock0, block0);
+
+ InputStream block1Stream =
+ resolver.getBlockData("app0", "exec0", "shuffle_0_0_1").createInputStream();
+ String block1 = CharStreams.toString(new InputStreamReader(block1Stream));
+ block1Stream.close();
+ assertEquals(sortBlock1, block1);
+ }
+
+ @Test
+ public void testHashShuffleBlocks() throws IOException {
+ ExternalShuffleBlockResolver resolver = new ExternalShuffleBlockResolver(conf, null);
+ resolver.registerExecutor("app0", "exec0",
+ dataContext.createExecutorInfo("hash"));
+
+ InputStream block0Stream =
+ resolver.getBlockData("app0", "exec0", "shuffle_1_0_0").createInputStream();
+ String block0 = CharStreams.toString(new InputStreamReader(block0Stream));
+ block0Stream.close();
+ assertEquals(hashBlock0, block0);
+
+ InputStream block1Stream =
+ resolver.getBlockData("app0", "exec0", "shuffle_1_0_1").createInputStream();
+ String block1 = CharStreams.toString(new InputStreamReader(block1Stream));
+ block1Stream.close();
+ assertEquals(hashBlock1, block1);
+ }
+
+ @Test
+ public void jsonSerializationOfExecutorRegistration() throws IOException {
+ ObjectMapper mapper = new ObjectMapper();
+ AppExecId appId = new AppExecId("foo", "bar");
+ String appIdJson = mapper.writeValueAsString(appId);
+ AppExecId parsedAppId = mapper.readValue(appIdJson, AppExecId.class);
+ assertEquals(parsedAppId, appId);
+
+ ExecutorShuffleInfo shuffleInfo =
+ new ExecutorShuffleInfo(new String[]{"/bippy", "/flippy"}, 7, "hash");
+ String shuffleJson = mapper.writeValueAsString(shuffleInfo);
+ ExecutorShuffleInfo parsedShuffleInfo =
+ mapper.readValue(shuffleJson, ExecutorShuffleInfo.class);
+ assertEquals(parsedShuffleInfo, shuffleInfo);
+
+ // Intentionally keep these hard-coded strings in here, to check backwards-compatibility.
+ // It's not legacy yet, but we keep it here in case anybody changes the format.
+ String legacyAppIdJson = "{\"appId\":\"foo\", \"execId\":\"bar\"}";
+ assertEquals(appId, mapper.readValue(legacyAppIdJson, AppExecId.class));
+ String legacyShuffleJson = "{\"localDirs\": [\"/bippy\", \"/flippy\"], " +
+ "\"subDirsPerLocalDir\": 7, \"shuffleManager\": \"hash\"}";
+ assertEquals(shuffleInfo, mapper.readValue(legacyShuffleJson, ExecutorShuffleInfo.class));
+ }
+}
diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java
new file mode 100644
index 0000000000..532d7ab8d0
--- /dev/null
+++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java
@@ -0,0 +1,149 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle;
+
+import java.io.File;
+import java.io.IOException;
+import java.util.Random;
+import java.util.concurrent.Executor;
+import java.util.concurrent.atomic.AtomicBoolean;
+
+import com.google.common.util.concurrent.MoreExecutors;
+import org.junit.Test;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+
+import org.apache.spark.network.util.SystemPropertyConfigProvider;
+import org.apache.spark.network.util.TransportConf;
+
+public class ExternalShuffleCleanupSuite {
+
+ // Same-thread Executor used to ensure cleanup happens synchronously in test thread.
+ Executor sameThreadExecutor = MoreExecutors.sameThreadExecutor();
+ TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider());
+
+ @Test
+ public void noCleanupAndCleanup() throws IOException {
+ TestShuffleDataContext dataContext = createSomeData();
+
+ ExternalShuffleBlockResolver resolver =
+ new ExternalShuffleBlockResolver(conf, null, sameThreadExecutor);
+ resolver.registerExecutor("app", "exec0", dataContext.createExecutorInfo("shuffleMgr"));
+ resolver.applicationRemoved("app", false /* cleanup */);
+
+ assertStillThere(dataContext);
+
+ resolver.registerExecutor("app", "exec1", dataContext.createExecutorInfo("shuffleMgr"));
+ resolver.applicationRemoved("app", true /* cleanup */);
+
+ assertCleanedUp(dataContext);
+ }
+
+ @Test
+ public void cleanupUsesExecutor() throws IOException {
+ TestShuffleDataContext dataContext = createSomeData();
+
+ final AtomicBoolean cleanupCalled = new AtomicBoolean(false);
+
+ // Executor which does nothing, so we can verify that cleanup is actually submitted to it.
+ Executor noThreadExecutor = new Executor() {
+ @Override public void execute(Runnable runnable) { cleanupCalled.set(true); }
+ };
+
+ ExternalShuffleBlockResolver manager =
+ new ExternalShuffleBlockResolver(conf, null, noThreadExecutor);
+
+ manager.registerExecutor("app", "exec0", dataContext.createExecutorInfo("shuffleMgr"));
+ manager.applicationRemoved("app", true);
+
+ assertTrue(cleanupCalled.get());
+ assertStillThere(dataContext);
+
+ dataContext.cleanup();
+ assertCleanedUp(dataContext);
+ }
+
+ @Test
+ public void cleanupMultipleExecutors() throws IOException {
+ TestShuffleDataContext dataContext0 = createSomeData();
+ TestShuffleDataContext dataContext1 = createSomeData();
+
+ ExternalShuffleBlockResolver resolver =
+ new ExternalShuffleBlockResolver(conf, null, sameThreadExecutor);
+
+ resolver.registerExecutor("app", "exec0", dataContext0.createExecutorInfo("shuffleMgr"));
+ resolver.registerExecutor("app", "exec1", dataContext1.createExecutorInfo("shuffleMgr"));
+ resolver.applicationRemoved("app", true);
+
+ assertCleanedUp(dataContext0);
+ assertCleanedUp(dataContext1);
+ }
+
+ @Test
+ public void cleanupOnlyRemovedApp() throws IOException {
+ TestShuffleDataContext dataContext0 = createSomeData();
+ TestShuffleDataContext dataContext1 = createSomeData();
+
+ ExternalShuffleBlockResolver resolver =
+ new ExternalShuffleBlockResolver(conf, null, sameThreadExecutor);
+
+ resolver.registerExecutor("app-0", "exec0", dataContext0.createExecutorInfo("shuffleMgr"));
+ resolver.registerExecutor("app-1", "exec0", dataContext1.createExecutorInfo("shuffleMgr"));
+
+ resolver.applicationRemoved("app-nonexistent", true);
+ assertStillThere(dataContext0);
+ assertStillThere(dataContext1);
+
+ resolver.applicationRemoved("app-0", true);
+ assertCleanedUp(dataContext0);
+ assertStillThere(dataContext1);
+
+ resolver.applicationRemoved("app-1", true);
+ assertCleanedUp(dataContext0);
+ assertCleanedUp(dataContext1);
+
+ // Make sure it's not an error to clean up multiple times
+ resolver.applicationRemoved("app-1", true);
+ assertCleanedUp(dataContext0);
+ assertCleanedUp(dataContext1);
+ }
+
+ private void assertStillThere(TestShuffleDataContext dataContext) {
+ for (String localDir : dataContext.localDirs) {
+ assertTrue(localDir + " was cleaned up prematurely", new File(localDir).exists());
+ }
+ }
+
+ private void assertCleanedUp(TestShuffleDataContext dataContext) {
+ for (String localDir : dataContext.localDirs) {
+ assertFalse(localDir + " wasn't cleaned up", new File(localDir).exists());
+ }
+ }
+
+ private TestShuffleDataContext createSomeData() throws IOException {
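+ // A fixed seed keeps the generated shuffle and map ids deterministic across runs.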
+ Random rand = new Random(123);
+ TestShuffleDataContext dataContext = new TestShuffleDataContext(10, 5);
+
+ dataContext.create();
+ dataContext.insertSortShuffleData(rand.nextInt(1000), rand.nextInt(1000),
+ new byte[][] { "ABC".getBytes(), "DEF".getBytes() } );
+ dataContext.insertHashShuffleData(rand.nextInt(1000), rand.nextInt(1000) + 1000,
+ new byte[][] { "GHI".getBytes(), "JKLMNOPQRSTUVWXYZ".getBytes() } );
+ return dataContext;
+ }
+}
diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java
new file mode 100644
index 0000000000..5e706bf401
--- /dev/null
+++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java
@@ -0,0 +1,301 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Random;
+import java.util.Set;
+import java.util.concurrent.Semaphore;
+import java.util.concurrent.TimeUnit;
+
+import com.google.common.collect.Lists;
+import com.google.common.collect.Sets;
+import org.junit.After;
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import static org.junit.Assert.*;
+
+import org.apache.spark.network.TestUtils;
+import org.apache.spark.network.TransportContext;
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.buffer.NioManagedBuffer;
+import org.apache.spark.network.server.TransportServer;
+import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo;
+import org.apache.spark.network.util.SystemPropertyConfigProvider;
+import org.apache.spark.network.util.TransportConf;
+
+public class ExternalShuffleIntegrationSuite {
+
+ static String APP_ID = "app-id";
+ static String SORT_MANAGER = "sort";
+ static String HASH_MANAGER = "hash";
+
+ // Executor 0 is sort-based
+ static TestShuffleDataContext dataContext0;
+ // Executor 1 is hash-based
+ static TestShuffleDataContext dataContext1;
+
+ static ExternalShuffleBlockHandler handler;
+ static TransportServer server;
+ static TransportConf conf;
+
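+ // Block contents are filled with random bytes in beforeAll().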
+ static byte[][] exec0Blocks = new byte[][] {
+ new byte[123],
+ new byte[12345],
+ new byte[1234567],
+ };
+
+ static byte[][] exec1Blocks = new byte[][] {
+ new byte[321],
+ new byte[54321],
+ };
+
+ @BeforeClass
+ public static void beforeAll() throws IOException {
+ Random rand = new Random();
+
+ for (byte[] block : exec0Blocks) {
+ rand.nextBytes(block);
+ }
+ for (byte[] block : exec1Blocks) {
+ rand.nextBytes(block);
+ }
+
+ dataContext0 = new TestShuffleDataContext(2, 5);
+ dataContext0.create();
+ dataContext0.insertSortShuffleData(0, 0, exec0Blocks);
+
+ dataContext1 = new TestShuffleDataContext(6, 2);
+ dataContext1.create();
+ dataContext1.insertHashShuffleData(1, 0, exec1Blocks);
+
+ conf = new TransportConf("shuffle", new SystemPropertyConfigProvider());
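+ // Passing null for the registered-executor file keeps registrations in memory only.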
+ handler = new ExternalShuffleBlockHandler(conf, null);
+ TransportContext transportContext = new TransportContext(conf, handler);
+ server = transportContext.createServer();
+ }
+
+ @AfterClass
+ public static void afterAll() {
+ dataContext0.cleanup();
+ dataContext1.cleanup();
+ server.close();
+ }
+
+ @After
+ public void afterEach() {
+ handler.applicationRemoved(APP_ID, false /* cleanupLocalDirs */);
+ }
+
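+ /** Collects the outcome of a single fetch: successful and failed block ids, plus buffers. */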
+ class FetchResult {
+ public Set<String> successBlocks;
+ public Set<String> failedBlocks;
+ public List<ManagedBuffer> buffers;
+
+ public void releaseBuffers() {
+ for (ManagedBuffer buffer : buffers) {
+ buffer.release();
+ }
+ }
+ }
+
+ // Fetch a set of blocks from a pre-registered executor.
+ private FetchResult fetchBlocks(String execId, String[] blockIds) throws Exception {
+ return fetchBlocks(execId, blockIds, server.getPort());
+ }
+
+ // Fetch a set of blocks from a pre-registered executor, connecting to the server on the
+ // given port. This allows tests to deliberately target an invalid server.
+ private FetchResult fetchBlocks(String execId, String[] blockIds, int port) throws Exception {
+ final FetchResult res = new FetchResult();
+ res.successBlocks = Collections.synchronizedSet(new HashSet<String>());
+ res.failedBlocks = Collections.synchronizedSet(new HashSet<String>());
+ res.buffers = Collections.synchronizedList(new LinkedList<ManagedBuffer>());
+
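+ // Each distinct block outcome releases one permit; the test thread waits below until all
+ // requested blocks have either succeeded or failed, or the timeout fires.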
+ final Semaphore requestsRemaining = new Semaphore(0);
+
+ ExternalShuffleClient client = new ExternalShuffleClient(conf, null, false, false);
+ client.init(APP_ID);
+ client.fetchBlocks(TestUtils.getLocalHost(), port, execId, blockIds,
+ new BlockFetchingListener() {
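+ // Guard against duplicate callbacks for the same block: record only the first outcome,
+ // releasing exactly one permit per distinct block.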
+ @Override
+ public void onBlockFetchSuccess(String blockId, ManagedBuffer data) {
+ synchronized (this) {
+ if (!res.successBlocks.contains(blockId) && !res.failedBlocks.contains(blockId)) {
+ data.retain();
+ res.successBlocks.add(blockId);
+ res.buffers.add(data);
+ requestsRemaining.release();
+ }
+ }
+ }
+
+ @Override
+ public void onBlockFetchFailure(String blockId, Throwable exception) {
+ synchronized (this) {
+ if (!res.successBlocks.contains(blockId) && !res.failedBlocks.contains(blockId)) {
+ res.failedBlocks.add(blockId);
+ requestsRemaining.release();
+ }
+ }
+ }
+ });
+
+ if (!requestsRemaining.tryAcquire(blockIds.length, 5, TimeUnit.SECONDS)) {
+ fail("Timeout getting response from the server");
+ }
+ client.close();
+ return res;
+ }
+
+ @Test
+ public void testFetchOneSort() throws Exception {
+ registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER));
+ FetchResult exec0Fetch = fetchBlocks("exec-0", new String[] { "shuffle_0_0_0" });
+ assertEquals(Sets.newHashSet("shuffle_0_0_0"), exec0Fetch.successBlocks);
+ assertTrue(exec0Fetch.failedBlocks.isEmpty());
+ assertBufferListsEqual(exec0Fetch.buffers, Lists.newArrayList(exec0Blocks[0]));
+ exec0Fetch.releaseBuffers();
+ }
+
+ @Test
+ public void testFetchThreeSort() throws Exception {
+ registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER));
+ FetchResult exec0Fetch = fetchBlocks("exec-0",
+ new String[] { "shuffle_0_0_0", "shuffle_0_0_1", "shuffle_0_0_2" });
+ assertEquals(Sets.newHashSet("shuffle_0_0_0", "shuffle_0_0_1", "shuffle_0_0_2"),
+ exec0Fetch.successBlocks);
+ assertTrue(exec0Fetch.failedBlocks.isEmpty());
+ assertBufferListsEqual(exec0Fetch.buffers, Lists.newArrayList(exec0Blocks));
+ exec0Fetch.releaseBuffers();
+ }
+
+ @Test
+ public void testFetchHash() throws Exception {
+ registerExecutor("exec-1", dataContext1.createExecutorInfo(HASH_MANAGER));
+ FetchResult execFetch = fetchBlocks("exec-1",
+ new String[] { "shuffle_1_0_0", "shuffle_1_0_1" });
+ assertEquals(Sets.newHashSet("shuffle_1_0_0", "shuffle_1_0_1"), execFetch.successBlocks);
+ assertTrue(execFetch.failedBlocks.isEmpty());
+ assertBufferListsEqual(execFetch.buffers, Lists.newArrayList(exec1Blocks));
+ execFetch.releaseBuffers();
+ }
+
+ @Test
+ public void testFetchWrongShuffle() throws Exception {
+ registerExecutor("exec-1", dataContext1.createExecutorInfo(SORT_MANAGER /* wrong manager */));
+ FetchResult execFetch = fetchBlocks("exec-1",
+ new String[] { "shuffle_1_0_0", "shuffle_1_0_1" });
+ assertTrue(execFetch.successBlocks.isEmpty());
+ assertEquals(Sets.newHashSet("shuffle_1_0_0", "shuffle_1_0_1"), execFetch.failedBlocks);
+ }
+
+ @Test
+ public void testFetchInvalidShuffle() throws Exception {
+ registerExecutor("exec-1", dataContext1.createExecutorInfo("unknown sort manager"));
+ FetchResult execFetch = fetchBlocks("exec-1",
+ new String[] { "shuffle_1_0_0" });
+ assertTrue(execFetch.successBlocks.isEmpty());
+ assertEquals(Sets.newHashSet("shuffle_1_0_0"), execFetch.failedBlocks);
+ }
+
+ @Test
+ public void testFetchWrongBlockId() throws Exception {
+ registerExecutor("exec-1", dataContext1.createExecutorInfo(SORT_MANAGER /* wrong manager */));
+ FetchResult execFetch = fetchBlocks("exec-1",
+ new String[] { "rdd_1_0_0" });
+ assertTrue(execFetch.successBlocks.isEmpty());
+ assertEquals(Sets.newHashSet("rdd_1_0_0"), execFetch.failedBlocks);
+ }
+
+ @Test
+ public void testFetchNonexistent() throws Exception {
+ registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER));
+ FetchResult execFetch = fetchBlocks("exec-0",
+ new String[] { "shuffle_2_0_0" });
+ assertTrue(execFetch.successBlocks.isEmpty());
+ assertEquals(Sets.newHashSet("shuffle_2_0_0"), execFetch.failedBlocks);
+ }
+
+ @Test
+ public void testFetchWrongExecutor() throws Exception {
+ registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER));
+ FetchResult execFetch = fetchBlocks("exec-0",
+ new String[] { "shuffle_0_0_0" /* right */, "shuffle_1_0_0" /* wrong */ });
+ // Both still fail, as we start by checking for all blocks up front.
+ assertTrue(execFetch.successBlocks.isEmpty());
+ assertEquals(Sets.newHashSet("shuffle_0_0_0", "shuffle_1_0_0"), execFetch.failedBlocks);
+ }
+
+ @Test
+ public void testFetchUnregisteredExecutor() throws Exception {
+ registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER));
+ FetchResult execFetch = fetchBlocks("exec-2",
+ new String[] { "shuffle_0_0_0", "shuffle_1_0_0" });
+ assertTrue(execFetch.successBlocks.isEmpty());
+ assertEquals(Sets.newHashSet("shuffle_0_0_0", "shuffle_1_0_0"), execFetch.failedBlocks);
+ }
+
+ @Test
+ public void testFetchNoServer() throws Exception {
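+ // Disable retries so the fetch fails fast when nothing is listening on the target port.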
+ System.setProperty("spark.shuffle.io.maxRetries", "0");
+ try {
+ registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER));
+ FetchResult execFetch = fetchBlocks("exec-0",
+ new String[]{"shuffle_1_0_0", "shuffle_1_0_1"}, 1 /* port */);
+ assertTrue(execFetch.successBlocks.isEmpty());
+ assertEquals(Sets.newHashSet("shuffle_1_0_0", "shuffle_1_0_1"), execFetch.failedBlocks);
+ } finally {
+ System.clearProperty("spark.shuffle.io.maxRetries");
+ }
+ }
+
+ private void registerExecutor(String executorId, ExecutorShuffleInfo executorInfo)
+ throws IOException {
+ ExternalShuffleClient client = new ExternalShuffleClient(conf, null, false, false);
+ client.init(APP_ID);
+ client.registerWithShuffleServer(TestUtils.getLocalHost(), server.getPort(),
+ executorId, executorInfo);
+ }
+
+ private void assertBufferListsEqual(List<ManagedBuffer> list0, List<byte[]> list1)
+ throws Exception {
+ assertEquals(list0.size(), list1.size());
+ for (int i = 0; i < list0.size(); i ++) {
+ assertBuffersEqual(list0.get(i), new NioManagedBuffer(ByteBuffer.wrap(list1.get(i))));
+ }
+ }
+
+ private void assertBuffersEqual(ManagedBuffer buffer0, ManagedBuffer buffer1) throws Exception {
+ ByteBuffer nio0 = buffer0.nioByteBuffer();
+ ByteBuffer nio1 = buffer1.nioByteBuffer();
+
+ int len = nio0.remaining();
+ assertEquals(nio0.remaining(), nio1.remaining());
+ for (int i = 0; i < len; i ++) {
+ assertEquals(nio0.get(), nio1.get());
+ }
+ }
+}
diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java
new file mode 100644
index 0000000000..08ddb3755b
--- /dev/null
+++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java
@@ -0,0 +1,124 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle;
+
+import java.io.IOException;
+import java.util.Arrays;
+
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+import static org.junit.Assert.*;
+
+import org.apache.spark.network.TestUtils;
+import org.apache.spark.network.TransportContext;
+import org.apache.spark.network.sasl.SaslServerBootstrap;
+import org.apache.spark.network.sasl.SecretKeyHolder;
+import org.apache.spark.network.server.RpcHandler;
+import org.apache.spark.network.server.TransportServer;
+import org.apache.spark.network.server.TransportServerBootstrap;
+import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo;
+import org.apache.spark.network.util.SystemPropertyConfigProvider;
+import org.apache.spark.network.util.TransportConf;
+
+public class ExternalShuffleSecuritySuite {
+
+ TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider());
+ TransportServer server;
+
+ @Before
+ public void beforeEach() throws IOException {
+ TransportContext context =
+ new TransportContext(conf, new ExternalShuffleBlockHandler(conf, null));
+ TransportServerBootstrap bootstrap = new SaslServerBootstrap(conf,
+ new TestSecretKeyHolder("my-app-id", "secret"));
+ this.server = context.createServer(Arrays.asList(bootstrap));
+ }
+
+ @After
+ public void afterEach() {
+ if (server != null) {
+ server.close();
+ server = null;
+ }
+ }
+
+ @Test
+ public void testValid() throws IOException {
+ validate("my-app-id", "secret", false);
+ }
+
+ @Test
+ public void testBadAppId() {
+ try {
+ validate("wrong-app-id", "secret", false);
+ } catch (Exception e) {
+ assertTrue(e.getMessage(), e.getMessage().contains("Wrong appId!"));
+ }
+ }
+
+ @Test
+ public void testBadSecret() {
+ try {
+ validate("my-app-id", "bad-secret", false);
+ } catch (Exception e) {
+ assertTrue(e.getMessage(), e.getMessage().contains("Mismatched response"));
+ }
+ }
+
+ @Test
+ public void testEncryption() throws IOException {
+ validate("my-app-id", "secret", true);
+ }
+
+ /** Creates an ExternalShuffleClient and attempts to register with the server. */
+ private void validate(String appId, String secretKey, boolean encrypt) throws IOException {
+ ExternalShuffleClient client =
+ new ExternalShuffleClient(conf, new TestSecretKeyHolder(appId, secretKey), true, encrypt);
+ client.init(appId);
+ // Registration either succeeds or throws an exception.
+ client.registerWithShuffleServer(TestUtils.getLocalHost(), server.getPort(), "exec0",
+ new ExecutorShuffleInfo(new String[0], 0, ""));
+ client.close();
+ }
+
+ /** Provides a secret key holder which always returns the given secret key, for a single appId. */
+ static class TestSecretKeyHolder implements SecretKeyHolder {
+ private final String appId;
+ private final String secretKey;
+
+ TestSecretKeyHolder(String appId, String secretKey) {
+ this.appId = appId;
+ this.secretKey = secretKey;
+ }
+
+ @Override
+ public String getSaslUser(String appId) {
+ return "user";
+ }
+
+ @Override
+ public String getSecretKey(String appId) {
+ if (!appId.equals(this.appId)) {
+ throw new IllegalArgumentException("Wrong appId!");
+ }
+ return secretKey;
+ }
+ }
+}
diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java
new file mode 100644
index 0000000000..2590b9ce4c
--- /dev/null
+++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java
@@ -0,0 +1,176 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle;
+
+import java.nio.ByteBuffer;
+import java.util.Iterator;
+import java.util.LinkedHashMap;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import com.google.common.collect.Maps;
+import io.netty.buffer.Unpooled;
+import org.junit.Test;
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.fail;
+import static org.mockito.Matchers.any;
+import static org.mockito.Matchers.anyInt;
+import static org.mockito.Matchers.anyLong;
+import static org.mockito.Matchers.eq;
+import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.buffer.NettyManagedBuffer;
+import org.apache.spark.network.buffer.NioManagedBuffer;
+import org.apache.spark.network.client.ChunkReceivedCallback;
+import org.apache.spark.network.client.RpcResponseCallback;
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
+import org.apache.spark.network.shuffle.protocol.OpenBlocks;
+import org.apache.spark.network.shuffle.protocol.StreamHandle;
+
+public class OneForOneBlockFetcherSuite {
+ @Test
+ public void testFetchOne() {
+ LinkedHashMap<String, ManagedBuffer> blocks = Maps.newLinkedHashMap();
+ blocks.put("shuffle_0_0_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[0])));
+
+ BlockFetchingListener listener = fetchBlocks(blocks);
+
+ verify(listener).onBlockFetchSuccess("shuffle_0_0_0", blocks.get("shuffle_0_0_0"));
+ }
+
+ @Test
+ public void testFetchThree() {
+ LinkedHashMap<String, ManagedBuffer> blocks = Maps.newLinkedHashMap();
+ blocks.put("b0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12])));
+ blocks.put("b1", new NioManagedBuffer(ByteBuffer.wrap(new byte[23])));
+ blocks.put("b2", new NettyManagedBuffer(Unpooled.wrappedBuffer(new byte[23])));
+
+ BlockFetchingListener listener = fetchBlocks(blocks);
+
+ for (int i = 0; i < 3; i ++) {
+ verify(listener, times(1)).onBlockFetchSuccess("b" + i, blocks.get("b" + i));
+ }
+ }
+
+ @Test
+ public void testFailure() {
+ LinkedHashMap<String, ManagedBuffer> blocks = Maps.newLinkedHashMap();
+ blocks.put("b0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12])));
+ blocks.put("b1", null);
+ blocks.put("b2", null);
+
+ BlockFetchingListener listener = fetchBlocks(blocks);
+
+ // Each failure also fails all remaining block fetches, so b2 receives a failure both for
+ // b1's error and for its own.
+ verify(listener, times(1)).onBlockFetchSuccess("b0", blocks.get("b0"));
+ verify(listener, times(1)).onBlockFetchFailure(eq("b1"), (Throwable) any());
+ verify(listener, times(2)).onBlockFetchFailure(eq("b2"), (Throwable) any());
+ }
+
+ @Test
+ public void testFailureAndSuccess() {
+ LinkedHashMap<String, ManagedBuffer> blocks = Maps.newLinkedHashMap();
+ blocks.put("b0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12])));
+ blocks.put("b1", null);
+ blocks.put("b2", new NioManagedBuffer(ByteBuffer.wrap(new byte[21])));
+
+ BlockFetchingListener listener = fetchBlocks(blocks);
+
+ // We may call both success and failure for the same block.
+ verify(listener, times(1)).onBlockFetchSuccess("b0", blocks.get("b0"));
+ verify(listener, times(1)).onBlockFetchFailure(eq("b1"), (Throwable) any());
+ verify(listener, times(1)).onBlockFetchSuccess("b2", blocks.get("b2"));
+ verify(listener, times(1)).onBlockFetchFailure(eq("b2"), (Throwable) any());
+ }
+
+ @Test
+ public void testEmptyBlockFetch() {
+ try {
+ fetchBlocks(Maps.<String, ManagedBuffer>newLinkedHashMap());
+ fail();
+ } catch (IllegalArgumentException e) {
+ assertEquals("Zero-sized blockIds array", e.getMessage());
+ }
+ }
+
+ /**
+ * Begins a fetch on the given set of blocks by mocking out the server side of the RPC so
+ * that it simply returns the given (BlockId, Block) pairs.
+ * As "blocks" is a LinkedHashMap, the blocks are guaranteed to be returned in the order
+ * they were inserted.
+ *
+ * If a block's buffer is "null", a fetch failure is triggered instead.
+ */
+ private BlockFetchingListener fetchBlocks(final LinkedHashMap<String, ManagedBuffer> blocks) {
+ TransportClient client = mock(TransportClient.class);
+ BlockFetchingListener listener = mock(BlockFetchingListener.class);
+ final String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]);
+ OneForOneBlockFetcher fetcher =
+ new OneForOneBlockFetcher(client, "app-id", "exec-id", blockIds, listener);
+
+ // Respond to the "OpenBlocks" message with an appropirate ShuffleStreamHandle with streamId 123
+ doAnswer(new Answer<Void>() {
+ @Override
+ public Void answer(InvocationOnMock invocationOnMock) throws Throwable {
+ BlockTransferMessage message = BlockTransferMessage.Decoder.fromByteBuffer(
+ (ByteBuffer) invocationOnMock.getArguments()[0]);
+ RpcResponseCallback callback = (RpcResponseCallback) invocationOnMock.getArguments()[1];
+ callback.onSuccess(new StreamHandle(123, blocks.size()).toByteBuffer());
+ assertEquals(new OpenBlocks("app-id", "exec-id", blockIds), message);
+ return null;
+ }
+ }).when(client).sendRpc(any(ByteBuffer.class), any(RpcResponseCallback.class));
+
+ // Respond to each chunk request with a single buffer from our blocks array.
+ final AtomicInteger expectedChunkIndex = new AtomicInteger(0);
+ final Iterator<ManagedBuffer> blockIterator = blocks.values().iterator();
+ doAnswer(new Answer<Void>() {
+ @Override
+ public Void answer(InvocationOnMock invocation) throws Throwable {
+ try {
+ long streamId = (Long) invocation.getArguments()[0];
+ int myChunkIndex = (Integer) invocation.getArguments()[1];
+ assertEquals(123, streamId);
+ assertEquals(expectedChunkIndex.getAndIncrement(), myChunkIndex);
+
+ ChunkReceivedCallback callback = (ChunkReceivedCallback) invocation.getArguments()[2];
+ ManagedBuffer result = blockIterator.next();
+ if (result != null) {
+ callback.onSuccess(myChunkIndex, result);
+ } else {
+ callback.onFailure(myChunkIndex, new RuntimeException("Failed " + myChunkIndex));
+ }
+ } catch (Exception e) {
+ e.printStackTrace();
+ fail("Unexpected failure");
+ }
+ return null;
+ }
+ }).when(client).fetchChunk(anyLong(), anyInt(), (ChunkReceivedCallback) any());
+
+ fetcher.start();
+ return listener;
+ }
+}
diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java
new file mode 100644
index 0000000000..3a6ef0d3f8
--- /dev/null
+++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java
@@ -0,0 +1,313 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle;
+
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.Arrays;
+import java.util.LinkedHashSet;
+import java.util.List;
+import java.util.Map;
+
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.Sets;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
+import org.mockito.stubbing.Stubber;
+
+import static org.junit.Assert.*;
+import static org.mockito.Mockito.*;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.buffer.NioManagedBuffer;
+import org.apache.spark.network.util.SystemPropertyConfigProvider;
+import org.apache.spark.network.util.TransportConf;
+import static org.apache.spark.network.shuffle.RetryingBlockFetcher.BlockFetchStarter;
+
+/**
+ * Tests retry logic by throwing IOExceptions and ensuring that subsequent attempts are made to
+ * fetch the lost blocks.
+ */
+public class RetryingBlockFetcherSuite {
+
+ ManagedBuffer block0 = new NioManagedBuffer(ByteBuffer.wrap(new byte[13]));
+ ManagedBuffer block1 = new NioManagedBuffer(ByteBuffer.wrap(new byte[7]));
+ ManagedBuffer block2 = new NioManagedBuffer(ByteBuffer.wrap(new byte[19]));
+
+ @Before
+ public void beforeEach() {
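+ // Allow up to two retries per fetch and eliminate the wait between attempts, so the retry
+ // paths are exercised without slowing the suite down.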
+ System.setProperty("spark.shuffle.io.maxRetries", "2");
+ System.setProperty("spark.shuffle.io.retryWait", "0");
+ }
+
+ @After
+ public void afterEach() {
+ System.clearProperty("spark.shuffle.io.maxRetries");
+ System.clearProperty("spark.shuffle.io.retryWait");
+ }
+
+ @Test
+ public void testNoFailures() throws IOException {
+ BlockFetchingListener listener = mock(BlockFetchingListener.class);
+
+ List<? extends Map<String, Object>> interactions = Arrays.asList(
+ // Immediately return both blocks successfully.
+ ImmutableMap.<String, Object>builder()
+ .put("b0", block0)
+ .put("b1", block1)
+ .build()
+ );
+
+ performInteractions(interactions, listener);
+
+ verify(listener).onBlockFetchSuccess("b0", block0);
+ verify(listener).onBlockFetchSuccess("b1", block1);
+ verifyNoMoreInteractions(listener);
+ }
+
+ @Test
+ public void testUnrecoverableFailure() throws IOException {
+ BlockFetchingListener listener = mock(BlockFetchingListener.class);
+
+ List<? extends Map<String, Object>> interactions = Arrays.asList(
+ // b0 throws a non-IOException error, so it will be failed without retry.
+ ImmutableMap.<String, Object>builder()
+ .put("b0", new RuntimeException("Ouch!"))
+ .put("b1", block1)
+ .build()
+ );
+
+ performInteractions(interactions, listener);
+
+ verify(listener).onBlockFetchFailure(eq("b0"), (Throwable) any());
+ verify(listener).onBlockFetchSuccess("b1", block1);
+ verifyNoMoreInteractions(listener);
+ }
+
+ @Test
+ public void testSingleIOExceptionOnFirst() throws IOException {
+ BlockFetchingListener listener = mock(BlockFetchingListener.class);
+
+ List<? extends Map<String, Object>> interactions = Arrays.asList(
+ // IOException will cause a retry. Since b0 fails, we will retry both.
+ ImmutableMap.<String, Object>builder()
+ .put("b0", new IOException("Connection failed or something"))
+ .put("b1", block1)
+ .build(),
+ ImmutableMap.<String, Object>builder()
+ .put("b0", block0)
+ .put("b1", block1)
+ .build()
+ );
+
+ performInteractions(interactions, listener);
+
+ verify(listener, timeout(5000)).onBlockFetchSuccess("b0", block0);
+ verify(listener, timeout(5000)).onBlockFetchSuccess("b1", block1);
+ verifyNoMoreInteractions(listener);
+ }
+
+ @Test
+ public void testSingleIOExceptionOnSecond() throws IOException {
+ BlockFetchingListener listener = mock(BlockFetchingListener.class);
+
+ List<? extends Map<String, Object>> interactions = Arrays.asList(
+ // IOException will cause a retry. Since b1 fails, we will not retry b0.
+ ImmutableMap.<String, Object>builder()
+ .put("b0", block0)
+ .put("b1", new IOException("Connection failed or something"))
+ .build(),
+ ImmutableMap.<String, Object>builder()
+ .put("b1", block1)
+ .build()
+ );
+
+ performInteractions(interactions, listener);
+
+ verify(listener, timeout(5000)).onBlockFetchSuccess("b0", block0);
+ verify(listener, timeout(5000)).onBlockFetchSuccess("b1", block1);
+ verifyNoMoreInteractions(listener);
+ }
+
+ @Test
+ public void testTwoIOExceptions() throws IOException {
+ BlockFetchingListener listener = mock(BlockFetchingListener.class);
+
+ List<? extends Map<String, Object>> interactions = Arrays.asList(
+ // b0's IOException will trigger retry, b1's will be ignored.
+ ImmutableMap.<String, Object>builder()
+ .put("b0", new IOException())
+ .put("b1", new IOException())
+ .build(),
+ // Next, b0 is successful and b1 errors again, so we just request that one.
+ ImmutableMap.<String, Object>builder()
+ .put("b0", block0)
+ .put("b1", new IOException())
+ .build(),
+ // b1 returns successfully within 2 retries.
+ ImmutableMap.<String, Object>builder()
+ .put("b1", block1)
+ .build()
+ );
+
+ performInteractions(interactions, listener);
+
+ verify(listener, timeout(5000)).onBlockFetchSuccess("b0", block0);
+ verify(listener, timeout(5000)).onBlockFetchSuccess("b1", block1);
+ verifyNoMoreInteractions(listener);
+ }
+
+ @Test
+ public void testThreeIOExceptions() throws IOException {
+ BlockFetchingListener listener = mock(BlockFetchingListener.class);
+
+ List<? extends Map<String, Object>> interactions = Arrays.asList(
+ // b0's IOException will trigger retry, b1's will be ignored.
+ ImmutableMap.<String, Object>builder()
+ .put("b0", new IOException())
+ .put("b1", new IOException())
+ .build(),
+ // Next, b0 is successful and b1 errors again, so we just request that one.
+ ImmutableMap.<String, Object>builder()
+ .put("b0", block0)
+ .put("b1", new IOException())
+ .build(),
+ // b1 errors again, but this was the last retry
+ ImmutableMap.<String, Object>builder()
+ .put("b1", new IOException())
+ .build(),
+ // This is not reached -- b1 has failed.
+ ImmutableMap.<String, Object>builder()
+ .put("b1", block1)
+ .build()
+ );
+
+ performInteractions(interactions, listener);
+
+ verify(listener, timeout(5000)).onBlockFetchSuccess("b0", block0);
+ verify(listener, timeout(5000)).onBlockFetchFailure(eq("b1"), (Throwable) any());
+ verifyNoMoreInteractions(listener);
+ }
+
+ @Test
+ public void testRetryAndUnrecoverable() throws IOException {
+ BlockFetchingListener listener = mock(BlockFetchingListener.class);
+
+ List<? extends Map<String, Object>> interactions = Arrays.asList(
+ // b0's IOException will trigger retry, subsequent messages will be ignored.
+ ImmutableMap.<String, Object>builder()
+ .put("b0", new IOException())
+ .put("b1", new RuntimeException())
+ .put("b2", block2)
+ .build(),
+ // Next, b0 is successful, b1 errors unrecoverably, and b2 triggers a retry.
+ ImmutableMap.<String, Object>builder()
+ .put("b0", block0)
+ .put("b1", new RuntimeException())
+ .put("b2", new IOException())
+ .build(),
+ // b2 succeeds in its last retry.
+ ImmutableMap.<String, Object>builder()
+ .put("b2", block2)
+ .build()
+ );
+
+ performInteractions(interactions, listener);
+
+ verify(listener, timeout(5000)).onBlockFetchSuccess("b0", block0);
+ verify(listener, timeout(5000)).onBlockFetchFailure(eq("b1"), (Throwable) any());
+ verify(listener, timeout(5000)).onBlockFetchSuccess("b2", block2);
+ verifyNoMoreInteractions(listener);
+ }
+
+ /**
+ * Performs a set of interactions in response to block requests from a RetryingBlockFetcher.
+ * Each interaction is a Map from BlockId to either a ManagedBuffer or an Exception. An
+ * interaction means "respond to the next block fetch request with these successful buffers
+ * and these failure exceptions". We verify that the requested block ids are exactly the
+ * ones expected.
+ *
+ * If multiple interactions are supplied, they will be used in order. This is useful for encoding
+ * retries -- the first interaction may include an IOException, which causes a retry of some
+ * subset of the original blocks in a second interaction.
+ */
+ @SuppressWarnings("unchecked")
+ private static void performInteractions(List<? extends Map<String, Object>> interactions,
+ BlockFetchingListener listener)
+ throws IOException {
+
+ TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider());
+ BlockFetchStarter fetchStarter = mock(BlockFetchStarter.class);
+
+ Stubber stub = null;
+
+ // Contains all blockIds that are referenced across all interactions.
+ final LinkedHashSet<String> blockIds = Sets.newLinkedHashSet();
+
+ for (final Map<String, Object> interaction : interactions) {
+ blockIds.addAll(interaction.keySet());
+
+ Answer<Void> answer = new Answer<Void>() {
+ @Override
+ public Void answer(InvocationOnMock invocationOnMock) throws Throwable {
+ try {
+ // Verify that the RetryingBlockFetcher requested the expected blocks.
+ String[] requestedBlockIds = (String[]) invocationOnMock.getArguments()[0];
+ String[] desiredBlockIds = interaction.keySet().toArray(new String[interaction.size()]);
+ assertArrayEquals(desiredBlockIds, requestedBlockIds);
+
+ // Now actually invoke the success/failure callbacks on each block.
+ BlockFetchingListener retryListener =
+ (BlockFetchingListener) invocationOnMock.getArguments()[1];
+ for (Map.Entry<String, Object> block : interaction.entrySet()) {
+ String blockId = block.getKey();
+ Object blockValue = block.getValue();
+
+ if (blockValue instanceof ManagedBuffer) {
+ retryListener.onBlockFetchSuccess(blockId, (ManagedBuffer) blockValue);
+ } else if (blockValue instanceof Exception) {
+ retryListener.onBlockFetchFailure(blockId, (Exception) blockValue);
+ } else {
+ fail("Can only handle ManagedBuffers and Exceptions, got " + blockValue);
+ }
+ }
+ return null;
+ } catch (Throwable e) {
+ e.printStackTrace();
+ throw e;
+ }
+ }
+ };
+
+ // This is either the first stub, or should be chained behind the prior ones.
+ if (stub == null) {
+ stub = doAnswer(answer);
+ } else {
+ stub.doAnswer(answer);
+ }
+ }
+
+ assert stub != null;
+ stub.when(fetchStarter).createAndStart((String[]) any(), (BlockFetchingListener) anyObject());
+ String[] blockIdArray = blockIds.toArray(new String[blockIds.size()]);
+ new RetryingBlockFetcher(conf, fetchStarter, blockIdArray, listener).start();
+ }
+}
diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java
new file mode 100644
index 0000000000..7ac1ca128a
--- /dev/null
+++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java
@@ -0,0 +1,117 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle;
+
+import java.io.DataOutputStream;
+import java.io.File;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.OutputStream;
+
+import com.google.common.io.Closeables;
+import com.google.common.io.Files;
+
+import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo;
+
+/**
+ * Manages some sort- and hash-based shuffle data, including the creation
+ * and cleanup of directories that can be read by the {@link ExternalShuffleBlockResolver}.
+ */
+public class TestShuffleDataContext {
+ public final String[] localDirs;
+ public final int subDirsPerLocalDir;
+
+ public TestShuffleDataContext(int numLocalDirs, int subDirsPerLocalDir) {
+ this.localDirs = new String[numLocalDirs];
+ this.subDirsPerLocalDir = subDirsPerLocalDir;
+ }
+
+ public void create() {
+ for (int i = 0; i < localDirs.length; i ++) {
+ localDirs[i] = Files.createTempDir().getAbsolutePath();
+
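+ // Pre-create the two-digit-hex sub-directories that ExternalShuffleBlockResolver hashes
+ // block files into.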
+ for (int p = 0; p < subDirsPerLocalDir; p ++) {
+ new File(localDirs[i], String.format("%02x", p)).mkdirs();
+ }
+ }
+ }
+
+ public void cleanup() {
+ for (String localDir : localDirs) {
+ deleteRecursively(new File(localDir));
+ }
+ }
+
+ /** Creates reducer blocks in a sort-based data format within our local dirs. */
+ public void insertSortShuffleData(int shuffleId, int mapId, byte[][] blocks) throws IOException {
+ String blockId = "shuffle_" + shuffleId + "_" + mapId + "_0";
+
+ OutputStream dataStream = null;
+ DataOutputStream indexStream = null;
+ boolean suppressExceptionsDuringClose = true;
+
+ try {
+ dataStream = new FileOutputStream(
+ ExternalShuffleBlockResolver.getFile(localDirs, subDirsPerLocalDir, blockId + ".data"));
+ indexStream = new DataOutputStream(new FileOutputStream(
+ ExternalShuffleBlockResolver.getFile(localDirs, subDirsPerLocalDir, blockId + ".index")));
+
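+ // The index file holds cumulative offsets: entries i and i+1 delimit block i within the
+ // single data file, matching the sort-based shuffle's on-disk format.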
+ long offset = 0;
+ indexStream.writeLong(offset);
+ for (byte[] block : blocks) {
+ offset += block.length;
+ dataStream.write(block);
+ indexStream.writeLong(offset);
+ }
+ suppressExceptionsDuringClose = false;
+ } finally {
+ Closeables.close(dataStream, suppressExceptionsDuringClose);
+ Closeables.close(indexStream, suppressExceptionsDuringClose);
+ }
+ }
+
+ /** Creates reducer blocks in a hash-based data format within our local dirs. */
+ public void insertHashShuffleData(int shuffleId, int mapId, byte[][] blocks) throws IOException {
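+ // Hash-based shuffle writes each block to its own file, named directly by its block id.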
+ for (int i = 0; i < blocks.length; i ++) {
+ String blockId = "shuffle_" + shuffleId + "_" + mapId + "_" + i;
+ Files.write(blocks[i],
+ ExternalShuffleBlockResolver.getFile(localDirs, subDirsPerLocalDir, blockId));
+ }
+ }
+
+ /**
+ * Creates an ExecutorShuffleInfo object based on the given shuffle manager which targets this
+ * context's directories.
+ */
+ public ExecutorShuffleInfo createExecutorInfo(String shuffleManager) {
+ return new ExecutorShuffleInfo(localDirs, subDirsPerLocalDir, shuffleManager);
+ }
+
+ private static void deleteRecursively(File f) {
+ assert f != null;
+ if (f.isDirectory()) {
+ File[] children = f.listFiles();
+ if (children != null) {
+ for (File child : children) {
+ deleteRecursively(child);
+ }
+ }
+ }
+ f.delete();
+ }
+}