aboutsummaryrefslogtreecommitdiff
path: root/common/network-shuffle
diff options
context:
space:
mode:
Diffstat (limited to 'common/network-shuffle')
-rw-r--r--common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java76
-rw-r--r--common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java5
-rw-r--r--common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/RegisterDriver.java14
-rw-r--r--common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/ShuffleServiceHeartbeat.java53
4 files changed, 130 insertions, 18 deletions
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java
index 675820308b..2add9c83a7 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java
@@ -19,7 +19,12 @@ package org.apache.spark.network.shuffle.mesos;
import java.io.IOException;
import java.nio.ByteBuffer;
+import java.util.concurrent.Executors;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.TimeUnit;
+import com.google.common.util.concurrent.ThreadFactoryBuilder;
+import org.apache.spark.network.shuffle.protocol.mesos.ShuffleServiceHeartbeat;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -41,6 +46,13 @@ import org.apache.spark.network.util.TransportConf;
public class MesosExternalShuffleClient extends ExternalShuffleClient {
private final Logger logger = LoggerFactory.getLogger(MesosExternalShuffleClient.class);
+ private final ScheduledExecutorService heartbeaterThread = // runs the periodic Heartbeater task
+ Executors.newSingleThreadScheduledExecutor( // one scheduler thread for this client instance
+ new ThreadFactoryBuilder()
+ .setDaemon(true) // daemon thread
+ .setNameFormat("mesos-external-shuffle-client-heartbeater")
+ .build());
+
/**
* Creates an Mesos external shuffle client that wraps the {@link ExternalShuffleClient}.
* Please refer to docs on {@link ExternalShuffleClient} for more information.
@@ -53,21 +65,59 @@ public class MesosExternalShuffleClient extends ExternalShuffleClient {
super(conf, secretKeyHolder, saslEnabled, saslEncryptionEnabled);
}
- public void registerDriverWithShuffleService(String host, int port) throws IOException {
+ public void registerDriverWithShuffleService( // sends a RegisterDriver RPC to host:port
+ String host,
+ int port,
+ long heartbeatTimeoutMs, // carried to the shuffle service inside the RegisterDriver message
+ long heartbeatIntervalMs) throws IOException { // period used once registration succeeds
+
checkInit();
- ByteBuffer registerDriver = new RegisterDriver(appId).toByteBuffer();
+ ByteBuffer registerDriver = new RegisterDriver(appId, heartbeatTimeoutMs).toByteBuffer();
TransportClient client = clientFactory.createClient(host, port);
- client.sendRpc(registerDriver, new RpcResponseCallback() {
- @Override
- public void onSuccess(ByteBuffer response) {
- logger.info("Successfully registered app " + appId + " with external shuffle service.");
- }
-
- @Override
- public void onFailure(Throwable e) {
- logger.warn("Unable to register app " + appId + " with external shuffle service. " +
+ client.sendRpc(registerDriver, new RegisterDriverCallback(client, heartbeatIntervalMs)); // callback schedules heartbeats on success
+ }
+
+ private class RegisterDriverCallback implements RpcResponseCallback { // handles the reply to the RegisterDriver RPC
+ private final TransportClient client; // same client the registration was sent on; reused for heartbeats
+ private final long heartbeatIntervalMs; // period passed to scheduleAtFixedRate, in milliseconds
+
+ private RegisterDriverCallback(TransportClient client, long heartbeatIntervalMs) {
+ this.client = client;
+ this.heartbeatIntervalMs = heartbeatIntervalMs;
+ }
+
+ @Override
+ public void onSuccess(ByteBuffer response) { // the response payload is not read
+ heartbeaterThread.scheduleAtFixedRate( // initial delay 0, so the first heartbeat is due immediately
+ new Heartbeater(client), 0, heartbeatIntervalMs, TimeUnit.MILLISECONDS);
+ logger.info("Successfully registered app " + appId + " with external shuffle service.");
+ }
+
+ @Override
+ public void onFailure(Throwable e) { // only logs a warning; no heartbeat is scheduled
+ logger.warn("Unable to register app " + appId + " with external shuffle service. " +
"Please manually remove shuffle data after driver exit. Error: " + e);
- }
- });
+ }
+ }
+
+ @Override
+ public void close() { // shuts down heartbeating, then delegates to the superclass
+ heartbeaterThread.shutdownNow(); // stop the heartbeat executor before super.close() runs
+ super.close();
+ }
+
+ private class Heartbeater implements Runnable { // task that sends one ShuffleServiceHeartbeat per run
+
+ private final TransportClient client; // client the heartbeat messages are sent through
+
+ private Heartbeater(TransportClient client) {
+ this.client = client;
+ }
+
+ @Override
+ public void run() { // invoked by heartbeaterThread at the scheduled fixed rate
+ // TODO: Stop sending heartbeats if the shuffle service has lost the app due to timeout
+ client.send(new ShuffleServiceHeartbeat(appId).toByteBuffer()); // no response callback is supplied
+ }
+ }
}
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java
index 7fbe3384b4..21c0ff4136 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java
@@ -24,6 +24,7 @@ import io.netty.buffer.Unpooled;
import org.apache.spark.network.protocol.Encodable;
import org.apache.spark.network.shuffle.protocol.mesos.RegisterDriver;
+import org.apache.spark.network.shuffle.protocol.mesos.ShuffleServiceHeartbeat;
/**
* Messages handled by the {@link org.apache.spark.network.shuffle.ExternalShuffleBlockHandler}, or
@@ -40,7 +41,8 @@ public abstract class BlockTransferMessage implements Encodable {
/** Preceding every serialized message is its type, which allows us to deserialize it. */
public static enum Type {
- OPEN_BLOCKS(0), UPLOAD_BLOCK(1), REGISTER_EXECUTOR(2), STREAM_HANDLE(3), REGISTER_DRIVER(4);
+ OPEN_BLOCKS(0), UPLOAD_BLOCK(1), REGISTER_EXECUTOR(2), STREAM_HANDLE(3), REGISTER_DRIVER(4),
+ HEARTBEAT(5);
private final byte id;
@@ -64,6 +66,7 @@ public abstract class BlockTransferMessage implements Encodable {
case 2: return RegisterExecutor.decode(buf);
case 3: return StreamHandle.decode(buf);
case 4: return RegisterDriver.decode(buf);
+ case 5: return ShuffleServiceHeartbeat.decode(buf);
default: throw new IllegalArgumentException("Unknown message type: " + type);
}
}
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/RegisterDriver.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/RegisterDriver.java
index eeb0019411..d5f53ccb7f 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/RegisterDriver.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/RegisterDriver.java
@@ -31,29 +31,34 @@ import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Typ
*/
public class RegisterDriver extends BlockTransferMessage {
private final String appId;
+ private final long heartbeatTimeoutMs;
- public RegisterDriver(String appId) {
+ public RegisterDriver(String appId, long heartbeatTimeoutMs) { // timeout is stored as given and serialized by encode()
this.appId = appId;
+ this.heartbeatTimeoutMs = heartbeatTimeoutMs;
}
public String getAppId() { return appId; }
+ public long getHeartbeatTimeoutMs() { return heartbeatTimeoutMs; }
+
@Override
protected Type type() { return Type.REGISTER_DRIVER; }
@Override
public int encodedLength() {
- return Encoders.Strings.encodedLength(appId);
+ return Encoders.Strings.encodedLength(appId) + Long.SIZE / Byte.SIZE; // encoded appId plus 8 bytes for the long timeout
}
@Override
public void encode(ByteBuf buf) {
Encoders.Strings.encode(buf, appId);
+ buf.writeLong(heartbeatTimeoutMs); // written after appId; decode() reads the fields in the same order
}
@Override
public int hashCode() {
- return Objects.hashCode(appId);
+ return Objects.hashCode(appId, heartbeatTimeoutMs); // NOTE(review): equals() is outside this hunk - confirm it also compares heartbeatTimeoutMs
}
@Override
@@ -66,6 +71,7 @@ public class RegisterDriver extends BlockTransferMessage {
public static RegisterDriver decode(ByteBuf buf) {
String appId = Encoders.Strings.decode(buf);
- return new RegisterDriver(appId);
+ long heartbeatTimeout = buf.readLong(); // read after appId, mirroring the write order in encode()
+ return new RegisterDriver(appId, heartbeatTimeout);
}
}
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/ShuffleServiceHeartbeat.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/ShuffleServiceHeartbeat.java
new file mode 100644
index 0000000000..b30bb9aed5
--- /dev/null
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/ShuffleServiceHeartbeat.java
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle.protocol.mesos;
+
+import io.netty.buffer.ByteBuf;
+import org.apache.spark.network.protocol.Encoders;
+import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
+
+// Needed by ScalaDoc. See SPARK-7726
+import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type;
+
+/**
+ * A heartbeat sent from the driver to the MesosExternalShuffleService, carrying only the app ID.
+ */
+public class ShuffleServiceHeartbeat extends BlockTransferMessage {
+ private final String appId; // the message's only field
+
+ public ShuffleServiceHeartbeat(String appId) {
+ this.appId = appId;
+ }
+
+ public String getAppId() { return appId; }
+
+ @Override
+ protected Type type() { return Type.HEARTBEAT; } // message type tag for this class
+
+ @Override
+ public int encodedLength() { return Encoders.Strings.encodedLength(appId); } // the encoded appId is the whole body
+
+ @Override
+ public void encode(ByteBuf buf) {
+ Encoders.Strings.encode(buf, appId);
+ }
+
+ public static ShuffleServiceHeartbeat decode(ByteBuf buf) { // inverse of encode(): reads the appId string back
+ return new ShuffleServiceHeartbeat(Encoders.Strings.decode(buf));
+ }
+}