-rw-r--r--  core/src/main/scala/org/apache/spark/SecurityManager.scala                            |  23
-rw-r--r--  core/src/main/scala/org/apache/spark/SparkContext.scala                               |   4
-rw-r--r--  core/src/main/scala/org/apache/spark/SparkEnv.scala                                   |  33
-rw-r--r--  core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala      |   6
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala    |   7
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala  |   6
-rw-r--r--  core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala                 |  28
-rw-r--r--  core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala               |  18
-rw-r--r--  core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala            | 135
-rw-r--r--  docs/configuration.md                                                                  |   3
-rw-r--r--  mesos/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala             |   2
-rw-r--r--  mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterManager.scala      |   4
-rw-r--r--  mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterManagerSuite.scala |  11
-rw-r--r--  yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala                         |   5
-rw-r--r--  yarn/src/test/scala/org/apache/spark/deploy/yarn/IOEncryptionSuite.scala              | 108
15 files changed, 166 insertions(+), 227 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala
index 199365ad92..87fe563152 100644
--- a/core/src/main/scala/org/apache/spark/SecurityManager.scala
+++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala
@@ -21,7 +21,6 @@ import java.lang.{Byte => JByte}
import java.net.{Authenticator, PasswordAuthentication}
import java.security.{KeyStore, SecureRandom}
import java.security.cert.X509Certificate
-import javax.crypto.KeyGenerator
import javax.net.ssl._
import com.google.common.hash.HashCodes
@@ -33,7 +32,6 @@ import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config._
import org.apache.spark.network.sasl.SecretKeyHolder
-import org.apache.spark.security.CryptoStreamUtils._
import org.apache.spark.util.Utils
/**
@@ -185,7 +183,9 @@ import org.apache.spark.util.Utils
* setting `spark.ssl.useNodeLocalConf` to `true`.
*/
-private[spark] class SecurityManager(sparkConf: SparkConf)
+private[spark] class SecurityManager(
+ sparkConf: SparkConf,
+ ioEncryptionKey: Option[Array[Byte]] = None)
extends Logging with SecretKeyHolder {
import SecurityManager._
@@ -415,6 +415,8 @@ private[spark] class SecurityManager(sparkConf: SparkConf)
logInfo("Changing acls enabled to: " + aclsOn)
}
+ def getIOEncryptionKey(): Option[Array[Byte]] = ioEncryptionKey
+
/**
* Generates or looks up the secret key.
*
@@ -559,19 +561,4 @@ private[spark] object SecurityManager {
// key used to store the spark secret in the Hadoop UGI
val SECRET_LOOKUP_KEY = "sparkCookie"
- /**
- * Setup the cryptographic key used by IO encryption in credentials. The key is generated using
- * [[KeyGenerator]]. The algorithm and key length is specified by the [[SparkConf]].
- */
- def initIOEncryptionKey(conf: SparkConf, credentials: Credentials): Unit = {
- if (credentials.getSecretKey(SPARK_IO_TOKEN) == null) {
- val keyLen = conf.get(IO_ENCRYPTION_KEY_SIZE_BITS)
- val ioKeyGenAlgorithm = conf.get(IO_ENCRYPTION_KEYGEN_ALGORITHM)
- val keyGen = KeyGenerator.getInstance(ioKeyGenAlgorithm)
- keyGen.init(keyLen)
-
- val ioKey = keyGen.generateKey()
- credentials.addSecretKey(SPARK_IO_TOKEN, ioKey.getEncoded)
- }
- }
}
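
For orientation (not part of the patch itself): with initIOEncryptionKey removed, the key is now generated up front via CryptoStreamUtils.createKey and handed to the SecurityManager constructor, which only exposes it through the new getIOEncryptionKey() accessor. A minimal sketch of that wiring, assuming the calling code lives inside the org.apache.spark package (both classes are private[spark]):

    import org.apache.spark.{SecurityManager, SparkConf}
    import org.apache.spark.security.CryptoStreamUtils

    val conf = new SparkConf().set("spark.io.encryption.enabled", "true")

    // The driver generates the key once (SparkEnv.createDriverEnv does this below)
    // and passes it to the SecurityManager instead of storing it in Hadoop Credentials.
    val key: Option[Array[Byte]] = Some(CryptoStreamUtils.createKey(conf))
    val securityManager = new SecurityManager(conf, ioEncryptionKey = key)

    // Anything that holds the SecurityManager can read the key back.
    assert(securityManager.getIOEncryptionKey() == key)
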
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 1261e3e735..a159a170eb 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -422,10 +422,6 @@ class SparkContext(config: SparkConf) extends Logging {
}
if (master == "yarn" && deployMode == "client") System.setProperty("SPARK_YARN_MODE", "true")
- if (_conf.get(IO_ENCRYPTION_ENABLED) && !SparkHadoopUtil.get.isYarnMode()) {
- throw new SparkException("IO encryption is only supported in YARN mode, please disable it " +
- s"by setting ${IO_ENCRYPTION_ENABLED.key} to false")
- }
// "_jobProgressListener" should be set up before creating SparkEnv because when creating
// "SparkEnv", some messages will be posted to "listenerBus" and we should not miss them.
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index 1ffeb12988..1296386ac9 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -36,6 +36,7 @@ import org.apache.spark.network.netty.NettyBlockTransferService
import org.apache.spark.rpc.{RpcEndpoint, RpcEndpointRef, RpcEnv}
import org.apache.spark.scheduler.{LiveListenerBus, OutputCommitCoordinator}
import org.apache.spark.scheduler.OutputCommitCoordinator.OutputCommitCoordinatorEndpoint
+import org.apache.spark.security.CryptoStreamUtils
import org.apache.spark.serializer.{JavaSerializer, Serializer, SerializerManager}
import org.apache.spark.shuffle.ShuffleManager
import org.apache.spark.storage._
@@ -165,15 +166,20 @@ object SparkEnv extends Logging {
val bindAddress = conf.get(DRIVER_BIND_ADDRESS)
val advertiseAddress = conf.get(DRIVER_HOST_ADDRESS)
val port = conf.get("spark.driver.port").toInt
+ val ioEncryptionKey = if (conf.get(IO_ENCRYPTION_ENABLED)) {
+ Some(CryptoStreamUtils.createKey(conf))
+ } else {
+ None
+ }
create(
conf,
SparkContext.DRIVER_IDENTIFIER,
bindAddress,
advertiseAddress,
port,
- isDriver = true,
- isLocal = isLocal,
- numUsableCores = numCores,
+ isLocal,
+ numCores,
+ ioEncryptionKey,
listenerBus = listenerBus,
mockOutputCommitCoordinator = mockOutputCommitCoordinator
)
@@ -189,6 +195,7 @@ object SparkEnv extends Logging {
hostname: String,
port: Int,
numCores: Int,
+ ioEncryptionKey: Option[Array[Byte]],
isLocal: Boolean): SparkEnv = {
val env = create(
conf,
@@ -196,9 +203,9 @@ object SparkEnv extends Logging {
hostname,
hostname,
port,
- isDriver = false,
- isLocal = isLocal,
- numUsableCores = numCores
+ isLocal,
+ numCores,
+ ioEncryptionKey
)
SparkEnv.set(env)
env
@@ -213,18 +220,26 @@ object SparkEnv extends Logging {
bindAddress: String,
advertiseAddress: String,
port: Int,
- isDriver: Boolean,
isLocal: Boolean,
numUsableCores: Int,
+ ioEncryptionKey: Option[Array[Byte]],
listenerBus: LiveListenerBus = null,
mockOutputCommitCoordinator: Option[OutputCommitCoordinator] = None): SparkEnv = {
+ val isDriver = executorId == SparkContext.DRIVER_IDENTIFIER
+
// Listener bus is only used on the driver
if (isDriver) {
assert(listenerBus != null, "Attempted to create driver SparkEnv with null listener bus!")
}
- val securityManager = new SecurityManager(conf)
+ val securityManager = new SecurityManager(conf, ioEncryptionKey)
+ ioEncryptionKey.foreach { _ =>
+ if (!securityManager.isSaslEncryptionEnabled()) {
+ logWarning("I/O encryption enabled without RPC encryption: keys will be visible on the " +
+ "wire.")
+ }
+ }
val systemName = if (isDriver) driverSystemName else executorSystemName
val rpcEnv = RpcEnv.create(systemName, bindAddress, advertiseAddress, port, conf,
@@ -270,7 +285,7 @@ object SparkEnv extends Logging {
"spark.serializer", "org.apache.spark.serializer.JavaSerializer")
logDebug(s"Using serializer: ${serializer.getClass}")
- val serializerManager = new SerializerManager(serializer, conf)
+ val serializerManager = new SerializerManager(serializer, conf, ioEncryptionKey)
val closureSerializer = new JavaSerializer(conf)
diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
index 7eec4ae64f..92a27902c6 100644
--- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
+++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
@@ -200,8 +200,8 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging {
new SecurityManager(executorConf),
clientMode = true)
val driver = fetcher.setupEndpointRefByURI(driverUrl)
- val props = driver.askWithRetry[Seq[(String, String)]](RetrieveSparkProps) ++
- Seq[(String, String)](("spark.app.id", appId))
+ val cfg = driver.askWithRetry[SparkAppConfig](RetrieveSparkAppConfig)
+ val props = cfg.sparkProperties ++ Seq[(String, String)](("spark.app.id", appId))
fetcher.shutdown()
// Create SparkEnv using properties we fetched from the driver.
@@ -221,7 +221,7 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging {
}
val env = SparkEnv.createExecutorEnv(
- driverConf, executorId, hostname, port, cores, isLocal = false)
+ driverConf, executorId, hostname, port, cores, cfg.ioEncryptionKey, isLocal = false)
env.rpcEnv.setupEndpoint("Executor", new CoarseGrainedExecutorBackend(
env.rpcEnv, driverUrl, executorId, hostname, cores, userClassPath, env))
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
index edc8aac5d1..0a4f19d760 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
@@ -28,7 +28,12 @@ private[spark] sealed trait CoarseGrainedClusterMessage extends Serializable
private[spark] object CoarseGrainedClusterMessages {
- case object RetrieveSparkProps extends CoarseGrainedClusterMessage
+ case object RetrieveSparkAppConfig extends CoarseGrainedClusterMessage
+
+ case class SparkAppConfig(
+ sparkProperties: Seq[(String, String)],
+ ioEncryptionKey: Option[Array[Byte]])
+ extends CoarseGrainedClusterMessage
case object RetrieveLastAllocatedExecutorId extends CoarseGrainedClusterMessage
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
index 10d55c87fb..3452487e72 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
@@ -206,8 +206,10 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
removeExecutor(executorId, reason)
context.reply(true)
- case RetrieveSparkProps =>
- context.reply(sparkProperties)
+ case RetrieveSparkAppConfig =>
+ val reply = SparkAppConfig(sparkProperties,
+ SparkEnv.get.securityManager.getIOEncryptionKey())
+ context.reply(reply)
}
// Make fake resource offers on all executors
diff --git a/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala b/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala
index f41fc38be2..8e3436f134 100644
--- a/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala
+++ b/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala
@@ -18,14 +18,13 @@ package org.apache.spark.security
import java.io.{InputStream, OutputStream}
import java.util.Properties
+import javax.crypto.KeyGenerator
import javax.crypto.spec.{IvParameterSpec, SecretKeySpec}
import org.apache.commons.crypto.random._
import org.apache.commons.crypto.stream._
-import org.apache.hadoop.io.Text
import org.apache.spark.SparkConf
-import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config._
@@ -33,10 +32,6 @@ import org.apache.spark.internal.config._
* A util class for manipulating IO encryption and decryption streams.
*/
private[spark] object CryptoStreamUtils extends Logging {
- /**
- * Constants and variables for spark IO encryption
- */
- val SPARK_IO_TOKEN = new Text("SPARK_IO_TOKEN")
// The initialization vector length in bytes.
val IV_LENGTH_IN_BYTES = 16
@@ -50,12 +45,11 @@ private[spark] object CryptoStreamUtils extends Logging {
*/
def createCryptoOutputStream(
os: OutputStream,
- sparkConf: SparkConf): OutputStream = {
+ sparkConf: SparkConf,
+ key: Array[Byte]): OutputStream = {
val properties = toCryptoConf(sparkConf)
val iv = createInitializationVector(properties)
os.write(iv)
- val credentials = SparkHadoopUtil.get.getCurrentUserCredentials()
- val key = credentials.getSecretKey(SPARK_IO_TOKEN)
val transformationStr = sparkConf.get(IO_CRYPTO_CIPHER_TRANSFORMATION)
new CryptoOutputStream(transformationStr, properties, os,
new SecretKeySpec(key, "AES"), new IvParameterSpec(iv))
@@ -66,12 +60,11 @@ private[spark] object CryptoStreamUtils extends Logging {
*/
def createCryptoInputStream(
is: InputStream,
- sparkConf: SparkConf): InputStream = {
+ sparkConf: SparkConf,
+ key: Array[Byte]): InputStream = {
val properties = toCryptoConf(sparkConf)
val iv = new Array[Byte](IV_LENGTH_IN_BYTES)
is.read(iv, 0, iv.length)
- val credentials = SparkHadoopUtil.get.getCurrentUserCredentials()
- val key = credentials.getSecretKey(SPARK_IO_TOKEN)
val transformationStr = sparkConf.get(IO_CRYPTO_CIPHER_TRANSFORMATION)
new CryptoInputStream(transformationStr, properties, is,
new SecretKeySpec(key, "AES"), new IvParameterSpec(iv))
@@ -92,6 +85,17 @@ private[spark] object CryptoStreamUtils extends Logging {
}
/**
+ * Creates a new encryption key.
+ */
+ def createKey(conf: SparkConf): Array[Byte] = {
+ val keyLen = conf.get(IO_ENCRYPTION_KEY_SIZE_BITS)
+ val ioKeyGenAlgorithm = conf.get(IO_ENCRYPTION_KEYGEN_ALGORITHM)
+ val keyGen = KeyGenerator.getInstance(ioKeyGenAlgorithm)
+ keyGen.init(keyLen)
+ keyGen.generateKey().getEncoded()
+ }
+
+ /**
* This method to generate an IV (Initialization Vector) using secure random.
*/
private[this] def createInitializationVector(properties: Properties): Array[Byte] = {
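
A minimal round-trip sketch of the reworked stream helpers (again assuming a caller inside the org.apache.spark package, and the same conf and key on both ends; this mirrors the new test suite below):

    import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
    import java.nio.charset.StandardCharsets.UTF_8

    import com.google.common.io.ByteStreams

    import org.apache.spark.SparkConf
    import org.apache.spark.security.CryptoStreamUtils

    val conf = new SparkConf().set("spark.io.encryption.enabled", "true")
    val key = CryptoStreamUtils.createKey(conf)  // replaces SecurityManager.initIOEncryptionKey

    // Encrypt: createCryptoOutputStream writes the IV at the head of the stream.
    val buffer = new ByteArrayOutputStream()
    val out = CryptoStreamUtils.createCryptoOutputStream(buffer, conf, key)
    out.write("hello world".getBytes(UTF_8))
    out.close()

    // Decrypt with the same conf and key; the IV is read back from the stream head.
    val in = CryptoStreamUtils.createCryptoInputStream(
      new ByteArrayInputStream(buffer.toByteArray), conf, key)
    assert(new String(ByteStreams.toByteArray(in), UTF_8) == "hello world")
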
diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala
index 2156d576f1..ef8432ec08 100644
--- a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala
@@ -33,7 +33,12 @@ import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStrea
* Component which configures serialization, compression and encryption for various Spark
* components, including automatic selection of which [[Serializer]] to use for shuffles.
*/
-private[spark] class SerializerManager(defaultSerializer: Serializer, conf: SparkConf) {
+private[spark] class SerializerManager(
+ defaultSerializer: Serializer,
+ conf: SparkConf,
+ encryptionKey: Option[Array[Byte]]) {
+
+ def this(defaultSerializer: Serializer, conf: SparkConf) = this(defaultSerializer, conf, None)
private[this] val kryoSerializer = new KryoSerializer(conf)
@@ -63,9 +68,6 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: Spar
// Whether to compress shuffle output temporarily spilled to disk
private[this] val compressShuffleSpill = conf.getBoolean("spark.shuffle.spill.compress", true)
- // Whether to enable IO encryption
- private[this] val enableIOEncryption = conf.get(IO_ENCRYPTION_ENABLED)
-
/* The compression codec to use. Note that the "lazy" val is necessary because we want to delay
* the initialization of the compression codec until it is first used. The reason is that a Spark
* program could be using a user-defined codec in a third party jar, which is loaded in
@@ -125,14 +127,18 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: Spar
* Wrap an input stream for encryption if shuffle encryption is enabled
*/
private[this] def wrapForEncryption(s: InputStream): InputStream = {
- if (enableIOEncryption) CryptoStreamUtils.createCryptoInputStream(s, conf) else s
+ encryptionKey
+ .map { key => CryptoStreamUtils.createCryptoInputStream(s, conf, key) }
+ .getOrElse(s)
}
/**
* Wrap an output stream for encryption if shuffle encryption is enabled
*/
private[this] def wrapForEncryption(s: OutputStream): OutputStream = {
- if (enableIOEncryption) CryptoStreamUtils.createCryptoOutputStream(s, conf) else s
+ encryptionKey
+ .map { key => CryptoStreamUtils.createCryptoOutputStream(s, conf, key) }
+ .getOrElse(s)
}
/**
diff --git a/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala b/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala
index 81eb907ac7..a61ec74c7d 100644
--- a/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala
@@ -16,18 +16,21 @@
*/
package org.apache.spark.security
-import java.security.PrivilegedExceptionAction
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
+import java.nio.charset.StandardCharsets.UTF_8
+import java.util.UUID
-import org.apache.hadoop.security.{Credentials, UserGroupInformation}
+import com.google.common.io.ByteStreams
-import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite}
+import org.apache.spark._
import org.apache.spark.internal.config._
import org.apache.spark.security.CryptoStreamUtils._
+import org.apache.spark.serializer.{JavaSerializer, SerializerManager}
+import org.apache.spark.storage.TempShuffleBlockId
class CryptoStreamUtilsSuite extends SparkFunSuite {
- val ugi = UserGroupInformation.createUserForTesting("testuser", Array("testgroup"))
- test("Crypto configuration conversion") {
+ test("crypto configuration conversion") {
val sparkKey1 = s"${SPARK_IO_ENCRYPTION_COMMONS_CONFIG_PREFIX}a.b.c"
val sparkVal1 = "val1"
val cryptoKey1 = s"${COMMONS_CRYPTO_CONF_PREFIX}a.b.c"
@@ -43,65 +46,85 @@ class CryptoStreamUtilsSuite extends SparkFunSuite {
assert(!props.containsKey(cryptoKey2))
}
- test("Shuffle encryption is disabled by default") {
- ugi.doAs(new PrivilegedExceptionAction[Unit]() {
- override def run(): Unit = {
- val credentials = UserGroupInformation.getCurrentUser.getCredentials()
- val conf = new SparkConf()
- initCredentials(conf, credentials)
- assert(credentials.getSecretKey(SPARK_IO_TOKEN) === null)
- }
- })
+ test("shuffle encryption key length should be 128 by default") {
+ val conf = createConf()
+ var key = CryptoStreamUtils.createKey(conf)
+ val actual = key.length * (java.lang.Byte.SIZE)
+ assert(actual === 128)
}
- test("Shuffle encryption key length should be 128 by default") {
- ugi.doAs(new PrivilegedExceptionAction[Unit]() {
- override def run(): Unit = {
- val credentials = UserGroupInformation.getCurrentUser.getCredentials()
- val conf = new SparkConf()
- conf.set(IO_ENCRYPTION_ENABLED, true)
- initCredentials(conf, credentials)
- var key = credentials.getSecretKey(SPARK_IO_TOKEN)
- assert(key !== null)
- val actual = key.length * (java.lang.Byte.SIZE)
- assert(actual === 128)
- }
- })
+ test("create 256-bit key") {
+ val conf = createConf(IO_ENCRYPTION_KEY_SIZE_BITS.key -> "256")
+ var key = CryptoStreamUtils.createKey(conf)
+ val actual = key.length * (java.lang.Byte.SIZE)
+ assert(actual === 256)
}
- test("Initial credentials with key length in 256") {
- ugi.doAs(new PrivilegedExceptionAction[Unit]() {
- override def run(): Unit = {
- val credentials = UserGroupInformation.getCurrentUser.getCredentials()
- val conf = new SparkConf()
- conf.set(IO_ENCRYPTION_KEY_SIZE_BITS, 256)
- conf.set(IO_ENCRYPTION_ENABLED, true)
- initCredentials(conf, credentials)
- var key = credentials.getSecretKey(SPARK_IO_TOKEN)
- assert(key !== null)
- val actual = key.length * (java.lang.Byte.SIZE)
- assert(actual === 256)
- }
- })
+ test("create key with invalid length") {
+ intercept[IllegalArgumentException] {
+ val conf = createConf(IO_ENCRYPTION_KEY_SIZE_BITS.key -> "328")
+ CryptoStreamUtils.createKey(conf)
+ }
}
- test("Initial credentials with invalid key length") {
- ugi.doAs(new PrivilegedExceptionAction[Unit]() {
- override def run(): Unit = {
- val credentials = UserGroupInformation.getCurrentUser.getCredentials()
- val conf = new SparkConf()
- conf.set(IO_ENCRYPTION_KEY_SIZE_BITS, 328)
- conf.set(IO_ENCRYPTION_ENABLED, true)
- val thrown = intercept[IllegalArgumentException] {
- initCredentials(conf, credentials)
- }
- }
- })
+ test("serializer manager integration") {
+ val conf = createConf()
+ .set("spark.shuffle.compress", "true")
+ .set("spark.shuffle.spill.compress", "true")
+
+ val plainStr = "hello world"
+ val blockId = new TempShuffleBlockId(UUID.randomUUID())
+ val key = Some(CryptoStreamUtils.createKey(conf))
+ val serializerManager = new SerializerManager(new JavaSerializer(conf), conf,
+ encryptionKey = key)
+
+ val outputStream = new ByteArrayOutputStream()
+ val wrappedOutputStream = serializerManager.wrapStream(blockId, outputStream)
+ wrappedOutputStream.write(plainStr.getBytes(UTF_8))
+ wrappedOutputStream.close()
+
+ val encryptedBytes = outputStream.toByteArray
+ val encryptedStr = new String(encryptedBytes, UTF_8)
+ assert(plainStr !== encryptedStr)
+
+ val inputStream = new ByteArrayInputStream(encryptedBytes)
+ val wrappedInputStream = serializerManager.wrapStream(blockId, inputStream)
+ val decryptedBytes = ByteStreams.toByteArray(wrappedInputStream)
+ val decryptedStr = new String(decryptedBytes, UTF_8)
+ assert(decryptedStr === plainStr)
}
- private[this] def initCredentials(conf: SparkConf, credentials: Credentials): Unit = {
- if (conf.get(IO_ENCRYPTION_ENABLED)) {
- SecurityManager.initIOEncryptionKey(conf, credentials)
+ test("encryption key propagation to executors") {
+ val conf = createConf().setAppName("Crypto Test").setMaster("local-cluster[1,1,1024]")
+ val sc = new SparkContext(conf)
+ try {
+ val content = "This is the content to be encrypted."
+ val encrypted = sc.parallelize(Seq(1))
+ .map { str =>
+ val bytes = new ByteArrayOutputStream()
+ val out = CryptoStreamUtils.createCryptoOutputStream(bytes, SparkEnv.get.conf,
+ SparkEnv.get.securityManager.getIOEncryptionKey().get)
+ out.write(content.getBytes(UTF_8))
+ out.close()
+ bytes.toByteArray()
+ }.collect()(0)
+
+ assert(content != encrypted)
+
+ val in = CryptoStreamUtils.createCryptoInputStream(new ByteArrayInputStream(encrypted),
+ sc.conf, SparkEnv.get.securityManager.getIOEncryptionKey().get)
+ val decrypted = new String(ByteStreams.toByteArray(in), UTF_8)
+ assert(content === decrypted)
+ } finally {
+ sc.stop()
}
}
+
+ private def createConf(extra: (String, String)*): SparkConf = {
+ val conf = new SparkConf()
+ extra.foreach { case (k, v) => conf.set(k, v) }
+ conf.set(IO_ENCRYPTION_ENABLED, true)
+ conf
+ }
+
}
diff --git a/docs/configuration.md b/docs/configuration.md
index aa201c6b6a..d8800e93da 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -590,7 +590,8 @@ Apart from these, the following properties are also available, and may be useful
<td><code>spark.io.encryption.enabled</code></td>
<td>false</td>
<td>
- Enable IO encryption. Only supported in YARN mode.
+ Enable IO encryption. Currently supported by all modes except Mesos. It's recommended that RPC encryption
+ be enabled when using this feature.
</td>
</tr>
<tr>
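
The recommendation in that description amounts to pairing I/O encryption with SASL-based RPC encryption, so the key sent to executors via RetrieveSparkAppConfig is not readable on the wire. A hedged configuration sketch; spark.authenticate and spark.authenticate.enableSaslEncryption are pre-existing Spark settings, not part of this patch:

    import org.apache.spark.SparkConf

    val conf = new SparkConf()
      .set("spark.io.encryption.enabled", "true")               // any mode except Mesos
      .set("spark.authenticate", "true")                        // prerequisite for SASL encryption
      .set("spark.authenticate.enableSaslEncryption", "true")   // keeps the I/O key off the wire
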
diff --git a/mesos/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala b/mesos/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala
index 1937bd30ba..ee9149ce02 100644
--- a/mesos/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala
+++ b/mesos/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala
@@ -75,7 +75,7 @@ private[spark] class MesosExecutorBackend
val conf = new SparkConf(loadDefaults = true).setAll(properties)
val port = conf.getInt("spark.executor.port", 0)
val env = SparkEnv.createExecutorEnv(
- conf, executorId, slaveInfo.getHostname, port, cpusPerTask, isLocal = false)
+ conf, executorId, slaveInfo.getHostname, port, cpusPerTask, None, isLocal = false)
executor = new Executor(
executorId,
diff --git a/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterManager.scala b/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterManager.scala
index a849c4afa2..ed29b346ba 100644
--- a/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterManager.scala
+++ b/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterManager.scala
@@ -18,6 +18,7 @@
package org.apache.spark.scheduler.cluster.mesos
import org.apache.spark.{SparkContext, SparkException}
+import org.apache.spark.internal.config._
import org.apache.spark.scheduler.{ExternalClusterManager, SchedulerBackend, TaskScheduler, TaskSchedulerImpl}
/**
@@ -37,6 +38,9 @@ private[spark] class MesosClusterManager extends ExternalClusterManager {
override def createSchedulerBackend(sc: SparkContext,
masterURL: String,
scheduler: TaskScheduler): SchedulerBackend = {
+ require(!sc.conf.get(IO_ENCRYPTION_ENABLED),
+ "I/O encryption is currently not supported in Mesos.")
+
val mesosUrl = MESOS_REGEX.findFirstMatchIn(masterURL).get.group(1)
val coarse = sc.conf.getBoolean("spark.mesos.coarse", defaultValue = true)
if (coarse) {
diff --git a/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterManagerSuite.scala b/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterManagerSuite.scala
index 6fce06632c..a55855428b 100644
--- a/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterManagerSuite.scala
+++ b/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterManagerSuite.scala
@@ -17,7 +17,8 @@
package org.apache.spark.scheduler.cluster.mesos
-import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite}
+import org.apache.spark._
+import org.apache.spark.internal.config._
class MesosClusterManagerSuite extends SparkFunSuite with LocalSparkContext {
def testURL(masterURL: String, expectedClass: Class[_], coarse: Boolean) {
@@ -44,4 +45,12 @@ class MesosClusterManagerSuite extends SparkFunSuite with LocalSparkContext {
classOf[MesosFineGrainedSchedulerBackend],
coarse = false)
}
+
+ test("mesos with i/o encryption throws error") {
+ val se = intercept[SparkException] {
+ val conf = new SparkConf().setAppName("test").set(IO_ENCRYPTION_ENABLED, true)
+ sc = new SparkContext("mesos", "test", conf)
+ }
+ assert(se.getCause().isInstanceOf[IllegalArgumentException])
+ }
}
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
index 1b75688b28..be419cee77 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
@@ -1014,12 +1014,7 @@ private[spark] class Client(
val securityManager = new SecurityManager(sparkConf)
amContainer.setApplicationACLs(
YarnSparkHadoopUtil.getApplicationAclsForYarn(securityManager).asJava)
-
- if (sparkConf.get(IO_ENCRYPTION_ENABLED)) {
- SecurityManager.initIOEncryptionKey(sparkConf, credentials)
- }
setupSecurityToken(amContainer)
-
amContainer
}
diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/IOEncryptionSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/IOEncryptionSuite.scala
deleted file mode 100644
index 1c60315b21..0000000000
--- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/IOEncryptionSuite.scala
+++ /dev/null
@@ -1,108 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.spark.deploy.yarn
-
-import java.io._
-import java.nio.charset.StandardCharsets
-import java.security.PrivilegedExceptionAction
-import java.util.UUID
-
-import org.apache.hadoop.security.{Credentials, UserGroupInformation}
-import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, Matchers}
-
-import org.apache.spark._
-import org.apache.spark.deploy.SparkHadoopUtil
-import org.apache.spark.internal.config._
-import org.apache.spark.serializer._
-import org.apache.spark.storage._
-
-class IOEncryptionSuite extends SparkFunSuite with Matchers with BeforeAndAfterAll
- with BeforeAndAfterEach {
- private[this] val blockId = new TempShuffleBlockId(UUID.randomUUID())
- private[this] val conf = new SparkConf()
- private[this] val ugi = UserGroupInformation.createUserForTesting("testuser", Array("testgroup"))
- private[this] val serializer = new KryoSerializer(conf)
-
- override def beforeAll(): Unit = {
- System.setProperty("SPARK_YARN_MODE", "true")
- ugi.doAs(new PrivilegedExceptionAction[Unit]() {
- override def run(): Unit = {
- conf.set(IO_ENCRYPTION_ENABLED, true)
- val creds = new Credentials()
- SecurityManager.initIOEncryptionKey(conf, creds)
- SparkHadoopUtil.get.addCurrentUserCredentials(creds)
- }
- })
- }
-
- override def afterAll(): Unit = {
- SparkEnv.set(null)
- System.clearProperty("SPARK_YARN_MODE")
- }
-
- override def beforeEach(): Unit = {
- super.beforeEach()
- }
-
- override def afterEach(): Unit = {
- super.afterEach()
- conf.set("spark.shuffle.compress", false.toString)
- conf.set("spark.shuffle.spill.compress", false.toString)
- }
-
- test("IO encryption read and write") {
- ugi.doAs(new PrivilegedExceptionAction[Unit] {
- override def run(): Unit = {
- conf.set(IO_ENCRYPTION_ENABLED, true)
- conf.set("spark.shuffle.compress", false.toString)
- conf.set("spark.shuffle.spill.compress", false.toString)
- testYarnIOEncryptionWriteRead()
- }
- })
- }
-
- test("IO encryption read and write with shuffle compression enabled") {
- ugi.doAs(new PrivilegedExceptionAction[Unit] {
- override def run(): Unit = {
- conf.set(IO_ENCRYPTION_ENABLED, true)
- conf.set("spark.shuffle.compress", true.toString)
- conf.set("spark.shuffle.spill.compress", true.toString)
- testYarnIOEncryptionWriteRead()
- }
- })
- }
-
- private[this] def testYarnIOEncryptionWriteRead(): Unit = {
- val plainStr = "hello world"
- val outputStream = new ByteArrayOutputStream()
- val serializerManager = new SerializerManager(serializer, conf)
- val wrappedOutputStream = serializerManager.wrapStream(blockId, outputStream)
- wrappedOutputStream.write(plainStr.getBytes(StandardCharsets.UTF_8))
- wrappedOutputStream.close()
-
- val encryptedBytes = outputStream.toByteArray
- val encryptedStr = new String(encryptedBytes)
- assert(plainStr !== encryptedStr)
-
- val inputStream = new ByteArrayInputStream(encryptedBytes)
- val wrappedInputStream = serializerManager.wrapStream(blockId, inputStream)
- val decryptedBytes = new Array[Byte](1024)
- val len = wrappedInputStream.read(decryptedBytes)
- val decryptedStr = new String(decryptedBytes, 0, len, StandardCharsets.UTF_8)
- assert(decryptedStr === plainStr)
- }
-}