aboutsummaryrefslogtreecommitdiff
path: root/network/common
diff options
context:
space:
mode:
authorMarcelo Vanzin <vanzin@cloudera.com>2015-09-02 12:53:24 -0700
committerMarcelo Vanzin <vanzin@cloudera.com>2015-09-02 12:53:24 -0700
commit2da3a9e98e5d129d4507b5db01bba5ee9558d28e (patch)
treec5197f543f18959d793db1caea4ee553acef4f97 /network/common
parentfc48307797912dc1d53893dce741ddda8630957b (diff)
downloadspark-2da3a9e98e5d129d4507b5db01bba5ee9558d28e.tar.gz
spark-2da3a9e98e5d129d4507b5db01bba5ee9558d28e.tar.bz2
spark-2da3a9e98e5d129d4507b5db01bba5ee9558d28e.zip
[SPARK-10004] [SHUFFLE] Perform auth checks when clients read shuffle data.
To correctly isolate applications, when requests to read shuffle data arrive at the shuffle service, proper authorization checks need to be performed. This change makes sure that only the application that created the shuffle data can read from it. Such checks are only enabled when "spark.authenticate" is enabled, otherwise there's no secure way to make sure that the client is really who it says it is. Author: Marcelo Vanzin <vanzin@cloudera.com> Closes #8218 from vanzin/SPARK-10004.
Diffstat (limited to 'network/common')
-rw-r--r--network/common/pom.xml4
-rw-r--r--network/common/src/main/java/org/apache/spark/network/client/TransportClient.java22
-rw-r--r--network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java2
-rw-r--r--network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java1
-rw-r--r--network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java31
-rw-r--r--network/common/src/main/java/org/apache/spark/network/server/StreamManager.java9
-rw-r--r--network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java1
7 files changed, 65 insertions, 5 deletions
diff --git a/network/common/pom.xml b/network/common/pom.xml
index 7dc3068ab8..4141fcb826 100644
--- a/network/common/pom.xml
+++ b/network/common/pom.xml
@@ -48,6 +48,10 @@
<artifactId>slf4j-api</artifactId>
<scope>provided</scope>
</dependency>
+ <dependency>
+ <groupId>com.google.code.findbugs</groupId>
+ <artifactId>jsr305</artifactId>
+ </dependency>
<!--
Promote Guava to "compile" so that maven-shade-plugin picks it up (for packaging the Optional
class exposed in the Java API). The plugin will then remove this dependency from the published
diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
index e8e7f06247..df841288a0 100644
--- a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
+++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
@@ -23,6 +23,7 @@ import java.net.SocketAddress;
import java.util.UUID;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
+import javax.annotation.Nullable;
import com.google.common.base.Objects;
import com.google.common.base.Preconditions;
@@ -70,6 +71,7 @@ public class TransportClient implements Closeable {
private final Channel channel;
private final TransportResponseHandler handler;
+ @Nullable private String clientId;
public TransportClient(Channel channel, TransportResponseHandler handler) {
this.channel = Preconditions.checkNotNull(channel);
@@ -85,6 +87,25 @@ public class TransportClient implements Closeable {
}
/**
+ * Returns the ID used by the client to authenticate itself when authentication is enabled.
+ *
+ * @return The client ID, or null if authentication is disabled.
+ */
+ public String getClientId() {
+ return clientId;
+ }
+
+ /**
+ * Sets the authenticated client ID. This is meant to be used by the authentication layer.
+ *
+ * Trying to set a different client ID after it's been set will result in an exception.
+ */
+ public void setClientId(String id) {
+ Preconditions.checkState(clientId == null, "Client ID has already been set.");
+ this.clientId = id;
+ }
+
+ /**
* Requests a single chunk from the remote side, from the pre-negotiated streamId.
*
* Chunk indices go from 0 onwards. It is valid to request the same chunk multiple times, though
@@ -207,6 +228,7 @@ public class TransportClient implements Closeable {
public String toString() {
return Objects.toStringHelper(this)
.add("remoteAdress", channel.remoteAddress())
+ .add("clientId", clientId)
.add("isActive", isActive())
.toString();
}
diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java b/network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java
index 185ba2ef3b..69923769d4 100644
--- a/network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java
+++ b/network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java
@@ -77,6 +77,8 @@ public class SaslClientBootstrap implements TransportClientBootstrap {
payload = saslClient.response(response);
}
+ client.setClientId(appId);
+
if (encrypt) {
if (!SparkSaslServer.QOP_AUTH_CONF.equals(saslClient.getNegotiatedProperty(Sasl.QOP))) {
throw new RuntimeException(
diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java b/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java
index be6165caf3..3f2ebe3288 100644
--- a/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java
@@ -81,6 +81,7 @@ class SaslRpcHandler extends RpcHandler {
if (saslServer == null) {
// First message in the handshake, setup the necessary state.
+ client.setClientId(saslMessage.appId);
saslServer = new SparkSaslServer(saslMessage.appId, secretKeyHolder,
conf.saslServerAlwaysEncrypt());
}
diff --git a/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java b/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java
index c95e64e8e2..e671854da1 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java
@@ -24,13 +24,13 @@ import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;
+import com.google.common.base.Preconditions;
import io.netty.channel.Channel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.spark.network.buffer.ManagedBuffer;
-
-import com.google.common.base.Preconditions;
+import org.apache.spark.network.client.TransportClient;
/**
* StreamManager which allows registration of an Iterator&lt;ManagedBuffer&gt;, which are individually
@@ -44,6 +44,7 @@ public class OneForOneStreamManager extends StreamManager {
/** State of a single stream. */
private static class StreamState {
+ final String appId;
final Iterator<ManagedBuffer> buffers;
// The channel associated to the stream
@@ -53,7 +54,8 @@ public class OneForOneStreamManager extends StreamManager {
// that the caller only requests each chunk one at a time, in order.
int curChunk = 0;
- StreamState(Iterator<ManagedBuffer> buffers) {
+ StreamState(String appId, Iterator<ManagedBuffer> buffers) {
+ this.appId = appId;
this.buffers = Preconditions.checkNotNull(buffers);
}
}
@@ -109,15 +111,34 @@ public class OneForOneStreamManager extends StreamManager {
}
}
+ @Override
+ public void checkAuthorization(TransportClient client, long streamId) {
+ if (client.getClientId() != null) {
+ StreamState state = streams.get(streamId);
+ Preconditions.checkArgument(state != null, "Unknown stream ID.");
+ if (!client.getClientId().equals(state.appId)) {
+ throw new SecurityException(String.format(
+ "Client %s not authorized to read stream %d (app %s).",
+ client.getClientId(),
+ streamId,
+ state.appId));
+ }
+ }
+ }
+
/**
* Registers a stream of ManagedBuffers which are served as individual chunks one at a time to
* callers. Each ManagedBuffer will be release()'d after it is transferred on the wire. If a
* client connection is closed before the iterator is fully drained, then the remaining buffers
* will all be release()'d.
+ *
+ * If an app ID is provided, only callers who've authenticated with the given app ID will be
+ * allowed to fetch from this stream.
*/
- public long registerStream(Iterator<ManagedBuffer> buffers) {
+ public long registerStream(String appId, Iterator<ManagedBuffer> buffers) {
long myStreamId = nextStreamId.getAndIncrement();
- streams.put(myStreamId, new StreamState(buffers));
+ streams.put(myStreamId, new StreamState(appId, buffers));
return myStreamId;
}
+
}
diff --git a/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java b/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java
index 929f789bf9..aaa677c965 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java
@@ -20,6 +20,7 @@ package org.apache.spark.network.server;
import io.netty.channel.Channel;
import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.client.TransportClient;
/**
* The StreamManager is used to fetch individual chunks from a stream. This is used in
@@ -60,4 +61,12 @@ public abstract class StreamManager {
* to read from the associated streams again, so any state can be cleaned up.
*/
public void connectionTerminated(Channel channel) { }
+
+ /**
+ * Verify that the client is authorized to read from the given stream.
+ *
+ * @throws SecurityException If client is not authorized.
+ */
+ public void checkAuthorization(TransportClient client, long streamId) { }
+
}
diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
index e5159ab56d..df60278058 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
@@ -97,6 +97,7 @@ public class TransportRequestHandler extends MessageHandler<RequestMessage> {
ManagedBuffer buf;
try {
+ streamManager.checkAuthorization(reverseClient, req.streamChunkId.streamId);
streamManager.registerChannel(channel, req.streamChunkId.streamId);
buf = streamManager.getChunk(req.streamChunkId.streamId, req.streamChunkId.chunkIndex);
} catch (Exception e) {