author     Ferdinand Xu <cheng.a.xu@intel.com>    2016-08-30 09:15:31 -0700
committer  Marcelo Vanzin <vanzin@cloudera.com>   2016-08-30 09:15:31 -0700
commit     4b4e329e49f8af28fa6301bd06c48d7097eaf9e6
tree       91ec684d78a76de75097723f82537be3e01a1c28
parent     27209252f09ff73c58e60c6df8aaba73b308088c
[SPARK-5682][CORE] Add encrypted shuffle in spark
This patch uses the Apache Commons Crypto library to enable shuffle encryption support.

Author: Ferdinand Xu <cheng.a.xu@intel.com>
Author: kellyzly <kellyzly@126.com>

Closes #8880 from winningsix/SPARK-10771.
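For orientation (this is not part of the patch itself), the new behavior is driven purely by configuration. A minimal sketch of enabling it for a YARN application, using the keys and defaults introduced by the diffs below:

    import org.apache.spark.SparkConf

    // Sketch only: the spark.io.encryption.* keys are added by this patch; the values shown
    // are the documented defaults/allowed choices. The SparkContext check added below rejects
    // these settings outside YARN mode.
    val conf = new SparkConf()
      .set("spark.io.encryption.enabled", "true")              // default: false
      .set("spark.io.encryption.keySizeBits", "256")           // allowed: 128, 192, 256
      .set("spark.io.encryption.keygen.algorithm", "HmacSHA1") // default
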
-rw-r--r--  core/pom.xml                                                                            |   4
-rw-r--r--  core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java |   2
-rw-r--r--  core/src/main/scala/org/apache/spark/SecurityManager.scala                             |  20
-rw-r--r--  core/src/main/scala/org/apache/spark/SparkContext.scala                                |   5
-rw-r--r--  core/src/main/scala/org/apache/spark/internal/config/package.scala                     |  20
-rw-r--r--  core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala                  | 109
-rw-r--r--  core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala                |  47
-rw-r--r--  core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala             |   4
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockManager.scala                        |   5
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala               |   5
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala                  |   2
-rw-r--r--  core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala       |   4
-rw-r--r--  core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala              |   6
-rw-r--r--  core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java         |   4
-rw-r--r--  core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java       |   4
-rw-r--r--  core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java |   4
-rw-r--r--  core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala             | 107
-rw-r--r--  core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala |   2
-rw-r--r--  dev/deps/spark-deps-hadoop-2.2                                                          |   1
-rw-r--r--  dev/deps/spark-deps-hadoop-2.3                                                          |   1
-rw-r--r--  dev/deps/spark-deps-hadoop-2.4                                                          |   1
-rw-r--r--  dev/deps/spark-deps-hadoop-2.6                                                          |   1
-rw-r--r--  dev/deps/spark-deps-hadoop-2.7                                                          |   1
-rw-r--r--  docs/configuration.md                                                                   |  23
-rw-r--r--  pom.xml                                                                                 |  12
-rw-r--r--  yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala                          |   4
-rw-r--r--  yarn/src/test/scala/org/apache/spark/deploy/yarn/IOEncryptionSuite.scala               | 108
27 files changed, 478 insertions(+), 28 deletions(-)
diff --git a/core/pom.xml b/core/pom.xml
index c04cf7e525..69a0b0ff27 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -327,6 +327,10 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-tags_${scala.binary.version}</artifactId>
</dependency>
+ <dependency>
+ <groupId>org.apache.commons</groupId>
+ <artifactId>commons-crypto</artifactId>
+ </dependency>
</dependencies>
<build>
<outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
index d048cf7aeb..2875b0d69d 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
@@ -72,7 +72,7 @@ public final class UnsafeSorterSpillReader extends UnsafeSorterIterator implemen
final BufferedInputStream bs =
new BufferedInputStream(new FileInputStream(file), (int) bufferSizeBytes);
try {
- this.in = serializerManager.wrapForCompression(blockId, bs);
+ this.in = serializerManager.wrapStream(blockId, bs);
this.din = new DataInputStream(this.in);
numRecords = numRecordsRemaining = din.readInt();
} catch (IOException e) {
diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala
index a6550b6ca8..199365ad92 100644
--- a/core/src/main/scala/org/apache/spark/SecurityManager.scala
+++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala
@@ -21,15 +21,19 @@ import java.lang.{Byte => JByte}
import java.net.{Authenticator, PasswordAuthentication}
import java.security.{KeyStore, SecureRandom}
import java.security.cert.X509Certificate
+import javax.crypto.KeyGenerator
import javax.net.ssl._
import com.google.common.hash.HashCodes
import com.google.common.io.Files
import org.apache.hadoop.io.Text
+import org.apache.hadoop.security.Credentials
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config._
import org.apache.spark.network.sasl.SecretKeyHolder
+import org.apache.spark.security.CryptoStreamUtils._
import org.apache.spark.util.Utils
/**
@@ -554,4 +558,20 @@ private[spark] object SecurityManager {
// key used to store the spark secret in the Hadoop UGI
val SECRET_LOOKUP_KEY = "sparkCookie"
+
+ /**
+ * Setup the cryptographic key used by IO encryption in credentials. The key is generated using
+ * [[KeyGenerator]]. The algorithm and key length is specified by the [[SparkConf]].
+ */
+ def initIOEncryptionKey(conf: SparkConf, credentials: Credentials): Unit = {
+ if (credentials.getSecretKey(SPARK_IO_TOKEN) == null) {
+ val keyLen = conf.get(IO_ENCRYPTION_KEY_SIZE_BITS)
+ val ioKeyGenAlgorithm = conf.get(IO_ENCRYPTION_KEYGEN_ALGORITHM)
+ val keyGen = KeyGenerator.getInstance(ioKeyGenAlgorithm)
+ keyGen.init(keyLen)
+
+ val ioKey = keyGen.generateKey()
+ credentials.addSecretKey(SPARK_IO_TOKEN, ioKey.getEncoded)
+ }
+ }
}
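To make the key flow above concrete, a rough sketch pieced together from this hunk and the YARN Client change further down (both objects are private[spark], so this only applies inside Spark's own code, as in the test suites below):

    import org.apache.hadoop.security.UserGroupInformation
    import org.apache.spark.{SecurityManager, SparkConf}
    import org.apache.spark.security.CryptoStreamUtils

    // Driver/Client side: generate the key once and stash it in the Hadoop credentials
    // that YARN ships to executors.
    val sparkConf = new SparkConf().set("spark.io.encryption.enabled", "true")
    val credentials = UserGroupInformation.getCurrentUser.getCredentials
    SecurityManager.initIOEncryptionKey(sparkConf, credentials)

    // Executor side: CryptoStreamUtils later reads the same key back under SPARK_IO_TOKEN.
    val key = credentials.getSecretKey(CryptoStreamUtils.SPARK_IO_TOKEN)
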
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 08d6343d62..744d5d0f7a 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -49,6 +49,7 @@ import org.apache.spark.deploy.{LocalSparkCluster, SparkHadoopUtil}
import org.apache.spark.input.{FixedLengthBinaryInputFormat, PortableDataStream, StreamInputFormat,
WholeTextFileInputFormat}
import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config._
import org.apache.spark.io.CompressionCodec
import org.apache.spark.partial.{ApproximateEvaluator, PartialResult}
import org.apache.spark.rdd._
@@ -411,6 +412,10 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
}
if (master == "yarn" && deployMode == "client") System.setProperty("SPARK_YARN_MODE", "true")
+ if (_conf.get(IO_ENCRYPTION_ENABLED) && !SparkHadoopUtil.get.isYarnMode()) {
+ throw new SparkException("IO encryption is only supported in YARN mode, please disable it " +
+ s"by setting ${IO_ENCRYPTION_ENABLED.key} to false")
+ }
// "_jobProgressListener" should be set up before creating SparkEnv because when creating
// "SparkEnv", some messages will be posted to "listenerBus" and we should not miss them.
diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala
index 47174e4efe..ebce07c1e3 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/package.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -119,4 +119,24 @@ package object config {
private[spark] val UI_RETAINED_TASKS = ConfigBuilder("spark.ui.retainedTasks")
.intConf
.createWithDefault(100000)
+
+ private[spark] val IO_ENCRYPTION_ENABLED = ConfigBuilder("spark.io.encryption.enabled")
+ .booleanConf
+ .createWithDefault(false)
+
+ private[spark] val IO_ENCRYPTION_KEYGEN_ALGORITHM =
+ ConfigBuilder("spark.io.encryption.keygen.algorithm")
+ .stringConf
+ .createWithDefault("HmacSHA1")
+
+ private[spark] val IO_ENCRYPTION_KEY_SIZE_BITS = ConfigBuilder("spark.io.encryption.keySizeBits")
+ .intConf
+ .checkValues(Set(128, 192, 256))
+ .createWithDefault(128)
+
+ private[spark] val IO_CRYPTO_CIPHER_TRANSFORMATION =
+ ConfigBuilder("spark.io.crypto.cipher.transformation")
+ .internal()
+ .stringConf
+ .createWithDefaultString("AES/CTR/NoPadding")
}
diff --git a/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala b/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala
new file mode 100644
index 0000000000..8f15f50bee
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.security
+
+import java.io.{InputStream, OutputStream}
+import java.util.Properties
+import javax.crypto.spec.{IvParameterSpec, SecretKeySpec}
+
+import org.apache.commons.crypto.random._
+import org.apache.commons.crypto.stream._
+import org.apache.hadoop.io.Text
+
+import org.apache.spark.SparkConf
+import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config._
+
+/**
+ * A util class for manipulating IO encryption and decryption streams.
+ */
+private[spark] object CryptoStreamUtils extends Logging {
+ /**
+ * Constants and variables for spark IO encryption
+ */
+ val SPARK_IO_TOKEN = new Text("SPARK_IO_TOKEN")
+
+ // The initialization vector length in bytes.
+ val IV_LENGTH_IN_BYTES = 16
+ // The prefix of IO encryption related configurations in Spark configuration.
+ val SPARK_IO_ENCRYPTION_COMMONS_CONFIG_PREFIX = "spark.io.encryption.commons.config."
+ // The prefix for the configurations passing to Apache Commons Crypto library.
+ val COMMONS_CRYPTO_CONF_PREFIX = "commons.crypto."
+
+ /**
+ * Helper method to wrap [[OutputStream]] with [[CryptoOutputStream]] for encryption.
+ */
+ def createCryptoOutputStream(
+ os: OutputStream,
+ sparkConf: SparkConf): OutputStream = {
+ val properties = toCryptoConf(sparkConf)
+ val iv = createInitializationVector(properties)
+ os.write(iv)
+ val credentials = SparkHadoopUtil.get.getCurrentUserCredentials()
+ val key = credentials.getSecretKey(SPARK_IO_TOKEN)
+ val transformationStr = sparkConf.get(IO_CRYPTO_CIPHER_TRANSFORMATION)
+ new CryptoOutputStream(transformationStr, properties, os,
+ new SecretKeySpec(key, "AES"), new IvParameterSpec(iv))
+ }
+
+ /**
+ * Helper method to wrap [[InputStream]] with [[CryptoInputStream]] for decryption.
+ */
+ def createCryptoInputStream(
+ is: InputStream,
+ sparkConf: SparkConf): InputStream = {
+ val properties = toCryptoConf(sparkConf)
+ val iv = new Array[Byte](IV_LENGTH_IN_BYTES)
+ is.read(iv, 0, iv.length)
+ val credentials = SparkHadoopUtil.get.getCurrentUserCredentials()
+ val key = credentials.getSecretKey(SPARK_IO_TOKEN)
+ val transformationStr = sparkConf.get(IO_CRYPTO_CIPHER_TRANSFORMATION)
+ new CryptoInputStream(transformationStr, properties, is,
+ new SecretKeySpec(key, "AES"), new IvParameterSpec(iv))
+ }
+
+ /**
+ * Get Commons-crypto configurations from Spark configurations identified by prefix.
+ */
+ def toCryptoConf(conf: SparkConf): Properties = {
+ val props = new Properties()
+ conf.getAll.foreach { case (k, v) =>
+ if (k.startsWith(SPARK_IO_ENCRYPTION_COMMONS_CONFIG_PREFIX)) {
+ props.put(COMMONS_CRYPTO_CONF_PREFIX + k.substring(
+ SPARK_IO_ENCRYPTION_COMMONS_CONFIG_PREFIX.length()), v)
+ }
+ }
+ props
+ }
+
+ /**
+ * This method to generate an IV (Initialization Vector) using secure random.
+ */
+ private[this] def createInitializationVector(properties: Properties): Array[Byte] = {
+ val iv = new Array[Byte](IV_LENGTH_IN_BYTES)
+ val initialIVStart = System.currentTimeMillis()
+ CryptoRandomFactory.getCryptoRandom(properties).nextBytes(iv)
+ val initialIVFinish = System.currentTimeMillis()
+ val initialIVTime = initialIVFinish - initialIVStart
+ if (initialIVTime > 2000) {
+ logWarning(s"It costs ${initialIVTime} milliseconds to create the Initialization Vector " +
+ s"used by CryptoStream")
+ }
+ iv
+ }
+}
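A small sketch of the prefix mapping performed by toCryptoConf above; the specific commons-crypto property name is only an example of the pass-through, and since the object is private[spark] this mirrors how CryptoStreamUtilsSuite (below) exercises it rather than public API usage:

    import org.apache.spark.SparkConf
    import org.apache.spark.security.CryptoStreamUtils

    val conf = new SparkConf()
      // "cipher.classes" is an illustrative commons-crypto property; the Spark-side key is
      // simply the commons-crypto key with the "spark.io.encryption.commons.config." prefix.
      .set("spark.io.encryption.commons.config.cipher.classes",
        "org.apache.commons.crypto.cipher.JceCipher")
    val props = CryptoStreamUtils.toCryptoConf(conf)
    // props.getProperty("commons.crypto.cipher.classes") == "org.apache.commons.crypto.cipher.JceCipher"
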
diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala
index 07caadbe40..7b1ec6fcbb 100644
--- a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala
@@ -23,13 +23,15 @@ import java.nio.ByteBuffer
import scala.reflect.ClassTag
import org.apache.spark.SparkConf
+import org.apache.spark.internal.config._
import org.apache.spark.io.CompressionCodec
+import org.apache.spark.security.CryptoStreamUtils
import org.apache.spark.storage._
import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream}
/**
- * Component which configures serialization and compression for various Spark components, including
- * automatic selection of which [[Serializer]] to use for shuffles.
+ * Component which configures serialization, compression and encryption for various Spark
+ * components, including automatic selection of which [[Serializer]] to use for shuffles.
*/
private[spark] class SerializerManager(defaultSerializer: Serializer, conf: SparkConf) {
@@ -61,6 +63,9 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: Spar
// Whether to compress shuffle output temporarily spilled to disk
private[this] val compressShuffleSpill = conf.getBoolean("spark.shuffle.spill.compress", true)
+ // Whether to enable IO encryption
+ private[this] val enableIOEncryption = conf.get(IO_ENCRYPTION_ENABLED)
+
/* The compression codec to use. Note that the "lazy" val is necessary because we want to delay
* the initialization of the compression codec until it is first used. The reason is that a Spark
* program could be using a user-defined codec in a third party jar, which is loaded in
@@ -103,16 +108,44 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: Spar
}
/**
+ * Wrap an input stream for encryption and compression
+ */
+ def wrapStream(blockId: BlockId, s: InputStream): InputStream = {
+ wrapForCompression(blockId, wrapForEncryption(s))
+ }
+
+ /**
+ * Wrap an output stream for encryption and compression
+ */
+ def wrapStream(blockId: BlockId, s: OutputStream): OutputStream = {
+ wrapForCompression(blockId, wrapForEncryption(s))
+ }
+
+ /**
+ * Wrap an input stream for encryption if shuffle encryption is enabled
+ */
+ private[this] def wrapForEncryption(s: InputStream): InputStream = {
+ if (enableIOEncryption) CryptoStreamUtils.createCryptoInputStream(s, conf) else s
+ }
+
+ /**
+ * Wrap an output stream for encryption if shuffle encryption is enabled
+ */
+ private[this] def wrapForEncryption(s: OutputStream): OutputStream = {
+ if (enableIOEncryption) CryptoStreamUtils.createCryptoOutputStream(s, conf) else s
+ }
+
+ /**
* Wrap an output stream for compression if block compression is enabled for its block type
*/
- def wrapForCompression(blockId: BlockId, s: OutputStream): OutputStream = {
+ private[this] def wrapForCompression(blockId: BlockId, s: OutputStream): OutputStream = {
if (shouldCompress(blockId)) compressionCodec.compressedOutputStream(s) else s
}
/**
* Wrap an input stream for compression if block compression is enabled for its block type
*/
- def wrapForCompression(blockId: BlockId, s: InputStream): InputStream = {
+ private[this] def wrapForCompression(blockId: BlockId, s: InputStream): InputStream = {
if (shouldCompress(blockId)) compressionCodec.compressedInputStream(s) else s
}
@@ -123,7 +156,7 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: Spar
values: Iterator[T]): Unit = {
val byteStream = new BufferedOutputStream(outputStream)
val ser = getSerializer(implicitly[ClassTag[T]]).newInstance()
- ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close()
+ ser.serializeStream(wrapStream(blockId, byteStream)).writeAll(values).close()
}
/** Serializes into a chunked byte buffer. */
@@ -139,7 +172,7 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: Spar
val bbos = new ChunkedByteBufferOutputStream(1024 * 1024 * 4, ByteBuffer.allocate)
val byteStream = new BufferedOutputStream(bbos)
val ser = getSerializer(classTag).newInstance()
- ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close()
+ ser.serializeStream(wrapStream(blockId, byteStream)).writeAll(values).close()
bbos.toChunkedByteBuffer
}
@@ -153,7 +186,7 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: Spar
val stream = new BufferedInputStream(inputStream)
getSerializer(implicitly[ClassTag[T]])
.newInstance()
- .deserializeStream(wrapForCompression(blockId, stream))
+ .deserializeStream(wrapStream(blockId, stream))
.asIterator.asInstanceOf[Iterator[T]]
}
}
diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
index 5794f542b7..b9d83495d2 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
@@ -51,9 +51,9 @@ private[spark] class BlockStoreShuffleReader[K, C](
SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024,
SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue))
- // Wrap the streams for compression based on configuration
+ // Wrap the streams for compression and encryption based on configuration
val wrappedStreams = blockFetcherItr.map { case (blockId, inputStream) =>
- serializerManager.wrapForCompression(blockId, inputStream)
+ serializerManager.wrapStream(blockId, inputStream)
}
val serializerInstance = dep.serializer.newInstance()
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index fe84652798..c72f28e00c 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -721,10 +721,9 @@ private[spark] class BlockManager(
serializerInstance: SerializerInstance,
bufferSize: Int,
writeMetrics: ShuffleWriteMetrics): DiskBlockObjectWriter = {
- val compressStream: OutputStream => OutputStream =
- serializerManager.wrapForCompression(blockId, _)
+ val wrapStream: OutputStream => OutputStream = serializerManager.wrapStream(blockId, _)
val syncWrites = conf.getBoolean("spark.shuffle.sync", false)
- new DiskBlockObjectWriter(file, serializerInstance, bufferSize, compressStream,
+ new DiskBlockObjectWriter(file, serializerInstance, bufferSize, wrapStream,
syncWrites, writeMetrics, blockId)
}
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
index e5b1bf2f4b..a499827ae1 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
@@ -39,7 +39,7 @@ private[spark] class DiskBlockObjectWriter(
val file: File,
serializerInstance: SerializerInstance,
bufferSize: Int,
- compressStream: OutputStream => OutputStream,
+ wrapStream: OutputStream => OutputStream,
syncWrites: Boolean,
// These write metrics concurrently shared with other active DiskBlockObjectWriters who
// are themselves performing writes. All updates must be relative.
@@ -115,7 +115,8 @@ private[spark] class DiskBlockObjectWriter(
initialize()
initialized = true
}
- bs = compressStream(mcs)
+
+ bs = wrapStream(mcs)
objOut = serializerInstance.serializeStream(bs)
streamOpen = true
this
diff --git a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala
index 586339a58d..d220ab51d1 100644
--- a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala
+++ b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala
@@ -330,7 +330,7 @@ private[spark] class MemoryStore(
redirectableStream.setOutputStream(bbos)
val serializationStream: SerializationStream = {
val ser = serializerManager.getSerializer(classTag).newInstance()
- ser.serializeStream(serializerManager.wrapForCompression(blockId, redirectableStream))
+ ser.serializeStream(serializerManager.wrapStream(blockId, redirectableStream))
}
// Request enough memory to begin unrolling
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
index 8c8860bb37..0943528119 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
@@ -486,8 +486,8 @@ class ExternalAppendOnlyMap[K, V, C](
", batchOffsets = " + batchOffsets.mkString("[", ", ", "]"))
val bufferedStream = new BufferedInputStream(ByteStreams.limit(fileStream, end - start))
- val compressedStream = serializerManager.wrapForCompression(blockId, bufferedStream)
- ser.deserializeStream(compressedStream)
+ val wrappedStream = serializerManager.wrapStream(blockId, bufferedStream)
+ ser.deserializeStream(wrappedStream)
} else {
// No more batches left
cleanup()
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
index 7c98e8cabb..3579918fac 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -28,7 +28,6 @@ import com.google.common.io.ByteStreams
import org.apache.spark._
import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.internal.Logging
-import org.apache.spark.memory.TaskMemoryManager
import org.apache.spark.serializer._
import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter}
@@ -522,8 +521,9 @@ private[spark] class ExternalSorter[K, V, C](
", batchOffsets = " + batchOffsets.mkString("[", ", ", "]"))
val bufferedStream = new BufferedInputStream(ByteStreams.limit(fileStream, end - start))
- val compressedStream = serializerManager.wrapForCompression(spill.blockId, bufferedStream)
- serInstance.deserializeStream(compressedStream)
+
+ val wrappedStream = serializerManager.wrapStream(spill.blockId, bufferedStream)
+ serInstance.deserializeStream(wrappedStream)
} else {
// No more batches left
cleanup()
diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
index daeb4675ea..a96cd82382 100644
--- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
@@ -86,7 +86,7 @@ public class UnsafeShuffleWriterSuite {
@Mock(answer = RETURNS_SMART_NULLS) TaskContext taskContext;
@Mock(answer = RETURNS_SMART_NULLS) ShuffleDependency<Object, Object, Object> shuffleDep;
- private final class CompressStream extends AbstractFunction1<OutputStream, OutputStream> {
+ private final class WrapStream extends AbstractFunction1<OutputStream, OutputStream> {
@Override
public OutputStream apply(OutputStream stream) {
if (conf.getBoolean("spark.shuffle.compress", true)) {
@@ -136,7 +136,7 @@ public class UnsafeShuffleWriterSuite {
(File) args[1],
(SerializerInstance) args[2],
(Integer) args[3],
- new CompressStream(),
+ new WrapStream(),
false,
(ShuffleWriteMetrics) args[4],
(BlockId) args[0]
diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
index fc127f07c8..33709b454c 100644
--- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
+++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
@@ -75,7 +75,7 @@ public abstract class AbstractBytesToBytesMapSuite {
@Mock(answer = RETURNS_SMART_NULLS) BlockManager blockManager;
@Mock(answer = RETURNS_SMART_NULLS) DiskBlockManager diskBlockManager;
- private static final class CompressStream extends AbstractFunction1<OutputStream, OutputStream> {
+ private static final class WrapStream extends AbstractFunction1<OutputStream, OutputStream> {
@Override
public OutputStream apply(OutputStream stream) {
return stream;
@@ -122,7 +122,7 @@ public abstract class AbstractBytesToBytesMapSuite {
(File) args[1],
(SerializerInstance) args[2],
(Integer) args[3],
- new CompressStream(),
+ new WrapStream(),
false,
(ShuffleWriteMetrics) args[4],
(BlockId) args[0]
diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
index 3ea99233fe..a9cf8ff520 100644
--- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
+++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
@@ -88,7 +88,7 @@ public class UnsafeExternalSorterSuite {
private final long pageSizeBytes = new SparkConf().getSizeAsBytes("spark.buffer.pageSize", "4m");
- private static final class CompressStream extends AbstractFunction1<OutputStream, OutputStream> {
+ private static final class WrapStream extends AbstractFunction1<OutputStream, OutputStream> {
@Override
public OutputStream apply(OutputStream stream) {
return stream;
@@ -128,7 +128,7 @@ public class UnsafeExternalSorterSuite {
(File) args[1],
(SerializerInstance) args[2],
(Integer) args[3],
- new CompressStream(),
+ new WrapStream(),
false,
(ShuffleWriteMetrics) args[4],
(BlockId) args[0]
diff --git a/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala b/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala
new file mode 100644
index 0000000000..81eb907ac7
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.security
+
+import java.security.PrivilegedExceptionAction
+
+import org.apache.hadoop.security.{Credentials, UserGroupInformation}
+
+import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite}
+import org.apache.spark.internal.config._
+import org.apache.spark.security.CryptoStreamUtils._
+
+class CryptoStreamUtilsSuite extends SparkFunSuite {
+ val ugi = UserGroupInformation.createUserForTesting("testuser", Array("testgroup"))
+
+ test("Crypto configuration conversion") {
+ val sparkKey1 = s"${SPARK_IO_ENCRYPTION_COMMONS_CONFIG_PREFIX}a.b.c"
+ val sparkVal1 = "val1"
+ val cryptoKey1 = s"${COMMONS_CRYPTO_CONF_PREFIX}a.b.c"
+
+ val sparkKey2 = SPARK_IO_ENCRYPTION_COMMONS_CONFIG_PREFIX.stripSuffix(".") + "A.b.c"
+ val sparkVal2 = "val2"
+ val cryptoKey2 = s"${COMMONS_CRYPTO_CONF_PREFIX}A.b.c"
+ val conf = new SparkConf()
+ conf.set(sparkKey1, sparkVal1)
+ conf.set(sparkKey2, sparkVal2)
+ val props = CryptoStreamUtils.toCryptoConf(conf)
+ assert(props.getProperty(cryptoKey1) === sparkVal1)
+ assert(!props.containsKey(cryptoKey2))
+ }
+
+ test("Shuffle encryption is disabled by default") {
+ ugi.doAs(new PrivilegedExceptionAction[Unit]() {
+ override def run(): Unit = {
+ val credentials = UserGroupInformation.getCurrentUser.getCredentials()
+ val conf = new SparkConf()
+ initCredentials(conf, credentials)
+ assert(credentials.getSecretKey(SPARK_IO_TOKEN) === null)
+ }
+ })
+ }
+
+ test("Shuffle encryption key length should be 128 by default") {
+ ugi.doAs(new PrivilegedExceptionAction[Unit]() {
+ override def run(): Unit = {
+ val credentials = UserGroupInformation.getCurrentUser.getCredentials()
+ val conf = new SparkConf()
+ conf.set(IO_ENCRYPTION_ENABLED, true)
+ initCredentials(conf, credentials)
+ var key = credentials.getSecretKey(SPARK_IO_TOKEN)
+ assert(key !== null)
+ val actual = key.length * (java.lang.Byte.SIZE)
+ assert(actual === 128)
+ }
+ })
+ }
+
+ test("Initial credentials with key length in 256") {
+ ugi.doAs(new PrivilegedExceptionAction[Unit]() {
+ override def run(): Unit = {
+ val credentials = UserGroupInformation.getCurrentUser.getCredentials()
+ val conf = new SparkConf()
+ conf.set(IO_ENCRYPTION_KEY_SIZE_BITS, 256)
+ conf.set(IO_ENCRYPTION_ENABLED, true)
+ initCredentials(conf, credentials)
+ var key = credentials.getSecretKey(SPARK_IO_TOKEN)
+ assert(key !== null)
+ val actual = key.length * (java.lang.Byte.SIZE)
+ assert(actual === 256)
+ }
+ })
+ }
+
+ test("Initial credentials with invalid key length") {
+ ugi.doAs(new PrivilegedExceptionAction[Unit]() {
+ override def run(): Unit = {
+ val credentials = UserGroupInformation.getCurrentUser.getCredentials()
+ val conf = new SparkConf()
+ conf.set(IO_ENCRYPTION_KEY_SIZE_BITS, 328)
+ conf.set(IO_ENCRYPTION_ENABLED, true)
+ val thrown = intercept[IllegalArgumentException] {
+ initCredentials(conf, credentials)
+ }
+ }
+ })
+ }
+
+ private[this] def initCredentials(conf: SparkConf, credentials: Credentials): Unit = {
+ if (conf.get(IO_ENCRYPTION_ENABLED)) {
+ SecurityManager.initIOEncryptionKey(conf, credentials)
+ }
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
index 5132384a5e..ed9428820f 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
@@ -94,7 +94,7 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
args(1).asInstanceOf[File],
args(2).asInstanceOf[SerializerInstance],
args(3).asInstanceOf[Int],
- compressStream = identity,
+ wrapStream = identity,
syncWrites = false,
args(4).asInstanceOf[ShuffleWriteMetrics],
blockId = args(0).asInstanceOf[BlockId]
diff --git a/dev/deps/spark-deps-hadoop-2.2 b/dev/deps/spark-deps-hadoop-2.2
index 326271a7e2..eaed0889ac 100644
--- a/dev/deps/spark-deps-hadoop-2.2
+++ b/dev/deps/spark-deps-hadoop-2.2
@@ -27,6 +27,7 @@ commons-collections-3.2.2.jar
commons-compiler-2.7.6.jar
commons-compress-1.4.1.jar
commons-configuration-1.6.jar
+commons-crypto-1.0.0.jar
commons-dbcp-1.4.jar
commons-digester-1.8.jar
commons-httpclient-3.1.jar
diff --git a/dev/deps/spark-deps-hadoop-2.3 b/dev/deps/spark-deps-hadoop-2.3
index 1ff6ecb734..d68a7f462b 100644
--- a/dev/deps/spark-deps-hadoop-2.3
+++ b/dev/deps/spark-deps-hadoop-2.3
@@ -30,6 +30,7 @@ commons-collections-3.2.2.jar
commons-compiler-2.7.6.jar
commons-compress-1.4.1.jar
commons-configuration-1.6.jar
+commons-crypto-1.0.0.jar
commons-dbcp-1.4.jar
commons-digester-1.8.jar
commons-httpclient-3.1.jar
diff --git a/dev/deps/spark-deps-hadoop-2.4 b/dev/deps/spark-deps-hadoop-2.4
index 68333849cf..346f19767d 100644
--- a/dev/deps/spark-deps-hadoop-2.4
+++ b/dev/deps/spark-deps-hadoop-2.4
@@ -30,6 +30,7 @@ commons-collections-3.2.2.jar
commons-compiler-2.7.6.jar
commons-compress-1.4.1.jar
commons-configuration-1.6.jar
+commons-crypto-1.0.0.jar
commons-dbcp-1.4.jar
commons-digester-1.8.jar
commons-httpclient-3.1.jar
diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6
index 787d06c351..6f4695f345 100644
--- a/dev/deps/spark-deps-hadoop-2.6
+++ b/dev/deps/spark-deps-hadoop-2.6
@@ -34,6 +34,7 @@ commons-collections-3.2.2.jar
commons-compiler-2.7.6.jar
commons-compress-1.4.1.jar
commons-configuration-1.6.jar
+commons-crypto-1.0.0.jar
commons-dbcp-1.4.jar
commons-digester-1.8.jar
commons-httpclient-3.1.jar
diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7
index 386495bf1b..7a86a8bd88 100644
--- a/dev/deps/spark-deps-hadoop-2.7
+++ b/dev/deps/spark-deps-hadoop-2.7
@@ -34,6 +34,7 @@ commons-collections-3.2.2.jar
commons-compiler-2.7.6.jar
commons-compress-1.4.1.jar
commons-configuration-1.6.jar
+commons-crypto-1.0.0.jar
commons-dbcp-1.4.jar
commons-digester-1.8.jar
commons-httpclient-3.1.jar
diff --git a/docs/configuration.md b/docs/configuration.md
index 2f80196105..d0c76aaad0 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -559,6 +559,29 @@ Apart from these, the following properties are also available, and may be useful
<code>spark.io.compression.codec</code>.
</td>
</tr>
+<tr>
+ <td><code>spark.io.encryption.enabled</code></td>
+ <td>false</td>
+ <td>
+ Enable IO encryption. Only supported in YARN mode.
+ </td>
+</tr>
+<tr>
+ <td><code>spark.io.encryption.keySizeBits</code></td>
+ <td>128</td>
+ <td>
+ IO encryption key size in bits. Supported values are 128, 192 and 256.
+ </td>
+</tr>
+<tr>
+ <td><code>spark.io.encryption.keygen.algorithm</code></td>
+ <td>HmacSHA1</td>
+ <td>
+ The algorithm to use when generating the IO encryption key. The supported algorithms are
+ described in the KeyGenerator section of the Java Cryptography Architecture Standard Algorithm
+ Name Documentation.
+ </td>
+</tr>
</table>
#### Spark UI
diff --git a/pom.xml b/pom.xml
index 74238db59e..2c265c1fa3 100644
--- a/pom.xml
+++ b/pom.xml
@@ -180,6 +180,7 @@
<selenium.version>2.52.0</selenium.version>
<paranamer.version>2.8</paranamer.version>
<maven-antrun.version>1.8</maven-antrun.version>
+ <commons-crypto.version>1.0.0</commons-crypto.version>
<test.java.home>${java.home}</test.java.home>
<test.exclude.tags></test.exclude.tags>
@@ -1825,6 +1826,17 @@
<artifactId>jline</artifactId>
<version>${jline.version}</version>
</dependency>
+ <dependency>
+ <groupId>org.apache.commons</groupId>
+ <artifactId>commons-crypto</artifactId>
+ <version>${commons-crypto.version}</version>
+ <exclusions>
+ <exclusion>
+ <groupId>net.java.dev.jna</groupId>
+ <artifactId>jna</artifactId>
+ </exclusion>
+ </exclusions>
+ </dependency>
</dependencies>
</dependencyManagement>
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
index 7fbbe91de9..2398f0aea3 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
@@ -1003,6 +1003,10 @@ private[spark] class Client(
val securityManager = new SecurityManager(sparkConf)
amContainer.setApplicationACLs(
YarnSparkHadoopUtil.getApplicationAclsForYarn(securityManager).asJava)
+
+ if (sparkConf.get(IO_ENCRYPTION_ENABLED)) {
+ SecurityManager.initIOEncryptionKey(sparkConf, credentials)
+ }
setupSecurityToken(amContainer)
amContainer
diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/IOEncryptionSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/IOEncryptionSuite.scala
new file mode 100644
index 0000000000..1c60315b21
--- /dev/null
+++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/IOEncryptionSuite.scala
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.deploy.yarn
+
+import java.io._
+import java.nio.charset.StandardCharsets
+import java.security.PrivilegedExceptionAction
+import java.util.UUID
+
+import org.apache.hadoop.security.{Credentials, UserGroupInformation}
+import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, Matchers}
+
+import org.apache.spark._
+import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.internal.config._
+import org.apache.spark.serializer._
+import org.apache.spark.storage._
+
+class IOEncryptionSuite extends SparkFunSuite with Matchers with BeforeAndAfterAll
+ with BeforeAndAfterEach {
+ private[this] val blockId = new TempShuffleBlockId(UUID.randomUUID())
+ private[this] val conf = new SparkConf()
+ private[this] val ugi = UserGroupInformation.createUserForTesting("testuser", Array("testgroup"))
+ private[this] val serializer = new KryoSerializer(conf)
+
+ override def beforeAll(): Unit = {
+ System.setProperty("SPARK_YARN_MODE", "true")
+ ugi.doAs(new PrivilegedExceptionAction[Unit]() {
+ override def run(): Unit = {
+ conf.set(IO_ENCRYPTION_ENABLED, true)
+ val creds = new Credentials()
+ SecurityManager.initIOEncryptionKey(conf, creds)
+ SparkHadoopUtil.get.addCurrentUserCredentials(creds)
+ }
+ })
+ }
+
+ override def afterAll(): Unit = {
+ SparkEnv.set(null)
+ System.clearProperty("SPARK_YARN_MODE")
+ }
+
+ override def beforeEach(): Unit = {
+ super.beforeEach()
+ }
+
+ override def afterEach(): Unit = {
+ super.afterEach()
+ conf.set("spark.shuffle.compress", false.toString)
+ conf.set("spark.shuffle.spill.compress", false.toString)
+ }
+
+ test("IO encryption read and write") {
+ ugi.doAs(new PrivilegedExceptionAction[Unit] {
+ override def run(): Unit = {
+ conf.set(IO_ENCRYPTION_ENABLED, true)
+ conf.set("spark.shuffle.compress", false.toString)
+ conf.set("spark.shuffle.spill.compress", false.toString)
+ testYarnIOEncryptionWriteRead()
+ }
+ })
+ }
+
+ test("IO encryption read and write with shuffle compression enabled") {
+ ugi.doAs(new PrivilegedExceptionAction[Unit] {
+ override def run(): Unit = {
+ conf.set(IO_ENCRYPTION_ENABLED, true)
+ conf.set("spark.shuffle.compress", true.toString)
+ conf.set("spark.shuffle.spill.compress", true.toString)
+ testYarnIOEncryptionWriteRead()
+ }
+ })
+ }
+
+ private[this] def testYarnIOEncryptionWriteRead(): Unit = {
+ val plainStr = "hello world"
+ val outputStream = new ByteArrayOutputStream()
+ val serializerManager = new SerializerManager(serializer, conf)
+ val wrappedOutputStream = serializerManager.wrapStream(blockId, outputStream)
+ wrappedOutputStream.write(plainStr.getBytes(StandardCharsets.UTF_8))
+ wrappedOutputStream.close()
+
+ val encryptedBytes = outputStream.toByteArray
+ val encryptedStr = new String(encryptedBytes)
+ assert(plainStr !== encryptedStr)
+
+ val inputStream = new ByteArrayInputStream(encryptedBytes)
+ val wrappedInputStream = serializerManager.wrapStream(blockId, inputStream)
+ val decryptedBytes = new Array[Byte](1024)
+ val len = wrappedInputStream.read(decryptedBytes)
+ val decryptedStr = new String(decryptedBytes, 0, len, StandardCharsets.UTF_8)
+ assert(decryptedStr === plainStr)
+ }
+}