/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.network.yarn;

import java.io.File;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.ByteBuffer;
import java.nio.file.Files;
import java.util.List;
import java.util.Map;

import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Objects;
import com.google.common.collect.Lists;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.fs.permission.FsPermission;
import org.apache.hadoop.yarn.api.records.ContainerId;
import org.apache.hadoop.yarn.server.api.*;
import org.apache.spark.network.util.LevelDBProvider;
import org.iq80.leveldb.DB;
import org.iq80.leveldb.DBIterator;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.apache.spark.network.TransportContext;
import org.apache.spark.network.sasl.SaslServerBootstrap;
import org.apache.spark.network.sasl.ShuffleSecretManager;
import org.apache.spark.network.server.TransportServer;
import org.apache.spark.network.server.TransportServerBootstrap;
import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler;
import org.apache.spark.network.util.TransportConf;
import org.apache.spark.network.yarn.util.HadoopConfigProvider;

/**
 * An external shuffle service used by Spark on Yarn.
 *
 * This is intended to be a long-running auxiliary service that runs in the NodeManager process.
 * A Spark application may connect to this service by setting `spark.shuffle.service.enabled`.
 * The application also picks up the service port from `spark.shuffle.service.port` in the
 * Yarn configuration, so that both the clients and the server agree on the same port to
 * communicate on.
 *
 * The service also optionally supports authentication. This ensures that executors from one
 * application cannot read the shuffle files written by those from another. This feature can be
 * enabled by setting `spark.authenticate` in the Yarn configuration before starting the NM.
 * Note that the Spark application must also set `spark.authenticate` manually and, unlike in
 * the case of the service port, will not inherit this setting from the Yarn configuration. This
 * is because an application running on the same Yarn cluster may choose to not use the external
 * shuffle service, in which case its setting of `spark.authenticate` should be independent of
 * the service's.
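 *
 * For illustration only (a hedged sketch; the property keys are the ones documented above,
 * the values are hypothetical), an application opting in to this service might set:
 *
 * <pre>{@code
 * SparkConf conf = new SparkConf()
 *   .set("spark.shuffle.service.enabled", "true")  // connect to the external shuffle service
 *   .set("spark.shuffle.service.port", "7337")     // must match the port in the Yarn config
 *   .set("spark.authenticate", "true");            // only if the NM also enables authentication
 * }</pre>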
 */
public class YarnShuffleService extends AuxiliaryService {
  private static final Logger logger = LoggerFactory.getLogger(YarnShuffleService.class);

  // Port on which the shuffle server listens for fetch requests
  private static final String SPARK_SHUFFLE_SERVICE_PORT_KEY = "spark.shuffle.service.port";
  private static final int DEFAULT_SPARK_SHUFFLE_SERVICE_PORT = 7337;

  // Whether the shuffle server should authenticate fetch requests
  private static final String SPARK_AUTHENTICATE_KEY = "spark.authenticate";
  private static final boolean DEFAULT_SPARK_AUTHENTICATE = false;

  private static final String RECOVERY_FILE_NAME = "registeredExecutors.ldb";
  private static final String SECRETS_RECOVERY_FILE_NAME = "sparkShuffleRecovery.ldb";

  // Whether failure during service initialization should stop the NM.
  @VisibleForTesting
  static final String STOP_ON_FAILURE_KEY = "spark.yarn.shuffle.stopOnFailure";
  private static final boolean DEFAULT_STOP_ON_FAILURE = false;

  // just for testing when you want to find an open port
  @VisibleForTesting
  static int boundPort = -1;
  private static final ObjectMapper mapper = new ObjectMapper();
  private static final String APP_CREDS_KEY_PREFIX = "AppCreds";
  private static final LevelDBProvider.StoreVersion CURRENT_VERSION = new LevelDBProvider
      .StoreVersion(1, 0);

  // just for integration tests that want to look at this file -- in general not sensible as
  // a static
  @VisibleForTesting
  static YarnShuffleService instance;

  // An entity that manages the shuffle secret per application
  // This is used only if authentication is enabled
  @VisibleForTesting
  ShuffleSecretManager secretManager;

  // The actual server that serves shuffle files
  private TransportServer shuffleServer = null;

  private Configuration _conf = null;

  // The recovery path used for shuffle service recovery
  @VisibleForTesting
  Path _recoveryPath = null;

  // Handles registering executors and opening shuffle blocks
  @VisibleForTesting
  ExternalShuffleBlockHandler blockHandler;

  // Where to store & reload executor info for recovering state after an NM restart
  @VisibleForTesting
  File registeredExecutorFile;

  // Where to store & reload application secrets for recovering state after an NM restart
  @VisibleForTesting
  File secretsFile;

  private DB db;

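  // Hedged configuration sketch (not consulted by this code; see the Spark on YARN docs for
  // the authoritative steps): the NodeManager registers this class as the "spark_shuffle"
  // auxiliary service via yarn-site.xml, e.g.
  //   yarn.nodemanager.aux-services = spark_shuffle
  //   yarn.nodemanager.aux-services.spark_shuffle.class =
  //     org.apache.spark.network.yarn.YarnShuffleService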
  public YarnShuffleService() {
    super("spark_shuffle");
    logger.info("Initializing YARN shuffle service for Spark");
    instance = this;
  }

  /**
   * Return whether authentication is enabled as specified by the configuration.
   * If so, fetch requests will fail unless the appropriate authentication secret
   * for the application is provided.
   */
  private boolean isAuthenticationEnabled() {
    return secretManager != null;
  }

  /**
   * Start the shuffle server with the given configuration.
   */
  @Override
  protected void serviceInit(Configuration conf) throws Exception {
    _conf = conf;

    boolean stopOnFailure = conf.getBoolean(STOP_ON_FAILURE_KEY, DEFAULT_STOP_ON_FAILURE);

    try {
      // In case this NM was killed while there were running Spark applications, we need to restore
      // lost state for the existing executors. We look for an existing file in the NM's local dirs.
      // If we don't find one, then we choose a file to use to save the state next time. Even if
      // an application was stopped while the NM was down, we expect Yarn to call stopApplication()
      // when it comes back.
      registeredExecutorFile = initRecoveryDb(RECOVERY_FILE_NAME);

      TransportConf transportConf = new TransportConf("shuffle", new HadoopConfigProvider(conf));
      blockHandler = new ExternalShuffleBlockHandler(transportConf, registeredExecutorFile);

      // If authentication is enabled, set up the shuffle server to use a
      // special RPC handler that filters out unauthenticated fetch requests
      List<TransportServerBootstrap> bootstraps = Lists.newArrayList();
      boolean authEnabled = conf.getBoolean(SPARK_AUTHENTICATE_KEY, DEFAULT_SPARK_AUTHENTICATE);
      if (authEnabled) {
        createSecretManager();
        bootstraps.add(new SaslServerBootstrap(transportConf, secretManager));
      }

      int port = conf.getInt(
        SPARK_SHUFFLE_SERVICE_PORT_KEY, DEFAULT_SPARK_SHUFFLE_SERVICE_PORT);
      TransportContext transportContext = new TransportContext(transportConf, blockHandler);
      shuffleServer = transportContext.createServer(port, bootstraps);
      // the port should normally be fixed, but for tests it's useful to find an open port
      port = shuffleServer.getPort();
      boundPort = port;
      String authEnabledString = authEnabled ? "enabled" : "not enabled";
      logger.info("Started YARN shuffle service for Spark on port {}. " +
        "Authentication is {}.  Registered executor file is {}", port, authEnabledString,
        registeredExecutorFile);
    } catch (Exception e) {
      if (stopOnFailure) {
        throw e;
      } else {
        noteFailure(e);
      }
    }
  }

  private void createSecretManager() throws IOException {
    secretManager = new ShuffleSecretManager();
    secretsFile = initRecoveryDb(SECRETS_RECOVERY_FILE_NAME);

    // Make sure this is protected in case it's not in the NM recovery dir
    FileSystem fs = FileSystem.getLocal(_conf);
    fs.mkdirs(new Path(secretsFile.getPath()), new FsPermission((short)0700));

    db = LevelDBProvider.initLevelDB(secretsFile, CURRENT_VERSION, mapper);
    logger.info("Recovery location is: " + secretsFile.getPath());
    if (db != null) {
      logger.info("Going to reload spark shuffle data");
      // DBIterator is Closeable; use try-with-resources so the iterator is released even if
      // deserialization fails part way through.
      try (DBIterator itr = db.iterator()) {
        itr.seek(APP_CREDS_KEY_PREFIX.getBytes(StandardCharsets.UTF_8));
        while (itr.hasNext()) {
          Map.Entry<byte[], byte[]> e = itr.next();
          String key = new String(e.getKey(), StandardCharsets.UTF_8);
          if (!key.startsWith(APP_CREDS_KEY_PREFIX)) {
            break;
          }
          String id = parseDbAppKey(key);
          ByteBuffer secret = mapper.readValue(e.getValue(), ByteBuffer.class);
          logger.info("Reloading tokens for app: " + id);
          secretManager.registerApp(id, secret);
        }
      }
    }
  }

  private static String parseDbAppKey(String s) throws IOException {
    if (!s.startsWith(APP_CREDS_KEY_PREFIX)) {
      throw new IllegalArgumentException("expected a string starting with " + APP_CREDS_KEY_PREFIX);
    }
    String json = s.substring(APP_CREDS_KEY_PREFIX.length() + 1);
    AppId parsed = mapper.readValue(json, AppId.class);
    return parsed.appId;
  }

  private static byte[] dbAppKey(AppId appExecId) throws IOException {
    // we stick a common prefix on all the keys so we can find them in the DB
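    // A resulting key looks like: AppCreds;{"appId":"application_1234567890123_0001"}
    // (the application ID shown here is hypothetical)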
    String appExecJson = mapper.writeValueAsString(appExecId);
    String key = (APP_CREDS_KEY_PREFIX + ";" + appExecJson);
    return key.getBytes(StandardCharsets.UTF_8);
  }

  @Override
  public void initializeApplication(ApplicationInitializationContext context) {
    String appId = context.getApplicationId().toString();
    try {
      ByteBuffer shuffleSecret = context.getApplicationDataForService();
      logger.info("Initializing application {}", appId);
      if (isAuthenticationEnabled()) {
        AppId fullId = new AppId(appId);
        if (db != null) {
          byte[] key = dbAppKey(fullId);
          byte[] value = mapper.writeValueAsString(shuffleSecret).getBytes(StandardCharsets.UTF_8);
          db.put(key, value);
        }
        secretManager.registerApp(appId, shuffleSecret);
      }
    } catch (Exception e) {
      logger.error("Exception when initializing application {}", appId, e);
    }
  }

  @Override
  public void stopApplication(ApplicationTerminationContext context) {
    String appId = context.getApplicationId().toString();
    try {
      logger.info("Stopping application {}", appId);
      if (isAuthenticationEnabled()) {
        AppId fullId = new AppId(appId);
        if (db != null) {
          try {
            db.delete(dbAppKey(fullId));
          } catch (IOException e) {
            logger.error("Error deleting {} from executor state db", appId, e);
          }
        }
        secretManager.unregisterApp(appId);
      }
      blockHandler.applicationRemoved(appId, false /* cleanupLocalDirs */);
    } catch (Exception e) {
      logger.error("Exception when stopping application {}", appId, e);
    }
  }

  @Override
  public void initializeContainer(ContainerInitializationContext context) {
    ContainerId containerId = context.getContainerId();
    logger.info("Initializing container {}", containerId);
  }

  @Override
  public void stopContainer(ContainerTerminationContext context) {
    ContainerId containerId = context.getContainerId();
    logger.info("Stopping container {}", containerId);
  }

  /**
   * Close the shuffle server to clean up any associated state.
   */
  @Override
  protected void serviceStop() {
    try {
      if (shuffleServer != null) {
        shuffleServer.close();
      }
      if (blockHandler != null) {
        blockHandler.close();
      }
      if (db != null) {
        db.close();
      }
    } catch (Exception e) {
      logger.error("Exception when stopping service", e);
    }
  }

  // Not currently used
  @Override
  public ByteBuffer getMetaData() {
    return ByteBuffer.allocate(0);
  }

  /**
   * Set the recovery path for shuffle service recovery when the NM is restarted. This method is
   * overridden and called by Yarn when the Hadoop version is 2.5+ and NM recovery is enabled;
   * otherwise we have to call it manually to set our own recovery path.
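   *
   * A hedged usage sketch for the manual case (the path shown is hypothetical):
   * <pre>{@code
   * YarnShuffleService service = new YarnShuffleService();
   * service.setRecoveryPath(new Path("/tmp/spark-shuffle-recovery"));
   * }</pre>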
   */
  public void setRecoveryPath(Path recoveryPath) {
    _recoveryPath = recoveryPath;
  }

  /**
   * Get the path specific to this auxiliary service to use for recovery.
   */
  protected Path getRecoveryPath(String fileName) {
    return _recoveryPath;
  }

  /**
   * Figure out the recovery path, and handle moving the DB if YARN NM recovery gets enabled
   * after having previously been disabled. If YARN NM recovery is enabled, this uses the
   * recovery path; otherwise it uses a YARN local dir.
   */
  protected File initRecoveryDb(String dbFileName) {
    if (_recoveryPath != null) {
      File recoveryFile = new File(_recoveryPath.toUri().getPath(), dbFileName);
      if (recoveryFile.exists()) {
        return recoveryFile;
      }
    }
    // The DB doesn't exist in the recovery path; go check the local dirs for it
    String[] localDirs = _conf.getTrimmedStrings("yarn.nodemanager.local-dirs");
    for (String dir : localDirs) {
      File f = new File(new Path(dir).toUri().getPath(), dbFileName);
      if (f.exists()) {
        if (_recoveryPath == null) {
          // If NM recovery is not enabled, we should specify the recovery path using NM local
          // dirs, which is compatible with the old code.
          _recoveryPath = new Path(dir);
          return f;
        } else {
          // If the recovery path is set, then either NM recovery is enabled or another recovery
          // DB has already been initialized. If NM recovery is enabled and set the recovery path,
          // make sure to move all DBs from the old NM local dirs to the recovery path.
          // If another DB was initialized first, just make sure all the DBs end up in the same
          // location.
          File newLoc = new File(_recoveryPath.toUri().getPath(), dbFileName);
          if (!newLoc.equals(f)) {
            try {
              Files.move(f.toPath(), newLoc.toPath());
            } catch (Exception e) {
              // Failed to move the recovery file to the new path; just continue with the
              // new DB location
              logger.error("Failed to move recovery file {} to the path {}",
                dbFileName, _recoveryPath.toString(), e);
            }
          }
          return newLoc;
        }
      }
    }
    if (_recoveryPath == null) {
      _recoveryPath = new Path(localDirs[0]);
    }

    return new File(_recoveryPath.toUri().getPath(), dbFileName);
  }

  /**
   * Simply encodes an application ID.
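   * <p>Jackson serializes an instance as {@code {"appId":"..."}}; this JSON is the suffix of
   * the LevelDB keys built by {@code dbAppKey}.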
   */
  public static class AppId {
    public final String appId;

    @JsonCreator
    public AppId(@JsonProperty("appId") String appId) {
      this.appId = appId;
    }

    @Override
    public boolean equals(Object o) {
      if (this == o) return true;
      if (o == null || getClass() != o.getClass()) return false;

      AppId appExecId = (AppId) o;
      return Objects.equal(appId, appExecId.appId);
    }

    @Override
    public int hashCode() {
      return Objects.hashCode(appId);
    }

    @Override
    public String toString() {
      return Objects.toStringHelper(this)
          .add("appId", appId)
          .toString();
    }
  }

}