Diffstat (limited to 'core/src')
-rw-r--r--  core/src/main/scala/org/apache/spark/SecurityManager.scala                                  |  23
-rw-r--r--  core/src/main/scala/org/apache/spark/SparkContext.scala                                     |   4
-rw-r--r--  core/src/main/scala/org/apache/spark/SparkEnv.scala                                         |  33
-rw-r--r--  core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala            |   6
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala    |   7
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala  |   6
-rw-r--r--  core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala                       |  28
-rw-r--r--  core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala                     |  18
-rw-r--r--  core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala                  | 135
9 files changed, 149 insertions(+), 111 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala
index 199365ad92..87fe563152 100644
--- a/core/src/main/scala/org/apache/spark/SecurityManager.scala
+++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala
@@ -21,7 +21,6 @@ import java.lang.{Byte => JByte}
import java.net.{Authenticator, PasswordAuthentication}
import java.security.{KeyStore, SecureRandom}
import java.security.cert.X509Certificate
-import javax.crypto.KeyGenerator
import javax.net.ssl._
import com.google.common.hash.HashCodes
@@ -33,7 +32,6 @@ import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config._
import org.apache.spark.network.sasl.SecretKeyHolder
-import org.apache.spark.security.CryptoStreamUtils._
import org.apache.spark.util.Utils
/**
@@ -185,7 +183,9 @@ import org.apache.spark.util.Utils
* setting `spark.ssl.useNodeLocalConf` to `true`.
*/
-private[spark] class SecurityManager(sparkConf: SparkConf)
+private[spark] class SecurityManager(
+ sparkConf: SparkConf,
+ ioEncryptionKey: Option[Array[Byte]] = None)
extends Logging with SecretKeyHolder {
import SecurityManager._
@@ -415,6 +415,8 @@ private[spark] class SecurityManager(sparkConf: SparkConf)
logInfo("Changing acls enabled to: " + aclsOn)
}
+ def getIOEncryptionKey(): Option[Array[Byte]] = ioEncryptionKey
+
/**
* Generates or looks up the secret key.
*
@@ -559,19 +561,4 @@ private[spark] object SecurityManager {
// key used to store the spark secret in the Hadoop UGI
val SECRET_LOOKUP_KEY = "sparkCookie"
- /**
- * Setup the cryptographic key used by IO encryption in credentials. The key is generated using
- * [[KeyGenerator]]. The algorithm and key length is specified by the [[SparkConf]].
- */
- def initIOEncryptionKey(conf: SparkConf, credentials: Credentials): Unit = {
- if (credentials.getSecretKey(SPARK_IO_TOKEN) == null) {
- val keyLen = conf.get(IO_ENCRYPTION_KEY_SIZE_BITS)
- val ioKeyGenAlgorithm = conf.get(IO_ENCRYPTION_KEYGEN_ALGORITHM)
- val keyGen = KeyGenerator.getInstance(ioKeyGenAlgorithm)
- keyGen.init(keyLen)
-
- val ioKey = keyGen.generateKey()
- credentials.addSecretKey(SPARK_IO_TOKEN, ioKey.getEncoded)
- }
- }
}
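
Usage sketch (not part of the patch): with this change the I/O encryption key is supplied to SecurityManager at construction time instead of being stored in Hadoop credentials, and other components read it back through getIOEncryptionKey(). The snippet assumes it runs inside the org.apache.spark package, since SecurityManager and these config constants are private[spark].

import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.internal.config._
import org.apache.spark.security.CryptoStreamUtils

// Generate the key once (SparkEnv does this on the driver when IO encryption is enabled)
// and hand it to the SecurityManager, which merely exposes it to other components.
val conf = new SparkConf().set(IO_ENCRYPTION_ENABLED, true)
val key: Array[Byte] = CryptoStreamUtils.createKey(conf)
val securityManager = new SecurityManager(conf, ioEncryptionKey = Some(key))
val ioKey: Option[Array[Byte]] = securityManager.getIOEncryptionKey()
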
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 1261e3e735..a159a170eb 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -422,10 +422,6 @@ class SparkContext(config: SparkConf) extends Logging {
}
if (master == "yarn" && deployMode == "client") System.setProperty("SPARK_YARN_MODE", "true")
- if (_conf.get(IO_ENCRYPTION_ENABLED) && !SparkHadoopUtil.get.isYarnMode()) {
- throw new SparkException("IO encryption is only supported in YARN mode, please disable it " +
- s"by setting ${IO_ENCRYPTION_ENABLED.key} to false")
- }
// "_jobProgressListener" should be set up before creating SparkEnv because when creating
// "SparkEnv", some messages will be posted to "listenerBus" and we should not miss them.
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index 1ffeb12988..1296386ac9 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -36,6 +36,7 @@ import org.apache.spark.network.netty.NettyBlockTransferService
import org.apache.spark.rpc.{RpcEndpoint, RpcEndpointRef, RpcEnv}
import org.apache.spark.scheduler.{LiveListenerBus, OutputCommitCoordinator}
import org.apache.spark.scheduler.OutputCommitCoordinator.OutputCommitCoordinatorEndpoint
+import org.apache.spark.security.CryptoStreamUtils
import org.apache.spark.serializer.{JavaSerializer, Serializer, SerializerManager}
import org.apache.spark.shuffle.ShuffleManager
import org.apache.spark.storage._
@@ -165,15 +166,20 @@ object SparkEnv extends Logging {
val bindAddress = conf.get(DRIVER_BIND_ADDRESS)
val advertiseAddress = conf.get(DRIVER_HOST_ADDRESS)
val port = conf.get("spark.driver.port").toInt
+ val ioEncryptionKey = if (conf.get(IO_ENCRYPTION_ENABLED)) {
+ Some(CryptoStreamUtils.createKey(conf))
+ } else {
+ None
+ }
create(
conf,
SparkContext.DRIVER_IDENTIFIER,
bindAddress,
advertiseAddress,
port,
- isDriver = true,
- isLocal = isLocal,
- numUsableCores = numCores,
+ isLocal,
+ numCores,
+ ioEncryptionKey,
listenerBus = listenerBus,
mockOutputCommitCoordinator = mockOutputCommitCoordinator
)
@@ -189,6 +195,7 @@ object SparkEnv extends Logging {
hostname: String,
port: Int,
numCores: Int,
+ ioEncryptionKey: Option[Array[Byte]],
isLocal: Boolean): SparkEnv = {
val env = create(
conf,
@@ -196,9 +203,9 @@ object SparkEnv extends Logging {
hostname,
hostname,
port,
- isDriver = false,
- isLocal = isLocal,
- numUsableCores = numCores
+ isLocal,
+ numCores,
+ ioEncryptionKey
)
SparkEnv.set(env)
env
@@ -213,18 +220,26 @@ object SparkEnv extends Logging {
bindAddress: String,
advertiseAddress: String,
port: Int,
- isDriver: Boolean,
isLocal: Boolean,
numUsableCores: Int,
+ ioEncryptionKey: Option[Array[Byte]],
listenerBus: LiveListenerBus = null,
mockOutputCommitCoordinator: Option[OutputCommitCoordinator] = None): SparkEnv = {
+ val isDriver = executorId == SparkContext.DRIVER_IDENTIFIER
+
// Listener bus is only used on the driver
if (isDriver) {
assert(listenerBus != null, "Attempted to create driver SparkEnv with null listener bus!")
}
- val securityManager = new SecurityManager(conf)
+ val securityManager = new SecurityManager(conf, ioEncryptionKey)
+ ioEncryptionKey.foreach { _ =>
+ if (!securityManager.isSaslEncryptionEnabled()) {
+ logWarning("I/O encryption enabled without RPC encryption: keys will be visible on the " +
+ "wire.")
+ }
+ }
val systemName = if (isDriver) driverSystemName else executorSystemName
val rpcEnv = RpcEnv.create(systemName, bindAddress, advertiseAddress, port, conf,
@@ -270,7 +285,7 @@ object SparkEnv extends Logging {
"spark.serializer", "org.apache.spark.serializer.JavaSerializer")
logDebug(s"Using serializer: ${serializer.getClass}")
- val serializerManager = new SerializerManager(serializer, conf)
+ val serializerManager = new SerializerManager(serializer, conf, ioEncryptionKey)
val closureSerializer = new JavaSerializer(conf)
diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
index 7eec4ae64f..92a27902c6 100644
--- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
+++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
@@ -200,8 +200,8 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging {
new SecurityManager(executorConf),
clientMode = true)
val driver = fetcher.setupEndpointRefByURI(driverUrl)
- val props = driver.askWithRetry[Seq[(String, String)]](RetrieveSparkProps) ++
- Seq[(String, String)](("spark.app.id", appId))
+ val cfg = driver.askWithRetry[SparkAppConfig](RetrieveSparkAppConfig)
+ val props = cfg.sparkProperties ++ Seq[(String, String)](("spark.app.id", appId))
fetcher.shutdown()
// Create SparkEnv using properties we fetched from the driver.
@@ -221,7 +221,7 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging {
}
val env = SparkEnv.createExecutorEnv(
- driverConf, executorId, hostname, port, cores, isLocal = false)
+ driverConf, executorId, hostname, port, cores, cfg.ioEncryptionKey, isLocal = false)
env.rpcEnv.setupEndpoint("Executor", new CoarseGrainedExecutorBackend(
env.rpcEnv, driverUrl, executorId, hostname, cores, userClassPath, env))
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
index edc8aac5d1..0a4f19d760 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
@@ -28,7 +28,12 @@ private[spark] sealed trait CoarseGrainedClusterMessage extends Serializable
private[spark] object CoarseGrainedClusterMessages {
- case object RetrieveSparkProps extends CoarseGrainedClusterMessage
+ case object RetrieveSparkAppConfig extends CoarseGrainedClusterMessage
+
+ case class SparkAppConfig(
+ sparkProperties: Seq[(String, String)],
+ ioEncryptionKey: Option[Array[Byte]])
+ extends CoarseGrainedClusterMessage
case object RetrieveLastAllocatedExecutorId extends CoarseGrainedClusterMessage
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
index 10d55c87fb..3452487e72 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
@@ -206,8 +206,10 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
removeExecutor(executorId, reason)
context.reply(true)
- case RetrieveSparkProps =>
- context.reply(sparkProperties)
+ case RetrieveSparkAppConfig =>
+ val reply = SparkAppConfig(sparkProperties,
+ SparkEnv.get.securityManager.getIOEncryptionKey())
+ context.reply(reply)
}
// Make fake resource offers on all executors
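
Taken together with the executor backend change earlier, the scheduler-side hunks above alter the registration handshake: instead of a bare property list, the driver now answers RetrieveSparkAppConfig with a SparkAppConfig that also carries the optional I/O encryption key. A condensed restatement of both ends of the exchange, as they appear in this patch:

// Driver side: reply with the properties plus the key held by the SecurityManager.
case RetrieveSparkAppConfig =>
  context.reply(SparkAppConfig(sparkProperties,
    SparkEnv.get.securityManager.getIOEncryptionKey()))

// Executor side: unpack properties and key from the single reply.
val cfg = driver.askWithRetry[SparkAppConfig](RetrieveSparkAppConfig)
val props = cfg.sparkProperties ++ Seq(("spark.app.id", appId))
val env = SparkEnv.createExecutorEnv(
  driverConf, executorId, hostname, port, cores, cfg.ioEncryptionKey, isLocal = false)
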
diff --git a/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala b/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala
index f41fc38be2..8e3436f134 100644
--- a/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala
+++ b/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala
@@ -18,14 +18,13 @@ package org.apache.spark.security
import java.io.{InputStream, OutputStream}
import java.util.Properties
+import javax.crypto.KeyGenerator
import javax.crypto.spec.{IvParameterSpec, SecretKeySpec}
import org.apache.commons.crypto.random._
import org.apache.commons.crypto.stream._
-import org.apache.hadoop.io.Text
import org.apache.spark.SparkConf
-import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config._
@@ -33,10 +32,6 @@ import org.apache.spark.internal.config._
* A util class for manipulating IO encryption and decryption streams.
*/
private[spark] object CryptoStreamUtils extends Logging {
- /**
- * Constants and variables for spark IO encryption
- */
- val SPARK_IO_TOKEN = new Text("SPARK_IO_TOKEN")
// The initialization vector length in bytes.
val IV_LENGTH_IN_BYTES = 16
@@ -50,12 +45,11 @@ private[spark] object CryptoStreamUtils extends Logging {
*/
def createCryptoOutputStream(
os: OutputStream,
- sparkConf: SparkConf): OutputStream = {
+ sparkConf: SparkConf,
+ key: Array[Byte]): OutputStream = {
val properties = toCryptoConf(sparkConf)
val iv = createInitializationVector(properties)
os.write(iv)
- val credentials = SparkHadoopUtil.get.getCurrentUserCredentials()
- val key = credentials.getSecretKey(SPARK_IO_TOKEN)
val transformationStr = sparkConf.get(IO_CRYPTO_CIPHER_TRANSFORMATION)
new CryptoOutputStream(transformationStr, properties, os,
new SecretKeySpec(key, "AES"), new IvParameterSpec(iv))
@@ -66,12 +60,11 @@ private[spark] object CryptoStreamUtils extends Logging {
*/
def createCryptoInputStream(
is: InputStream,
- sparkConf: SparkConf): InputStream = {
+ sparkConf: SparkConf,
+ key: Array[Byte]): InputStream = {
val properties = toCryptoConf(sparkConf)
val iv = new Array[Byte](IV_LENGTH_IN_BYTES)
is.read(iv, 0, iv.length)
- val credentials = SparkHadoopUtil.get.getCurrentUserCredentials()
- val key = credentials.getSecretKey(SPARK_IO_TOKEN)
val transformationStr = sparkConf.get(IO_CRYPTO_CIPHER_TRANSFORMATION)
new CryptoInputStream(transformationStr, properties, is,
new SecretKeySpec(key, "AES"), new IvParameterSpec(iv))
@@ -92,6 +85,17 @@ private[spark] object CryptoStreamUtils extends Logging {
}
/**
+ * Creates a new encryption key.
+ */
+ def createKey(conf: SparkConf): Array[Byte] = {
+ val keyLen = conf.get(IO_ENCRYPTION_KEY_SIZE_BITS)
+ val ioKeyGenAlgorithm = conf.get(IO_ENCRYPTION_KEYGEN_ALGORITHM)
+ val keyGen = KeyGenerator.getInstance(ioKeyGenAlgorithm)
+ keyGen.init(keyLen)
+ keyGen.generateKey().getEncoded()
+ }
+
+ /**
* This method generates an IV (Initialization Vector) using secure random.
*/
private[this] def createInitializationVector(properties: Properties): Array[Byte] = {
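
Round-trip sketch of the reworked CryptoStreamUtils API (mirrors the new test below; illustrative, and assumes access to the private[spark] object): the key is now an explicit argument rather than a lookup in the current user's Hadoop credentials, and the IV is still written and read back as the stream prefix.

import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import java.nio.charset.StandardCharsets.UTF_8

import com.google.common.io.ByteStreams

import org.apache.spark.SparkConf
import org.apache.spark.security.CryptoStreamUtils

val conf = new SparkConf()
val key = CryptoStreamUtils.createKey(conf)   // 128-bit AES key with the default settings

// Encrypt: a fresh IV is generated and written as the first 16 bytes of the stream.
val buffer = new ByteArrayOutputStream()
val out = CryptoStreamUtils.createCryptoOutputStream(buffer, conf, key)
out.write("shuffle bytes".getBytes(UTF_8))
out.close()

// Decrypt with the same key (on an executor it arrives via SparkAppConfig).
val in = CryptoStreamUtils.createCryptoInputStream(
  new ByteArrayInputStream(buffer.toByteArray), conf, key)
val decrypted = new String(ByteStreams.toByteArray(in), UTF_8)
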
diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala
index 2156d576f1..ef8432ec08 100644
--- a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala
@@ -33,7 +33,12 @@ import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStrea
* Component which configures serialization, compression and encryption for various Spark
* components, including automatic selection of which [[Serializer]] to use for shuffles.
*/
-private[spark] class SerializerManager(defaultSerializer: Serializer, conf: SparkConf) {
+private[spark] class SerializerManager(
+ defaultSerializer: Serializer,
+ conf: SparkConf,
+ encryptionKey: Option[Array[Byte]]) {
+
+ def this(defaultSerializer: Serializer, conf: SparkConf) = this(defaultSerializer, conf, None)
private[this] val kryoSerializer = new KryoSerializer(conf)
@@ -63,9 +68,6 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: Spar
// Whether to compress shuffle output temporarily spilled to disk
private[this] val compressShuffleSpill = conf.getBoolean("spark.shuffle.spill.compress", true)
- // Whether to enable IO encryption
- private[this] val enableIOEncryption = conf.get(IO_ENCRYPTION_ENABLED)
-
/* The compression codec to use. Note that the "lazy" val is necessary because we want to delay
* the initialization of the compression codec until it is first used. The reason is that a Spark
* program could be using a user-defined codec in a third party jar, which is loaded in
@@ -125,14 +127,18 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: Spar
* Wrap an input stream for encryption if shuffle encryption is enabled
*/
private[this] def wrapForEncryption(s: InputStream): InputStream = {
- if (enableIOEncryption) CryptoStreamUtils.createCryptoInputStream(s, conf) else s
+ encryptionKey
+ .map { key => CryptoStreamUtils.createCryptoInputStream(s, conf, key) }
+ .getOrElse(s)
}
/**
* Wrap an output stream for encryption if shuffle encryption is enabled
*/
private[this] def wrapForEncryption(s: OutputStream): OutputStream = {
- if (enableIOEncryption) CryptoStreamUtils.createCryptoOutputStream(s, conf) else s
+ encryptionKey
+ .map { key => CryptoStreamUtils.createCryptoOutputStream(s, conf, key) }
+ .getOrElse(s)
}
/**
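
The same change seen from SerializerManager's side, condensed from the "serializer manager integration" test below (block id and payload are illustrative): encryption is now driven purely by whether a key was passed in, not by re-reading IO_ENCRYPTION_ENABLED from the conf.

import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import java.nio.charset.StandardCharsets.UTF_8
import java.util.UUID

import org.apache.spark.SparkConf
import org.apache.spark.security.CryptoStreamUtils
import org.apache.spark.serializer.{JavaSerializer, SerializerManager}
import org.apache.spark.storage.TempShuffleBlockId

val conf = new SparkConf()
val key = Some(CryptoStreamUtils.createKey(conf))
val serializerManager = new SerializerManager(new JavaSerializer(conf), conf, key)

// wrapStream compresses (per the usual shuffle settings) and, because a key is
// present, also encrypts; wrapping the read side with the same manager reverses both.
val blockId = new TempShuffleBlockId(UUID.randomUUID())
val out = new ByteArrayOutputStream()
val wrapped = serializerManager.wrapStream(blockId, out)
wrapped.write("plain text".getBytes(UTF_8))
wrapped.close()
val readBack = serializerManager.wrapStream(blockId, new ByteArrayInputStream(out.toByteArray))
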
diff --git a/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala b/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala
index 81eb907ac7..a61ec74c7d 100644
--- a/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala
@@ -16,18 +16,21 @@
*/
package org.apache.spark.security
-import java.security.PrivilegedExceptionAction
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
+import java.nio.charset.StandardCharsets.UTF_8
+import java.util.UUID
-import org.apache.hadoop.security.{Credentials, UserGroupInformation}
+import com.google.common.io.ByteStreams
-import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite}
+import org.apache.spark._
import org.apache.spark.internal.config._
import org.apache.spark.security.CryptoStreamUtils._
+import org.apache.spark.serializer.{JavaSerializer, SerializerManager}
+import org.apache.spark.storage.TempShuffleBlockId
class CryptoStreamUtilsSuite extends SparkFunSuite {
- val ugi = UserGroupInformation.createUserForTesting("testuser", Array("testgroup"))
- test("Crypto configuration conversion") {
+ test("crypto configuration conversion") {
val sparkKey1 = s"${SPARK_IO_ENCRYPTION_COMMONS_CONFIG_PREFIX}a.b.c"
val sparkVal1 = "val1"
val cryptoKey1 = s"${COMMONS_CRYPTO_CONF_PREFIX}a.b.c"
@@ -43,65 +46,85 @@ class CryptoStreamUtilsSuite extends SparkFunSuite {
assert(!props.containsKey(cryptoKey2))
}
- test("Shuffle encryption is disabled by default") {
- ugi.doAs(new PrivilegedExceptionAction[Unit]() {
- override def run(): Unit = {
- val credentials = UserGroupInformation.getCurrentUser.getCredentials()
- val conf = new SparkConf()
- initCredentials(conf, credentials)
- assert(credentials.getSecretKey(SPARK_IO_TOKEN) === null)
- }
- })
+ test("shuffle encryption key length should be 128 by default") {
+ val conf = createConf()
+ var key = CryptoStreamUtils.createKey(conf)
+ val actual = key.length * (java.lang.Byte.SIZE)
+ assert(actual === 128)
}
- test("Shuffle encryption key length should be 128 by default") {
- ugi.doAs(new PrivilegedExceptionAction[Unit]() {
- override def run(): Unit = {
- val credentials = UserGroupInformation.getCurrentUser.getCredentials()
- val conf = new SparkConf()
- conf.set(IO_ENCRYPTION_ENABLED, true)
- initCredentials(conf, credentials)
- var key = credentials.getSecretKey(SPARK_IO_TOKEN)
- assert(key !== null)
- val actual = key.length * (java.lang.Byte.SIZE)
- assert(actual === 128)
- }
- })
+ test("create 256-bit key") {
+ val conf = createConf(IO_ENCRYPTION_KEY_SIZE_BITS.key -> "256")
+ var key = CryptoStreamUtils.createKey(conf)
+ val actual = key.length * (java.lang.Byte.SIZE)
+ assert(actual === 256)
}
- test("Initial credentials with key length in 256") {
- ugi.doAs(new PrivilegedExceptionAction[Unit]() {
- override def run(): Unit = {
- val credentials = UserGroupInformation.getCurrentUser.getCredentials()
- val conf = new SparkConf()
- conf.set(IO_ENCRYPTION_KEY_SIZE_BITS, 256)
- conf.set(IO_ENCRYPTION_ENABLED, true)
- initCredentials(conf, credentials)
- var key = credentials.getSecretKey(SPARK_IO_TOKEN)
- assert(key !== null)
- val actual = key.length * (java.lang.Byte.SIZE)
- assert(actual === 256)
- }
- })
+ test("create key with invalid length") {
+ intercept[IllegalArgumentException] {
+ val conf = createConf(IO_ENCRYPTION_KEY_SIZE_BITS.key -> "328")
+ CryptoStreamUtils.createKey(conf)
+ }
}
- test("Initial credentials with invalid key length") {
- ugi.doAs(new PrivilegedExceptionAction[Unit]() {
- override def run(): Unit = {
- val credentials = UserGroupInformation.getCurrentUser.getCredentials()
- val conf = new SparkConf()
- conf.set(IO_ENCRYPTION_KEY_SIZE_BITS, 328)
- conf.set(IO_ENCRYPTION_ENABLED, true)
- val thrown = intercept[IllegalArgumentException] {
- initCredentials(conf, credentials)
- }
- }
- })
+ test("serializer manager integration") {
+ val conf = createConf()
+ .set("spark.shuffle.compress", "true")
+ .set("spark.shuffle.spill.compress", "true")
+
+ val plainStr = "hello world"
+ val blockId = new TempShuffleBlockId(UUID.randomUUID())
+ val key = Some(CryptoStreamUtils.createKey(conf))
+ val serializerManager = new SerializerManager(new JavaSerializer(conf), conf,
+ encryptionKey = key)
+
+ val outputStream = new ByteArrayOutputStream()
+ val wrappedOutputStream = serializerManager.wrapStream(blockId, outputStream)
+ wrappedOutputStream.write(plainStr.getBytes(UTF_8))
+ wrappedOutputStream.close()
+
+ val encryptedBytes = outputStream.toByteArray
+ val encryptedStr = new String(encryptedBytes, UTF_8)
+ assert(plainStr !== encryptedStr)
+
+ val inputStream = new ByteArrayInputStream(encryptedBytes)
+ val wrappedInputStream = serializerManager.wrapStream(blockId, inputStream)
+ val decryptedBytes = ByteStreams.toByteArray(wrappedInputStream)
+ val decryptedStr = new String(decryptedBytes, UTF_8)
+ assert(decryptedStr === plainStr)
}
- private[this] def initCredentials(conf: SparkConf, credentials: Credentials): Unit = {
- if (conf.get(IO_ENCRYPTION_ENABLED)) {
- SecurityManager.initIOEncryptionKey(conf, credentials)
+ test("encryption key propagation to executors") {
+ val conf = createConf().setAppName("Crypto Test").setMaster("local-cluster[1,1,1024]")
+ val sc = new SparkContext(conf)
+ try {
+ val content = "This is the content to be encrypted."
+ val encrypted = sc.parallelize(Seq(1))
+ .map { str =>
+ val bytes = new ByteArrayOutputStream()
+ val out = CryptoStreamUtils.createCryptoOutputStream(bytes, SparkEnv.get.conf,
+ SparkEnv.get.securityManager.getIOEncryptionKey().get)
+ out.write(content.getBytes(UTF_8))
+ out.close()
+ bytes.toByteArray()
+ }.collect()(0)
+
+ assert(content != encrypted)
+
+ val in = CryptoStreamUtils.createCryptoInputStream(new ByteArrayInputStream(encrypted),
+ sc.conf, SparkEnv.get.securityManager.getIOEncryptionKey().get)
+ val decrypted = new String(ByteStreams.toByteArray(in), UTF_8)
+ assert(content === decrypted)
+ } finally {
+ sc.stop()
}
}
+
+ private def createConf(extra: (String, String)*): SparkConf = {
+ val conf = new SparkConf()
+ extra.foreach { case (k, v) => conf.set(k, v) }
+ conf.set(IO_ENCRYPTION_ENABLED, true)
+ conf
+ }
+
}