aboutsummaryrefslogtreecommitdiff
path: root/network/shuffle/src
diff options
context:
space:
mode:
authorAaron Davidson <aaron@databricks.com>2014-11-04 16:15:38 -0800
committerReynold Xin <rxin@databricks.com>2014-11-04 16:15:38 -0800
commit5e73138a0152b78380b3f1def4b969b58e70dd11 (patch)
treed899e389f29e7c87c8f40e84872477a9b5e52277 /network/shuffle/src
parentf90ad5d426cb726079c490a9bb4b1100e2b4e602 (diff)
downloadspark-5e73138a0152b78380b3f1def4b969b58e70dd11.tar.gz
spark-5e73138a0152b78380b3f1def4b969b58e70dd11.tar.bz2
spark-5e73138a0152b78380b3f1def4b969b58e70dd11.zip
[SPARK-2938] Support SASL authentication in NettyBlockTransferService
Also lays the groundwork for supporting it inside the external shuffle service. Author: Aaron Davidson <aaron@databricks.com> Closes #3087 from aarondav/sasl and squashes the following commits: 3481718 [Aaron Davidson] Delete rogue println 44f8410 [Aaron Davidson] Delete documentation - muahaha! eb9f065 [Aaron Davidson] Improve documentation and add end-to-end test at Spark-level a6b95f1 [Aaron Davidson] Address comments 785bbde [Aaron Davidson] Cleanup 79973cb [Aaron Davidson] Remove unused file 151b3c5 [Aaron Davidson] Add docs, timeout config, better failure handling f6177d7 [Aaron Davidson] Cleanup SASL state upon connection termination 7b42adb [Aaron Davidson] Add unit tests 8191bcb [Aaron Davidson] [SPARK-2938] Support SASL authentication in NettyBlockTransferService
Diffstat (limited to 'network/shuffle/src')
-rw-r--r--network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java74
-rw-r--r--network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslMessage.java74
-rw-r--r--network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java97
-rw-r--r--network/shuffle/src/main/java/org/apache/spark/network/sasl/SecretKeyHolder.java35
-rw-r--r--network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java138
-rw-r--r--network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java170
-rw-r--r--network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java2
-rw-r--r--network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java15
-rw-r--r--network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java11
-rw-r--r--network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java172
-rw-r--r--network/shuffle/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java89
-rw-r--r--network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java7
12 files changed, 874 insertions, 10 deletions
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java
new file mode 100644
index 0000000000..7bc91e3753
--- /dev/null
+++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.sasl;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.client.TransportClientBootstrap;
+import org.apache.spark.network.util.TransportConf;
+
+/**
+ * Bootstraps a {@link TransportClient} by performing SASL authentication on the connection. The
+ * server should be setup with a {@link SaslRpcHandler} with matching keys for the given appId.
+ */
+public class SaslClientBootstrap implements TransportClientBootstrap {
+ private final Logger logger = LoggerFactory.getLogger(SaslClientBootstrap.class);
+
+ private final TransportConf conf;
+ private final String appId;
+ private final SecretKeyHolder secretKeyHolder;
+
+ public SaslClientBootstrap(TransportConf conf, String appId, SecretKeyHolder secretKeyHolder) {
+ this.conf = conf;
+ this.appId = appId;
+ this.secretKeyHolder = secretKeyHolder;
+ }
+
+ /**
+ * Performs SASL authentication by sending a token, and then proceeding with the SASL
+ * challenge-response tokens until we either successfully authenticate or throw an exception
+ * due to mismatch.
+ */
+ @Override
+ public void doBootstrap(TransportClient client) {
+ SparkSaslClient saslClient = new SparkSaslClient(appId, secretKeyHolder);
+ try {
+ byte[] payload = saslClient.firstToken();
+
+ while (!saslClient.isComplete()) {
+ SaslMessage msg = new SaslMessage(appId, payload);
+ ByteBuf buf = Unpooled.buffer(msg.encodedLength());
+ msg.encode(buf);
+
+ byte[] response = client.sendRpcSync(buf.array(), conf.saslRTTimeout());
+ payload = saslClient.response(response);
+ }
+ } finally {
+ try {
+ // Once authentication is complete, the server will trust all remaining communication.
+ saslClient.dispose();
+ } catch (RuntimeException e) {
+ logger.error("Error while disposing SASL client", e);
+ }
+ }
+ }
+}
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslMessage.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslMessage.java
new file mode 100644
index 0000000000..5b77e18c26
--- /dev/null
+++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslMessage.java
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.sasl;
+
+import com.google.common.base.Charsets;
+import io.netty.buffer.ByteBuf;
+
+import org.apache.spark.network.protocol.Encodable;
+
+/**
+ * Encodes a Sasl-related message which is attempting to authenticate using some credentials tagged
+ * with the given appId. This appId allows a single SaslRpcHandler to multiplex different
+ * applications which may be using different sets of credentials.
+ */
+class SaslMessage implements Encodable {
+
+ /** Serialization tag used to catch incorrect payloads. */
+ private static final byte TAG_BYTE = (byte) 0xEA;
+
+ public final String appId;
+ public final byte[] payload;
+
+ public SaslMessage(String appId, byte[] payload) {
+ this.appId = appId;
+ this.payload = payload;
+ }
+
+ @Override
+ public int encodedLength() {
+ // tag + appIdLength + appId + payloadLength + payload
+ return 1 + 4 + appId.getBytes(Charsets.UTF_8).length + 4 + payload.length;
+ }
+
+ @Override
+ public void encode(ByteBuf buf) {
+ buf.writeByte(TAG_BYTE);
+ byte[] idBytes = appId.getBytes(Charsets.UTF_8);
+ buf.writeInt(idBytes.length);
+ buf.writeBytes(idBytes);
+ buf.writeInt(payload.length);
+ buf.writeBytes(payload);
+ }
+
+ public static SaslMessage decode(ByteBuf buf) {
+ if (buf.readByte() != TAG_BYTE) {
+ throw new IllegalStateException("Expected SaslMessage, received something else");
+ }
+
+ int idLength = buf.readInt();
+ byte[] idBytes = new byte[idLength];
+ buf.readBytes(idBytes);
+
+ int payloadLength = buf.readInt();
+ byte[] payload = new byte[payloadLength];
+ buf.readBytes(payload);
+
+ return new SaslMessage(new String(idBytes, Charsets.UTF_8), payload);
+ }
+}
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java
new file mode 100644
index 0000000000..3777a18e33
--- /dev/null
+++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.sasl;
+
+import java.util.concurrent.ConcurrentMap;
+
+import com.google.common.base.Charsets;
+import com.google.common.collect.Maps;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.client.RpcResponseCallback;
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.protocol.Encodable;
+import org.apache.spark.network.server.RpcHandler;
+import org.apache.spark.network.server.StreamManager;
+
+/**
+ * RPC Handler which performs SASL authentication before delegating to a child RPC handler.
+ * The delegate will only receive messages if the given connection has been successfully
+ * authenticated. A connection may be authenticated at most once.
+ *
+ * Note that the authentication process consists of multiple challenge-response pairs, each of
+ * which are individual RPCs.
+ */
+public class SaslRpcHandler extends RpcHandler {
+ private final Logger logger = LoggerFactory.getLogger(SaslRpcHandler.class);
+
+ /** RpcHandler we will delegate to for authenticated connections. */
+ private final RpcHandler delegate;
+
+ /** Class which provides secret keys which are shared by server and client on a per-app basis. */
+ private final SecretKeyHolder secretKeyHolder;
+
+ /** Maps each channel to its SASL authentication state. */
+ private final ConcurrentMap<TransportClient, SparkSaslServer> channelAuthenticationMap;
+
+ public SaslRpcHandler(RpcHandler delegate, SecretKeyHolder secretKeyHolder) {
+ this.delegate = delegate;
+ this.secretKeyHolder = secretKeyHolder;
+ this.channelAuthenticationMap = Maps.newConcurrentMap();
+ }
+
+ @Override
+ public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) {
+ SparkSaslServer saslServer = channelAuthenticationMap.get(client);
+ if (saslServer != null && saslServer.isComplete()) {
+ // Authentication complete, delegate to base handler.
+ delegate.receive(client, message, callback);
+ return;
+ }
+
+ SaslMessage saslMessage = SaslMessage.decode(Unpooled.wrappedBuffer(message));
+
+ if (saslServer == null) {
+ // First message in the handshake, setup the necessary state.
+ saslServer = new SparkSaslServer(saslMessage.appId, secretKeyHolder);
+ channelAuthenticationMap.put(client, saslServer);
+ }
+
+ byte[] response = saslServer.response(saslMessage.payload);
+ if (saslServer.isComplete()) {
+ logger.debug("SASL authentication successful for channel {}", client);
+ }
+ callback.onSuccess(response);
+ }
+
+ @Override
+ public StreamManager getStreamManager() {
+ return delegate.getStreamManager();
+ }
+
+ @Override
+ public void connectionTerminated(TransportClient client) {
+ SparkSaslServer saslServer = channelAuthenticationMap.remove(client);
+ if (saslServer != null) {
+ saslServer.dispose();
+ }
+ }
+}
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SecretKeyHolder.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SecretKeyHolder.java
new file mode 100644
index 0000000000..81d5766794
--- /dev/null
+++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SecretKeyHolder.java
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.sasl;
+
+/**
+ * Interface for getting a secret key associated with some application.
+ */
+public interface SecretKeyHolder {
+ /**
+ * Gets an appropriate SASL User for the given appId.
+ * @throws IllegalArgumentException if the given appId is not associated with a SASL user.
+ */
+ String getSaslUser(String appId);
+
+ /**
+ * Gets an appropriate SASL secret key for the given appId.
+ * @throws IllegalArgumentException if the given appId is not associated with a SASL secret key.
+ */
+ String getSecretKey(String appId);
+}
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java
new file mode 100644
index 0000000000..72ba737b99
--- /dev/null
+++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java
@@ -0,0 +1,138 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.sasl;
+
+import javax.security.auth.callback.Callback;
+import javax.security.auth.callback.CallbackHandler;
+import javax.security.auth.callback.NameCallback;
+import javax.security.auth.callback.PasswordCallback;
+import javax.security.auth.callback.UnsupportedCallbackException;
+import javax.security.sasl.RealmCallback;
+import javax.security.sasl.RealmChoiceCallback;
+import javax.security.sasl.Sasl;
+import javax.security.sasl.SaslClient;
+import javax.security.sasl.SaslException;
+import java.io.IOException;
+
+import com.google.common.base.Throwables;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import static org.apache.spark.network.sasl.SparkSaslServer.*;
+
+/**
+ * A SASL Client for Spark which simply keeps track of the state of a single SASL session, from the
+ * initial state to the "authenticated" state. This client initializes the protocol via a
+ * firstToken, which is then followed by a set of challenges and responses.
+ */
+public class SparkSaslClient {
+ private final Logger logger = LoggerFactory.getLogger(SparkSaslClient.class);
+
+ private final String secretKeyId;
+ private final SecretKeyHolder secretKeyHolder;
+ private SaslClient saslClient;
+
+ public SparkSaslClient(String secretKeyId, SecretKeyHolder secretKeyHolder) {
+ this.secretKeyId = secretKeyId;
+ this.secretKeyHolder = secretKeyHolder;
+ try {
+ this.saslClient = Sasl.createSaslClient(new String[] { DIGEST }, null, null, DEFAULT_REALM,
+ SASL_PROPS, new ClientCallbackHandler());
+ } catch (SaslException e) {
+ throw Throwables.propagate(e);
+ }
+ }
+
+ /** Used to initiate SASL handshake with server. */
+ public synchronized byte[] firstToken() {
+ if (saslClient != null && saslClient.hasInitialResponse()) {
+ try {
+ return saslClient.evaluateChallenge(new byte[0]);
+ } catch (SaslException e) {
+ throw Throwables.propagate(e);
+ }
+ } else {
+ return new byte[0];
+ }
+ }
+
+ /** Determines whether the authentication exchange has completed. */
+ public synchronized boolean isComplete() {
+ return saslClient != null && saslClient.isComplete();
+ }
+
+ /**
+ * Respond to server's SASL token.
+ * @param token contains server's SASL token
+ * @return client's response SASL token
+ */
+ public synchronized byte[] response(byte[] token) {
+ try {
+ return saslClient != null ? saslClient.evaluateChallenge(token) : new byte[0];
+ } catch (SaslException e) {
+ throw Throwables.propagate(e);
+ }
+ }
+
+ /**
+ * Disposes of any system resources or security-sensitive information the
+ * SaslClient might be using.
+ */
+ public synchronized void dispose() {
+ if (saslClient != null) {
+ try {
+ saslClient.dispose();
+ } catch (SaslException e) {
+ // ignore
+ } finally {
+ saslClient = null;
+ }
+ }
+ }
+
+ /**
+ * Implementation of javax.security.auth.callback.CallbackHandler
+ * that works with share secrets.
+ */
+ private class ClientCallbackHandler implements CallbackHandler {
+ @Override
+ public void handle(Callback[] callbacks) throws IOException, UnsupportedCallbackException {
+
+ for (Callback callback : callbacks) {
+ if (callback instanceof NameCallback) {
+ logger.trace("SASL client callback: setting username");
+ NameCallback nc = (NameCallback) callback;
+ nc.setName(encodeIdentifier(secretKeyHolder.getSaslUser(secretKeyId)));
+ } else if (callback instanceof PasswordCallback) {
+ logger.trace("SASL client callback: setting password");
+ PasswordCallback pc = (PasswordCallback) callback;
+ pc.setPassword(encodePassword(secretKeyHolder.getSecretKey(secretKeyId)));
+ } else if (callback instanceof RealmCallback) {
+ logger.trace("SASL client callback: setting realm");
+ RealmCallback rc = (RealmCallback) callback;
+ rc.setText(rc.getDefaultText());
+ logger.info("Realm callback");
+ } else if (callback instanceof RealmChoiceCallback) {
+ // ignore (?)
+ } else {
+ throw new UnsupportedCallbackException(callback, "Unrecognized SASL DIGEST-MD5 Callback");
+ }
+ }
+ }
+ }
+}
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java
new file mode 100644
index 0000000000..2c0ce40c75
--- /dev/null
+++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java
@@ -0,0 +1,170 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.sasl;
+
+import javax.security.auth.callback.Callback;
+import javax.security.auth.callback.CallbackHandler;
+import javax.security.auth.callback.NameCallback;
+import javax.security.auth.callback.PasswordCallback;
+import javax.security.auth.callback.UnsupportedCallbackException;
+import javax.security.sasl.AuthorizeCallback;
+import javax.security.sasl.RealmCallback;
+import javax.security.sasl.Sasl;
+import javax.security.sasl.SaslException;
+import javax.security.sasl.SaslServer;
+import java.io.IOException;
+import java.util.Map;
+
+import com.google.common.base.Charsets;
+import com.google.common.base.Preconditions;
+import com.google.common.base.Throwables;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.io.BaseEncoding;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * A SASL Server for Spark which simply keeps track of the state of a single SASL session, from the
+ * initial state to the "authenticated" state. (It is not a server in the sense of accepting
+ * connections on some socket.)
+ */
+public class SparkSaslServer {
+ private final Logger logger = LoggerFactory.getLogger(SparkSaslServer.class);
+
+ /**
+ * This is passed as the server name when creating the sasl client/server.
+ * This could be changed to be configurable in the future.
+ */
+ static final String DEFAULT_REALM = "default";
+
+ /**
+ * The authentication mechanism used here is DIGEST-MD5. This could be changed to be
+ * configurable in the future.
+ */
+ static final String DIGEST = "DIGEST-MD5";
+
+ /**
+ * The quality of protection is just "auth". This means that we are doing
+ * authentication only, we are not supporting integrity or privacy protection of the
+ * communication channel after authentication. This could be changed to be configurable
+ * in the future.
+ */
+ static final Map<String, String> SASL_PROPS = ImmutableMap.<String, String>builder()
+ .put(Sasl.QOP, "auth")
+ .put(Sasl.SERVER_AUTH, "true")
+ .build();
+
+ /** Identifier for a certain secret key within the secretKeyHolder. */
+ private final String secretKeyId;
+ private final SecretKeyHolder secretKeyHolder;
+ private SaslServer saslServer;
+
+ public SparkSaslServer(String secretKeyId, SecretKeyHolder secretKeyHolder) {
+ this.secretKeyId = secretKeyId;
+ this.secretKeyHolder = secretKeyHolder;
+ try {
+ this.saslServer = Sasl.createSaslServer(DIGEST, null, DEFAULT_REALM, SASL_PROPS,
+ new DigestCallbackHandler());
+ } catch (SaslException e) {
+ throw Throwables.propagate(e);
+ }
+ }
+
+ /**
+ * Determines whether the authentication exchange has completed successfully.
+ */
+ public synchronized boolean isComplete() {
+ return saslServer != null && saslServer.isComplete();
+ }
+
+ /**
+ * Used to respond to server SASL tokens.
+ * @param token Server's SASL token
+ * @return response to send back to the server.
+ */
+ public synchronized byte[] response(byte[] token) {
+ try {
+ return saslServer != null ? saslServer.evaluateResponse(token) : new byte[0];
+ } catch (SaslException e) {
+ throw Throwables.propagate(e);
+ }
+ }
+
+ /**
+ * Disposes of any system resources or security-sensitive information the
+ * SaslServer might be using.
+ */
+ public synchronized void dispose() {
+ if (saslServer != null) {
+ try {
+ saslServer.dispose();
+ } catch (SaslException e) {
+ // ignore
+ } finally {
+ saslServer = null;
+ }
+ }
+ }
+
+ /**
+ * Implementation of javax.security.auth.callback.CallbackHandler for SASL DIGEST-MD5 mechanism.
+ */
+ private class DigestCallbackHandler implements CallbackHandler {
+ @Override
+ public void handle(Callback[] callbacks) throws IOException, UnsupportedCallbackException {
+ for (Callback callback : callbacks) {
+ if (callback instanceof NameCallback) {
+ logger.trace("SASL server callback: setting username");
+ NameCallback nc = (NameCallback) callback;
+ nc.setName(encodeIdentifier(secretKeyHolder.getSaslUser(secretKeyId)));
+ } else if (callback instanceof PasswordCallback) {
+ logger.trace("SASL server callback: setting password");
+ PasswordCallback pc = (PasswordCallback) callback;
+ pc.setPassword(encodePassword(secretKeyHolder.getSecretKey(secretKeyId)));
+ } else if (callback instanceof RealmCallback) {
+ logger.trace("SASL server callback: setting realm");
+ RealmCallback rc = (RealmCallback) callback;
+ rc.setText(rc.getDefaultText());
+ } else if (callback instanceof AuthorizeCallback) {
+ AuthorizeCallback ac = (AuthorizeCallback) callback;
+ String authId = ac.getAuthenticationID();
+ String authzId = ac.getAuthorizationID();
+ ac.setAuthorized(authId.equals(authzId));
+ if (ac.isAuthorized()) {
+ ac.setAuthorizedID(authzId);
+ }
+ logger.debug("SASL Authorization complete, authorized set to {}", ac.isAuthorized());
+ } else {
+ throw new UnsupportedCallbackException(callback, "Unrecognized SASL DIGEST-MD5 Callback");
+ }
+ }
+ }
+ }
+
+ /* Encode a byte[] identifier as a Base64-encoded string. */
+ public static String encodeIdentifier(String identifier) {
+ Preconditions.checkNotNull(identifier, "User cannot be null if SASL is enabled");
+ return BaseEncoding.base64().encode(identifier.getBytes(Charsets.UTF_8));
+ }
+
+ /** Encode a password as a base64-encoded char[] array. */
+ public static char[] encodePassword(String password) {
+ Preconditions.checkNotNull(password, "Password cannot be null if SASL is enabled");
+ return BaseEncoding.base64().encode(password.getBytes(Charsets.UTF_8)).toCharArray();
+ }
+}
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java
index a9dff31dec..cd3fea85b1 100644
--- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java
@@ -41,7 +41,7 @@ import org.apache.spark.network.util.JavaUtils;
* with the "one-for-one" strategy, meaning each Transport-layer Chunk is equivalent to one Spark-
* level shuffle block.
*/
-public class ExternalShuffleBlockHandler implements RpcHandler {
+public class ExternalShuffleBlockHandler extends RpcHandler {
private final Logger logger = LoggerFactory.getLogger(ExternalShuffleBlockHandler.class);
private final ExternalShuffleBlockManager blockManager;
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
index 6bbabc44b9..b0b19ba67b 100644
--- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
@@ -17,8 +17,6 @@
package org.apache.spark.network.shuffle;
-import java.io.Closeable;
-
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -36,15 +34,20 @@ import org.apache.spark.network.util.TransportConf;
* BlockTransferService), which has the downside of losing the shuffle data if we lose the
* executors.
*/
-public class ExternalShuffleClient implements ShuffleClient {
+public class ExternalShuffleClient extends ShuffleClient {
private final Logger logger = LoggerFactory.getLogger(ExternalShuffleClient.class);
private final TransportClientFactory clientFactory;
- private final String appId;
- public ExternalShuffleClient(TransportConf conf, String appId) {
+ private String appId;
+
+ public ExternalShuffleClient(TransportConf conf) {
TransportContext context = new TransportContext(conf, new NoOpRpcHandler());
this.clientFactory = context.createClientFactory();
+ }
+
+ @Override
+ public void init(String appId) {
this.appId = appId;
}
@@ -55,6 +58,7 @@ public class ExternalShuffleClient implements ShuffleClient {
String execId,
String[] blockIds,
BlockFetchingListener listener) {
+ assert appId != null : "Called before init()";
logger.debug("External shuffle fetch from {}:{} (executor id {})", host, port, execId);
try {
TransportClient client = clientFactory.createClient(host, port);
@@ -82,6 +86,7 @@ public class ExternalShuffleClient implements ShuffleClient {
int port,
String execId,
ExecutorShuffleInfo executorInfo) {
+ assert appId != null : "Called before init()";
TransportClient client = clientFactory.createClient(host, port);
byte[] registerExecutorMessage =
JavaUtils.serialize(new RegisterExecutor(appId, execId, executorInfo));
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java
index d46a562394..f72ab40690 100644
--- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java
@@ -20,7 +20,14 @@ package org.apache.spark.network.shuffle;
import java.io.Closeable;
/** Provides an interface for reading shuffle files, either from an Executor or external service. */
-public interface ShuffleClient extends Closeable {
+public abstract class ShuffleClient implements Closeable {
+
+ /**
+ * Initializes the ShuffleClient, specifying this Executor's appId.
+ * Must be called before any other method on the ShuffleClient.
+ */
+ public void init(String appId) { }
+
/**
* Fetch a sequence of blocks from a remote node asynchronously,
*
@@ -28,7 +35,7 @@ public interface ShuffleClient extends Closeable {
* return a future so the underlying implementation can invoke onBlockFetchSuccess as soon as
* the data of a block is fetched, rather than waiting for all blocks to be fetched.
*/
- public void fetchBlocks(
+ public abstract void fetchBlocks(
String host,
int port,
String execId,
diff --git a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java
new file mode 100644
index 0000000000..8478120786
--- /dev/null
+++ b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java
@@ -0,0 +1,172 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.sasl;
+
+import java.io.IOException;
+
+import com.google.common.collect.Lists;
+import org.junit.After;
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import static org.junit.Assert.*;
+
+import org.apache.spark.network.TestUtils;
+import org.apache.spark.network.TransportContext;
+import org.apache.spark.network.client.RpcResponseCallback;
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.client.TransportClientBootstrap;
+import org.apache.spark.network.client.TransportClientFactory;
+import org.apache.spark.network.server.OneForOneStreamManager;
+import org.apache.spark.network.server.RpcHandler;
+import org.apache.spark.network.server.StreamManager;
+import org.apache.spark.network.server.TransportServer;
+import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler;
+import org.apache.spark.network.util.SystemPropertyConfigProvider;
+import org.apache.spark.network.util.TransportConf;
+
+public class SaslIntegrationSuite {
+ static ExternalShuffleBlockHandler handler;
+ static TransportServer server;
+ static TransportConf conf;
+ static TransportContext context;
+
+ TransportClientFactory clientFactory;
+
+ /** Provides a secret key holder which always returns the given secret key. */
+ static class TestSecretKeyHolder implements SecretKeyHolder {
+
+ private final String secretKey;
+
+ TestSecretKeyHolder(String secretKey) {
+ this.secretKey = secretKey;
+ }
+
+ @Override
+ public String getSaslUser(String appId) {
+ return "user";
+ }
+ @Override
+ public String getSecretKey(String appId) {
+ return secretKey;
+ }
+ }
+
+
+ @BeforeClass
+ public static void beforeAll() throws IOException {
+ SecretKeyHolder secretKeyHolder = new TestSecretKeyHolder("good-key");
+ SaslRpcHandler handler = new SaslRpcHandler(new TestRpcHandler(), secretKeyHolder);
+ conf = new TransportConf(new SystemPropertyConfigProvider());
+ context = new TransportContext(conf, handler);
+ server = context.createServer();
+ }
+
+
+ @AfterClass
+ public static void afterAll() {
+ server.close();
+ }
+
+ @After
+ public void afterEach() {
+ if (clientFactory != null) {
+ clientFactory.close();
+ clientFactory = null;
+ }
+ }
+
+ @Test
+ public void testGoodClient() {
+ clientFactory = context.createClientFactory(
+ Lists.<TransportClientBootstrap>newArrayList(
+ new SaslClientBootstrap(conf, "app-id", new TestSecretKeyHolder("good-key"))));
+
+ TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
+ String msg = "Hello, World!";
+ byte[] resp = client.sendRpcSync(msg.getBytes(), 1000);
+ assertEquals(msg, new String(resp)); // our rpc handler should just return the given msg
+ }
+
+ @Test
+ public void testBadClient() {
+ clientFactory = context.createClientFactory(
+ Lists.<TransportClientBootstrap>newArrayList(
+ new SaslClientBootstrap(conf, "app-id", new TestSecretKeyHolder("bad-key"))));
+
+ try {
+ // Bootstrap should fail on startup.
+ clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
+ } catch (Exception e) {
+ assertTrue(e.getMessage(), e.getMessage().contains("Mismatched response"));
+ }
+ }
+
+ @Test
+ public void testNoSaslClient() {
+ clientFactory = context.createClientFactory(
+ Lists.<TransportClientBootstrap>newArrayList());
+
+ TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
+ try {
+ client.sendRpcSync(new byte[13], 1000);
+ fail("Should have failed");
+ } catch (Exception e) {
+ assertTrue(e.getMessage(), e.getMessage().contains("Expected SaslMessage"));
+ }
+
+ try {
+ // Guessing the right tag byte doesn't magically get you in...
+ client.sendRpcSync(new byte[] { (byte) 0xEA }, 1000);
+ fail("Should have failed");
+ } catch (Exception e) {
+ assertTrue(e.getMessage(), e.getMessage().contains("java.lang.IndexOutOfBoundsException"));
+ }
+ }
+
+ @Test
+ public void testNoSaslServer() {
+ RpcHandler handler = new TestRpcHandler();
+ TransportContext context = new TransportContext(conf, handler);
+ clientFactory = context.createClientFactory(
+ Lists.<TransportClientBootstrap>newArrayList(
+ new SaslClientBootstrap(conf, "app-id", new TestSecretKeyHolder("key"))));
+ TransportServer server = context.createServer();
+ try {
+ clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
+ } catch (Exception e) {
+ assertTrue(e.getMessage(), e.getMessage().contains("Digest-challenge format violation"));
+ } finally {
+ server.close();
+ }
+ }
+
+ /** RPC handler which simply responds with the message it received. */
+ public static class TestRpcHandler extends RpcHandler {
+ @Override
+ public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) {
+ callback.onSuccess(message);
+ }
+
+ @Override
+ public StreamManager getStreamManager() {
+ return new OneForOneStreamManager();
+ }
+ }
+}
diff --git a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java
new file mode 100644
index 0000000000..67a07f38eb
--- /dev/null
+++ b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.sasl;
+
+import java.util.Map;
+
+import com.google.common.collect.ImmutableMap;
+import org.junit.Test;
+
+import static org.junit.Assert.*;
+
+/**
+ * Jointly tests SparkSaslClient and SparkSaslServer, as both are black boxes.
+ */
+public class SparkSaslSuite {
+
+ /** Provides a secret key holder which returns secret key == appId */
+ private SecretKeyHolder secretKeyHolder = new SecretKeyHolder() {
+ @Override
+ public String getSaslUser(String appId) {
+ return "user";
+ }
+
+ @Override
+ public String getSecretKey(String appId) {
+ return appId;
+ }
+ };
+
+ @Test
+ public void testMatching() {
+ SparkSaslClient client = new SparkSaslClient("shared-secret", secretKeyHolder);
+ SparkSaslServer server = new SparkSaslServer("shared-secret", secretKeyHolder);
+
+ assertFalse(client.isComplete());
+ assertFalse(server.isComplete());
+
+ byte[] clientMessage = client.firstToken();
+
+ while (!client.isComplete()) {
+ clientMessage = client.response(server.response(clientMessage));
+ }
+ assertTrue(server.isComplete());
+
+ // Disposal should invalidate
+ server.dispose();
+ assertFalse(server.isComplete());
+ client.dispose();
+ assertFalse(client.isComplete());
+ }
+
+
+ @Test
+ public void testNonMatching() {
+ SparkSaslClient client = new SparkSaslClient("my-secret", secretKeyHolder);
+ SparkSaslServer server = new SparkSaslServer("your-secret", secretKeyHolder);
+
+ assertFalse(client.isComplete());
+ assertFalse(server.isComplete());
+
+ byte[] clientMessage = client.firstToken();
+
+ try {
+ while (!client.isComplete()) {
+ clientMessage = client.response(server.response(clientMessage));
+ }
+ fail("Should not have completed");
+ } catch (Exception e) {
+ assertTrue(e.getMessage().contains("Mismatched response"));
+ assertFalse(client.isComplete());
+ assertFalse(server.isComplete());
+ }
+ }
+}
diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java
index b3bcf5fd68..bc101f5384 100644
--- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java
+++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java
@@ -135,7 +135,8 @@ public class ExternalShuffleIntegrationSuite {
final Semaphore requestsRemaining = new Semaphore(0);
- ExternalShuffleClient client = new ExternalShuffleClient(conf, APP_ID);
+ ExternalShuffleClient client = new ExternalShuffleClient(conf);
+ client.init(APP_ID);
client.fetchBlocks(TestUtils.getLocalHost(), port, execId, blockIds,
new BlockFetchingListener() {
@Override
@@ -164,6 +165,7 @@ public class ExternalShuffleIntegrationSuite {
if (!requestsRemaining.tryAcquire(blockIds.length, 5, TimeUnit.SECONDS)) {
fail("Timeout getting response from the server");
}
+ client.close();
return res;
}
@@ -265,7 +267,8 @@ public class ExternalShuffleIntegrationSuite {
}
private void registerExecutor(String executorId, ExecutorShuffleInfo executorInfo) {
- ExternalShuffleClient client = new ExternalShuffleClient(conf, APP_ID);
+ ExternalShuffleClient client = new ExternalShuffleClient(conf);
+ client.init(APP_ID);
client.registerWithShuffleServer(TestUtils.getLocalHost(), server.getPort(),
executorId, executorInfo);
}