aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java76
-rw-r--r--common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java5
-rw-r--r--common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/RegisterDriver.java14
-rw-r--r--common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/ShuffleServiceHeartbeat.java53
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala87
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala10
-rw-r--r--core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala3
7 files changed, 195 insertions, 53 deletions
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java
index 675820308b..2add9c83a7 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java
@@ -19,7 +19,12 @@ package org.apache.spark.network.shuffle.mesos;
import java.io.IOException;
import java.nio.ByteBuffer;
+import java.util.concurrent.Executors;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.TimeUnit;
+import com.google.common.util.concurrent.ThreadFactoryBuilder;
+import org.apache.spark.network.shuffle.protocol.mesos.ShuffleServiceHeartbeat;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -41,6 +46,13 @@ import org.apache.spark.network.util.TransportConf;
public class MesosExternalShuffleClient extends ExternalShuffleClient {
private final Logger logger = LoggerFactory.getLogger(MesosExternalShuffleClient.class);
+ private final ScheduledExecutorService heartbeaterThread = // one scheduler shared by every Heartbeater task this client schedules
+ Executors.newSingleThreadScheduledExecutor(
+ new ThreadFactoryBuilder()
+ .setDaemon(true) // daemon thread: does not keep the JVM alive
+ .setNameFormat("mesos-external-shuffle-client-heartbeater")
+ .build());
+
/**
* Creates an Mesos external shuffle client that wraps the {@link ExternalShuffleClient}.
* Please refer to docs on {@link ExternalShuffleClient} for more information.
@@ -53,21 +65,59 @@ public class MesosExternalShuffleClient extends ExternalShuffleClient {
super(conf, secretKeyHolder, saslEnabled, saslEncryptionEnabled);
}
- public void registerDriverWithShuffleService(String host, int port) throws IOException {
+ public void registerDriverWithShuffleService( // sends a RegisterDriver RPC for appId; heartbeats start when it succeeds
+ String host, // passed together with port to clientFactory.createClient
+ int port,
+ long heartbeatTimeoutMs, // carried to the service inside the RegisterDriver message
+ long heartbeatIntervalMs) throws IOException { // period of the heartbeats scheduled by RegisterDriverCallback
+
checkInit();
- ByteBuffer registerDriver = new RegisterDriver(appId).toByteBuffer();
+ ByteBuffer registerDriver = new RegisterDriver(appId, heartbeatTimeoutMs).toByteBuffer();
TransportClient client = clientFactory.createClient(host, port);
- client.sendRpc(registerDriver, new RpcResponseCallback() {
- @Override
- public void onSuccess(ByteBuffer response) {
- logger.info("Successfully registered app " + appId + " with external shuffle service.");
- }
-
- @Override
- public void onFailure(Throwable e) {
- logger.warn("Unable to register app " + appId + " with external shuffle service. " +
+ client.sendRpc(registerDriver, new RegisterDriverCallback(client, heartbeatIntervalMs)); // the callback reuses this client for the heartbeats
+ }
+
+ private class RegisterDriverCallback implements RpcResponseCallback { // handles the reply to the RegisterDriver RPC
+ private final TransportClient client;
+ private final long heartbeatIntervalMs;
+
+ private RegisterDriverCallback(TransportClient client, long heartbeatIntervalMs) {
+ this.client = client;
+ this.heartbeatIntervalMs = heartbeatIntervalMs;
+ }
+
+ @Override
+ public void onSuccess(ByteBuffer response) {
+ heartbeaterThread.scheduleAtFixedRate( // first heartbeat immediately (initial delay 0), then one every heartbeatIntervalMs
+ new Heartbeater(client), 0, heartbeatIntervalMs, TimeUnit.MILLISECONDS);
+ logger.info("Successfully registered app " + appId + " with external shuffle service.");
+ }
+
+ @Override
+ public void onFailure(Throwable e) { // failure is only logged; no heartbeat task is scheduled
+ logger.warn("Unable to register app " + appId + " with external shuffle service. " +
"Please manually remove shuffle data after driver exit. Error: " + e);
- }
- });
+ }
+ }
+
+ @Override
+ public void close() {
+ heartbeaterThread.shutdownNow(); // stop the heartbeat scheduler before super.close() runs
+ super.close();
+ }
+
+ private class Heartbeater implements Runnable { // periodic task: each run sends one ShuffleServiceHeartbeat for appId
+
+ private final TransportClient client;
+
+ private Heartbeater(TransportClient client) {
+ this.client = client;
+ }
+
+ @Override
+ public void run() {
+ // TODO: Stop sending heartbeats if the shuffle service has lost the app due to timeout
+ client.send(new ShuffleServiceHeartbeat(appId).toByteBuffer()); // NOTE(review): an exception thrown here would suppress all later runs of this scheduleAtFixedRate task — confirm send() cannot throw
+ }
}
}
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java
index 7fbe3384b4..21c0ff4136 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java
@@ -24,6 +24,7 @@ import io.netty.buffer.Unpooled;
import org.apache.spark.network.protocol.Encodable;
import org.apache.spark.network.shuffle.protocol.mesos.RegisterDriver;
+import org.apache.spark.network.shuffle.protocol.mesos.ShuffleServiceHeartbeat;
/**
* Messages handled by the {@link org.apache.spark.network.shuffle.ExternalShuffleBlockHandler}, or
@@ -40,7 +41,8 @@ public abstract class BlockTransferMessage implements Encodable {
/** Preceding every serialized message is its type, which allows us to deserialize it. */
public static enum Type {
- OPEN_BLOCKS(0), UPLOAD_BLOCK(1), REGISTER_EXECUTOR(2), STREAM_HANDLE(3), REGISTER_DRIVER(4);
+ OPEN_BLOCKS(0), UPLOAD_BLOCK(1), REGISTER_EXECUTOR(2), STREAM_HANDLE(3), REGISTER_DRIVER(4),
+ HEARTBEAT(5);
private final byte id;
@@ -64,6 +66,7 @@ public abstract class BlockTransferMessage implements Encodable {
case 2: return RegisterExecutor.decode(buf);
case 3: return StreamHandle.decode(buf);
case 4: return RegisterDriver.decode(buf);
+ case 5: return ShuffleServiceHeartbeat.decode(buf);
default: throw new IllegalArgumentException("Unknown message type: " + type);
}
}
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/RegisterDriver.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/RegisterDriver.java
index eeb0019411..d5f53ccb7f 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/RegisterDriver.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/RegisterDriver.java
@@ -31,29 +31,34 @@ import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Typ
*/
public class RegisterDriver extends BlockTransferMessage {
private final String appId;
+ private final long heartbeatTimeoutMs;
- public RegisterDriver(String appId) {
+ public RegisterDriver(String appId, long heartbeatTimeoutMs) { // timeout is stored as given and serialized by encode()
this.appId = appId;
+ this.heartbeatTimeoutMs = heartbeatTimeoutMs;
}
public String getAppId() { return appId; }
+ public long getHeartbeatTimeoutMs() { return heartbeatTimeoutMs; } // value given to the constructor, or read from the wire by decode()
+
@Override
protected Type type() { return Type.REGISTER_DRIVER; }
@Override
public int encodedLength() {
- return Encoders.Strings.encodedLength(appId);
+ return Encoders.Strings.encodedLength(appId) + Long.SIZE / Byte.SIZE; // plus 8 bytes for the long that encode() writes
}
@Override
public void encode(ByteBuf buf) {
Encoders.Strings.encode(buf, appId);
+ buf.writeLong(heartbeatTimeoutMs); // written after appId; decode() reads the fields in the same order
}
@Override
public int hashCode() {
- return Objects.hashCode(appId);
+ return Objects.hashCode(appId, heartbeatTimeoutMs); // NOTE(review): equals() lies outside this hunk — confirm it also compares heartbeatTimeoutMs
}
@Override
@@ -66,6 +71,7 @@ public class RegisterDriver extends BlockTransferMessage {
public static RegisterDriver decode(ByteBuf buf) {
String appId = Encoders.Strings.decode(buf);
- return new RegisterDriver(appId);
+ long heartbeatTimeout = buf.readLong(); // read after the appId string, mirroring the order used by encode()
+ return new RegisterDriver(appId, heartbeatTimeout);
}
}
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/ShuffleServiceHeartbeat.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/ShuffleServiceHeartbeat.java
new file mode 100644
index 0000000000..b30bb9aed5
--- /dev/null
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/ShuffleServiceHeartbeat.java
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle.protocol.mesos;
+
+import io.netty.buffer.ByteBuf;
+import org.apache.spark.network.protocol.Encoders;
+import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
+
+// Needed by ScalaDoc. See SPARK-7726
+import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type;
+
+/**
+ * A heartbeat sent from the driver to the MesosExternalShuffleService; its only payload is the app id.
+ */
+public class ShuffleServiceHeartbeat extends BlockTransferMessage {
+ private final String appId;
+
+ public ShuffleServiceHeartbeat(String appId) {
+ this.appId = appId;
+ }
+
+ public String getAppId() { return appId; }
+
+ @Override
+ protected Type type() { return Type.HEARTBEAT; } // HEARTBEAT is type id 5 in BlockTransferMessage.Type
+
+ @Override
+ public int encodedLength() { return Encoders.Strings.encodedLength(appId); } // the encoded appId is the whole message body
+
+ @Override
+ public void encode(ByteBuf buf) {
+ Encoders.Strings.encode(buf, appId);
+ }
+
+ public static ShuffleServiceHeartbeat decode(ByteBuf buf) { // inverse of encode(): reads back the app id string
+ return new ShuffleServiceHeartbeat(Encoders.Strings.decode(buf));
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala
index 4172d924c8..c0f9129a42 100644
--- a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala
@@ -17,69 +17,89 @@
package org.apache.spark.deploy.mesos
-import java.net.SocketAddress
import java.nio.ByteBuffer
+import java.util.concurrent.{ConcurrentHashMap, TimeUnit}
-import scala.collection.mutable
+import scala.collection.JavaConverters._
import org.apache.spark.{Logging, SecurityManager, SparkConf}
import org.apache.spark.deploy.ExternalShuffleService
import org.apache.spark.network.client.{RpcResponseCallback, TransportClient}
import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler
import org.apache.spark.network.shuffle.protocol.BlockTransferMessage
-import org.apache.spark.network.shuffle.protocol.mesos.RegisterDriver
+import org.apache.spark.network.shuffle.protocol.mesos.{RegisterDriver, ShuffleServiceHeartbeat}
import org.apache.spark.network.util.TransportConf
+import org.apache.spark.util.ThreadUtils
/**
* An RPC endpoint that receives registration requests from Spark drivers running on Mesos.
* It detects driver termination and calls the cleanup callback to [[ExternalShuffleService]].
*/
-private[mesos] class MesosExternalShuffleBlockHandler(transportConf: TransportConf)
+private[mesos] class MesosExternalShuffleBlockHandler(
+ transportConf: TransportConf,
+ cleanerIntervalS: Long)
extends ExternalShuffleBlockHandler(transportConf, null) with Logging {
- // Stores a map of driver socket addresses to app ids
- private val connectedApps = new mutable.HashMap[SocketAddress, String]
+ ThreadUtils.newDaemonSingleThreadScheduledExecutor("shuffle-cleaner-watcher")
+ .scheduleAtFixedRate(new CleanerThread(), 0, cleanerIntervalS, TimeUnit.SECONDS) // NOTE(review): scheduled with zero initial delay before connectedApps (below) is assigned — confirm the first CleanerThread run cannot observe a null map
+
+ // Stores a map of app id to app state (timeout value and last heartbeat)
+ private val connectedApps = new ConcurrentHashMap[String, AppState]()
protected override def handleMessage(
message: BlockTransferMessage,
client: TransportClient,
callback: RpcResponseCallback): Unit = {
message match {
- case RegisterDriverParam(appId) =>
+ case RegisterDriverParam(appId, appState) => // driver registration: record the app keyed by app id
val address = client.getSocketAddress
- logDebug(s"Received registration request from app $appId (remote address $address).")
- if (connectedApps.contains(address)) {
- val existingAppId = connectedApps(address)
- if (!existingAppId.equals(appId)) {
- logError(s"A new app '$appId' has connected to existing address $address, " +
- s"removing previously registered app '$existingAppId'.")
- applicationRemoved(existingAppId, true)
- }
+ val timeout = appState.heartbeatTimeout
+ logInfo(s"Received registration request from app $appId (remote address $address, " +
+ s"heartbeat timeout $timeout ms).")
+ if (connectedApps.containsKey(appId)) { // duplicate registration is only warned about, not rejected
+ logWarning(s"Received a registration request from app $appId, but it was already " +
+ s"registered")
}
- connectedApps(address) = appId
+ connectedApps.put(appId, appState) // replaces any earlier state, so lastHeartbeat restarts from this registration
callback.onSuccess(ByteBuffer.allocate(0))
+ case Heartbeat(appId) => // heartbeat: refresh the timestamp; callback is not invoked on this path
+ val address = client.getSocketAddress
+ Option(connectedApps.get(appId)) match {
+ case Some(existingAppState) =>
+ logTrace(s"Received ShuffleServiceHeartbeat from app '$appId' (remote " +
+ s"address $address).")
+ existingAppState.lastHeartbeat = System.nanoTime() // the value CleanerThread compares against its own nanoTime reading
+ case None => // app id not in connectedApps: log only
+ logWarning(s"Received ShuffleServiceHeartbeat from an unknown app (remote " +
+ s"address $address, appId '$appId').")
+ }
case _ => super.handleMessage(message, client, callback)
}
}
- /**
- * On connection termination, clean up shuffle files written by the associated application.
- */
- override def channelInactive(client: TransportClient): Unit = {
- val address = client.getSocketAddress
- if (connectedApps.contains(address)) {
- val appId = connectedApps(address)
- logInfo(s"Application $appId disconnected (address was $address).")
- applicationRemoved(appId, true /* cleanupLocalDirs */)
- connectedApps.remove(address)
- } else {
- logWarning(s"Unknown $address disconnected.")
- }
- }
-
/** An extractor object for matching [[RegisterDriver]] message. */
private object RegisterDriverParam {
- def unapply(r: RegisterDriver): Option[String] = Some(r.getAppId)
+ def unapply(r: RegisterDriver): Option[(String, AppState)] = // yields the app id together with a freshly built AppState
+ Some((r.getAppId, new AppState(r.getHeartbeatTimeoutMs, System.nanoTime()))) // lastHeartbeat starts at the moment the message is matched
+ }
+
+ private object Heartbeat { // extractor: matches a ShuffleServiceHeartbeat and yields its app id
+ def unapply(h: ShuffleServiceHeartbeat): Option[String] = Some(h.getAppId)
+ }
+
+ private class AppState(val heartbeatTimeout: Long, @volatile var lastHeartbeat: Long) // heartbeatTimeout comes from getHeartbeatTimeoutMs (ms); lastHeartbeat holds a System.nanoTime() reading
+
+ private class CleanerThread extends Runnable { // periodic sweep: removes apps whose last heartbeat is older than their timeout
+ override def run(): Unit = {
+ val now = System.nanoTime()
+ connectedApps.asScala.foreach { case (appId, appState) =>
+ if (now - appState.lastHeartbeat > appState.heartbeatTimeout * 1000 * 1000) { // timeout is in ms; scaled to ns to match the nanoTime difference
+ logInfo(s"Application $appId timed out. Removing shuffle files.")
+ connectedApps.remove(appId)
+ applicationRemoved(appId, true) // NOTE(review): an exception here would suppress later scheduleAtFixedRate runs — confirm this call cannot throw
+ }
+ }
+ }
}
}
@@ -93,7 +113,8 @@ private[mesos] class MesosExternalShuffleService(conf: SparkConf, securityManage
protected override def newShuffleBlockHandler(
conf: TransportConf): ExternalShuffleBlockHandler = {
- new MesosExternalShuffleBlockHandler(conf)
+ val cleanerIntervalS = this.conf.getTimeAsSeconds("spark.shuffle.cleaner.interval", "30s") // this.conf is the service's SparkConf, not the TransportConf parameter; "30s" is the fallback value
+ new MesosExternalShuffleBlockHandler(conf, cleanerIntervalS)
}
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
index e1180980ee..90b1813750 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
@@ -448,7 +448,12 @@ private[spark] class CoarseMesosSchedulerBackend(
s"host ${slave.hostname}, port $externalShufflePort for app ${conf.getAppId}")
mesosExternalShuffleClient.get
- .registerDriverWithShuffleService(slave.hostname, externalShufflePort)
+ .registerDriverWithShuffleService(
+ slave.hostname,
+ externalShufflePort,
+ sc.conf.getTimeAsMs("spark.storage.blockManagerSlaveTimeoutMs",
+ s"${sc.conf.getTimeAsMs("spark.network.timeout", "120s")}ms"),
+ sc.conf.getTimeAsMs("spark.executor.heartbeatInterval", "10s"))
slave.shuffleRegistered = true
}
@@ -506,6 +511,9 @@ private[spark] class CoarseMesosSchedulerBackend(
+ "on the mesos nodes.")
}
+ // Close the mesos external shuffle client if used
+ mesosExternalShuffleClient.foreach(_.close())
+
if (mesosDriver != null) {
mesosDriver.stop()
}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala
index dd76644288..b18f0eb162 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala
@@ -192,7 +192,8 @@ class CoarseMesosSchedulerBackendSuite extends SparkFunSuite
val status2 = createTaskStatus("1", "s1", TaskState.TASK_RUNNING)
backend.statusUpdate(driver, status2)
- verify(externalShuffleClient, times(1)).registerDriverWithShuffleService(anyString, anyInt)
+ verify(externalShuffleClient, times(1))
+ .registerDriverWithShuffleService(anyString, anyInt, anyLong, anyLong)
}
test("mesos kills an executor when told") {