author     Marcelo Vanzin <vanzin@cloudera.com>    2016-11-28 21:10:57 -0800
committer  Shixiong Zhu <shixiong@databricks.com>  2016-11-28 21:10:57 -0800
commit     8b325b17ecdf013b7a6edcb7ee3773546bd914df (patch)
tree       e2826f751402537582646f88fe3b905783fa2f7e /core/src
parent     1633ff3b6c97e33191859f34c868782cbb0972fd (diff)
[SPARK-18547][CORE] Propagate I/O encryption key when executors register.
This change modifies the method used to propagate encryption keys used during
shuffle. Instead of relying on YARN's UserGroupInformation credential propagation,
this change explicitly distributes the key using the messages exchanged between
driver and executor during registration. When RPC encryption is enabled, this
means key propagation is also secure.

This allows shuffle encryption to work in non-YARN mode, which means that it's
easier to write unit tests for areas of the code that are affected by the feature.

The key is stored in the SecurityManager; because there are many instances of
that class used in the code, the key is only guaranteed to exist in the instance
managed by the SparkEnv. This path was chosen to avoid storing the key in the
SparkConf, which would risk having the key being written to disk as part of the
configuration (as, for example, is done when starting YARN applications).

Tested by new and existing unit tests (which were moved from the YARN module
to core), and by running apps with shuffle encryption enabled.

Author: Marcelo Vanzin <vanzin@cloudera.com>

Closes #15981 from vanzin/SPARK-18547.
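For context while reading the patch below, here is a minimal, self-contained sketch
of the two halves of the new flow: the driver generating the I/O encryption key, and
the registration reply that carries it to executors. It mirrors identifiers introduced
in this change (CryptoStreamUtils.createKey, SparkAppConfig), but it is an illustration
only, not code taken verbatim from the patch; the wrapper object KeyPropagationSketch,
its parameters, and the HmacSHA1 default are assumptions for the example.

import javax.crypto.KeyGenerator

object KeyPropagationSketch {

  // Driver side: generate the key once when I/O encryption is enabled
  // (the patch does this in CryptoStreamUtils.createKey, driven by the
  // spark.io.encryption.* settings).
  def createKey(keyLenBits: Int, keyGenAlgorithm: String): Array[Byte] = {
    val keyGen = KeyGenerator.getInstance(keyGenAlgorithm)
    keyGen.init(keyLenBits)
    keyGen.generateKey().getEncoded()
  }

  // Registration reply: executors no longer ask the driver for RetrieveSparkProps;
  // they ask for RetrieveSparkAppConfig and receive the Spark properties plus the
  // optional encryption key in one message.
  case class SparkAppConfig(
      sparkProperties: Seq[(String, String)],
      ioEncryptionKey: Option[Array[Byte]])

  def main(args: Array[String]): Unit = {
    // Example: a 128-bit key; HmacSHA1 is assumed here purely for illustration.
    val key = createKey(128, "HmacSHA1")
    val cfg = SparkAppConfig(Seq("spark.app.id" -> "app-123"), Some(key))
    println(s"key length in bits: ${cfg.ioEncryptionKey.get.length * 8}")
  }
}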
Diffstat (limited to 'core/src')
-rw-r--r--  core/src/main/scala/org/apache/spark/SecurityManager.scala | 23
-rw-r--r--  core/src/main/scala/org/apache/spark/SparkContext.scala | 4
-rw-r--r--  core/src/main/scala/org/apache/spark/SparkEnv.scala | 33
-rw-r--r--  core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala | 6
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala | 7
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala | 6
-rw-r--r--  core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala | 28
-rw-r--r--  core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala | 18
-rw-r--r--  core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala | 135
9 files changed, 149 insertions(+), 111 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala
index 199365ad92..87fe563152 100644
--- a/core/src/main/scala/org/apache/spark/SecurityManager.scala
+++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala
@@ -21,7 +21,6 @@ import java.lang.{Byte => JByte}
import java.net.{Authenticator, PasswordAuthentication}
import java.security.{KeyStore, SecureRandom}
import java.security.cert.X509Certificate
-import javax.crypto.KeyGenerator
import javax.net.ssl._
import com.google.common.hash.HashCodes
@@ -33,7 +32,6 @@ import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config._
import org.apache.spark.network.sasl.SecretKeyHolder
-import org.apache.spark.security.CryptoStreamUtils._
import org.apache.spark.util.Utils
/**
@@ -185,7 +183,9 @@ import org.apache.spark.util.Utils
* setting `spark.ssl.useNodeLocalConf` to `true`.
*/
-private[spark] class SecurityManager(sparkConf: SparkConf)
+private[spark] class SecurityManager(
+ sparkConf: SparkConf,
+ ioEncryptionKey: Option[Array[Byte]] = None)
extends Logging with SecretKeyHolder {
import SecurityManager._
@@ -415,6 +415,8 @@ private[spark] class SecurityManager(sparkConf: SparkConf)
logInfo("Changing acls enabled to: " + aclsOn)
}
+ def getIOEncryptionKey(): Option[Array[Byte]] = ioEncryptionKey
+
/**
* Generates or looks up the secret key.
*
@@ -559,19 +561,4 @@ private[spark] object SecurityManager {
// key used to store the spark secret in the Hadoop UGI
val SECRET_LOOKUP_KEY = "sparkCookie"
- /**
- * Setup the cryptographic key used by IO encryption in credentials. The key is generated using
- * [[KeyGenerator]]. The algorithm and key length is specified by the [[SparkConf]].
- */
- def initIOEncryptionKey(conf: SparkConf, credentials: Credentials): Unit = {
- if (credentials.getSecretKey(SPARK_IO_TOKEN) == null) {
- val keyLen = conf.get(IO_ENCRYPTION_KEY_SIZE_BITS)
- val ioKeyGenAlgorithm = conf.get(IO_ENCRYPTION_KEYGEN_ALGORITHM)
- val keyGen = KeyGenerator.getInstance(ioKeyGenAlgorithm)
- keyGen.init(keyLen)
-
- val ioKey = keyGen.generateKey()
- credentials.addSecretKey(SPARK_IO_TOKEN, ioKey.getEncoded)
- }
- }
}
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 1261e3e735..a159a170eb 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -422,10 +422,6 @@ class SparkContext(config: SparkConf) extends Logging {
}
if (master == "yarn" && deployMode == "client") System.setProperty("SPARK_YARN_MODE", "true")
- if (_conf.get(IO_ENCRYPTION_ENABLED) && !SparkHadoopUtil.get.isYarnMode()) {
- throw new SparkException("IO encryption is only supported in YARN mode, please disable it " +
- s"by setting ${IO_ENCRYPTION_ENABLED.key} to false")
- }
// "_jobProgressListener" should be set up before creating SparkEnv because when creating
// "SparkEnv", some messages will be posted to "listenerBus" and we should not miss them.
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index 1ffeb12988..1296386ac9 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -36,6 +36,7 @@ import org.apache.spark.network.netty.NettyBlockTransferService
import org.apache.spark.rpc.{RpcEndpoint, RpcEndpointRef, RpcEnv}
import org.apache.spark.scheduler.{LiveListenerBus, OutputCommitCoordinator}
import org.apache.spark.scheduler.OutputCommitCoordinator.OutputCommitCoordinatorEndpoint
+import org.apache.spark.security.CryptoStreamUtils
import org.apache.spark.serializer.{JavaSerializer, Serializer, SerializerManager}
import org.apache.spark.shuffle.ShuffleManager
import org.apache.spark.storage._
@@ -165,15 +166,20 @@ object SparkEnv extends Logging {
val bindAddress = conf.get(DRIVER_BIND_ADDRESS)
val advertiseAddress = conf.get(DRIVER_HOST_ADDRESS)
val port = conf.get("spark.driver.port").toInt
+ val ioEncryptionKey = if (conf.get(IO_ENCRYPTION_ENABLED)) {
+ Some(CryptoStreamUtils.createKey(conf))
+ } else {
+ None
+ }
create(
conf,
SparkContext.DRIVER_IDENTIFIER,
bindAddress,
advertiseAddress,
port,
- isDriver = true,
- isLocal = isLocal,
- numUsableCores = numCores,
+ isLocal,
+ numCores,
+ ioEncryptionKey,
listenerBus = listenerBus,
mockOutputCommitCoordinator = mockOutputCommitCoordinator
)
@@ -189,6 +195,7 @@ object SparkEnv extends Logging {
hostname: String,
port: Int,
numCores: Int,
+ ioEncryptionKey: Option[Array[Byte]],
isLocal: Boolean): SparkEnv = {
val env = create(
conf,
@@ -196,9 +203,9 @@ object SparkEnv extends Logging {
hostname,
hostname,
port,
- isDriver = false,
- isLocal = isLocal,
- numUsableCores = numCores
+ isLocal,
+ numCores,
+ ioEncryptionKey
)
SparkEnv.set(env)
env
@@ -213,18 +220,26 @@ object SparkEnv extends Logging {
bindAddress: String,
advertiseAddress: String,
port: Int,
- isDriver: Boolean,
isLocal: Boolean,
numUsableCores: Int,
+ ioEncryptionKey: Option[Array[Byte]],
listenerBus: LiveListenerBus = null,
mockOutputCommitCoordinator: Option[OutputCommitCoordinator] = None): SparkEnv = {
+ val isDriver = executorId == SparkContext.DRIVER_IDENTIFIER
+
// Listener bus is only used on the driver
if (isDriver) {
assert(listenerBus != null, "Attempted to create driver SparkEnv with null listener bus!")
}
- val securityManager = new SecurityManager(conf)
+ val securityManager = new SecurityManager(conf, ioEncryptionKey)
+ ioEncryptionKey.foreach { _ =>
+ if (!securityManager.isSaslEncryptionEnabled()) {
+ logWarning("I/O encryption enabled without RPC encryption: keys will be visible on the " +
+ "wire.")
+ }
+ }
val systemName = if (isDriver) driverSystemName else executorSystemName
val rpcEnv = RpcEnv.create(systemName, bindAddress, advertiseAddress, port, conf,
@@ -270,7 +285,7 @@ object SparkEnv extends Logging {
"spark.serializer", "org.apache.spark.serializer.JavaSerializer")
logDebug(s"Using serializer: ${serializer.getClass}")
- val serializerManager = new SerializerManager(serializer, conf)
+ val serializerManager = new SerializerManager(serializer, conf, ioEncryptionKey)
val closureSerializer = new JavaSerializer(conf)
diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
index 7eec4ae64f..92a27902c6 100644
--- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
+++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
@@ -200,8 +200,8 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging {
new SecurityManager(executorConf),
clientMode = true)
val driver = fetcher.setupEndpointRefByURI(driverUrl)
- val props = driver.askWithRetry[Seq[(String, String)]](RetrieveSparkProps) ++
- Seq[(String, String)](("spark.app.id", appId))
+ val cfg = driver.askWithRetry[SparkAppConfig](RetrieveSparkAppConfig)
+ val props = cfg.sparkProperties ++ Seq[(String, String)](("spark.app.id", appId))
fetcher.shutdown()
// Create SparkEnv using properties we fetched from the driver.
@@ -221,7 +221,7 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging {
}
val env = SparkEnv.createExecutorEnv(
- driverConf, executorId, hostname, port, cores, isLocal = false)
+ driverConf, executorId, hostname, port, cores, cfg.ioEncryptionKey, isLocal = false)
env.rpcEnv.setupEndpoint("Executor", new CoarseGrainedExecutorBackend(
env.rpcEnv, driverUrl, executorId, hostname, cores, userClassPath, env))
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
index edc8aac5d1..0a4f19d760 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
@@ -28,7 +28,12 @@ private[spark] sealed trait CoarseGrainedClusterMessage extends Serializable
private[spark] object CoarseGrainedClusterMessages {
- case object RetrieveSparkProps extends CoarseGrainedClusterMessage
+ case object RetrieveSparkAppConfig extends CoarseGrainedClusterMessage
+
+ case class SparkAppConfig(
+ sparkProperties: Seq[(String, String)],
+ ioEncryptionKey: Option[Array[Byte]])
+ extends CoarseGrainedClusterMessage
case object RetrieveLastAllocatedExecutorId extends CoarseGrainedClusterMessage
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
index 10d55c87fb..3452487e72 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
@@ -206,8 +206,10 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
removeExecutor(executorId, reason)
context.reply(true)
- case RetrieveSparkProps =>
- context.reply(sparkProperties)
+ case RetrieveSparkAppConfig =>
+ val reply = SparkAppConfig(sparkProperties,
+ SparkEnv.get.securityManager.getIOEncryptionKey())
+ context.reply(reply)
}
// Make fake resource offers on all executors
diff --git a/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala b/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala
index f41fc38be2..8e3436f134 100644
--- a/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala
+++ b/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala
@@ -18,14 +18,13 @@ package org.apache.spark.security
import java.io.{InputStream, OutputStream}
import java.util.Properties
+import javax.crypto.KeyGenerator
import javax.crypto.spec.{IvParameterSpec, SecretKeySpec}
import org.apache.commons.crypto.random._
import org.apache.commons.crypto.stream._
-import org.apache.hadoop.io.Text
import org.apache.spark.SparkConf
-import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config._
@@ -33,10 +32,6 @@ import org.apache.spark.internal.config._
* A util class for manipulating IO encryption and decryption streams.
*/
private[spark] object CryptoStreamUtils extends Logging {
- /**
- * Constants and variables for spark IO encryption
- */
- val SPARK_IO_TOKEN = new Text("SPARK_IO_TOKEN")
// The initialization vector length in bytes.
val IV_LENGTH_IN_BYTES = 16
@@ -50,12 +45,11 @@ private[spark] object CryptoStreamUtils extends Logging {
*/
def createCryptoOutputStream(
os: OutputStream,
- sparkConf: SparkConf): OutputStream = {
+ sparkConf: SparkConf,
+ key: Array[Byte]): OutputStream = {
val properties = toCryptoConf(sparkConf)
val iv = createInitializationVector(properties)
os.write(iv)
- val credentials = SparkHadoopUtil.get.getCurrentUserCredentials()
- val key = credentials.getSecretKey(SPARK_IO_TOKEN)
val transformationStr = sparkConf.get(IO_CRYPTO_CIPHER_TRANSFORMATION)
new CryptoOutputStream(transformationStr, properties, os,
new SecretKeySpec(key, "AES"), new IvParameterSpec(iv))
@@ -66,12 +60,11 @@ private[spark] object CryptoStreamUtils extends Logging {
*/
def createCryptoInputStream(
is: InputStream,
- sparkConf: SparkConf): InputStream = {
+ sparkConf: SparkConf,
+ key: Array[Byte]): InputStream = {
val properties = toCryptoConf(sparkConf)
val iv = new Array[Byte](IV_LENGTH_IN_BYTES)
is.read(iv, 0, iv.length)
- val credentials = SparkHadoopUtil.get.getCurrentUserCredentials()
- val key = credentials.getSecretKey(SPARK_IO_TOKEN)
val transformationStr = sparkConf.get(IO_CRYPTO_CIPHER_TRANSFORMATION)
new CryptoInputStream(transformationStr, properties, is,
new SecretKeySpec(key, "AES"), new IvParameterSpec(iv))
@@ -92,6 +85,17 @@ private[spark] object CryptoStreamUtils extends Logging {
}
/**
+ * Creates a new encryption key.
+ */
+ def createKey(conf: SparkConf): Array[Byte] = {
+ val keyLen = conf.get(IO_ENCRYPTION_KEY_SIZE_BITS)
+ val ioKeyGenAlgorithm = conf.get(IO_ENCRYPTION_KEYGEN_ALGORITHM)
+ val keyGen = KeyGenerator.getInstance(ioKeyGenAlgorithm)
+ keyGen.init(keyLen)
+ keyGen.generateKey().getEncoded()
+ }
+
+ /**
* This method to generate an IV (Initialization Vector) using secure random.
*/
private[this] def createInitializationVector(properties: Properties): Array[Byte] = {
diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala
index 2156d576f1..ef8432ec08 100644
--- a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala
@@ -33,7 +33,12 @@ import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStrea
* Component which configures serialization, compression and encryption for various Spark
* components, including automatic selection of which [[Serializer]] to use for shuffles.
*/
-private[spark] class SerializerManager(defaultSerializer: Serializer, conf: SparkConf) {
+private[spark] class SerializerManager(
+ defaultSerializer: Serializer,
+ conf: SparkConf,
+ encryptionKey: Option[Array[Byte]]) {
+
+ def this(defaultSerializer: Serializer, conf: SparkConf) = this(defaultSerializer, conf, None)
private[this] val kryoSerializer = new KryoSerializer(conf)
@@ -63,9 +68,6 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: Spar
// Whether to compress shuffle output temporarily spilled to disk
private[this] val compressShuffleSpill = conf.getBoolean("spark.shuffle.spill.compress", true)
- // Whether to enable IO encryption
- private[this] val enableIOEncryption = conf.get(IO_ENCRYPTION_ENABLED)
-
/* The compression codec to use. Note that the "lazy" val is necessary because we want to delay
* the initialization of the compression codec until it is first used. The reason is that a Spark
* program could be using a user-defined codec in a third party jar, which is loaded in
@@ -125,14 +127,18 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: Spar
* Wrap an input stream for encryption if shuffle encryption is enabled
*/
private[this] def wrapForEncryption(s: InputStream): InputStream = {
- if (enableIOEncryption) CryptoStreamUtils.createCryptoInputStream(s, conf) else s
+ encryptionKey
+ .map { key => CryptoStreamUtils.createCryptoInputStream(s, conf, key) }
+ .getOrElse(s)
}
/**
* Wrap an output stream for encryption if shuffle encryption is enabled
*/
private[this] def wrapForEncryption(s: OutputStream): OutputStream = {
- if (enableIOEncryption) CryptoStreamUtils.createCryptoOutputStream(s, conf) else s
+ encryptionKey
+ .map { key => CryptoStreamUtils.createCryptoOutputStream(s, conf, key) }
+ .getOrElse(s)
}
/**
diff --git a/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala b/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala
index 81eb907ac7..a61ec74c7d 100644
--- a/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala
@@ -16,18 +16,21 @@
*/
package org.apache.spark.security
-import java.security.PrivilegedExceptionAction
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
+import java.nio.charset.StandardCharsets.UTF_8
+import java.util.UUID
-import org.apache.hadoop.security.{Credentials, UserGroupInformation}
+import com.google.common.io.ByteStreams
-import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite}
+import org.apache.spark._
import org.apache.spark.internal.config._
import org.apache.spark.security.CryptoStreamUtils._
+import org.apache.spark.serializer.{JavaSerializer, SerializerManager}
+import org.apache.spark.storage.TempShuffleBlockId
class CryptoStreamUtilsSuite extends SparkFunSuite {
- val ugi = UserGroupInformation.createUserForTesting("testuser", Array("testgroup"))
- test("Crypto configuration conversion") {
+ test("crypto configuration conversion") {
val sparkKey1 = s"${SPARK_IO_ENCRYPTION_COMMONS_CONFIG_PREFIX}a.b.c"
val sparkVal1 = "val1"
val cryptoKey1 = s"${COMMONS_CRYPTO_CONF_PREFIX}a.b.c"
@@ -43,65 +46,85 @@ class CryptoStreamUtilsSuite extends SparkFunSuite {
assert(!props.containsKey(cryptoKey2))
}
- test("Shuffle encryption is disabled by default") {
- ugi.doAs(new PrivilegedExceptionAction[Unit]() {
- override def run(): Unit = {
- val credentials = UserGroupInformation.getCurrentUser.getCredentials()
- val conf = new SparkConf()
- initCredentials(conf, credentials)
- assert(credentials.getSecretKey(SPARK_IO_TOKEN) === null)
- }
- })
+ test("shuffle encryption key length should be 128 by default") {
+ val conf = createConf()
+ var key = CryptoStreamUtils.createKey(conf)
+ val actual = key.length * (java.lang.Byte.SIZE)
+ assert(actual === 128)
}
- test("Shuffle encryption key length should be 128 by default") {
- ugi.doAs(new PrivilegedExceptionAction[Unit]() {
- override def run(): Unit = {
- val credentials = UserGroupInformation.getCurrentUser.getCredentials()
- val conf = new SparkConf()
- conf.set(IO_ENCRYPTION_ENABLED, true)
- initCredentials(conf, credentials)
- var key = credentials.getSecretKey(SPARK_IO_TOKEN)
- assert(key !== null)
- val actual = key.length * (java.lang.Byte.SIZE)
- assert(actual === 128)
- }
- })
+ test("create 256-bit key") {
+ val conf = createConf(IO_ENCRYPTION_KEY_SIZE_BITS.key -> "256")
+ var key = CryptoStreamUtils.createKey(conf)
+ val actual = key.length * (java.lang.Byte.SIZE)
+ assert(actual === 256)
}
- test("Initial credentials with key length in 256") {
- ugi.doAs(new PrivilegedExceptionAction[Unit]() {
- override def run(): Unit = {
- val credentials = UserGroupInformation.getCurrentUser.getCredentials()
- val conf = new SparkConf()
- conf.set(IO_ENCRYPTION_KEY_SIZE_BITS, 256)
- conf.set(IO_ENCRYPTION_ENABLED, true)
- initCredentials(conf, credentials)
- var key = credentials.getSecretKey(SPARK_IO_TOKEN)
- assert(key !== null)
- val actual = key.length * (java.lang.Byte.SIZE)
- assert(actual === 256)
- }
- })
+ test("create key with invalid length") {
+ intercept[IllegalArgumentException] {
+ val conf = createConf(IO_ENCRYPTION_KEY_SIZE_BITS.key -> "328")
+ CryptoStreamUtils.createKey(conf)
+ }
}
- test("Initial credentials with invalid key length") {
- ugi.doAs(new PrivilegedExceptionAction[Unit]() {
- override def run(): Unit = {
- val credentials = UserGroupInformation.getCurrentUser.getCredentials()
- val conf = new SparkConf()
- conf.set(IO_ENCRYPTION_KEY_SIZE_BITS, 328)
- conf.set(IO_ENCRYPTION_ENABLED, true)
- val thrown = intercept[IllegalArgumentException] {
- initCredentials(conf, credentials)
- }
- }
- })
+ test("serializer manager integration") {
+ val conf = createConf()
+ .set("spark.shuffle.compress", "true")
+ .set("spark.shuffle.spill.compress", "true")
+
+ val plainStr = "hello world"
+ val blockId = new TempShuffleBlockId(UUID.randomUUID())
+ val key = Some(CryptoStreamUtils.createKey(conf))
+ val serializerManager = new SerializerManager(new JavaSerializer(conf), conf,
+ encryptionKey = key)
+
+ val outputStream = new ByteArrayOutputStream()
+ val wrappedOutputStream = serializerManager.wrapStream(blockId, outputStream)
+ wrappedOutputStream.write(plainStr.getBytes(UTF_8))
+ wrappedOutputStream.close()
+
+ val encryptedBytes = outputStream.toByteArray
+ val encryptedStr = new String(encryptedBytes, UTF_8)
+ assert(plainStr !== encryptedStr)
+
+ val inputStream = new ByteArrayInputStream(encryptedBytes)
+ val wrappedInputStream = serializerManager.wrapStream(blockId, inputStream)
+ val decryptedBytes = ByteStreams.toByteArray(wrappedInputStream)
+ val decryptedStr = new String(decryptedBytes, UTF_8)
+ assert(decryptedStr === plainStr)
}
- private[this] def initCredentials(conf: SparkConf, credentials: Credentials): Unit = {
- if (conf.get(IO_ENCRYPTION_ENABLED)) {
- SecurityManager.initIOEncryptionKey(conf, credentials)
+ test("encryption key propagation to executors") {
+ val conf = createConf().setAppName("Crypto Test").setMaster("local-cluster[1,1,1024]")
+ val sc = new SparkContext(conf)
+ try {
+ val content = "This is the content to be encrypted."
+ val encrypted = sc.parallelize(Seq(1))
+ .map { str =>
+ val bytes = new ByteArrayOutputStream()
+ val out = CryptoStreamUtils.createCryptoOutputStream(bytes, SparkEnv.get.conf,
+ SparkEnv.get.securityManager.getIOEncryptionKey().get)
+ out.write(content.getBytes(UTF_8))
+ out.close()
+ bytes.toByteArray()
+ }.collect()(0)
+
+ assert(content != encrypted)
+
+ val in = CryptoStreamUtils.createCryptoInputStream(new ByteArrayInputStream(encrypted),
+ sc.conf, SparkEnv.get.securityManager.getIOEncryptionKey().get)
+ val decrypted = new String(ByteStreams.toByteArray(in), UTF_8)
+ assert(content === decrypted)
+ } finally {
+ sc.stop()
}
}
+
+ private def createConf(extra: (String, String)*): SparkConf = {
+ val conf = new SparkConf()
+ extra.foreach { case (k, v) => conf.set(k, v) }
+ conf.set(IO_ENCRYPTION_ENABLED, true)
+ conf
+ }
+
}