diff options
Diffstat (limited to 'core/src/main/scala')
8 files changed, 70 insertions, 55 deletions
diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index 199365ad92..87fe563152 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -21,7 +21,6 @@ import java.lang.{Byte => JByte} import java.net.{Authenticator, PasswordAuthentication} import java.security.{KeyStore, SecureRandom} import java.security.cert.X509Certificate -import javax.crypto.KeyGenerator import javax.net.ssl._ import com.google.common.hash.HashCodes @@ -33,7 +32,6 @@ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.network.sasl.SecretKeyHolder -import org.apache.spark.security.CryptoStreamUtils._ import org.apache.spark.util.Utils /** @@ -185,7 +183,9 @@ import org.apache.spark.util.Utils * setting `spark.ssl.useNodeLocalConf` to `true`. */ -private[spark] class SecurityManager(sparkConf: SparkConf) +private[spark] class SecurityManager( + sparkConf: SparkConf, + ioEncryptionKey: Option[Array[Byte]] = None) extends Logging with SecretKeyHolder { import SecurityManager._ @@ -415,6 +415,8 @@ private[spark] class SecurityManager(sparkConf: SparkConf) logInfo("Changing acls enabled to: " + aclsOn) } + def getIOEncryptionKey(): Option[Array[Byte]] = ioEncryptionKey + /** * Generates or looks up the secret key. * @@ -559,19 +561,4 @@ private[spark] object SecurityManager { // key used to store the spark secret in the Hadoop UGI val SECRET_LOOKUP_KEY = "sparkCookie" - /** - * Setup the cryptographic key used by IO encryption in credentials. The key is generated using - * [[KeyGenerator]]. The algorithm and key length is specified by the [[SparkConf]]. - */ - def initIOEncryptionKey(conf: SparkConf, credentials: Credentials): Unit = { - if (credentials.getSecretKey(SPARK_IO_TOKEN) == null) { - val keyLen = conf.get(IO_ENCRYPTION_KEY_SIZE_BITS) - val ioKeyGenAlgorithm = conf.get(IO_ENCRYPTION_KEYGEN_ALGORITHM) - val keyGen = KeyGenerator.getInstance(ioKeyGenAlgorithm) - keyGen.init(keyLen) - - val ioKey = keyGen.generateKey() - credentials.addSecretKey(SPARK_IO_TOKEN, ioKey.getEncoded) - } - } } diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 1261e3e735..a159a170eb 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -422,10 +422,6 @@ class SparkContext(config: SparkConf) extends Logging { } if (master == "yarn" && deployMode == "client") System.setProperty("SPARK_YARN_MODE", "true") - if (_conf.get(IO_ENCRYPTION_ENABLED) && !SparkHadoopUtil.get.isYarnMode()) { - throw new SparkException("IO encryption is only supported in YARN mode, please disable it " + - s"by setting ${IO_ENCRYPTION_ENABLED.key} to false") - } // "_jobProgressListener" should be set up before creating SparkEnv because when creating // "SparkEnv", some messages will be posted to "listenerBus" and we should not miss them. diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 1ffeb12988..1296386ac9 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -36,6 +36,7 @@ import org.apache.spark.network.netty.NettyBlockTransferService import org.apache.spark.rpc.{RpcEndpoint, RpcEndpointRef, RpcEnv} import org.apache.spark.scheduler.{LiveListenerBus, OutputCommitCoordinator} import org.apache.spark.scheduler.OutputCommitCoordinator.OutputCommitCoordinatorEndpoint +import org.apache.spark.security.CryptoStreamUtils import org.apache.spark.serializer.{JavaSerializer, Serializer, SerializerManager} import org.apache.spark.shuffle.ShuffleManager import org.apache.spark.storage._ @@ -165,15 +166,20 @@ object SparkEnv extends Logging { val bindAddress = conf.get(DRIVER_BIND_ADDRESS) val advertiseAddress = conf.get(DRIVER_HOST_ADDRESS) val port = conf.get("spark.driver.port").toInt + val ioEncryptionKey = if (conf.get(IO_ENCRYPTION_ENABLED)) { + Some(CryptoStreamUtils.createKey(conf)) + } else { + None + } create( conf, SparkContext.DRIVER_IDENTIFIER, bindAddress, advertiseAddress, port, - isDriver = true, - isLocal = isLocal, - numUsableCores = numCores, + isLocal, + numCores, + ioEncryptionKey, listenerBus = listenerBus, mockOutputCommitCoordinator = mockOutputCommitCoordinator ) @@ -189,6 +195,7 @@ object SparkEnv extends Logging { hostname: String, port: Int, numCores: Int, + ioEncryptionKey: Option[Array[Byte]], isLocal: Boolean): SparkEnv = { val env = create( conf, @@ -196,9 +203,9 @@ object SparkEnv extends Logging { hostname, hostname, port, - isDriver = false, - isLocal = isLocal, - numUsableCores = numCores + isLocal, + numCores, + ioEncryptionKey ) SparkEnv.set(env) env @@ -213,18 +220,26 @@ object SparkEnv extends Logging { bindAddress: String, advertiseAddress: String, port: Int, - isDriver: Boolean, isLocal: Boolean, numUsableCores: Int, + ioEncryptionKey: Option[Array[Byte]], listenerBus: LiveListenerBus = null, mockOutputCommitCoordinator: Option[OutputCommitCoordinator] = None): SparkEnv = { + val isDriver = executorId == SparkContext.DRIVER_IDENTIFIER + // Listener bus is only used on the driver if (isDriver) { assert(listenerBus != null, "Attempted to create driver SparkEnv with null listener bus!") } - val securityManager = new SecurityManager(conf) + val securityManager = new SecurityManager(conf, ioEncryptionKey) + ioEncryptionKey.foreach { _ => + if (!securityManager.isSaslEncryptionEnabled()) { + logWarning("I/O encryption enabled without RPC encryption: keys will be visible on the " + + "wire.") + } + } val systemName = if (isDriver) driverSystemName else executorSystemName val rpcEnv = RpcEnv.create(systemName, bindAddress, advertiseAddress, port, conf, @@ -270,7 +285,7 @@ object SparkEnv extends Logging { "spark.serializer", "org.apache.spark.serializer.JavaSerializer") logDebug(s"Using serializer: ${serializer.getClass}") - val serializerManager = new SerializerManager(serializer, conf) + val serializerManager = new SerializerManager(serializer, conf, ioEncryptionKey) val closureSerializer = new JavaSerializer(conf) diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 7eec4ae64f..92a27902c6 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -200,8 +200,8 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { new SecurityManager(executorConf), clientMode = true) val driver = fetcher.setupEndpointRefByURI(driverUrl) - val props = driver.askWithRetry[Seq[(String, String)]](RetrieveSparkProps) ++ - Seq[(String, String)](("spark.app.id", appId)) + val cfg = driver.askWithRetry[SparkAppConfig](RetrieveSparkAppConfig) + val props = cfg.sparkProperties ++ Seq[(String, String)](("spark.app.id", appId)) fetcher.shutdown() // Create SparkEnv using properties we fetched from the driver. @@ -221,7 +221,7 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { } val env = SparkEnv.createExecutorEnv( - driverConf, executorId, hostname, port, cores, isLocal = false) + driverConf, executorId, hostname, port, cores, cfg.ioEncryptionKey, isLocal = false) env.rpcEnv.setupEndpoint("Executor", new CoarseGrainedExecutorBackend( env.rpcEnv, driverUrl, executorId, hostname, cores, userClassPath, env)) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala index edc8aac5d1..0a4f19d760 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala @@ -28,7 +28,12 @@ private[spark] sealed trait CoarseGrainedClusterMessage extends Serializable private[spark] object CoarseGrainedClusterMessages { - case object RetrieveSparkProps extends CoarseGrainedClusterMessage + case object RetrieveSparkAppConfig extends CoarseGrainedClusterMessage + + case class SparkAppConfig( + sparkProperties: Seq[(String, String)], + ioEncryptionKey: Option[Array[Byte]]) + extends CoarseGrainedClusterMessage case object RetrieveLastAllocatedExecutorId extends CoarseGrainedClusterMessage diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 10d55c87fb..3452487e72 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -206,8 +206,10 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp removeExecutor(executorId, reason) context.reply(true) - case RetrieveSparkProps => - context.reply(sparkProperties) + case RetrieveSparkAppConfig => + val reply = SparkAppConfig(sparkProperties, + SparkEnv.get.securityManager.getIOEncryptionKey()) + context.reply(reply) } // Make fake resource offers on all executors diff --git a/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala b/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala index f41fc38be2..8e3436f134 100644 --- a/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala +++ b/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala @@ -18,14 +18,13 @@ package org.apache.spark.security import java.io.{InputStream, OutputStream} import java.util.Properties +import javax.crypto.KeyGenerator import javax.crypto.spec.{IvParameterSpec, SecretKeySpec} import org.apache.commons.crypto.random._ import org.apache.commons.crypto.stream._ -import org.apache.hadoop.io.Text import org.apache.spark.SparkConf -import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ @@ -33,10 +32,6 @@ import org.apache.spark.internal.config._ * A util class for manipulating IO encryption and decryption streams. */ private[spark] object CryptoStreamUtils extends Logging { - /** - * Constants and variables for spark IO encryption - */ - val SPARK_IO_TOKEN = new Text("SPARK_IO_TOKEN") // The initialization vector length in bytes. val IV_LENGTH_IN_BYTES = 16 @@ -50,12 +45,11 @@ private[spark] object CryptoStreamUtils extends Logging { */ def createCryptoOutputStream( os: OutputStream, - sparkConf: SparkConf): OutputStream = { + sparkConf: SparkConf, + key: Array[Byte]): OutputStream = { val properties = toCryptoConf(sparkConf) val iv = createInitializationVector(properties) os.write(iv) - val credentials = SparkHadoopUtil.get.getCurrentUserCredentials() - val key = credentials.getSecretKey(SPARK_IO_TOKEN) val transformationStr = sparkConf.get(IO_CRYPTO_CIPHER_TRANSFORMATION) new CryptoOutputStream(transformationStr, properties, os, new SecretKeySpec(key, "AES"), new IvParameterSpec(iv)) @@ -66,12 +60,11 @@ private[spark] object CryptoStreamUtils extends Logging { */ def createCryptoInputStream( is: InputStream, - sparkConf: SparkConf): InputStream = { + sparkConf: SparkConf, + key: Array[Byte]): InputStream = { val properties = toCryptoConf(sparkConf) val iv = new Array[Byte](IV_LENGTH_IN_BYTES) is.read(iv, 0, iv.length) - val credentials = SparkHadoopUtil.get.getCurrentUserCredentials() - val key = credentials.getSecretKey(SPARK_IO_TOKEN) val transformationStr = sparkConf.get(IO_CRYPTO_CIPHER_TRANSFORMATION) new CryptoInputStream(transformationStr, properties, is, new SecretKeySpec(key, "AES"), new IvParameterSpec(iv)) @@ -92,6 +85,17 @@ private[spark] object CryptoStreamUtils extends Logging { } /** + * Creates a new encryption key. + */ + def createKey(conf: SparkConf): Array[Byte] = { + val keyLen = conf.get(IO_ENCRYPTION_KEY_SIZE_BITS) + val ioKeyGenAlgorithm = conf.get(IO_ENCRYPTION_KEYGEN_ALGORITHM) + val keyGen = KeyGenerator.getInstance(ioKeyGenAlgorithm) + keyGen.init(keyLen) + keyGen.generateKey().getEncoded() + } + + /** * This method to generate an IV (Initialization Vector) using secure random. */ private[this] def createInitializationVector(properties: Properties): Array[Byte] = { diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala index 2156d576f1..ef8432ec08 100644 --- a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala +++ b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala @@ -33,7 +33,12 @@ import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStrea * Component which configures serialization, compression and encryption for various Spark * components, including automatic selection of which [[Serializer]] to use for shuffles. */ -private[spark] class SerializerManager(defaultSerializer: Serializer, conf: SparkConf) { +private[spark] class SerializerManager( + defaultSerializer: Serializer, + conf: SparkConf, + encryptionKey: Option[Array[Byte]]) { + + def this(defaultSerializer: Serializer, conf: SparkConf) = this(defaultSerializer, conf, None) private[this] val kryoSerializer = new KryoSerializer(conf) @@ -63,9 +68,6 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: Spar // Whether to compress shuffle output temporarily spilled to disk private[this] val compressShuffleSpill = conf.getBoolean("spark.shuffle.spill.compress", true) - // Whether to enable IO encryption - private[this] val enableIOEncryption = conf.get(IO_ENCRYPTION_ENABLED) - /* The compression codec to use. Note that the "lazy" val is necessary because we want to delay * the initialization of the compression codec until it is first used. The reason is that a Spark * program could be using a user-defined codec in a third party jar, which is loaded in @@ -125,14 +127,18 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: Spar * Wrap an input stream for encryption if shuffle encryption is enabled */ private[this] def wrapForEncryption(s: InputStream): InputStream = { - if (enableIOEncryption) CryptoStreamUtils.createCryptoInputStream(s, conf) else s + encryptionKey + .map { key => CryptoStreamUtils.createCryptoInputStream(s, conf, key) } + .getOrElse(s) } /** * Wrap an output stream for encryption if shuffle encryption is enabled */ private[this] def wrapForEncryption(s: OutputStream): OutputStream = { - if (enableIOEncryption) CryptoStreamUtils.createCryptoOutputStream(s, conf) else s + encryptionKey + .map { key => CryptoStreamUtils.createCryptoOutputStream(s, conf, key) } + .getOrElse(s) } /** |