diff options
author | Aaron Davidson <aaron@databricks.com> | 2014-11-05 14:38:43 -0800 |
---|---|---|
committer | Patrick Wendell <pwendell@gmail.com> | 2014-11-05 14:38:43 -0800 |
commit | 4c42986cc070d9c5c55c7bf8a2a67585967b1082 (patch) | |
tree | 0c20263f4d5b7cca3be13e3f9a160e2eb8014a63 /network | |
parent | 5b3b6f6f5f029164d7749366506e142b104c1d43 (diff) | |
download | spark-4c42986cc070d9c5c55c7bf8a2a67585967b1082.tar.gz spark-4c42986cc070d9c5c55c7bf8a2a67585967b1082.tar.bz2 spark-4c42986cc070d9c5c55c7bf8a2a67585967b1082.zip |
[SPARK-4242] [Core] Add SASL to external shuffle service
Does three things: (1) Adds SASL to ExternalShuffleClient, (2) puts SecurityManager in BlockManager's constructor, and (3) adds unit test.
Author: Aaron Davidson <aaron@databricks.com>
Closes #3108 from aarondav/sasl-client and squashes the following commits:
48b622d [Aaron Davidson] Screw it, let's just get LimitedInputStream
3543b70 [Aaron Davidson] Back out of pom change due to unknown test issue?
b58518a [Aaron Davidson] ByteStreams.limit() not available :(
cbe451a [Aaron Davidson] Address comments
2bf2908 [Aaron Davidson] [SPARK-4242] [Core] Add SASL to external shuffle service
Diffstat (limited to 'network')
9 files changed, 239 insertions, 11 deletions
diff --git a/network/common/pom.xml b/network/common/pom.xml index ea887148d9..6144548a8f 100644 --- a/network/common/pom.xml +++ b/network/common/pom.xml @@ -50,6 +50,7 @@ <dependency> <groupId>com.google.guava</groupId> <artifactId>guava</artifactId> + <version>11.0.2</version> <!-- yarn 2.4.0's version --> <scope>provided</scope> </dependency> diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java b/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java index 89ed79bc63..5fa1527ddf 100644 --- a/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java +++ b/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java @@ -30,6 +30,7 @@ import com.google.common.io.ByteStreams; import io.netty.channel.DefaultFileRegion; import org.apache.spark.network.util.JavaUtils; +import org.apache.spark.network.util.LimitedInputStream; /** * A {@link ManagedBuffer} backed by a segment in a file. @@ -101,7 +102,7 @@ public final class FileSegmentManagedBuffer extends ManagedBuffer { try { is = new FileInputStream(file); ByteStreams.skipFully(is, offset); - return ByteStreams.limit(is, length); + return new LimitedInputStream(is, length); } catch (IOException e) { try { if (is != null) { diff --git a/network/common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java b/network/common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java new file mode 100644 index 0000000000..63ca43c046 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.util; + +import java.io.FilterInputStream; +import java.io.IOException; +import java.io.InputStream; + +import com.google.common.base.Preconditions; + +/** + * Wraps a {@link InputStream}, limiting the number of bytes which can be read. + * + * This code is from Guava's 14.0 source code, because there is no compatible way to + * use this functionality in both a Guava 11 environment and a Guava >14 environment. + */ +public final class LimitedInputStream extends FilterInputStream { + private long left; + private long mark = -1; + + public LimitedInputStream(InputStream in, long limit) { + super(in); + Preconditions.checkNotNull(in); + Preconditions.checkArgument(limit >= 0, "limit must be non-negative"); + left = limit; + } + @Override public int available() throws IOException { + return (int) Math.min(in.available(), left); + } + // it's okay to mark even if mark isn't supported, as reset won't work + @Override public synchronized void mark(int readLimit) { + in.mark(readLimit); + mark = left; + } + @Override public int read() throws IOException { + if (left == 0) { + return -1; + } + int result = in.read(); + if (result != -1) { + --left; + } + return result; + } + @Override public int read(byte[] b, int off, int len) throws IOException { + if (left == 0) { + return -1; + } + len = (int) Math.min(len, left); + int result = in.read(b, off, len); + if (result != -1) { + left -= result; + } + return result; + } + @Override public synchronized void reset() throws IOException { + if (!in.markSupported()) { + throw new IOException("Mark not supported"); + } + if (mark == -1) { + throw new IOException("Mark not set"); + } + in.reset(); + left = mark; + } + @Override public long skip(long n) throws IOException { + n = Math.min(n, left); + long skipped = in.skip(n); + left -= skipped; + return skipped; + } +} diff --git a/network/shuffle/pom.xml b/network/shuffle/pom.xml index d271704d98..fe5681d463 100644 --- a/network/shuffle/pom.xml +++ b/network/shuffle/pom.xml @@ -51,6 +51,7 @@ <dependency> <groupId>com.google.guava</groupId> <artifactId>guava</artifactId> + <version>11.0.2</version> <!-- yarn 2.4.0's version --> <scope>provided</scope> </dependency> diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java index 72ba737b99..9abad1f30a 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java @@ -126,7 +126,6 @@ public class SparkSaslClient { logger.trace("SASL client callback: setting realm"); RealmCallback rc = (RealmCallback) callback; rc.setText(rc.getDefaultText()); - logger.info("Realm callback"); } else if (callback instanceof RealmChoiceCallback) { // ignore (?) } else { diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java index 2c0ce40c75..e87b17ead1 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java @@ -34,7 +34,8 @@ import com.google.common.base.Charsets; import com.google.common.base.Preconditions; import com.google.common.base.Throwables; import com.google.common.collect.ImmutableMap; -import com.google.common.io.BaseEncoding; +import io.netty.buffer.Unpooled; +import io.netty.handler.codec.base64.Base64; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -159,12 +160,14 @@ public class SparkSaslServer { /* Encode a byte[] identifier as a Base64-encoded string. */ public static String encodeIdentifier(String identifier) { Preconditions.checkNotNull(identifier, "User cannot be null if SASL is enabled"); - return BaseEncoding.base64().encode(identifier.getBytes(Charsets.UTF_8)); + return Base64.encode(Unpooled.wrappedBuffer(identifier.getBytes(Charsets.UTF_8))) + .toString(Charsets.UTF_8); } /** Encode a password as a base64-encoded char[] array. */ public static char[] encodePassword(String password) { Preconditions.checkNotNull(password, "Password cannot be null if SASL is enabled"); - return BaseEncoding.base64().encode(password.getBytes(Charsets.UTF_8)).toCharArray(); + return Base64.encode(Unpooled.wrappedBuffer(password.getBytes(Charsets.UTF_8))) + .toString(Charsets.UTF_8).toCharArray(); } } diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java index b0b19ba67b..3aa95d00f6 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java @@ -17,12 +17,18 @@ package org.apache.spark.network.shuffle; +import java.util.List; + +import com.google.common.collect.Lists; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.apache.spark.network.TransportContext; import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.client.TransportClientBootstrap; import org.apache.spark.network.client.TransportClientFactory; +import org.apache.spark.network.sasl.SaslClientBootstrap; +import org.apache.spark.network.sasl.SecretKeyHolder; import org.apache.spark.network.server.NoOpRpcHandler; import org.apache.spark.network.shuffle.ExternalShuffleMessages.RegisterExecutor; import org.apache.spark.network.util.JavaUtils; @@ -37,18 +43,35 @@ import org.apache.spark.network.util.TransportConf; public class ExternalShuffleClient extends ShuffleClient { private final Logger logger = LoggerFactory.getLogger(ExternalShuffleClient.class); - private final TransportClientFactory clientFactory; + private final TransportConf conf; + private final boolean saslEnabled; + private final SecretKeyHolder secretKeyHolder; + private TransportClientFactory clientFactory; private String appId; - public ExternalShuffleClient(TransportConf conf) { - TransportContext context = new TransportContext(conf, new NoOpRpcHandler()); - this.clientFactory = context.createClientFactory(); + /** + * Creates an external shuffle client, with SASL optionally enabled. If SASL is not enabled, + * then secretKeyHolder may be null. + */ + public ExternalShuffleClient( + TransportConf conf, + SecretKeyHolder secretKeyHolder, + boolean saslEnabled) { + this.conf = conf; + this.secretKeyHolder = secretKeyHolder; + this.saslEnabled = saslEnabled; } @Override public void init(String appId) { this.appId = appId; + TransportContext context = new TransportContext(conf, new NoOpRpcHandler()); + List<TransportClientBootstrap> bootstraps = Lists.newArrayList(); + if (saslEnabled) { + bootstraps.add(new SaslClientBootstrap(conf, appId, secretKeyHolder)); + } + clientFactory = context.createClientFactory(bootstraps); } @Override diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java index bc101f5384..71e017b9e4 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java @@ -135,7 +135,7 @@ public class ExternalShuffleIntegrationSuite { final Semaphore requestsRemaining = new Semaphore(0); - ExternalShuffleClient client = new ExternalShuffleClient(conf); + ExternalShuffleClient client = new ExternalShuffleClient(conf, null, false); client.init(APP_ID); client.fetchBlocks(TestUtils.getLocalHost(), port, execId, blockIds, new BlockFetchingListener() { @@ -267,7 +267,7 @@ public class ExternalShuffleIntegrationSuite { } private void registerExecutor(String executorId, ExecutorShuffleInfo executorInfo) { - ExternalShuffleClient client = new ExternalShuffleClient(conf); + ExternalShuffleClient client = new ExternalShuffleClient(conf, null, false); client.init(APP_ID); client.registerWithShuffleServer(TestUtils.getLocalHost(), server.getPort(), executorId, executorInfo); diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java new file mode 100644 index 0000000000..4c18fcdfbc --- /dev/null +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import static org.junit.Assert.*; + +import org.apache.spark.network.TestUtils; +import org.apache.spark.network.TransportContext; +import org.apache.spark.network.sasl.SaslRpcHandler; +import org.apache.spark.network.sasl.SecretKeyHolder; +import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.util.SystemPropertyConfigProvider; +import org.apache.spark.network.util.TransportConf; + +public class ExternalShuffleSecuritySuite { + + TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + TransportServer server; + + @Before + public void beforeEach() { + RpcHandler handler = new SaslRpcHandler(new ExternalShuffleBlockHandler(), + new TestSecretKeyHolder("my-app-id", "secret")); + TransportContext context = new TransportContext(conf, handler); + this.server = context.createServer(); + } + + @After + public void afterEach() { + if (server != null) { + server.close(); + server = null; + } + } + + @Test + public void testValid() { + validate("my-app-id", "secret"); + } + + @Test + public void testBadAppId() { + try { + validate("wrong-app-id", "secret"); + } catch (Exception e) { + assertTrue(e.getMessage(), e.getMessage().contains("Wrong appId!")); + } + } + + @Test + public void testBadSecret() { + try { + validate("my-app-id", "bad-secret"); + } catch (Exception e) { + assertTrue(e.getMessage(), e.getMessage().contains("Mismatched response")); + } + } + + /** Creates an ExternalShuffleClient and attempts to register with the server. */ + private void validate(String appId, String secretKey) { + ExternalShuffleClient client = + new ExternalShuffleClient(conf, new TestSecretKeyHolder(appId, secretKey), true); + client.init(appId); + // Registration either succeeds or throws an exception. + client.registerWithShuffleServer(TestUtils.getLocalHost(), server.getPort(), "exec0", + new ExecutorShuffleInfo(new String[0], 0, "")); + client.close(); + } + + /** Provides a secret key holder which always returns the given secret key, for a single appId. */ + static class TestSecretKeyHolder implements SecretKeyHolder { + private final String appId; + private final String secretKey; + + TestSecretKeyHolder(String appId, String secretKey) { + this.appId = appId; + this.secretKey = secretKey; + } + + @Override + public String getSaslUser(String appId) { + return "user"; + } + + @Override + public String getSecretKey(String appId) { + if (!appId.equals(this.appId)) { + throw new IllegalArgumentException("Wrong appId!"); + } + return secretKey; + } + } +} |