path: root/core/src/test
author    Marcelo Vanzin <vanzin@cloudera.com>    2016-11-28 21:10:57 -0800
committer Shixiong Zhu <shixiong@databricks.com>  2016-11-28 21:10:57 -0800
commit    8b325b17ecdf013b7a6edcb7ee3773546bd914df (patch)
tree      e2826f751402537582646f88fe3b905783fa2f7e /core/src/test
parent    1633ff3b6c97e33191859f34c868782cbb0972fd (diff)
[SPARK-18547][CORE] Propagate I/O encryption key when executors register.
This change modifies the method used to propagate encryption keys used during shuffle. Instead of relying on YARN's UserGroupInformation credential propagation, this change explicitly distributes the key using the messages exchanged between driver and executor during registration. When RPC encryption is enabled, this means key propagation is also secure.

This allows shuffle encryption to work in non-YARN mode, which means that it's easier to write unit tests for areas of the code that are affected by the feature.

The key is stored in the SecurityManager; because there are many instances of that class used in the code, the key is only guaranteed to exist in the instance managed by the SparkEnv. This path was chosen to avoid storing the key in the SparkConf, which would risk having the key be written to disk as part of the configuration (as is done, for example, when starting YARN applications).

Tested by new and existing unit tests (which were moved from the YARN module to core), and by running apps with shuffle encryption enabled.

Author: Marcelo Vanzin <vanzin@cloudera.com>

Closes #15981 from vanzin/SPARK-18547.
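A minimal sketch of the new flow, assembled from the calls exercised by the test suite below (CryptoStreamUtils.createKey, SecurityManager.getIOEncryptionKey, CryptoStreamUtils.createCryptoOutputStream). It assumes a live SparkEnv, as inside a running executor, and uses top-level statements for brevity; it illustrates the idea rather than the actual registration code path:

    import java.io.ByteArrayOutputStream
    import java.nio.charset.StandardCharsets.UTF_8

    import org.apache.spark.{SparkConf, SparkEnv}
    import org.apache.spark.internal.config._
    import org.apache.spark.security.CryptoStreamUtils

    // Driver side: create the key once. It lives in the SecurityManager owned
    // by the SparkEnv and is sent to executors in the registration message,
    // rather than being stored in SparkConf (which may be written to disk).
    val conf = new SparkConf().set(IO_ENCRYPTION_ENABLED, true)
    val key = CryptoStreamUtils.createKey(conf)

    // Executor side: the propagated key is available from the executor's own
    // SparkEnv-managed SecurityManager and can wrap any shuffle stream.
    val executorKey = SparkEnv.get.securityManager.getIOEncryptionKey().get
    val buffer = new ByteArrayOutputStream()
    val out = CryptoStreamUtils.createCryptoOutputStream(buffer, conf, executorKey)
    out.write("secret payload".getBytes(UTF_8))
    out.close()

Because only the SparkEnv-managed SecurityManager is guaranteed to hold the key, code that needs it should go through SparkEnv rather than constructing its own SecurityManager.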
Diffstat (limited to 'core/src/test')
-rw-r--r--  core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala | 135
1 file changed, 79 insertions(+), 56 deletions(-)
diff --git a/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala b/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala
index 81eb907ac7..a61ec74c7d 100644
--- a/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala
@@ -16,18 +16,21 @@
*/
package org.apache.spark.security
-import java.security.PrivilegedExceptionAction
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
+import java.nio.charset.StandardCharsets.UTF_8
+import java.util.UUID
-import org.apache.hadoop.security.{Credentials, UserGroupInformation}
+import com.google.common.io.ByteStreams
-import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite}
+import org.apache.spark._
import org.apache.spark.internal.config._
import org.apache.spark.security.CryptoStreamUtils._
+import org.apache.spark.serializer.{JavaSerializer, SerializerManager}
+import org.apache.spark.storage.TempShuffleBlockId
class CryptoStreamUtilsSuite extends SparkFunSuite {
- val ugi = UserGroupInformation.createUserForTesting("testuser", Array("testgroup"))
- test("Crypto configuration conversion") {
+ test("crypto configuration conversion") {
val sparkKey1 = s"${SPARK_IO_ENCRYPTION_COMMONS_CONFIG_PREFIX}a.b.c"
val sparkVal1 = "val1"
val cryptoKey1 = s"${COMMONS_CRYPTO_CONF_PREFIX}a.b.c"
@@ -43,65 +46,85 @@ class CryptoStreamUtilsSuite extends SparkFunSuite {
assert(!props.containsKey(cryptoKey2))
}
- test("Shuffle encryption is disabled by default") {
- ugi.doAs(new PrivilegedExceptionAction[Unit]() {
- override def run(): Unit = {
- val credentials = UserGroupInformation.getCurrentUser.getCredentials()
- val conf = new SparkConf()
- initCredentials(conf, credentials)
- assert(credentials.getSecretKey(SPARK_IO_TOKEN) === null)
- }
- })
+ test("shuffle encryption key length should be 128 by default") {
+ val conf = createConf()
+ var key = CryptoStreamUtils.createKey(conf)
+ val actual = key.length * (java.lang.Byte.SIZE)
+ assert(actual === 128)
}
- test("Shuffle encryption key length should be 128 by default") {
- ugi.doAs(new PrivilegedExceptionAction[Unit]() {
- override def run(): Unit = {
- val credentials = UserGroupInformation.getCurrentUser.getCredentials()
- val conf = new SparkConf()
- conf.set(IO_ENCRYPTION_ENABLED, true)
- initCredentials(conf, credentials)
- var key = credentials.getSecretKey(SPARK_IO_TOKEN)
- assert(key !== null)
- val actual = key.length * (java.lang.Byte.SIZE)
- assert(actual === 128)
- }
- })
+ test("create 256-bit key") {
+ val conf = createConf(IO_ENCRYPTION_KEY_SIZE_BITS.key -> "256")
+ var key = CryptoStreamUtils.createKey(conf)
+ val actual = key.length * (java.lang.Byte.SIZE)
+ assert(actual === 256)
}
- test("Initial credentials with key length in 256") {
- ugi.doAs(new PrivilegedExceptionAction[Unit]() {
- override def run(): Unit = {
- val credentials = UserGroupInformation.getCurrentUser.getCredentials()
- val conf = new SparkConf()
- conf.set(IO_ENCRYPTION_KEY_SIZE_BITS, 256)
- conf.set(IO_ENCRYPTION_ENABLED, true)
- initCredentials(conf, credentials)
- var key = credentials.getSecretKey(SPARK_IO_TOKEN)
- assert(key !== null)
- val actual = key.length * (java.lang.Byte.SIZE)
- assert(actual === 256)
- }
- })
+ test("create key with invalid length") {
+ intercept[IllegalArgumentException] {
+ val conf = createConf(IO_ENCRYPTION_KEY_SIZE_BITS.key -> "328")
+ CryptoStreamUtils.createKey(conf)
+ }
}
- test("Initial credentials with invalid key length") {
- ugi.doAs(new PrivilegedExceptionAction[Unit]() {
- override def run(): Unit = {
- val credentials = UserGroupInformation.getCurrentUser.getCredentials()
- val conf = new SparkConf()
- conf.set(IO_ENCRYPTION_KEY_SIZE_BITS, 328)
- conf.set(IO_ENCRYPTION_ENABLED, true)
- val thrown = intercept[IllegalArgumentException] {
- initCredentials(conf, credentials)
- }
- }
- })
+ test("serializer manager integration") {
+ val conf = createConf()
+ .set("spark.shuffle.compress", "true")
+ .set("spark.shuffle.spill.compress", "true")
+
+ val plainStr = "hello world"
+ val blockId = new TempShuffleBlockId(UUID.randomUUID())
+ val key = Some(CryptoStreamUtils.createKey(conf))
+ val serializerManager = new SerializerManager(new JavaSerializer(conf), conf,
+ encryptionKey = key)
+
+ val outputStream = new ByteArrayOutputStream()
+ val wrappedOutputStream = serializerManager.wrapStream(blockId, outputStream)
+ wrappedOutputStream.write(plainStr.getBytes(UTF_8))
+ wrappedOutputStream.close()
+
+ val encryptedBytes = outputStream.toByteArray
+ val encryptedStr = new String(encryptedBytes, UTF_8)
+ assert(plainStr !== encryptedStr)
+
+ val inputStream = new ByteArrayInputStream(encryptedBytes)
+ val wrappedInputStream = serializerManager.wrapStream(blockId, inputStream)
+ val decryptedBytes = ByteStreams.toByteArray(wrappedInputStream)
+ val decryptedStr = new String(decryptedBytes, UTF_8)
+ assert(decryptedStr === plainStr)
}
- private[this] def initCredentials(conf: SparkConf, credentials: Credentials): Unit = {
- if (conf.get(IO_ENCRYPTION_ENABLED)) {
- SecurityManager.initIOEncryptionKey(conf, credentials)
+ test("encryption key propagation to executors") {
+ val conf = createConf().setAppName("Crypto Test").setMaster("local-cluster[1,1,1024]")
+ val sc = new SparkContext(conf)
+ try {
+ val content = "This is the content to be encrypted."
+ val encrypted = sc.parallelize(Seq(1))
+ .map { str =>
+ val bytes = new ByteArrayOutputStream()
+ val out = CryptoStreamUtils.createCryptoOutputStream(bytes, SparkEnv.get.conf,
+ SparkEnv.get.securityManager.getIOEncryptionKey().get)
+ out.write(content.getBytes(UTF_8))
+ out.close()
+ bytes.toByteArray()
+ }.collect()(0)
+
+ assert(content != encrypted)
+
+ val in = CryptoStreamUtils.createCryptoInputStream(new ByteArrayInputStream(encrypted),
+ sc.conf, SparkEnv.get.securityManager.getIOEncryptionKey().get)
+ val decrypted = new String(ByteStreams.toByteArray(in), UTF_8)
+ assert(content === decrypted)
+ } finally {
+ sc.stop()
}
}
+
+ private def createConf(extra: (String, String)*): SparkConf = {
+ val conf = new SparkConf()
+ extra.foreach { case (k, v) => conf.set(k, v) }
+ conf.set(IO_ENCRYPTION_ENABLED, true)
+ conf
+ }
+
}