diff options
author | Marcelo Vanzin <vanzin@cloudera.com> | 2015-03-11 13:16:22 +0000 |
---|---|---|
committer | Sean Owen <sowen@cloudera.com> | 2015-03-11 13:16:22 +0000 |
commit | 5b335bdda3efb7c6a5b18b4eeff189064c11e6c3 (patch) | |
tree | 6d891475ad6fe0466a38a7032f16bb21fd8d0c47 /network/common | |
parent | 6e94c4eadf443ac3d34eaae4c334c8386fdec960 (diff) | |
download | spark-5b335bdda3efb7c6a5b18b4eeff189064c11e6c3.tar.gz spark-5b335bdda3efb7c6a5b18b4eeff189064c11e6c3.tar.bz2 spark-5b335bdda3efb7c6a5b18b4eeff189064c11e6c3.zip |
[SPARK-6228] [network] Move SASL classes from network/shuffle to network...
.../common.
No code changes. Left the shuffle-related files in the shuffle module.
Author: Marcelo Vanzin <vanzin@cloudera.com>
Closes #4953 from vanzin/SPARK-6228 and squashes the following commits:
664ef30 [Marcelo Vanzin] [SPARK-6228] [network] Move SASL classes from network/shuffle to network/common.
Diffstat (limited to 'network/common')
7 files changed, 667 insertions, 0 deletions
diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java b/network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java new file mode 100644 index 0000000000..33aa134434 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.sasl; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.client.TransportClientBootstrap; +import org.apache.spark.network.util.TransportConf; + +/** + * Bootstraps a {@link TransportClient} by performing SASL authentication on the connection. The + * server should be setup with a {@link SaslRpcHandler} with matching keys for the given appId. + */ +public class SaslClientBootstrap implements TransportClientBootstrap { + private final Logger logger = LoggerFactory.getLogger(SaslClientBootstrap.class); + + private final TransportConf conf; + private final String appId; + private final SecretKeyHolder secretKeyHolder; + + public SaslClientBootstrap(TransportConf conf, String appId, SecretKeyHolder secretKeyHolder) { + this.conf = conf; + this.appId = appId; + this.secretKeyHolder = secretKeyHolder; + } + + /** + * Performs SASL authentication by sending a token, and then proceeding with the SASL + * challenge-response tokens until we either successfully authenticate or throw an exception + * due to mismatch. + */ + @Override + public void doBootstrap(TransportClient client) { + SparkSaslClient saslClient = new SparkSaslClient(appId, secretKeyHolder); + try { + byte[] payload = saslClient.firstToken(); + + while (!saslClient.isComplete()) { + SaslMessage msg = new SaslMessage(appId, payload); + ByteBuf buf = Unpooled.buffer(msg.encodedLength()); + msg.encode(buf); + + byte[] response = client.sendRpcSync(buf.array(), conf.saslRTTimeoutMs()); + payload = saslClient.response(response); + } + } finally { + try { + // Once authentication is complete, the server will trust all remaining communication. + saslClient.dispose(); + } catch (RuntimeException e) { + logger.error("Error while disposing SASL client", e); + } + } + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java b/network/common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java new file mode 100644 index 0000000000..cad76ab7aa --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.sasl; + +import io.netty.buffer.ByteBuf; + +import org.apache.spark.network.protocol.Encodable; +import org.apache.spark.network.protocol.Encoders; + +/** + * Encodes a Sasl-related message which is attempting to authenticate using some credentials tagged + * with the given appId. This appId allows a single SaslRpcHandler to multiplex different + * applications which may be using different sets of credentials. + */ +class SaslMessage implements Encodable { + + /** Serialization tag used to catch incorrect payloads. */ + private static final byte TAG_BYTE = (byte) 0xEA; + + public final String appId; + public final byte[] payload; + + public SaslMessage(String appId, byte[] payload) { + this.appId = appId; + this.payload = payload; + } + + @Override + public int encodedLength() { + return 1 + Encoders.Strings.encodedLength(appId) + Encoders.ByteArrays.encodedLength(payload); + } + + @Override + public void encode(ByteBuf buf) { + buf.writeByte(TAG_BYTE); + Encoders.Strings.encode(buf, appId); + Encoders.ByteArrays.encode(buf, payload); + } + + public static SaslMessage decode(ByteBuf buf) { + if (buf.readByte() != TAG_BYTE) { + throw new IllegalStateException("Expected SaslMessage, received something else" + + " (maybe your client does not have SASL enabled?)"); + } + + String appId = Encoders.Strings.decode(buf); + byte[] payload = Encoders.ByteArrays.decode(buf); + return new SaslMessage(appId, payload); + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java b/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java new file mode 100644 index 0000000000..026cbd260d --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.sasl; + +import java.util.concurrent.ConcurrentMap; + +import com.google.common.collect.Maps; +import io.netty.buffer.Unpooled; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.server.StreamManager; + +/** + * RPC Handler which performs SASL authentication before delegating to a child RPC handler. + * The delegate will only receive messages if the given connection has been successfully + * authenticated. A connection may be authenticated at most once. + * + * Note that the authentication process consists of multiple challenge-response pairs, each of + * which are individual RPCs. + */ +public class SaslRpcHandler extends RpcHandler { + private final Logger logger = LoggerFactory.getLogger(SaslRpcHandler.class); + + /** RpcHandler we will delegate to for authenticated connections. */ + private final RpcHandler delegate; + + /** Class which provides secret keys which are shared by server and client on a per-app basis. */ + private final SecretKeyHolder secretKeyHolder; + + /** Maps each channel to its SASL authentication state. */ + private final ConcurrentMap<TransportClient, SparkSaslServer> channelAuthenticationMap; + + public SaslRpcHandler(RpcHandler delegate, SecretKeyHolder secretKeyHolder) { + this.delegate = delegate; + this.secretKeyHolder = secretKeyHolder; + this.channelAuthenticationMap = Maps.newConcurrentMap(); + } + + @Override + public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { + SparkSaslServer saslServer = channelAuthenticationMap.get(client); + if (saslServer != null && saslServer.isComplete()) { + // Authentication complete, delegate to base handler. + delegate.receive(client, message, callback); + return; + } + + SaslMessage saslMessage = SaslMessage.decode(Unpooled.wrappedBuffer(message)); + + if (saslServer == null) { + // First message in the handshake, setup the necessary state. + saslServer = new SparkSaslServer(saslMessage.appId, secretKeyHolder); + channelAuthenticationMap.put(client, saslServer); + } + + byte[] response = saslServer.response(saslMessage.payload); + if (saslServer.isComplete()) { + logger.debug("SASL authentication successful for channel {}", client); + } + callback.onSuccess(response); + } + + @Override + public StreamManager getStreamManager() { + return delegate.getStreamManager(); + } + + @Override + public void connectionTerminated(TransportClient client) { + SparkSaslServer saslServer = channelAuthenticationMap.remove(client); + if (saslServer != null) { + saslServer.dispose(); + } + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SecretKeyHolder.java b/network/common/src/main/java/org/apache/spark/network/sasl/SecretKeyHolder.java new file mode 100644 index 0000000000..81d5766794 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/sasl/SecretKeyHolder.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.sasl; + +/** + * Interface for getting a secret key associated with some application. + */ +public interface SecretKeyHolder { + /** + * Gets an appropriate SASL User for the given appId. + * @throws IllegalArgumentException if the given appId is not associated with a SASL user. + */ + String getSaslUser(String appId); + + /** + * Gets an appropriate SASL secret key for the given appId. + * @throws IllegalArgumentException if the given appId is not associated with a SASL secret key. + */ + String getSecretKey(String appId); +} diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java b/network/common/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java new file mode 100644 index 0000000000..9abad1f30a --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.sasl; + +import javax.security.auth.callback.Callback; +import javax.security.auth.callback.CallbackHandler; +import javax.security.auth.callback.NameCallback; +import javax.security.auth.callback.PasswordCallback; +import javax.security.auth.callback.UnsupportedCallbackException; +import javax.security.sasl.RealmCallback; +import javax.security.sasl.RealmChoiceCallback; +import javax.security.sasl.Sasl; +import javax.security.sasl.SaslClient; +import javax.security.sasl.SaslException; +import java.io.IOException; + +import com.google.common.base.Throwables; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import static org.apache.spark.network.sasl.SparkSaslServer.*; + +/** + * A SASL Client for Spark which simply keeps track of the state of a single SASL session, from the + * initial state to the "authenticated" state. This client initializes the protocol via a + * firstToken, which is then followed by a set of challenges and responses. + */ +public class SparkSaslClient { + private final Logger logger = LoggerFactory.getLogger(SparkSaslClient.class); + + private final String secretKeyId; + private final SecretKeyHolder secretKeyHolder; + private SaslClient saslClient; + + public SparkSaslClient(String secretKeyId, SecretKeyHolder secretKeyHolder) { + this.secretKeyId = secretKeyId; + this.secretKeyHolder = secretKeyHolder; + try { + this.saslClient = Sasl.createSaslClient(new String[] { DIGEST }, null, null, DEFAULT_REALM, + SASL_PROPS, new ClientCallbackHandler()); + } catch (SaslException e) { + throw Throwables.propagate(e); + } + } + + /** Used to initiate SASL handshake with server. */ + public synchronized byte[] firstToken() { + if (saslClient != null && saslClient.hasInitialResponse()) { + try { + return saslClient.evaluateChallenge(new byte[0]); + } catch (SaslException e) { + throw Throwables.propagate(e); + } + } else { + return new byte[0]; + } + } + + /** Determines whether the authentication exchange has completed. */ + public synchronized boolean isComplete() { + return saslClient != null && saslClient.isComplete(); + } + + /** + * Respond to server's SASL token. + * @param token contains server's SASL token + * @return client's response SASL token + */ + public synchronized byte[] response(byte[] token) { + try { + return saslClient != null ? saslClient.evaluateChallenge(token) : new byte[0]; + } catch (SaslException e) { + throw Throwables.propagate(e); + } + } + + /** + * Disposes of any system resources or security-sensitive information the + * SaslClient might be using. + */ + public synchronized void dispose() { + if (saslClient != null) { + try { + saslClient.dispose(); + } catch (SaslException e) { + // ignore + } finally { + saslClient = null; + } + } + } + + /** + * Implementation of javax.security.auth.callback.CallbackHandler + * that works with share secrets. + */ + private class ClientCallbackHandler implements CallbackHandler { + @Override + public void handle(Callback[] callbacks) throws IOException, UnsupportedCallbackException { + + for (Callback callback : callbacks) { + if (callback instanceof NameCallback) { + logger.trace("SASL client callback: setting username"); + NameCallback nc = (NameCallback) callback; + nc.setName(encodeIdentifier(secretKeyHolder.getSaslUser(secretKeyId))); + } else if (callback instanceof PasswordCallback) { + logger.trace("SASL client callback: setting password"); + PasswordCallback pc = (PasswordCallback) callback; + pc.setPassword(encodePassword(secretKeyHolder.getSecretKey(secretKeyId))); + } else if (callback instanceof RealmCallback) { + logger.trace("SASL client callback: setting realm"); + RealmCallback rc = (RealmCallback) callback; + rc.setText(rc.getDefaultText()); + } else if (callback instanceof RealmChoiceCallback) { + // ignore (?) + } else { + throw new UnsupportedCallbackException(callback, "Unrecognized SASL DIGEST-MD5 Callback"); + } + } + } + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java b/network/common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java new file mode 100644 index 0000000000..e87b17ead1 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java @@ -0,0 +1,173 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.sasl; + +import javax.security.auth.callback.Callback; +import javax.security.auth.callback.CallbackHandler; +import javax.security.auth.callback.NameCallback; +import javax.security.auth.callback.PasswordCallback; +import javax.security.auth.callback.UnsupportedCallbackException; +import javax.security.sasl.AuthorizeCallback; +import javax.security.sasl.RealmCallback; +import javax.security.sasl.Sasl; +import javax.security.sasl.SaslException; +import javax.security.sasl.SaslServer; +import java.io.IOException; +import java.util.Map; + +import com.google.common.base.Charsets; +import com.google.common.base.Preconditions; +import com.google.common.base.Throwables; +import com.google.common.collect.ImmutableMap; +import io.netty.buffer.Unpooled; +import io.netty.handler.codec.base64.Base64; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * A SASL Server for Spark which simply keeps track of the state of a single SASL session, from the + * initial state to the "authenticated" state. (It is not a server in the sense of accepting + * connections on some socket.) + */ +public class SparkSaslServer { + private final Logger logger = LoggerFactory.getLogger(SparkSaslServer.class); + + /** + * This is passed as the server name when creating the sasl client/server. + * This could be changed to be configurable in the future. + */ + static final String DEFAULT_REALM = "default"; + + /** + * The authentication mechanism used here is DIGEST-MD5. This could be changed to be + * configurable in the future. + */ + static final String DIGEST = "DIGEST-MD5"; + + /** + * The quality of protection is just "auth". This means that we are doing + * authentication only, we are not supporting integrity or privacy protection of the + * communication channel after authentication. This could be changed to be configurable + * in the future. + */ + static final Map<String, String> SASL_PROPS = ImmutableMap.<String, String>builder() + .put(Sasl.QOP, "auth") + .put(Sasl.SERVER_AUTH, "true") + .build(); + + /** Identifier for a certain secret key within the secretKeyHolder. */ + private final String secretKeyId; + private final SecretKeyHolder secretKeyHolder; + private SaslServer saslServer; + + public SparkSaslServer(String secretKeyId, SecretKeyHolder secretKeyHolder) { + this.secretKeyId = secretKeyId; + this.secretKeyHolder = secretKeyHolder; + try { + this.saslServer = Sasl.createSaslServer(DIGEST, null, DEFAULT_REALM, SASL_PROPS, + new DigestCallbackHandler()); + } catch (SaslException e) { + throw Throwables.propagate(e); + } + } + + /** + * Determines whether the authentication exchange has completed successfully. + */ + public synchronized boolean isComplete() { + return saslServer != null && saslServer.isComplete(); + } + + /** + * Used to respond to server SASL tokens. + * @param token Server's SASL token + * @return response to send back to the server. + */ + public synchronized byte[] response(byte[] token) { + try { + return saslServer != null ? saslServer.evaluateResponse(token) : new byte[0]; + } catch (SaslException e) { + throw Throwables.propagate(e); + } + } + + /** + * Disposes of any system resources or security-sensitive information the + * SaslServer might be using. + */ + public synchronized void dispose() { + if (saslServer != null) { + try { + saslServer.dispose(); + } catch (SaslException e) { + // ignore + } finally { + saslServer = null; + } + } + } + + /** + * Implementation of javax.security.auth.callback.CallbackHandler for SASL DIGEST-MD5 mechanism. + */ + private class DigestCallbackHandler implements CallbackHandler { + @Override + public void handle(Callback[] callbacks) throws IOException, UnsupportedCallbackException { + for (Callback callback : callbacks) { + if (callback instanceof NameCallback) { + logger.trace("SASL server callback: setting username"); + NameCallback nc = (NameCallback) callback; + nc.setName(encodeIdentifier(secretKeyHolder.getSaslUser(secretKeyId))); + } else if (callback instanceof PasswordCallback) { + logger.trace("SASL server callback: setting password"); + PasswordCallback pc = (PasswordCallback) callback; + pc.setPassword(encodePassword(secretKeyHolder.getSecretKey(secretKeyId))); + } else if (callback instanceof RealmCallback) { + logger.trace("SASL server callback: setting realm"); + RealmCallback rc = (RealmCallback) callback; + rc.setText(rc.getDefaultText()); + } else if (callback instanceof AuthorizeCallback) { + AuthorizeCallback ac = (AuthorizeCallback) callback; + String authId = ac.getAuthenticationID(); + String authzId = ac.getAuthorizationID(); + ac.setAuthorized(authId.equals(authzId)); + if (ac.isAuthorized()) { + ac.setAuthorizedID(authzId); + } + logger.debug("SASL Authorization complete, authorized set to {}", ac.isAuthorized()); + } else { + throw new UnsupportedCallbackException(callback, "Unrecognized SASL DIGEST-MD5 Callback"); + } + } + } + } + + /* Encode a byte[] identifier as a Base64-encoded string. */ + public static String encodeIdentifier(String identifier) { + Preconditions.checkNotNull(identifier, "User cannot be null if SASL is enabled"); + return Base64.encode(Unpooled.wrappedBuffer(identifier.getBytes(Charsets.UTF_8))) + .toString(Charsets.UTF_8); + } + + /** Encode a password as a base64-encoded char[] array. */ + public static char[] encodePassword(String password) { + Preconditions.checkNotNull(password, "Password cannot be null if SASL is enabled"); + return Base64.encode(Unpooled.wrappedBuffer(password.getBytes(Charsets.UTF_8))) + .toString(Charsets.UTF_8).toCharArray(); + } +} diff --git a/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java new file mode 100644 index 0000000000..23b4e06f06 --- /dev/null +++ b/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.sasl; + +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +import org.junit.Test; + + +/** + * Jointly tests SparkSaslClient and SparkSaslServer, as both are black boxes. + */ +public class SparkSaslSuite { + + /** Provides a secret key holder which returns secret key == appId */ + private SecretKeyHolder secretKeyHolder = new SecretKeyHolder() { + @Override + public String getSaslUser(String appId) { + return "user"; + } + + @Override + public String getSecretKey(String appId) { + return appId; + } + }; + + @Test + public void testMatching() { + SparkSaslClient client = new SparkSaslClient("shared-secret", secretKeyHolder); + SparkSaslServer server = new SparkSaslServer("shared-secret", secretKeyHolder); + + assertFalse(client.isComplete()); + assertFalse(server.isComplete()); + + byte[] clientMessage = client.firstToken(); + + while (!client.isComplete()) { + clientMessage = client.response(server.response(clientMessage)); + } + assertTrue(server.isComplete()); + + // Disposal should invalidate + server.dispose(); + assertFalse(server.isComplete()); + client.dispose(); + assertFalse(client.isComplete()); + } + + + @Test + public void testNonMatching() { + SparkSaslClient client = new SparkSaslClient("my-secret", secretKeyHolder); + SparkSaslServer server = new SparkSaslServer("your-secret", secretKeyHolder); + + assertFalse(client.isComplete()); + assertFalse(server.isComplete()); + + byte[] clientMessage = client.firstToken(); + + try { + while (!client.isComplete()) { + clientMessage = client.response(server.response(clientMessage)); + } + fail("Should not have completed"); + } catch (Exception e) { + assertTrue(e.getMessage().contains("Mismatched response")); + assertFalse(client.isComplete()); + assertFalse(server.isComplete()); + } + } +} |