-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala | 2
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala | 2
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockManager.scala | 14
-rw-r--r--  core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala | 2
-rw-r--r--  network/shuffle/pom.xml | 16
-rw-r--r--  network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java | 37
-rw-r--r--  network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java | 225
-rw-r--r--  network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExecutorShuffleInfo.java | 8
-rw-r--r--  network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java | 35
-rw-r--r--  network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java | 9
-rw-r--r--  network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java | 2
-rw-r--r--  network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java | 5
-rw-r--r--  network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java | 62
-rw-r--r--  pom.xml | 5
-rw-r--r--  yarn/pom.xml | 6
-rw-r--r--  yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala | 193
-rw-r--r--  yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala | 173
-rw-r--r--  yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala | 109
-rw-r--r--  yarn/src/test/scala/org/apache/spark/network/shuffle/ShuffleTestAccessor.scala | 71
-rw-r--r--  yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceSuite.scala | 233
-rw-r--r--  yarn/src/test/scala/org/apache/spark/network/yarn/YarnTestAccessor.scala | 37
21 files changed, 1031 insertions(+), 215 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala
index 20a9faa178..22ef701d83 100644
--- a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala
@@ -53,7 +53,7 @@ class ExternalShuffleService(sparkConf: SparkConf, securityManager: SecurityMana
/** Create a new shuffle block handler. Factored out for subclasses to override. */
protected def newShuffleBlockHandler(conf: TransportConf): ExternalShuffleBlockHandler = {
- new ExternalShuffleBlockHandler(conf)
+ new ExternalShuffleBlockHandler(conf, null)
}
/** Starts the external shuffle service if the user has configured us to. */
diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala
index 061857476a..12337a940a 100644
--- a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala
@@ -34,7 +34,7 @@ import org.apache.spark.network.util.TransportConf
* It detects driver termination and calls the cleanup callback to [[ExternalShuffleService]].
*/
private[mesos] class MesosExternalShuffleBlockHandler(transportConf: TransportConf)
- extends ExternalShuffleBlockHandler(transportConf) with Logging {
+ extends ExternalShuffleBlockHandler(transportConf, null) with Logging {
// Stores a map of driver socket addresses to app ids
private val connectedApps = new mutable.HashMap[SocketAddress, String]
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index eedb27942e..fefaef0ab8 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -93,8 +93,17 @@ private[spark] class BlockManager(
// Port used by the external shuffle service. In Yarn mode, this may be already be
// set through the Hadoop configuration as the server is launched in the Yarn NM.
- private val externalShuffleServicePort =
- Utils.getSparkOrYarnConfig(conf, "spark.shuffle.service.port", "7337").toInt
+ private val externalShuffleServicePort = {
+ val tmpPort = Utils.getSparkOrYarnConfig(conf, "spark.shuffle.service.port", "7337").toInt
+ if (tmpPort == 0) {
+ // for testing, we set "spark.shuffle.service.port" to 0 in the yarn config, so yarn finds
+ // an open port. But we still need to tell our spark apps the right port to use. So
+ // only if the yarn config has the port set to 0, we prefer the value in the spark config
+ conf.get("spark.shuffle.service.port").toInt
+ } else {
+ tmpPort
+ }
+ }
// Check that we're not using external shuffle service with consolidated shuffle files.
if (externalShuffleServiceEnabled
@@ -191,6 +200,7 @@ private[spark] class BlockManager(
executorId, blockTransferService.hostName, blockTransferService.port)
shuffleServerId = if (externalShuffleServiceEnabled) {
+ logInfo(s"external shuffle service port = $externalShuffleServicePort")
BlockManagerId(executorId, blockTransferService.hostName, externalShuffleServicePort)
} else {
blockManagerId
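
Note on the hunk above: the new comment explains that, in tests, the YARN-side configuration deliberately sets spark.shuffle.service.port to 0 so the service binds to an ephemeral port, and the executor must then fall back to the real port recorded in its own SparkConf. A hypothetical helper that states the same rule in isolation (names are made up; the patch itself implements this inline in Scala):

    public final class ShuffleServicePortSketch {
      // yarnPort: value resolved through getSparkOrYarnConfig (the NodeManager-side setting);
      // sparkPort: value the application carries in its own SparkConf.
      static int resolveShuffleServicePort(int yarnPort, int sparkPort) {
        // 0 only appears in tests, where YARN was told to bind any free port; executors still
        // need the real bound port, which the test publishes through the Spark configuration.
        return yarnPort == 0 ? sparkPort : yarnPort;
      }
    }
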
diff --git a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala
index c38d70252a..e846a72c88 100644
--- a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala
@@ -36,7 +36,7 @@ class ExternalShuffleServiceSuite extends ShuffleSuite with BeforeAndAfterAll {
override def beforeAll() {
val transportConf = SparkTransportConf.fromSparkConf(conf, numUsableCores = 2)
- rpcHandler = new ExternalShuffleBlockHandler(transportConf)
+ rpcHandler = new ExternalShuffleBlockHandler(transportConf, null)
val transportContext = new TransportContext(transportConf, rpcHandler)
server = transportContext.createServer()
diff --git a/network/shuffle/pom.xml b/network/shuffle/pom.xml
index 532463e96f..3d2edf9d94 100644
--- a/network/shuffle/pom.xml
+++ b/network/shuffle/pom.xml
@@ -43,6 +43,22 @@
<version>${project.version}</version>
</dependency>
+ <dependency>
+ <groupId>org.fusesource.leveldbjni</groupId>
+ <artifactId>leveldbjni-all</artifactId>
+ <version>1.8</version>
+ </dependency>
+
+ <dependency>
+ <groupId>com.fasterxml.jackson.core</groupId>
+ <artifactId>jackson-databind</artifactId>
+ </dependency>
+
+ <dependency>
+ <groupId>com.fasterxml.jackson.core</groupId>
+ <artifactId>jackson-annotations</artifactId>
+ </dependency>
+
<!-- Provided dependencies -->
<dependency>
<groupId>org.slf4j</groupId>
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java
index db9dc4f17c..0df1dd621f 100644
--- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java
@@ -17,11 +17,12 @@
package org.apache.spark.network.shuffle;
+import java.io.File;
+import java.io.IOException;
import java.util.List;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.Lists;
-import org.apache.spark.network.util.TransportConf;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -31,10 +32,10 @@ import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.server.OneForOneStreamManager;
import org.apache.spark.network.server.RpcHandler;
import org.apache.spark.network.server.StreamManager;
-import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
-import org.apache.spark.network.shuffle.protocol.OpenBlocks;
-import org.apache.spark.network.shuffle.protocol.RegisterExecutor;
-import org.apache.spark.network.shuffle.protocol.StreamHandle;
+import org.apache.spark.network.shuffle.ExternalShuffleBlockResolver.AppExecId;
+import org.apache.spark.network.shuffle.protocol.*;
+import org.apache.spark.network.util.TransportConf;
+
/**
* RPC Handler for a server which can serve shuffle blocks from outside of an Executor process.
@@ -46,11 +47,13 @@ import org.apache.spark.network.shuffle.protocol.StreamHandle;
public class ExternalShuffleBlockHandler extends RpcHandler {
private final Logger logger = LoggerFactory.getLogger(ExternalShuffleBlockHandler.class);
- private final ExternalShuffleBlockResolver blockManager;
+ @VisibleForTesting
+ final ExternalShuffleBlockResolver blockManager;
private final OneForOneStreamManager streamManager;
- public ExternalShuffleBlockHandler(TransportConf conf) {
- this(new OneForOneStreamManager(), new ExternalShuffleBlockResolver(conf));
+ public ExternalShuffleBlockHandler(TransportConf conf, File registeredExecutorFile) throws IOException {
+ this(new OneForOneStreamManager(),
+ new ExternalShuffleBlockResolver(conf, registeredExecutorFile));
}
/** Enables mocking out the StreamManager and BlockManager. */
@@ -105,4 +108,22 @@ public class ExternalShuffleBlockHandler extends RpcHandler {
public void applicationRemoved(String appId, boolean cleanupLocalDirs) {
blockManager.applicationRemoved(appId, cleanupLocalDirs);
}
+
+ /**
+ * Register an (application, executor) with the given shuffle info.
+ *
+ * The "re-" is meant to highlight the intended use of this method -- when this service is
+ * restarted, this is used to restore the state of executors from before the restart. Normal
+ * registration will happen via a message handled in receive()
+ *
+ * @param appExecId
+ * @param executorInfo
+ */
+ public void reregisterExecutor(AppExecId appExecId, ExecutorShuffleInfo executorInfo) {
+ blockManager.registerExecutor(appExecId.appId, appExecId.execId, executorInfo);
+ }
+
+ public void close() {
+ blockManager.close();
+ }
}
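
Taken together, the new constructor and reregisterExecutor give a caller a way to rebuild handler state after a restart: constructing the handler with a non-null file opens (or creates) the leveldb store and reloads whatever was persisted, while reregisterExecutor lets state recovered by some other mechanism be pushed back in explicitly. A minimal usage sketch, not part of the patch; the recovery path, app id and executor layout are made up for illustration:

    import java.io.File;

    import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler;
    import org.apache.spark.network.shuffle.ExternalShuffleBlockResolver.AppExecId;
    import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo;
    import org.apache.spark.network.util.SystemPropertyConfigProvider;
    import org.apache.spark.network.util.TransportConf;

    public class ShuffleRecoverySketch {
      public static void main(String[] args) throws Exception {
        TransportConf conf = new TransportConf(new SystemPropertyConfigProvider());
        // Hypothetical recovery location; the YARN service puts this under an NM local dir.
        File recoveryFile = new File("/tmp/registeredExecutors.ldb");
        // A non-null file makes the resolver open or create a leveldb store there and
        // reload any executors registered before the restart.
        ExternalShuffleBlockHandler handler = new ExternalShuffleBlockHandler(conf, recoveryFile);
        // Explicit re-registration; normal registration still arrives as a
        // RegisterExecutor message handled in receive().
        handler.reregisterExecutor(
            new AppExecId("app-20150812-0001", "1"),
            new ExecutorShuffleInfo(new String[] { "/tmp/blockmgr-0" }, 64, "hash"));
        handler.close();   // closes the underlying leveldb handle
      }
    }
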
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java
index 022ed88a16..79beec4429 100644
--- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java
@@ -17,19 +17,24 @@
package org.apache.spark.network.shuffle;
-import java.io.DataInputStream;
-import java.io.File;
-import java.io.FileInputStream;
-import java.io.IOException;
-import java.util.Iterator;
-import java.util.Map;
+import java.io.*;
+import java.util.*;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
+import com.fasterxml.jackson.annotation.JsonCreator;
+import com.fasterxml.jackson.annotation.JsonProperty;
+import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Charsets;
import com.google.common.base.Objects;
import com.google.common.collect.Maps;
+import org.fusesource.leveldbjni.JniDBFactory;
+import org.fusesource.leveldbjni.internal.NativeDB;
+import org.iq80.leveldb.DB;
+import org.iq80.leveldb.DBIterator;
+import org.iq80.leveldb.Options;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -52,25 +57,87 @@ import org.apache.spark.network.util.TransportConf;
public class ExternalShuffleBlockResolver {
private static final Logger logger = LoggerFactory.getLogger(ExternalShuffleBlockResolver.class);
+ private static final ObjectMapper mapper = new ObjectMapper();
+ /**
+ * This a common prefix to the key for each app registration we stick in leveldb, so they
+ * are easy to find, since leveldb lets you search based on prefix.
+ */
+ private static final String APP_KEY_PREFIX = "AppExecShuffleInfo";
+ private static final StoreVersion CURRENT_VERSION = new StoreVersion(1, 0);
+
// Map containing all registered executors' metadata.
- private final ConcurrentMap<AppExecId, ExecutorShuffleInfo> executors;
+ @VisibleForTesting
+ final ConcurrentMap<AppExecId, ExecutorShuffleInfo> executors;
// Single-threaded Java executor used to perform expensive recursive directory deletion.
private final Executor directoryCleaner;
private final TransportConf conf;
- public ExternalShuffleBlockResolver(TransportConf conf) {
- this(conf, Executors.newSingleThreadExecutor(
+ @VisibleForTesting
+ final File registeredExecutorFile;
+ @VisibleForTesting
+ final DB db;
+
+ public ExternalShuffleBlockResolver(TransportConf conf, File registeredExecutorFile)
+ throws IOException {
+ this(conf, registeredExecutorFile, Executors.newSingleThreadExecutor(
// Add `spark` prefix because it will run in NM in Yarn mode.
NettyUtils.createThreadFactory("spark-shuffle-directory-cleaner")));
}
// Allows tests to have more control over when directories are cleaned up.
@VisibleForTesting
- ExternalShuffleBlockResolver(TransportConf conf, Executor directoryCleaner) {
+ ExternalShuffleBlockResolver(
+ TransportConf conf,
+ File registeredExecutorFile,
+ Executor directoryCleaner) throws IOException {
this.conf = conf;
- this.executors = Maps.newConcurrentMap();
+ this.registeredExecutorFile = registeredExecutorFile;
+ if (registeredExecutorFile != null) {
+ Options options = new Options();
+ options.createIfMissing(false);
+ options.logger(new LevelDBLogger());
+ DB tmpDb;
+ try {
+ tmpDb = JniDBFactory.factory.open(registeredExecutorFile, options);
+ } catch (NativeDB.DBException e) {
+ if (e.isNotFound() || e.getMessage().contains(" does not exist ")) {
+ logger.info("Creating state database at " + registeredExecutorFile);
+ options.createIfMissing(true);
+ try {
+ tmpDb = JniDBFactory.factory.open(registeredExecutorFile, options);
+ } catch (NativeDB.DBException dbExc) {
+ throw new IOException("Unable to create state store", dbExc);
+ }
+ } else {
+ // the leveldb file seems to be corrupt somehow. Lets just blow it away and create a new
+ // one, so we can keep processing new apps
+ logger.error("error opening leveldb file {}. Creating new file, will not be able to " +
+ "recover state for existing applications", registeredExecutorFile, e);
+ if (registeredExecutorFile.isDirectory()) {
+ for (File f : registeredExecutorFile.listFiles()) {
+ f.delete();
+ }
+ }
+ registeredExecutorFile.delete();
+ options.createIfMissing(true);
+ try {
+ tmpDb = JniDBFactory.factory.open(registeredExecutorFile, options);
+ } catch (NativeDB.DBException dbExc) {
+ throw new IOException("Unable to create state store", dbExc);
+ }
+
+ }
+ }
+ // if there is a version mismatch, we throw an exception, which means the service is unusable
+ checkVersion(tmpDb);
+ executors = reloadRegisteredExecutors(tmpDb);
+ db = tmpDb;
+ } else {
+ db = null;
+ executors = Maps.newConcurrentMap();
+ }
this.directoryCleaner = directoryCleaner;
}
@@ -81,6 +148,15 @@ public class ExternalShuffleBlockResolver {
ExecutorShuffleInfo executorInfo) {
AppExecId fullId = new AppExecId(appId, execId);
logger.info("Registered executor {} with {}", fullId, executorInfo);
+ try {
+ if (db != null) {
+ byte[] key = dbAppExecKey(fullId);
+ byte[] value = mapper.writeValueAsString(executorInfo).getBytes(Charsets.UTF_8);
+ db.put(key, value);
+ }
+ } catch (Exception e) {
+ logger.error("Error saving registered executors", e);
+ }
executors.put(fullId, executorInfo);
}
@@ -136,6 +212,13 @@ public class ExternalShuffleBlockResolver {
// Only touch executors associated with the appId that was removed.
if (appId.equals(fullId.appId)) {
it.remove();
+ if (db != null) {
+ try {
+ db.delete(dbAppExecKey(fullId));
+ } catch (IOException e) {
+ logger.error("Error deleting {} from executor state db", appId, e);
+ }
+ }
if (cleanupLocalDirs) {
logger.info("Cleaning up executor {}'s {} local dirs", fullId, executor.localDirs.length);
@@ -220,12 +303,23 @@ public class ExternalShuffleBlockResolver {
return new File(new File(localDir, String.format("%02x", subDirId)), filename);
}
+ void close() {
+ if (db != null) {
+ try {
+ db.close();
+ } catch (IOException e) {
+ logger.error("Exception closing leveldb with registered executors", e);
+ }
+ }
+ }
+
/** Simply encodes an executor's full ID, which is appId + execId. */
- private static class AppExecId {
- final String appId;
- final String execId;
+ public static class AppExecId {
+ public final String appId;
+ public final String execId;
- private AppExecId(String appId, String execId) {
+ @JsonCreator
+ public AppExecId(@JsonProperty("appId") String appId, @JsonProperty("execId") String execId) {
this.appId = appId;
this.execId = execId;
}
@@ -252,4 +346,105 @@ public class ExternalShuffleBlockResolver {
.toString();
}
}
+
+ private static byte[] dbAppExecKey(AppExecId appExecId) throws IOException {
+ // we stick a common prefix on all the keys so we can find them in the DB
+ String appExecJson = mapper.writeValueAsString(appExecId);
+ String key = (APP_KEY_PREFIX + ";" + appExecJson);
+ return key.getBytes(Charsets.UTF_8);
+ }
+
+ private static AppExecId parseDbAppExecKey(String s) throws IOException {
+ if (!s.startsWith(APP_KEY_PREFIX)) {
+ throw new IllegalArgumentException("expected a string starting with " + APP_KEY_PREFIX);
+ }
+ String json = s.substring(APP_KEY_PREFIX.length() + 1);
+ AppExecId parsed = mapper.readValue(json, AppExecId.class);
+ return parsed;
+ }
+
+ @VisibleForTesting
+ static ConcurrentMap<AppExecId, ExecutorShuffleInfo> reloadRegisteredExecutors(DB db)
+ throws IOException {
+ ConcurrentMap<AppExecId, ExecutorShuffleInfo> registeredExecutors = Maps.newConcurrentMap();
+ if (db != null) {
+ DBIterator itr = db.iterator();
+ itr.seek(APP_KEY_PREFIX.getBytes(Charsets.UTF_8));
+ while (itr.hasNext()) {
+ Map.Entry<byte[], byte[]> e = itr.next();
+ String key = new String(e.getKey(), Charsets.UTF_8);
+ if (!key.startsWith(APP_KEY_PREFIX)) {
+ break;
+ }
+ AppExecId id = parseDbAppExecKey(key);
+ ExecutorShuffleInfo shuffleInfo = mapper.readValue(e.getValue(), ExecutorShuffleInfo.class);
+ registeredExecutors.put(id, shuffleInfo);
+ }
+ }
+ return registeredExecutors;
+ }
+
+ private static class LevelDBLogger implements org.iq80.leveldb.Logger {
+ private static final Logger LOG = LoggerFactory.getLogger(LevelDBLogger.class);
+
+ @Override
+ public void log(String message) {
+ LOG.info(message);
+ }
+ }
+
+ /**
+ * Simple major.minor versioning scheme. Any incompatible changes should be across major
+ * versions. Minor version differences are allowed -- meaning we should be able to read
+ * dbs that are either earlier *or* later on the minor version.
+ */
+ private static void checkVersion(DB db) throws IOException {
+ byte[] bytes = db.get(StoreVersion.KEY);
+ if (bytes == null) {
+ storeVersion(db);
+ } else {
+ StoreVersion version = mapper.readValue(bytes, StoreVersion.class);
+ if (version.major != CURRENT_VERSION.major) {
+ throw new IOException("cannot read state DB with version " + version + ", incompatible " +
+ "with current version " + CURRENT_VERSION);
+ }
+ storeVersion(db);
+ }
+ }
+
+ private static void storeVersion(DB db) throws IOException {
+ db.put(StoreVersion.KEY, mapper.writeValueAsBytes(CURRENT_VERSION));
+ }
+
+
+ public static class StoreVersion {
+
+ final static byte[] KEY = "StoreVersion".getBytes(Charsets.UTF_8);
+
+ public final int major;
+ public final int minor;
+
+ @JsonCreator public StoreVersion(@JsonProperty("major") int major, @JsonProperty("minor") int minor) {
+ this.major = major;
+ this.minor = minor;
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (o == null || getClass() != o.getClass()) return false;
+
+ StoreVersion that = (StoreVersion) o;
+
+ return major == that.major && minor == that.minor;
+ }
+
+ @Override
+ public int hashCode() {
+ int result = major;
+ result = 31 * result + minor;
+ return result;
+ }
+ }
+
}
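
For readers unfamiliar with the storage layout introduced above: each registration becomes a single leveldb entry whose key is the literal prefix "AppExecShuffleInfo;" followed by the JSON form of the AppExecId, and whose value is the JSON form of the ExecutorShuffleInfo. Because leveldb keeps keys sorted byte-wise, reloading is a seek to the prefix followed by a scan that stops at the first key without that prefix. A standalone sketch of that prefix-scan pattern, using the same leveldbjni and Guava APIs the patch depends on (path and data are hypothetical):

    import java.io.File;
    import java.util.Map;

    import com.google.common.base.Charsets;
    import org.fusesource.leveldbjni.JniDBFactory;
    import org.iq80.leveldb.DB;
    import org.iq80.leveldb.DBIterator;
    import org.iq80.leveldb.Options;

    public class PrefixScanSketch {
      private static final String PREFIX = "AppExecShuffleInfo";

      public static void main(String[] args) throws Exception {
        Options options = new Options();
        options.createIfMissing(true);
        DB db = JniDBFactory.factory.open(new File("/tmp/demo-registry.ldb"), options);
        // One fake registration entry: key = prefix + ";" + JSON app/exec id, value = JSON info.
        db.put((PREFIX + ";{\"appId\":\"app-1\",\"execId\":\"0\"}").getBytes(Charsets.UTF_8),
            "{\"localDirs\":[\"/tmp/blockmgr-0\"],\"subDirsPerLocalDir\":64,\"shuffleManager\":\"hash\"}"
                .getBytes(Charsets.UTF_8));
        DBIterator itr = db.iterator();
        itr.seek(PREFIX.getBytes(Charsets.UTF_8));   // jump to the first key with the prefix
        while (itr.hasNext()) {
          Map.Entry<byte[], byte[]> entry = itr.next();
          String key = new String(entry.getKey(), Charsets.UTF_8);
          if (!key.startsWith(PREFIX)) {
            break;   // keys are sorted, so anything past the prefix block is unrelated
          }
          System.out.println(key + " -> " + new String(entry.getValue(), Charsets.UTF_8));
        }
        itr.close();
        db.close();
      }
    }
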
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExecutorShuffleInfo.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExecutorShuffleInfo.java
index cadc8e8369..102d4efb8b 100644
--- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExecutorShuffleInfo.java
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExecutorShuffleInfo.java
@@ -19,6 +19,8 @@ package org.apache.spark.network.shuffle.protocol;
import java.util.Arrays;
+import com.fasterxml.jackson.annotation.JsonCreator;
+import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.base.Objects;
import io.netty.buffer.ByteBuf;
@@ -34,7 +36,11 @@ public class ExecutorShuffleInfo implements Encodable {
/** Shuffle manager (SortShuffleManager or HashShuffleManager) that the executor is using. */
public final String shuffleManager;
- public ExecutorShuffleInfo(String[] localDirs, int subDirsPerLocalDir, String shuffleManager) {
+ @JsonCreator
+ public ExecutorShuffleInfo(
+ @JsonProperty("localDirs") String[] localDirs,
+ @JsonProperty("subDirsPerLocalDir") int subDirsPerLocalDir,
+ @JsonProperty("shuffleManager") String shuffleManager) {
this.localDirs = localDirs;
this.subDirsPerLocalDir = subDirsPerLocalDir;
this.shuffleManager = shuffleManager;
diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java
index d02f4f0fdb..3c6cb367de 100644
--- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java
+++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java
@@ -21,9 +21,12 @@ import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
+import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.io.CharStreams;
+import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo;
import org.apache.spark.network.util.SystemPropertyConfigProvider;
import org.apache.spark.network.util.TransportConf;
+import org.apache.spark.network.shuffle.ExternalShuffleBlockResolver.AppExecId;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;
@@ -59,8 +62,8 @@ public class ExternalShuffleBlockResolverSuite {
}
@Test
- public void testBadRequests() {
- ExternalShuffleBlockResolver resolver = new ExternalShuffleBlockResolver(conf);
+ public void testBadRequests() throws IOException {
+ ExternalShuffleBlockResolver resolver = new ExternalShuffleBlockResolver(conf, null);
// Unregistered executor
try {
resolver.getBlockData("app0", "exec1", "shuffle_1_1_0");
@@ -91,7 +94,7 @@ public class ExternalShuffleBlockResolverSuite {
@Test
public void testSortShuffleBlocks() throws IOException {
- ExternalShuffleBlockResolver resolver = new ExternalShuffleBlockResolver(conf);
+ ExternalShuffleBlockResolver resolver = new ExternalShuffleBlockResolver(conf, null);
resolver.registerExecutor("app0", "exec0",
dataContext.createExecutorInfo("org.apache.spark.shuffle.sort.SortShuffleManager"));
@@ -110,7 +113,7 @@ public class ExternalShuffleBlockResolverSuite {
@Test
public void testHashShuffleBlocks() throws IOException {
- ExternalShuffleBlockResolver resolver = new ExternalShuffleBlockResolver(conf);
+ ExternalShuffleBlockResolver resolver = new ExternalShuffleBlockResolver(conf, null);
resolver.registerExecutor("app0", "exec0",
dataContext.createExecutorInfo("org.apache.spark.shuffle.hash.HashShuffleManager"));
@@ -126,4 +129,28 @@ public class ExternalShuffleBlockResolverSuite {
block1Stream.close();
assertEquals(hashBlock1, block1);
}
+
+ @Test
+ public void jsonSerializationOfExecutorRegistration() throws IOException {
+ ObjectMapper mapper = new ObjectMapper();
+ AppExecId appId = new AppExecId("foo", "bar");
+ String appIdJson = mapper.writeValueAsString(appId);
+ AppExecId parsedAppId = mapper.readValue(appIdJson, AppExecId.class);
+ assertEquals(parsedAppId, appId);
+
+ ExecutorShuffleInfo shuffleInfo =
+ new ExecutorShuffleInfo(new String[]{"/bippy", "/flippy"}, 7, "hash");
+ String shuffleJson = mapper.writeValueAsString(shuffleInfo);
+ ExecutorShuffleInfo parsedShuffleInfo =
+ mapper.readValue(shuffleJson, ExecutorShuffleInfo.class);
+ assertEquals(parsedShuffleInfo, shuffleInfo);
+
+ // Intentionally keep these hard-coded strings in here, to check backwards-compatability.
+ // its not legacy yet, but keeping this here in case anybody changes it
+ String legacyAppIdJson = "{\"appId\":\"foo\", \"execId\":\"bar\"}";
+ assertEquals(appId, mapper.readValue(legacyAppIdJson, AppExecId.class));
+ String legacyShuffleJson = "{\"localDirs\": [\"/bippy\", \"/flippy\"], " +
+ "\"subDirsPerLocalDir\": 7, \"shuffleManager\": \"hash\"}";
+ assertEquals(shuffleInfo, mapper.readValue(legacyShuffleJson, ExecutorShuffleInfo.class));
+ }
}
diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java
index d9d9c1bf2f..2f4f1d0df4 100644
--- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java
+++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java
@@ -42,7 +42,7 @@ public class ExternalShuffleCleanupSuite {
TestShuffleDataContext dataContext = createSomeData();
ExternalShuffleBlockResolver resolver =
- new ExternalShuffleBlockResolver(conf, sameThreadExecutor);
+ new ExternalShuffleBlockResolver(conf, null, sameThreadExecutor);
resolver.registerExecutor("app", "exec0", dataContext.createExecutorInfo("shuffleMgr"));
resolver.applicationRemoved("app", false /* cleanup */);
@@ -65,7 +65,8 @@ public class ExternalShuffleCleanupSuite {
@Override public void execute(Runnable runnable) { cleanupCalled.set(true); }
};
- ExternalShuffleBlockResolver manager = new ExternalShuffleBlockResolver(conf, noThreadExecutor);
+ ExternalShuffleBlockResolver manager =
+ new ExternalShuffleBlockResolver(conf, null, noThreadExecutor);
manager.registerExecutor("app", "exec0", dataContext.createExecutorInfo("shuffleMgr"));
manager.applicationRemoved("app", true);
@@ -83,7 +84,7 @@ public class ExternalShuffleCleanupSuite {
TestShuffleDataContext dataContext1 = createSomeData();
ExternalShuffleBlockResolver resolver =
- new ExternalShuffleBlockResolver(conf, sameThreadExecutor);
+ new ExternalShuffleBlockResolver(conf, null, sameThreadExecutor);
resolver.registerExecutor("app", "exec0", dataContext0.createExecutorInfo("shuffleMgr"));
resolver.registerExecutor("app", "exec1", dataContext1.createExecutorInfo("shuffleMgr"));
@@ -99,7 +100,7 @@ public class ExternalShuffleCleanupSuite {
TestShuffleDataContext dataContext1 = createSomeData();
ExternalShuffleBlockResolver resolver =
- new ExternalShuffleBlockResolver(conf, sameThreadExecutor);
+ new ExternalShuffleBlockResolver(conf, null, sameThreadExecutor);
resolver.registerExecutor("app-0", "exec0", dataContext0.createExecutorInfo("shuffleMgr"));
resolver.registerExecutor("app-1", "exec0", dataContext1.createExecutorInfo("shuffleMgr"));
diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java
index 39aa49911d..a3f9a38b1a 100644
--- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java
+++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java
@@ -92,7 +92,7 @@ public class ExternalShuffleIntegrationSuite {
dataContext1.insertHashShuffleData(1, 0, exec1Blocks);
conf = new TransportConf(new SystemPropertyConfigProvider());
- handler = new ExternalShuffleBlockHandler(conf);
+ handler = new ExternalShuffleBlockHandler(conf, null);
TransportContext transportContext = new TransportContext(conf, handler);
server = transportContext.createServer();
}
diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java
index d4ec1956c1..aa99efda94 100644
--- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java
+++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java
@@ -43,8 +43,9 @@ public class ExternalShuffleSecuritySuite {
TransportServer server;
@Before
- public void beforeEach() {
- TransportContext context = new TransportContext(conf, new ExternalShuffleBlockHandler(conf));
+ public void beforeEach() throws IOException {
+ TransportContext context =
+ new TransportContext(conf, new ExternalShuffleBlockHandler(conf, null));
TransportServerBootstrap bootstrap = new SaslServerBootstrap(conf,
new TestSecretKeyHolder("my-app-id", "secret"));
this.server = context.createServer(Arrays.asList(bootstrap));
diff --git a/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java b/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java
index 463f99ef33..11ea7f3fd3 100644
--- a/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java
+++ b/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java
@@ -17,25 +17,21 @@
package org.apache.spark.network.yarn;
+import java.io.File;
import java.nio.ByteBuffer;
import java.util.List;
+import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.Lists;
import org.apache.hadoop.conf.Configuration;
-import org.apache.hadoop.yarn.api.records.ApplicationId;
import org.apache.hadoop.yarn.api.records.ContainerId;
-import org.apache.hadoop.yarn.server.api.AuxiliaryService;
-import org.apache.hadoop.yarn.server.api.ApplicationInitializationContext;
-import org.apache.hadoop.yarn.server.api.ApplicationTerminationContext;
-import org.apache.hadoop.yarn.server.api.ContainerInitializationContext;
-import org.apache.hadoop.yarn.server.api.ContainerTerminationContext;
+import org.apache.hadoop.yarn.server.api.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.spark.network.TransportContext;
import org.apache.spark.network.sasl.SaslServerBootstrap;
import org.apache.spark.network.sasl.ShuffleSecretManager;
-import org.apache.spark.network.server.RpcHandler;
import org.apache.spark.network.server.TransportServer;
import org.apache.spark.network.server.TransportServerBootstrap;
import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler;
@@ -79,11 +75,26 @@ public class YarnShuffleService extends AuxiliaryService {
private TransportServer shuffleServer = null;
// Handles registering executors and opening shuffle blocks
- private ExternalShuffleBlockHandler blockHandler;
+ @VisibleForTesting
+ ExternalShuffleBlockHandler blockHandler;
+
+ // Where to store & reload executor info for recovering state after an NM restart
+ @VisibleForTesting
+ File registeredExecutorFile;
+
+ // just for testing when you want to find an open port
+ @VisibleForTesting
+ static int boundPort = -1;
+
+ // just for integration tests that want to look at this file -- in general not sensible as
+ // a static
+ @VisibleForTesting
+ static YarnShuffleService instance;
public YarnShuffleService() {
super("spark_shuffle");
logger.info("Initializing YARN shuffle service for Spark");
+ instance = this;
}
/**
@@ -100,11 +111,24 @@ public class YarnShuffleService extends AuxiliaryService {
*/
@Override
protected void serviceInit(Configuration conf) {
+
+ // In case this NM was killed while there were running spark applications, we need to restore
+ // lost state for the existing executors. We look for an existing file in the NM's local dirs.
+ // If we don't find one, then we choose a file to use to save the state next time. Even if
+ // an application was stopped while the NM was down, we expect yarn to call stopApplication()
+ // when it comes back
+ registeredExecutorFile =
+ findRegisteredExecutorFile(conf.getStrings("yarn.nodemanager.local-dirs"));
+
TransportConf transportConf = new TransportConf(new HadoopConfigProvider(conf));
// If authentication is enabled, set up the shuffle server to use a
// special RPC handler that filters out unauthenticated fetch requests
boolean authEnabled = conf.getBoolean(SPARK_AUTHENTICATE_KEY, DEFAULT_SPARK_AUTHENTICATE);
- blockHandler = new ExternalShuffleBlockHandler(transportConf);
+ try {
+ blockHandler = new ExternalShuffleBlockHandler(transportConf, registeredExecutorFile);
+ } catch (Exception e) {
+ logger.error("Failed to initialize external shuffle service", e);
+ }
List<TransportServerBootstrap> bootstraps = Lists.newArrayList();
if (authEnabled) {
@@ -116,9 +140,13 @@ public class YarnShuffleService extends AuxiliaryService {
SPARK_SHUFFLE_SERVICE_PORT_KEY, DEFAULT_SPARK_SHUFFLE_SERVICE_PORT);
TransportContext transportContext = new TransportContext(transportConf, blockHandler);
shuffleServer = transportContext.createServer(port, bootstraps);
+ // the port should normally be fixed, but for tests its useful to find an open port
+ port = shuffleServer.getPort();
+ boundPort = port;
String authEnabledString = authEnabled ? "enabled" : "not enabled";
logger.info("Started YARN shuffle service for Spark on port {}. " +
- "Authentication is {}.", port, authEnabledString);
+ "Authentication is {}. Registered executor file is {}", port, authEnabledString,
+ registeredExecutorFile);
}
@Override
@@ -161,6 +189,16 @@ public class YarnShuffleService extends AuxiliaryService {
logger.info("Stopping container {}", containerId);
}
+ private File findRegisteredExecutorFile(String[] localDirs) {
+ for (String dir: localDirs) {
+ File f = new File(dir, "registeredExecutors.ldb");
+ if (f.exists()) {
+ return f;
+ }
+ }
+ return new File(localDirs[0], "registeredExecutors.ldb");
+ }
+
/**
* Close the shuffle server to clean up any associated state.
*/
@@ -170,6 +208,9 @@ public class YarnShuffleService extends AuxiliaryService {
if (shuffleServer != null) {
shuffleServer.close();
}
+ if (blockHandler != null) {
+ blockHandler.close();
+ }
} catch (Exception e) {
logger.error("Exception when stopping service", e);
}
@@ -180,5 +221,4 @@ public class YarnShuffleService extends AuxiliaryService {
public ByteBuffer getMetaData() {
return ByteBuffer.allocate(0);
}
-
}
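
To actually run this auxiliary service, the NodeManager configuration has to name it and point at this class; the new Scala test suites later in this patch do exactly that. The same wiring, sketched in Java as a test-style configuration (port 0 here is the test-only trick of letting YARN pick a free port):

    import org.apache.hadoop.conf.Configuration;
    import org.apache.hadoop.yarn.conf.YarnConfiguration;

    public class AuxServiceConfigSketch {
      public static Configuration shuffleServiceConf() {
        Configuration conf = new YarnConfiguration();
        conf.set(YarnConfiguration.NM_AUX_SERVICES, "spark_shuffle");
        conf.set(String.format(YarnConfiguration.NM_AUX_SERVICE_FMT, "spark_shuffle"),
            "org.apache.spark.network.yarn.YarnShuffleService");
        // Let YARN bind an ephemeral port; BlockManager then falls back to the port
        // recorded in the application's SparkConf (see the BlockManager change above).
        conf.set("spark.shuffle.service.port", "0");
        return conf;
      }
    }
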
diff --git a/pom.xml b/pom.xml
index ccfa1ea19b..d5945f2546 100644
--- a/pom.xml
+++ b/pom.xml
@@ -655,6 +655,11 @@
<artifactId>jackson-databind</artifactId>
<version>${fasterxml.jackson.version}</version>
</dependency>
+ <dependency>
+ <groupId>com.fasterxml.jackson.core</groupId>
+ <artifactId>jackson-annotations</artifactId>
+ <version>${fasterxml.jackson.version}</version>
+ </dependency>
<!-- Guava is excluded because of SPARK-6149. The Guava version referenced in this module is
15.0, which causes runtime incompatibility issues. -->
<dependency>
diff --git a/yarn/pom.xml b/yarn/pom.xml
index 15db54e4e7..f673769530 100644
--- a/yarn/pom.xml
+++ b/yarn/pom.xml
@@ -40,6 +40,12 @@
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
+ <artifactId>spark-network-yarn_${scala.binary.version}</artifactId>
+ <version>${project.version}</version>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.spark</groupId>
<artifactId>spark-core_${scala.binary.version}</artifactId>
<version>${project.version}</version>
<type>test-jar</type>
diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala
new file mode 100644
index 0000000000..128e996b71
--- /dev/null
+++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala
@@ -0,0 +1,193 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn
+
+import java.io.{File, FileOutputStream, OutputStreamWriter}
+import java.util.Properties
+import java.util.concurrent.TimeUnit
+
+import scala.collection.JavaConversions._
+
+import com.google.common.base.Charsets.UTF_8
+import com.google.common.io.Files
+import org.apache.hadoop.yarn.conf.YarnConfiguration
+import org.apache.hadoop.yarn.server.MiniYARNCluster
+import org.scalatest.{BeforeAndAfterAll, Matchers}
+
+import org.apache.spark._
+import org.apache.spark.util.Utils
+
+abstract class BaseYarnClusterSuite
+ extends SparkFunSuite with BeforeAndAfterAll with Matchers with Logging {
+
+ // log4j configuration for the YARN containers, so that their output is collected
+ // by YARN instead of trying to overwrite unit-tests.log.
+ protected val LOG4J_CONF = """
+ |log4j.rootCategory=DEBUG, console
+ |log4j.appender.console=org.apache.log4j.ConsoleAppender
+ |log4j.appender.console.target=System.err
+ |log4j.appender.console.layout=org.apache.log4j.PatternLayout
+ |log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n
+ """.stripMargin
+
+ private var yarnCluster: MiniYARNCluster = _
+ protected var tempDir: File = _
+ private var fakeSparkJar: File = _
+ private var hadoopConfDir: File = _
+ private var logConfDir: File = _
+
+
+ def yarnConfig: YarnConfiguration
+
+ override def beforeAll() {
+ super.beforeAll()
+
+ tempDir = Utils.createTempDir()
+ logConfDir = new File(tempDir, "log4j")
+ logConfDir.mkdir()
+ System.setProperty("SPARK_YARN_MODE", "true")
+
+ val logConfFile = new File(logConfDir, "log4j.properties")
+ Files.write(LOG4J_CONF, logConfFile, UTF_8)
+
+ yarnCluster = new MiniYARNCluster(getClass().getName(), 1, 1, 1)
+ yarnCluster.init(yarnConfig)
+ yarnCluster.start()
+
+ // There's a race in MiniYARNCluster in which start() may return before the RM has updated
+ // its address in the configuration. You can see this in the logs by noticing that when
+ // MiniYARNCluster prints the address, it still has port "0" assigned, although later the
+ // test works sometimes:
+ //
+ // INFO MiniYARNCluster: MiniYARN ResourceManager address: blah:0
+ //
+ // That log message prints the contents of the RM_ADDRESS config variable. If you check it
+ // later on, it looks something like this:
+ //
+ // INFO YarnClusterSuite: RM address in configuration is blah:42631
+ //
+ // This hack loops for a bit waiting for the port to change, and fails the test if it hasn't
+ // done so in a timely manner (defined to be 10 seconds).
+ val config = yarnCluster.getConfig()
+ val deadline = System.currentTimeMillis() + TimeUnit.SECONDS.toMillis(10)
+ while (config.get(YarnConfiguration.RM_ADDRESS).split(":")(1) == "0") {
+ if (System.currentTimeMillis() > deadline) {
+ throw new IllegalStateException("Timed out waiting for RM to come up.")
+ }
+ logDebug("RM address still not set in configuration, waiting...")
+ TimeUnit.MILLISECONDS.sleep(100)
+ }
+
+ logInfo(s"RM address in configuration is ${config.get(YarnConfiguration.RM_ADDRESS)}")
+
+ fakeSparkJar = File.createTempFile("sparkJar", null, tempDir)
+ hadoopConfDir = new File(tempDir, Client.LOCALIZED_CONF_DIR)
+ assert(hadoopConfDir.mkdir())
+ File.createTempFile("token", ".txt", hadoopConfDir)
+ }
+
+ override def afterAll() {
+ yarnCluster.stop()
+ System.clearProperty("SPARK_YARN_MODE")
+ super.afterAll()
+ }
+
+ protected def runSpark(
+ clientMode: Boolean,
+ klass: String,
+ appArgs: Seq[String] = Nil,
+ sparkArgs: Seq[String] = Nil,
+ extraClassPath: Seq[String] = Nil,
+ extraJars: Seq[String] = Nil,
+ extraConf: Map[String, String] = Map()): Unit = {
+ val master = if (clientMode) "yarn-client" else "yarn-cluster"
+ val props = new Properties()
+
+ props.setProperty("spark.yarn.jar", "local:" + fakeSparkJar.getAbsolutePath())
+
+ val childClasspath = logConfDir.getAbsolutePath() +
+ File.pathSeparator +
+ sys.props("java.class.path") +
+ File.pathSeparator +
+ extraClassPath.mkString(File.pathSeparator)
+ props.setProperty("spark.driver.extraClassPath", childClasspath)
+ props.setProperty("spark.executor.extraClassPath", childClasspath)
+
+ // SPARK-4267: make sure java options are propagated correctly.
+ props.setProperty("spark.driver.extraJavaOptions", "-Dfoo=\"one two three\"")
+ props.setProperty("spark.executor.extraJavaOptions", "-Dfoo=\"one two three\"")
+
+ yarnCluster.getConfig().foreach { e =>
+ props.setProperty("spark.hadoop." + e.getKey(), e.getValue())
+ }
+
+ sys.props.foreach { case (k, v) =>
+ if (k.startsWith("spark.")) {
+ props.setProperty(k, v)
+ }
+ }
+
+ extraConf.foreach { case (k, v) => props.setProperty(k, v) }
+
+ val propsFile = File.createTempFile("spark", ".properties", tempDir)
+ val writer = new OutputStreamWriter(new FileOutputStream(propsFile), UTF_8)
+ props.store(writer, "Spark properties.")
+ writer.close()
+
+ val extraJarArgs = if (!extraJars.isEmpty()) Seq("--jars", extraJars.mkString(",")) else Nil
+ val mainArgs =
+ if (klass.endsWith(".py")) {
+ Seq(klass)
+ } else {
+ Seq("--class", klass, fakeSparkJar.getAbsolutePath())
+ }
+ val argv =
+ Seq(
+ new File(sys.props("spark.test.home"), "bin/spark-submit").getAbsolutePath(),
+ "--master", master,
+ "--num-executors", "1",
+ "--properties-file", propsFile.getAbsolutePath()) ++
+ extraJarArgs ++
+ sparkArgs ++
+ mainArgs ++
+ appArgs
+
+ Utils.executeAndGetOutput(argv,
+ extraEnvironment = Map("YARN_CONF_DIR" -> hadoopConfDir.getAbsolutePath()))
+ }
+
+ /**
+ * This is a workaround for an issue with yarn-cluster mode: the Client class will not provide
+ * any sort of error when the job process finishes successfully, but the job itself fails. So
+ * the tests enforce that something is written to a file after everything is ok to indicate
+ * that the job succeeded.
+ */
+ protected def checkResult(result: File): Unit = {
+ checkResult(result, "success")
+ }
+
+ protected def checkResult(result: File, expected: String): Unit = {
+ val resultString = Files.toString(result, UTF_8)
+ resultString should be (expected)
+ }
+
+ protected def mainClassName(klass: Class[_]): String = {
+ klass.getName().stripSuffix("$")
+ }
+
+}
diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
index eb6e1fd370..128350b648 100644
--- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
+++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
@@ -17,25 +17,20 @@
package org.apache.spark.deploy.yarn
-import java.io.{File, FileOutputStream, OutputStreamWriter}
+import java.io.File
import java.net.URL
-import java.util.Properties
-import java.util.concurrent.TimeUnit
-import scala.collection.JavaConversions._
import scala.collection.mutable
+import scala.collection.JavaConversions._
import com.google.common.base.Charsets.UTF_8
-import com.google.common.io.ByteStreams
-import com.google.common.io.Files
+import com.google.common.io.{ByteStreams, Files}
import org.apache.hadoop.yarn.conf.YarnConfiguration
-import org.apache.hadoop.yarn.server.MiniYARNCluster
-import org.scalatest.{BeforeAndAfterAll, Matchers}
+import org.scalatest.Matchers
import org.apache.spark._
+import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationStart, SparkListenerExecutorAdded}
import org.apache.spark.scheduler.cluster.ExecutorInfo
-import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationStart,
- SparkListenerExecutorAdded}
import org.apache.spark.util.Utils
/**
@@ -43,17 +38,9 @@ import org.apache.spark.util.Utils
* applications, and require the Spark assembly to be built before they can be successfully
* run.
*/
-class YarnClusterSuite extends SparkFunSuite with BeforeAndAfterAll with Matchers with Logging {
-
- // log4j configuration for the YARN containers, so that their output is collected
- // by YARN instead of trying to overwrite unit-tests.log.
- private val LOG4J_CONF = """
- |log4j.rootCategory=DEBUG, console
- |log4j.appender.console=org.apache.log4j.ConsoleAppender
- |log4j.appender.console.target=System.err
- |log4j.appender.console.layout=org.apache.log4j.PatternLayout
- |log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n
- """.stripMargin
+class YarnClusterSuite extends BaseYarnClusterSuite {
+
+ override def yarnConfig: YarnConfiguration = new YarnConfiguration()
private val TEST_PYFILE = """
|import mod1, mod2
@@ -82,65 +69,6 @@ class YarnClusterSuite extends SparkFunSuite with BeforeAndAfterAll with Matcher
| return 42
""".stripMargin
- private var yarnCluster: MiniYARNCluster = _
- private var tempDir: File = _
- private var fakeSparkJar: File = _
- private var hadoopConfDir: File = _
- private var logConfDir: File = _
-
- override def beforeAll() {
- super.beforeAll()
-
- tempDir = Utils.createTempDir()
- logConfDir = new File(tempDir, "log4j")
- logConfDir.mkdir()
- System.setProperty("SPARK_YARN_MODE", "true")
-
- val logConfFile = new File(logConfDir, "log4j.properties")
- Files.write(LOG4J_CONF, logConfFile, UTF_8)
-
- yarnCluster = new MiniYARNCluster(getClass().getName(), 1, 1, 1)
- yarnCluster.init(new YarnConfiguration())
- yarnCluster.start()
-
- // There's a race in MiniYARNCluster in which start() may return before the RM has updated
- // its address in the configuration. You can see this in the logs by noticing that when
- // MiniYARNCluster prints the address, it still has port "0" assigned, although later the
- // test works sometimes:
- //
- // INFO MiniYARNCluster: MiniYARN ResourceManager address: blah:0
- //
- // That log message prints the contents of the RM_ADDRESS config variable. If you check it
- // later on, it looks something like this:
- //
- // INFO YarnClusterSuite: RM address in configuration is blah:42631
- //
- // This hack loops for a bit waiting for the port to change, and fails the test if it hasn't
- // done so in a timely manner (defined to be 10 seconds).
- val config = yarnCluster.getConfig()
- val deadline = System.currentTimeMillis() + TimeUnit.SECONDS.toMillis(10)
- while (config.get(YarnConfiguration.RM_ADDRESS).split(":")(1) == "0") {
- if (System.currentTimeMillis() > deadline) {
- throw new IllegalStateException("Timed out waiting for RM to come up.")
- }
- logDebug("RM address still not set in configuration, waiting...")
- TimeUnit.MILLISECONDS.sleep(100)
- }
-
- logInfo(s"RM address in configuration is ${config.get(YarnConfiguration.RM_ADDRESS)}")
-
- fakeSparkJar = File.createTempFile("sparkJar", null, tempDir)
- hadoopConfDir = new File(tempDir, Client.LOCALIZED_CONF_DIR)
- assert(hadoopConfDir.mkdir())
- File.createTempFile("token", ".txt", hadoopConfDir)
- }
-
- override def afterAll() {
- yarnCluster.stop()
- System.clearProperty("SPARK_YARN_MODE")
- super.afterAll()
- }
-
test("run Spark in yarn-client mode") {
testBasicYarnApp(true)
}
@@ -174,7 +102,7 @@ class YarnClusterSuite extends SparkFunSuite with BeforeAndAfterAll with Matcher
}
private def testBasicYarnApp(clientMode: Boolean): Unit = {
- var result = File.createTempFile("result", null, tempDir)
+ val result = File.createTempFile("result", null, tempDir)
runSpark(clientMode, mainClassName(YarnClusterDriver.getClass),
appArgs = Seq(result.getAbsolutePath()))
checkResult(result)
@@ -224,89 +152,6 @@ class YarnClusterSuite extends SparkFunSuite with BeforeAndAfterAll with Matcher
checkResult(executorResult, "OVERRIDDEN")
}
- private def runSpark(
- clientMode: Boolean,
- klass: String,
- appArgs: Seq[String] = Nil,
- sparkArgs: Seq[String] = Nil,
- extraClassPath: Seq[String] = Nil,
- extraJars: Seq[String] = Nil,
- extraConf: Map[String, String] = Map()): Unit = {
- val master = if (clientMode) "yarn-client" else "yarn-cluster"
- val props = new Properties()
-
- props.setProperty("spark.yarn.jar", "local:" + fakeSparkJar.getAbsolutePath())
-
- val childClasspath = logConfDir.getAbsolutePath() +
- File.pathSeparator +
- sys.props("java.class.path") +
- File.pathSeparator +
- extraClassPath.mkString(File.pathSeparator)
- props.setProperty("spark.driver.extraClassPath", childClasspath)
- props.setProperty("spark.executor.extraClassPath", childClasspath)
-
- // SPARK-4267: make sure java options are propagated correctly.
- props.setProperty("spark.driver.extraJavaOptions", "-Dfoo=\"one two three\"")
- props.setProperty("spark.executor.extraJavaOptions", "-Dfoo=\"one two three\"")
-
- yarnCluster.getConfig().foreach { e =>
- props.setProperty("spark.hadoop." + e.getKey(), e.getValue())
- }
-
- sys.props.foreach { case (k, v) =>
- if (k.startsWith("spark.")) {
- props.setProperty(k, v)
- }
- }
-
- extraConf.foreach { case (k, v) => props.setProperty(k, v) }
-
- val propsFile = File.createTempFile("spark", ".properties", tempDir)
- val writer = new OutputStreamWriter(new FileOutputStream(propsFile), UTF_8)
- props.store(writer, "Spark properties.")
- writer.close()
-
- val extraJarArgs = if (!extraJars.isEmpty()) Seq("--jars", extraJars.mkString(",")) else Nil
- val mainArgs =
- if (klass.endsWith(".py")) {
- Seq(klass)
- } else {
- Seq("--class", klass, fakeSparkJar.getAbsolutePath())
- }
- val argv =
- Seq(
- new File(sys.props("spark.test.home"), "bin/spark-submit").getAbsolutePath(),
- "--master", master,
- "--num-executors", "1",
- "--properties-file", propsFile.getAbsolutePath()) ++
- extraJarArgs ++
- sparkArgs ++
- mainArgs ++
- appArgs
-
- Utils.executeAndGetOutput(argv,
- extraEnvironment = Map("YARN_CONF_DIR" -> hadoopConfDir.getAbsolutePath()))
- }
-
- /**
- * This is a workaround for an issue with yarn-cluster mode: the Client class will not provide
- * any sort of error when the job process finishes successfully, but the job itself fails. So
- * the tests enforce that something is written to a file after everything is ok to indicate
- * that the job succeeded.
- */
- private def checkResult(result: File): Unit = {
- checkResult(result, "success")
- }
-
- private def checkResult(result: File, expected: String): Unit = {
- var resultString = Files.toString(result, UTF_8)
- resultString should be (expected)
- }
-
- private def mainClassName(klass: Class[_]): String = {
- klass.getName().stripSuffix("$")
- }
-
}
private[spark] class SaveExecutorInfo extends SparkListener {
diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala
new file mode 100644
index 0000000000..5e8238822b
--- /dev/null
+++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala
@@ -0,0 +1,109 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements. See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.spark.deploy.yarn
+
+import java.io.File
+
+import com.google.common.base.Charsets.UTF_8
+import com.google.common.io.Files
+import org.apache.commons.io.FileUtils
+import org.apache.hadoop.yarn.conf.YarnConfiguration
+import org.scalatest.Matchers
+
+import org.apache.spark._
+import org.apache.spark.network.shuffle.ShuffleTestAccessor
+import org.apache.spark.network.yarn.{YarnShuffleService, YarnTestAccessor}
+
+/**
+ * Integration test for the external shuffle service with a yarn mini-cluster
+ */
+class YarnShuffleIntegrationSuite extends BaseYarnClusterSuite {
+
+ override def yarnConfig: YarnConfiguration = {
+ val yarnConfig = new YarnConfiguration()
+ yarnConfig.set(YarnConfiguration.NM_AUX_SERVICES, "spark_shuffle")
+ yarnConfig.set(YarnConfiguration.NM_AUX_SERVICE_FMT.format("spark_shuffle"),
+ classOf[YarnShuffleService].getCanonicalName)
+ yarnConfig.set("spark.shuffle.service.port", "0")
+ yarnConfig
+ }
+
+ test("external shuffle service") {
+ val shuffleServicePort = YarnTestAccessor.getShuffleServicePort
+ val shuffleService = YarnTestAccessor.getShuffleServiceInstance
+
+ val registeredExecFile = YarnTestAccessor.getRegisteredExecutorFile(shuffleService)
+
+ logInfo("Shuffle service port = " + shuffleServicePort)
+ val result = File.createTempFile("result", null, tempDir)
+ runSpark(
+ false,
+ mainClassName(YarnExternalShuffleDriver.getClass),
+ appArgs = Seq(result.getAbsolutePath(), registeredExecFile.getAbsolutePath),
+ extraConf = Map(
+ "spark.shuffle.service.enabled" -> "true",
+ "spark.shuffle.service.port" -> shuffleServicePort.toString
+ )
+ )
+ checkResult(result)
+ assert(YarnTestAccessor.getRegisteredExecutorFile(shuffleService).exists())
+ }
+}
+
+private object YarnExternalShuffleDriver extends Logging with Matchers {
+
+ val WAIT_TIMEOUT_MILLIS = 10000
+
+ def main(args: Array[String]): Unit = {
+ if (args.length != 2) {
+ // scalastyle:off println
+ System.err.println(
+ s"""
+ |Invalid command line: ${args.mkString(" ")}
+ |
+ |Usage: ExternalShuffleDriver [result file] [registed exec file]
+ """.stripMargin)
+ // scalastyle:on println
+ System.exit(1)
+ }
+
+ val sc = new SparkContext(new SparkConf()
+ .setAppName("External Shuffle Test"))
+ val conf = sc.getConf
+ val status = new File(args(0))
+ val registeredExecFile = new File(args(1))
+ logInfo("shuffle service executor file = " + registeredExecFile)
+ var result = "failure"
+ val execStateCopy = new File(registeredExecFile.getAbsolutePath + "_dup")
+ try {
+ val data = sc.parallelize(0 until 100, 10).map { x => (x % 10) -> x }.reduceByKey{ _ + _ }.
+ collect().toSet
+ sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
+ data should be ((0 until 10).map{x => x -> (x * 10 + 450)}.toSet)
+ result = "success"
+ // only one process can open a leveldb file at a time, so we copy the files
+ FileUtils.copyDirectory(registeredExecFile, execStateCopy)
+ assert(!ShuffleTestAccessor.reloadRegisteredExecutors(execStateCopy).isEmpty)
+ } finally {
+ sc.stop()
+ FileUtils.deleteDirectory(execStateCopy)
+ Files.write(result, status, UTF_8)
+ }
+ }
+
+}
diff --git a/yarn/src/test/scala/org/apache/spark/network/shuffle/ShuffleTestAccessor.scala b/yarn/src/test/scala/org/apache/spark/network/shuffle/ShuffleTestAccessor.scala
new file mode 100644
index 0000000000..aa46ec5100
--- /dev/null
+++ b/yarn/src/test/scala/org/apache/spark/network/shuffle/ShuffleTestAccessor.scala
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.network.shuffle
+
+import java.io.{IOException, File}
+import java.util.concurrent.ConcurrentMap
+
+import com.google.common.annotations.VisibleForTesting
+import org.apache.hadoop.yarn.api.records.ApplicationId
+import org.fusesource.leveldbjni.JniDBFactory
+import org.iq80.leveldb.{DB, Options}
+
+import org.apache.spark.network.shuffle.ExternalShuffleBlockResolver.AppExecId
+import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo
+
+/**
+ * Test-only helper that exposes package-visible members of the external shuffle classes.
+ */
+object ShuffleTestAccessor {
+
+ def getBlockResolver(handler: ExternalShuffleBlockHandler): ExternalShuffleBlockResolver = {
+ handler.blockManager
+ }
+
+ def getExecutorInfo(
+ appId: ApplicationId,
+ execId: String,
+ resolver: ExternalShuffleBlockResolver
+ ): Option[ExecutorShuffleInfo] = {
+ val id = new AppExecId(appId.toString, execId)
+ Option(resolver.executors.get(id))
+ }
+
+ def registeredExecutorFile(resolver: ExternalShuffleBlockResolver): File = {
+ resolver.registeredExecutorFile
+ }
+
+ def shuffleServiceLevelDB(resolver: ExternalShuffleBlockResolver): DB = {
+ resolver.db
+ }
+
+ def reloadRegisteredExecutors(
+ file: File): ConcurrentMap[ExternalShuffleBlockResolver.AppExecId, ExecutorShuffleInfo] = {
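+    // open the LevelDB directory directly, read back the registered-executor map, then close it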
+ val options: Options = new Options
+ options.createIfMissing(true)
+ val factory = new JniDBFactory
+ val db = factory.open(file, options)
+ val result = ExternalShuffleBlockResolver.reloadRegisteredExecutors(db)
+ db.close()
+ result
+ }
+
+ def reloadRegisteredExecutors(
+ db: DB): ConcurrentMap[ExternalShuffleBlockResolver.AppExecId, ExecutorShuffleInfo] = {
+ ExternalShuffleBlockResolver.reloadRegisteredExecutors(db)
+ }
+}
diff --git a/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceSuite.scala b/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceSuite.scala
new file mode 100644
index 0000000000..2f22cbdbea
--- /dev/null
+++ b/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceSuite.scala
@@ -0,0 +1,233 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.network.yarn
+
+import java.io.{DataOutputStream, File, FileOutputStream}
+
+import scala.annotation.tailrec
+
+import org.apache.commons.io.FileUtils
+import org.apache.hadoop.yarn.api.records.ApplicationId
+import org.apache.hadoop.yarn.conf.YarnConfiguration
+import org.apache.hadoop.yarn.server.api.{ApplicationInitializationContext, ApplicationTerminationContext}
+import org.scalatest.{BeforeAndAfterEach, Matchers}
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.network.shuffle.ShuffleTestAccessor
+import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo
+
+class YarnShuffleServiceSuite extends SparkFunSuite with Matchers with BeforeAndAfterEach {
+ private[yarn] var yarnConfig: YarnConfiguration = new YarnConfiguration
+
+ override def beforeEach(): Unit = {
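+    // register the Spark shuffle service as a YARN auxiliary service, as a NodeManager's config would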
+ yarnConfig.set(YarnConfiguration.NM_AUX_SERVICES, "spark_shuffle")
+ yarnConfig.set(YarnConfiguration.NM_AUX_SERVICE_FMT.format("spark_shuffle"),
+ classOf[YarnShuffleService].getCanonicalName)
+
+ yarnConfig.get("yarn.nodemanager.local-dirs").split(",").foreach { dir =>
+ val d = new File(dir)
+ if (d.exists()) {
+ FileUtils.deleteDirectory(d)
+ }
+ FileUtils.forceMkdir(d)
+ logInfo(s"creating yarn.nodemanager.local-dirs: $d")
+ }
+ }
+
+ var s1: YarnShuffleService = null
+ var s2: YarnShuffleService = null
+ var s3: YarnShuffleService = null
+
+ override def afterEach(): Unit = {
+ if (s1 != null) {
+ s1.stop()
+ s1 = null
+ }
+ if (s2 != null) {
+ s2.stop()
+ s2 = null
+ }
+ if (s3 != null) {
+ s3.stop()
+ s3 = null
+ }
+ }
+
+ test("executor state kept across NM restart") {
+ s1 = new YarnShuffleService
+ s1.init(yarnConfig)
+ val app1Id = ApplicationId.newInstance(0, 1)
+ val app1Data: ApplicationInitializationContext =
+ new ApplicationInitializationContext("user", app1Id, null)
+ s1.initializeApplication(app1Data)
+ val app2Id = ApplicationId.newInstance(0, 2)
+ val app2Data: ApplicationInitializationContext =
+ new ApplicationInitializationContext("user", app2Id, null)
+ s1.initializeApplication(app2Data)
+
+ val execStateFile = s1.registeredExecutorFile
+ execStateFile should not be (null)
+ val shuffleInfo1 = new ExecutorShuffleInfo(Array("/foo", "/bar"), 3, "sort")
+ val shuffleInfo2 = new ExecutorShuffleInfo(Array("/bippy"), 5, "hash")
+
+ val blockHandler = s1.blockHandler
+ val blockResolver = ShuffleTestAccessor.getBlockResolver(blockHandler)
+ ShuffleTestAccessor.registeredExecutorFile(blockResolver) should be (execStateFile)
+
+ blockResolver.registerExecutor(app1Id.toString, "exec-1", shuffleInfo1)
+ blockResolver.registerExecutor(app2Id.toString, "exec-2", shuffleInfo2)
+ ShuffleTestAccessor.getExecutorInfo(app1Id, "exec-1", blockResolver) should
+ be (Some(shuffleInfo1))
+ ShuffleTestAccessor.getExecutorInfo(app2Id, "exec-2", blockResolver) should
+ be (Some(shuffleInfo2))
+
+ if (!execStateFile.exists()) {
+ @tailrec def findExistingParent(file: File): File = {
+ if (file == null) file
+ else if (file.exists()) file
+ else findExistingParent(file.getParentFile())
+ }
+ val existingParent = findExistingParent(execStateFile)
+ assert(false, s"$execStateFile does not exist -- closest existing parent is $existingParent")
+ }
+ assert(execStateFile.exists(), s"$execStateFile did not exist")
+
+ // now we pretend the shuffle service goes down, and comes back up
+ s1.stop()
+ s2 = new YarnShuffleService
+ s2.init(yarnConfig)
+ s2.registeredExecutorFile should be (execStateFile)
+
+ val handler2 = s2.blockHandler
+ val resolver2 = ShuffleTestAccessor.getBlockResolver(handler2)
+
+ // now we reinitialize only one of the apps, and expect yarn to tell us that app2 was stopped
+ // during the restart
+ s2.initializeApplication(app1Data)
+ s2.stopApplication(new ApplicationTerminationContext(app2Id))
+ ShuffleTestAccessor.getExecutorInfo(app1Id, "exec-1", resolver2) should be (Some(shuffleInfo1))
+ ShuffleTestAccessor.getExecutorInfo(app2Id, "exec-2", resolver2) should be (None)
+
+ // Act like the NM restarts one more time
+ s2.stop()
+ s3 = new YarnShuffleService
+ s3.init(yarnConfig)
+ s3.registeredExecutorFile should be (execStateFile)
+
+ val handler3 = s3.blockHandler
+ val resolver3 = ShuffleTestAccessor.getBlockResolver(handler3)
+
+ // app1 is still running
+ s3.initializeApplication(app1Data)
+ ShuffleTestAccessor.getExecutorInfo(app1Id, "exec-1", resolver3) should be (Some(shuffleInfo1))
+ ShuffleTestAccessor.getExecutorInfo(app2Id, "exec-2", resolver3) should be (None)
+ s3.stop()
+ }
+
+ test("removed applications should not be in registered executor file") {
+ s1 = new YarnShuffleService
+ s1.init(yarnConfig)
+ val app1Id = ApplicationId.newInstance(0, 1)
+ val app1Data: ApplicationInitializationContext =
+ new ApplicationInitializationContext("user", app1Id, null)
+ s1.initializeApplication(app1Data)
+ val app2Id = ApplicationId.newInstance(0, 2)
+ val app2Data: ApplicationInitializationContext =
+ new ApplicationInitializationContext("user", app2Id, null)
+ s1.initializeApplication(app2Data)
+
+ val execStateFile = s1.registeredExecutorFile
+ execStateFile should not be (null)
+ val shuffleInfo1 = new ExecutorShuffleInfo(Array("/foo", "/bar"), 3, "sort")
+ val shuffleInfo2 = new ExecutorShuffleInfo(Array("/bippy"), 5, "hash")
+
+ val blockHandler = s1.blockHandler
+ val blockResolver = ShuffleTestAccessor.getBlockResolver(blockHandler)
+ ShuffleTestAccessor.registeredExecutorFile(blockResolver) should be (execStateFile)
+
+ blockResolver.registerExecutor(app1Id.toString, "exec-1", shuffleInfo1)
+ blockResolver.registerExecutor(app2Id.toString, "exec-2", shuffleInfo2)
+
+ val db = ShuffleTestAccessor.shuffleServiceLevelDB(blockResolver)
+ ShuffleTestAccessor.reloadRegisteredExecutors(db) should not be empty
+
+ s1.stopApplication(new ApplicationTerminationContext(app1Id))
+ ShuffleTestAccessor.reloadRegisteredExecutors(db) should not be empty
+ s1.stopApplication(new ApplicationTerminationContext(app2Id))
+ ShuffleTestAccessor.reloadRegisteredExecutors(db) shouldBe empty
+ }
+
+ test("shuffle service should be robust to corrupt registered executor file") {
+ s1 = new YarnShuffleService
+ s1.init(yarnConfig)
+ val app1Id = ApplicationId.newInstance(0, 1)
+ val app1Data: ApplicationInitializationContext =
+ new ApplicationInitializationContext("user", app1Id, null)
+ s1.initializeApplication(app1Data)
+
+ val execStateFile = s1.registeredExecutorFile
+ val shuffleInfo1 = new ExecutorShuffleInfo(Array("/foo", "/bar"), 3, "sort")
+
+ val blockHandler = s1.blockHandler
+ val blockResolver = ShuffleTestAccessor.getBlockResolver(blockHandler)
+ ShuffleTestAccessor.registeredExecutorFile(blockResolver) should be (execStateFile)
+
+ blockResolver.registerExecutor(app1Id.toString, "exec-1", shuffleInfo1)
+
+ // now we pretend the shuffle service goes down, and comes back up. But we'll also
+ // make a corrupt registeredExecutor File
+ s1.stop()
+
+    execStateFile.listFiles().foreach { _.delete() }
+
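+    // overwrite LevelDB's CURRENT metadata file so the state directory can no longer be opened cleanly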
+ val out = new DataOutputStream(new FileOutputStream(execStateFile + "/CURRENT"))
+ out.writeInt(42)
+ out.close()
+
+ s2 = new YarnShuffleService
+ s2.init(yarnConfig)
+ s2.registeredExecutorFile should be (execStateFile)
+
+ val handler2 = s2.blockHandler
+ val resolver2 = ShuffleTestAccessor.getBlockResolver(handler2)
+
+    // we re-initialize app1, but since the file was corrupt its executor info cannot be recovered
+ s2.initializeApplication(app1Data)
+ // however, when we initialize a totally new app2, everything is still happy
+ val app2Id = ApplicationId.newInstance(0, 2)
+ val app2Data: ApplicationInitializationContext =
+ new ApplicationInitializationContext("user", app2Id, null)
+ s2.initializeApplication(app2Data)
+ val shuffleInfo2 = new ExecutorShuffleInfo(Array("/bippy"), 5, "hash")
+ resolver2.registerExecutor(app2Id.toString, "exec-2", shuffleInfo2)
+ ShuffleTestAccessor.getExecutorInfo(app2Id, "exec-2", resolver2) should be (Some(shuffleInfo2))
+ s2.stop()
+
+    // another stop & restart should be fine though (i.e., we recover from the previous corruption)
+ s3 = new YarnShuffleService
+ s3.init(yarnConfig)
+ s3.registeredExecutorFile should be (execStateFile)
+ val handler3 = s3.blockHandler
+ val resolver3 = ShuffleTestAccessor.getBlockResolver(handler3)
+
+ s3.initializeApplication(app2Data)
+ ShuffleTestAccessor.getExecutorInfo(app2Id, "exec-2", resolver3) should be (Some(shuffleInfo2))
+ s3.stop()
+
+ }
+
+}
diff --git a/yarn/src/test/scala/org/apache/spark/network/yarn/YarnTestAccessor.scala b/yarn/src/test/scala/org/apache/spark/network/yarn/YarnTestAccessor.scala
new file mode 100644
index 0000000000..db322cd18e
--- /dev/null
+++ b/yarn/src/test/scala/org/apache/spark/network/yarn/YarnTestAccessor.scala
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.network.yarn
+
+import java.io.File
+
+/**
+ * Test-only helper that exposes package-visible members of the YARN shuffle service.
+ */
+object YarnTestAccessor {
+ def getShuffleServicePort: Int = {
+ YarnShuffleService.boundPort
+ }
+
+ def getShuffleServiceInstance: YarnShuffleService = {
+ YarnShuffleService.instance
+ }
+
+ def getRegisteredExecutorFile(service: YarnShuffleService): File = {
+ service.registeredExecutorFile
+ }
+
+}