aboutsummaryrefslogtreecommitdiff
path: root/network/common
diff options
context:
space:
mode:
authorMarcelo Vanzin <vanzin@cloudera.com>2015-03-11 13:16:22 +0000
committerSean Owen <sowen@cloudera.com>2015-03-11 13:16:22 +0000
commit5b335bdda3efb7c6a5b18b4eeff189064c11e6c3 (patch)
tree6d891475ad6fe0466a38a7032f16bb21fd8d0c47 /network/common
parent6e94c4eadf443ac3d34eaae4c334c8386fdec960 (diff)
downloadspark-5b335bdda3efb7c6a5b18b4eeff189064c11e6c3.tar.gz
spark-5b335bdda3efb7c6a5b18b4eeff189064c11e6c3.tar.bz2
spark-5b335bdda3efb7c6a5b18b4eeff189064c11e6c3.zip
[SPARK-6228] [network] Move SASL classes from network/shuffle to network...
.../common. No code changes. Left the shuffle-related files in the shuffle module. Author: Marcelo Vanzin <vanzin@cloudera.com> Closes #4953 from vanzin/SPARK-6228 and squashes the following commits: 664ef30 [Marcelo Vanzin] [SPARK-6228] [network] Move SASL classes from network/shuffle to network/common.
Diffstat (limited to 'network/common')
-rw-r--r--network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java74
-rw-r--r--network/common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java65
-rw-r--r--network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java94
-rw-r--r--network/common/src/main/java/org/apache/spark/network/sasl/SecretKeyHolder.java35
-rw-r--r--network/common/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java137
-rw-r--r--network/common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java173
-rw-r--r--network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java89
7 files changed, 667 insertions, 0 deletions
diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java b/network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java
new file mode 100644
index 0000000000..33aa134434
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.sasl;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.client.TransportClientBootstrap;
+import org.apache.spark.network.util.TransportConf;
+
+/**
+ * Bootstraps a {@link TransportClient} by performing SASL authentication on the connection. The
+ * server should be setup with a {@link SaslRpcHandler} with matching keys for the given appId.
+ */
+public class SaslClientBootstrap implements TransportClientBootstrap {
+ private final Logger logger = LoggerFactory.getLogger(SaslClientBootstrap.class);
+
+ private final TransportConf conf;
+ private final String appId;
+ private final SecretKeyHolder secretKeyHolder;
+
+ public SaslClientBootstrap(TransportConf conf, String appId, SecretKeyHolder secretKeyHolder) {
+ this.conf = conf;
+ this.appId = appId;
+ this.secretKeyHolder = secretKeyHolder;
+ }
+
+ /**
+ * Performs SASL authentication by sending a token, and then proceeding with the SASL
+ * challenge-response tokens until we either successfully authenticate or throw an exception
+ * due to mismatch.
+ */
+ @Override
+ public void doBootstrap(TransportClient client) {
+ SparkSaslClient saslClient = new SparkSaslClient(appId, secretKeyHolder);
+ try {
+ byte[] payload = saslClient.firstToken();
+
+ while (!saslClient.isComplete()) {
+ SaslMessage msg = new SaslMessage(appId, payload);
+ ByteBuf buf = Unpooled.buffer(msg.encodedLength());
+ msg.encode(buf);
+
+ byte[] response = client.sendRpcSync(buf.array(), conf.saslRTTimeoutMs());
+ payload = saslClient.response(response);
+ }
+ } finally {
+ try {
+ // Once authentication is complete, the server will trust all remaining communication.
+ saslClient.dispose();
+ } catch (RuntimeException e) {
+ logger.error("Error while disposing SASL client", e);
+ }
+ }
+ }
+}
diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java b/network/common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java
new file mode 100644
index 0000000000..cad76ab7aa
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.sasl;
+
+import io.netty.buffer.ByteBuf;
+
+import org.apache.spark.network.protocol.Encodable;
+import org.apache.spark.network.protocol.Encoders;
+
+/**
+ * Encodes a Sasl-related message which is attempting to authenticate using some credentials tagged
+ * with the given appId. This appId allows a single SaslRpcHandler to multiplex different
+ * applications which may be using different sets of credentials.
+ */
+class SaslMessage implements Encodable {
+
+ /** Serialization tag used to catch incorrect payloads. */
+ private static final byte TAG_BYTE = (byte) 0xEA;
+
+ public final String appId;
+ public final byte[] payload;
+
+ public SaslMessage(String appId, byte[] payload) {
+ this.appId = appId;
+ this.payload = payload;
+ }
+
+ @Override
+ public int encodedLength() {
+ return 1 + Encoders.Strings.encodedLength(appId) + Encoders.ByteArrays.encodedLength(payload);
+ }
+
+ @Override
+ public void encode(ByteBuf buf) {
+ buf.writeByte(TAG_BYTE);
+ Encoders.Strings.encode(buf, appId);
+ Encoders.ByteArrays.encode(buf, payload);
+ }
+
+ public static SaslMessage decode(ByteBuf buf) {
+ if (buf.readByte() != TAG_BYTE) {
+ throw new IllegalStateException("Expected SaslMessage, received something else"
+ + " (maybe your client does not have SASL enabled?)");
+ }
+
+ String appId = Encoders.Strings.decode(buf);
+ byte[] payload = Encoders.ByteArrays.decode(buf);
+ return new SaslMessage(appId, payload);
+ }
+}
diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java b/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java
new file mode 100644
index 0000000000..026cbd260d
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.sasl;
+
+import java.util.concurrent.ConcurrentMap;
+
+import com.google.common.collect.Maps;
+import io.netty.buffer.Unpooled;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.client.RpcResponseCallback;
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.server.RpcHandler;
+import org.apache.spark.network.server.StreamManager;
+
+/**
+ * RPC Handler which performs SASL authentication before delegating to a child RPC handler.
+ * The delegate will only receive messages if the given connection has been successfully
+ * authenticated. A connection may be authenticated at most once.
+ *
+ * Note that the authentication process consists of multiple challenge-response pairs, each of
+ * which are individual RPCs.
+ */
+public class SaslRpcHandler extends RpcHandler {
+ private final Logger logger = LoggerFactory.getLogger(SaslRpcHandler.class);
+
+ /** RpcHandler we will delegate to for authenticated connections. */
+ private final RpcHandler delegate;
+
+ /** Class which provides secret keys which are shared by server and client on a per-app basis. */
+ private final SecretKeyHolder secretKeyHolder;
+
+ /** Maps each channel to its SASL authentication state. */
+ private final ConcurrentMap<TransportClient, SparkSaslServer> channelAuthenticationMap;
+
+ public SaslRpcHandler(RpcHandler delegate, SecretKeyHolder secretKeyHolder) {
+ this.delegate = delegate;
+ this.secretKeyHolder = secretKeyHolder;
+ this.channelAuthenticationMap = Maps.newConcurrentMap();
+ }
+
+ @Override
+ public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) {
+ SparkSaslServer saslServer = channelAuthenticationMap.get(client);
+ if (saslServer != null && saslServer.isComplete()) {
+ // Authentication complete, delegate to base handler.
+ delegate.receive(client, message, callback);
+ return;
+ }
+
+ SaslMessage saslMessage = SaslMessage.decode(Unpooled.wrappedBuffer(message));
+
+ if (saslServer == null) {
+ // First message in the handshake, setup the necessary state.
+ saslServer = new SparkSaslServer(saslMessage.appId, secretKeyHolder);
+ channelAuthenticationMap.put(client, saslServer);
+ }
+
+ byte[] response = saslServer.response(saslMessage.payload);
+ if (saslServer.isComplete()) {
+ logger.debug("SASL authentication successful for channel {}", client);
+ }
+ callback.onSuccess(response);
+ }
+
+ @Override
+ public StreamManager getStreamManager() {
+ return delegate.getStreamManager();
+ }
+
+ @Override
+ public void connectionTerminated(TransportClient client) {
+ SparkSaslServer saslServer = channelAuthenticationMap.remove(client);
+ if (saslServer != null) {
+ saslServer.dispose();
+ }
+ }
+}
diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SecretKeyHolder.java b/network/common/src/main/java/org/apache/spark/network/sasl/SecretKeyHolder.java
new file mode 100644
index 0000000000..81d5766794
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/sasl/SecretKeyHolder.java
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.sasl;
+
+/**
+ * Interface for getting a secret key associated with some application.
+ */
+public interface SecretKeyHolder {
+ /**
+ * Gets an appropriate SASL User for the given appId.
+ * @throws IllegalArgumentException if the given appId is not associated with a SASL user.
+ */
+ String getSaslUser(String appId);
+
+ /**
+ * Gets an appropriate SASL secret key for the given appId.
+ * @throws IllegalArgumentException if the given appId is not associated with a SASL secret key.
+ */
+ String getSecretKey(String appId);
+}
diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java b/network/common/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java
new file mode 100644
index 0000000000..9abad1f30a
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java
@@ -0,0 +1,137 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.sasl;
+
+import javax.security.auth.callback.Callback;
+import javax.security.auth.callback.CallbackHandler;
+import javax.security.auth.callback.NameCallback;
+import javax.security.auth.callback.PasswordCallback;
+import javax.security.auth.callback.UnsupportedCallbackException;
+import javax.security.sasl.RealmCallback;
+import javax.security.sasl.RealmChoiceCallback;
+import javax.security.sasl.Sasl;
+import javax.security.sasl.SaslClient;
+import javax.security.sasl.SaslException;
+import java.io.IOException;
+
+import com.google.common.base.Throwables;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import static org.apache.spark.network.sasl.SparkSaslServer.*;
+
+/**
+ * A SASL Client for Spark which simply keeps track of the state of a single SASL session, from the
+ * initial state to the "authenticated" state. This client initializes the protocol via a
+ * firstToken, which is then followed by a set of challenges and responses.
+ */
+public class SparkSaslClient {
+ private final Logger logger = LoggerFactory.getLogger(SparkSaslClient.class);
+
+ private final String secretKeyId;
+ private final SecretKeyHolder secretKeyHolder;
+ private SaslClient saslClient;
+
+ public SparkSaslClient(String secretKeyId, SecretKeyHolder secretKeyHolder) {
+ this.secretKeyId = secretKeyId;
+ this.secretKeyHolder = secretKeyHolder;
+ try {
+ this.saslClient = Sasl.createSaslClient(new String[] { DIGEST }, null, null, DEFAULT_REALM,
+ SASL_PROPS, new ClientCallbackHandler());
+ } catch (SaslException e) {
+ throw Throwables.propagate(e);
+ }
+ }
+
+ /** Used to initiate SASL handshake with server. */
+ public synchronized byte[] firstToken() {
+ if (saslClient != null && saslClient.hasInitialResponse()) {
+ try {
+ return saslClient.evaluateChallenge(new byte[0]);
+ } catch (SaslException e) {
+ throw Throwables.propagate(e);
+ }
+ } else {
+ return new byte[0];
+ }
+ }
+
+ /** Determines whether the authentication exchange has completed. */
+ public synchronized boolean isComplete() {
+ return saslClient != null && saslClient.isComplete();
+ }
+
+ /**
+ * Respond to server's SASL token.
+ * @param token contains server's SASL token
+ * @return client's response SASL token
+ */
+ public synchronized byte[] response(byte[] token) {
+ try {
+ return saslClient != null ? saslClient.evaluateChallenge(token) : new byte[0];
+ } catch (SaslException e) {
+ throw Throwables.propagate(e);
+ }
+ }
+
+ /**
+ * Disposes of any system resources or security-sensitive information the
+ * SaslClient might be using.
+ */
+ public synchronized void dispose() {
+ if (saslClient != null) {
+ try {
+ saslClient.dispose();
+ } catch (SaslException e) {
+ // ignore
+ } finally {
+ saslClient = null;
+ }
+ }
+ }
+
+ /**
+ * Implementation of javax.security.auth.callback.CallbackHandler
+ * that works with share secrets.
+ */
+ private class ClientCallbackHandler implements CallbackHandler {
+ @Override
+ public void handle(Callback[] callbacks) throws IOException, UnsupportedCallbackException {
+
+ for (Callback callback : callbacks) {
+ if (callback instanceof NameCallback) {
+ logger.trace("SASL client callback: setting username");
+ NameCallback nc = (NameCallback) callback;
+ nc.setName(encodeIdentifier(secretKeyHolder.getSaslUser(secretKeyId)));
+ } else if (callback instanceof PasswordCallback) {
+ logger.trace("SASL client callback: setting password");
+ PasswordCallback pc = (PasswordCallback) callback;
+ pc.setPassword(encodePassword(secretKeyHolder.getSecretKey(secretKeyId)));
+ } else if (callback instanceof RealmCallback) {
+ logger.trace("SASL client callback: setting realm");
+ RealmCallback rc = (RealmCallback) callback;
+ rc.setText(rc.getDefaultText());
+ } else if (callback instanceof RealmChoiceCallback) {
+ // ignore (?)
+ } else {
+ throw new UnsupportedCallbackException(callback, "Unrecognized SASL DIGEST-MD5 Callback");
+ }
+ }
+ }
+ }
+}
diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java b/network/common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java
new file mode 100644
index 0000000000..e87b17ead1
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java
@@ -0,0 +1,173 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.sasl;
+
+import javax.security.auth.callback.Callback;
+import javax.security.auth.callback.CallbackHandler;
+import javax.security.auth.callback.NameCallback;
+import javax.security.auth.callback.PasswordCallback;
+import javax.security.auth.callback.UnsupportedCallbackException;
+import javax.security.sasl.AuthorizeCallback;
+import javax.security.sasl.RealmCallback;
+import javax.security.sasl.Sasl;
+import javax.security.sasl.SaslException;
+import javax.security.sasl.SaslServer;
+import java.io.IOException;
+import java.util.Map;
+
+import com.google.common.base.Charsets;
+import com.google.common.base.Preconditions;
+import com.google.common.base.Throwables;
+import com.google.common.collect.ImmutableMap;
+import io.netty.buffer.Unpooled;
+import io.netty.handler.codec.base64.Base64;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * A SASL Server for Spark which simply keeps track of the state of a single SASL session, from the
+ * initial state to the "authenticated" state. (It is not a server in the sense of accepting
+ * connections on some socket.)
+ */
+public class SparkSaslServer {
+ private final Logger logger = LoggerFactory.getLogger(SparkSaslServer.class);
+
+ /**
+ * This is passed as the server name when creating the sasl client/server.
+ * This could be changed to be configurable in the future.
+ */
+ static final String DEFAULT_REALM = "default";
+
+ /**
+ * The authentication mechanism used here is DIGEST-MD5. This could be changed to be
+ * configurable in the future.
+ */
+ static final String DIGEST = "DIGEST-MD5";
+
+ /**
+ * The quality of protection is just "auth". This means that we are doing
+ * authentication only, we are not supporting integrity or privacy protection of the
+ * communication channel after authentication. This could be changed to be configurable
+ * in the future.
+ */
+ static final Map<String, String> SASL_PROPS = ImmutableMap.<String, String>builder()
+ .put(Sasl.QOP, "auth")
+ .put(Sasl.SERVER_AUTH, "true")
+ .build();
+
+ /** Identifier for a certain secret key within the secretKeyHolder. */
+ private final String secretKeyId;
+ private final SecretKeyHolder secretKeyHolder;
+ private SaslServer saslServer;
+
+ public SparkSaslServer(String secretKeyId, SecretKeyHolder secretKeyHolder) {
+ this.secretKeyId = secretKeyId;
+ this.secretKeyHolder = secretKeyHolder;
+ try {
+ this.saslServer = Sasl.createSaslServer(DIGEST, null, DEFAULT_REALM, SASL_PROPS,
+ new DigestCallbackHandler());
+ } catch (SaslException e) {
+ throw Throwables.propagate(e);
+ }
+ }
+
+ /**
+ * Determines whether the authentication exchange has completed successfully.
+ */
+ public synchronized boolean isComplete() {
+ return saslServer != null && saslServer.isComplete();
+ }
+
+ /**
+ * Used to respond to server SASL tokens.
+ * @param token Server's SASL token
+ * @return response to send back to the server.
+ */
+ public synchronized byte[] response(byte[] token) {
+ try {
+ return saslServer != null ? saslServer.evaluateResponse(token) : new byte[0];
+ } catch (SaslException e) {
+ throw Throwables.propagate(e);
+ }
+ }
+
+ /**
+ * Disposes of any system resources or security-sensitive information the
+ * SaslServer might be using.
+ */
+ public synchronized void dispose() {
+ if (saslServer != null) {
+ try {
+ saslServer.dispose();
+ } catch (SaslException e) {
+ // ignore
+ } finally {
+ saslServer = null;
+ }
+ }
+ }
+
+ /**
+ * Implementation of javax.security.auth.callback.CallbackHandler for SASL DIGEST-MD5 mechanism.
+ */
+ private class DigestCallbackHandler implements CallbackHandler {
+ @Override
+ public void handle(Callback[] callbacks) throws IOException, UnsupportedCallbackException {
+ for (Callback callback : callbacks) {
+ if (callback instanceof NameCallback) {
+ logger.trace("SASL server callback: setting username");
+ NameCallback nc = (NameCallback) callback;
+ nc.setName(encodeIdentifier(secretKeyHolder.getSaslUser(secretKeyId)));
+ } else if (callback instanceof PasswordCallback) {
+ logger.trace("SASL server callback: setting password");
+ PasswordCallback pc = (PasswordCallback) callback;
+ pc.setPassword(encodePassword(secretKeyHolder.getSecretKey(secretKeyId)));
+ } else if (callback instanceof RealmCallback) {
+ logger.trace("SASL server callback: setting realm");
+ RealmCallback rc = (RealmCallback) callback;
+ rc.setText(rc.getDefaultText());
+ } else if (callback instanceof AuthorizeCallback) {
+ AuthorizeCallback ac = (AuthorizeCallback) callback;
+ String authId = ac.getAuthenticationID();
+ String authzId = ac.getAuthorizationID();
+ ac.setAuthorized(authId.equals(authzId));
+ if (ac.isAuthorized()) {
+ ac.setAuthorizedID(authzId);
+ }
+ logger.debug("SASL Authorization complete, authorized set to {}", ac.isAuthorized());
+ } else {
+ throw new UnsupportedCallbackException(callback, "Unrecognized SASL DIGEST-MD5 Callback");
+ }
+ }
+ }
+ }
+
+ /* Encode a byte[] identifier as a Base64-encoded string. */
+ public static String encodeIdentifier(String identifier) {
+ Preconditions.checkNotNull(identifier, "User cannot be null if SASL is enabled");
+ return Base64.encode(Unpooled.wrappedBuffer(identifier.getBytes(Charsets.UTF_8)))
+ .toString(Charsets.UTF_8);
+ }
+
+ /** Encode a password as a base64-encoded char[] array. */
+ public static char[] encodePassword(String password) {
+ Preconditions.checkNotNull(password, "Password cannot be null if SASL is enabled");
+ return Base64.encode(Unpooled.wrappedBuffer(password.getBytes(Charsets.UTF_8)))
+ .toString(Charsets.UTF_8).toCharArray();
+ }
+}
diff --git a/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java
new file mode 100644
index 0000000000..23b4e06f06
--- /dev/null
+++ b/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.sasl;
+
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+
+import org.junit.Test;
+
+
+/**
+ * Jointly tests SparkSaslClient and SparkSaslServer, as both are black boxes.
+ */
+public class SparkSaslSuite {
+
+ /** Provides a secret key holder which returns secret key == appId */
+ private SecretKeyHolder secretKeyHolder = new SecretKeyHolder() {
+ @Override
+ public String getSaslUser(String appId) {
+ return "user";
+ }
+
+ @Override
+ public String getSecretKey(String appId) {
+ return appId;
+ }
+ };
+
+ @Test
+ public void testMatching() {
+ SparkSaslClient client = new SparkSaslClient("shared-secret", secretKeyHolder);
+ SparkSaslServer server = new SparkSaslServer("shared-secret", secretKeyHolder);
+
+ assertFalse(client.isComplete());
+ assertFalse(server.isComplete());
+
+ byte[] clientMessage = client.firstToken();
+
+ while (!client.isComplete()) {
+ clientMessage = client.response(server.response(clientMessage));
+ }
+ assertTrue(server.isComplete());
+
+ // Disposal should invalidate
+ server.dispose();
+ assertFalse(server.isComplete());
+ client.dispose();
+ assertFalse(client.isComplete());
+ }
+
+
+ @Test
+ public void testNonMatching() {
+ SparkSaslClient client = new SparkSaslClient("my-secret", secretKeyHolder);
+ SparkSaslServer server = new SparkSaslServer("your-secret", secretKeyHolder);
+
+ assertFalse(client.isComplete());
+ assertFalse(server.isComplete());
+
+ byte[] clientMessage = client.firstToken();
+
+ try {
+ while (!client.isComplete()) {
+ clientMessage = client.response(server.response(clientMessage));
+ }
+ fail("Should not have completed");
+ } catch (Exception e) {
+ assertTrue(e.getMessage().contains("Mismatched response"));
+ assertFalse(client.isComplete());
+ assertFalse(server.isComplete());
+ }
+ }
+}