Diffstat (limited to 'core')
5 files changed, 564 insertions, 1 deletion
diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala
index b81bfb3182..16423e771a 100644
--- a/core/src/main/scala/org/apache/spark/SparkConf.scala
+++ b/core/src/main/scala/org/apache/spark/SparkConf.scala
@@ -17,13 +17,15 @@ package org.apache.spark
 
-import java.util.concurrent.ConcurrentHashMap
+import java.util.concurrent.{ConcurrentHashMap, TimeUnit}
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable.LinkedHashSet
 
 import org.apache.avro.{Schema, SchemaNormalization}
 
+import org.apache.spark.internal.config.{ConfigEntry, OptionalConfigEntry}
+import org.apache.spark.network.util.JavaUtils
 import org.apache.spark.serializer.KryoSerializer
 import org.apache.spark.util.Utils
 
@@ -74,6 +76,16 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
     this
   }
 
+  private[spark] def set[T](entry: ConfigEntry[T], value: T): SparkConf = {
+    set(entry.key, entry.stringConverter(value))
+    this
+  }
+
+  private[spark] def set[T](entry: OptionalConfigEntry[T], value: T): SparkConf = {
+    set(entry.key, entry.rawStringConverter(value))
+    this
+  }
+
   /**
    * The master URL to connect to, such as "local" to run locally with one thread, "local[4]" to
    * run locally with 4 cores, or "spark://master:7077" to run on a Spark standalone cluster.
@@ -148,6 +160,20 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
     this
   }
 
+  private[spark] def setIfMissing[T](entry: ConfigEntry[T], value: T): SparkConf = {
+    if (settings.putIfAbsent(entry.key, entry.stringConverter(value)) == null) {
+      logDeprecationWarning(entry.key)
+    }
+    this
+  }
+
+  private[spark] def setIfMissing[T](entry: OptionalConfigEntry[T], value: T): SparkConf = {
+    if (settings.putIfAbsent(entry.key, entry.rawStringConverter(value)) == null) {
+      logDeprecationWarning(entry.key)
+    }
+    this
+  }
+
   /**
    * Use Kryo serialization and register the given set of classes with Kryo.
    * If called multiple times, this will append the classes from all calls together.
@@ -199,6 +225,17 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
   }
 
   /**
+   * Retrieves the value of a pre-defined configuration entry.
+   *
+   * - This is an internal Spark API.
+   * - The return type is defined by the configuration entry.
+   * - This will throw an exception if the config is not optional and the value is not set.
+   */
+  private[spark] def get[T](entry: ConfigEntry[T]): T = {
+    entry.readFrom(this)
+  }
+
+  /**
    * Get a time parameter as seconds; throws a NoSuchElementException if it's not set. If no
    * suffix is provided then seconds are assumed.
    * @throws NoSuchElementException
diff --git a/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala b/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala
new file mode 100644
index 0000000000..770b43697a
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala
@@ -0,0 +1,184 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.internal.config
+
+import java.util.concurrent.TimeUnit
+
+import org.apache.spark.network.util.{ByteUnit, JavaUtils}
+
+private object ConfigHelpers {
+
+  def toNumber[T](s: String, converter: String => T, key: String, configType: String): T = {
+    try {
+      converter(s)
+    } catch {
+      case _: NumberFormatException =>
+        throw new IllegalArgumentException(s"$key should be $configType, but was $s")
+    }
+  }
+
+  def toBoolean(s: String, key: String): Boolean = {
+    try {
+      s.toBoolean
+    } catch {
+      case _: IllegalArgumentException =>
+        throw new IllegalArgumentException(s"$key should be boolean, but was $s")
+    }
+  }
+
+  def stringToSeq[T](str: String, converter: String => T): Seq[T] = {
+    str.split(",").map(_.trim()).filter(_.nonEmpty).map(converter)
+  }
+
+  def seqToString[T](v: Seq[T], stringConverter: T => String): String = {
+    v.map(stringConverter).mkString(",")
+  }
+
+  def timeFromString(str: String, unit: TimeUnit): Long = JavaUtils.timeStringAs(str, unit)
+
+  def timeToString(v: Long, unit: TimeUnit): String = TimeUnit.MILLISECONDS.convert(v, unit) + "ms"
+
+  def byteFromString(str: String, unit: ByteUnit): Long = {
+    val (input, multiplier) =
+      if (str.length() > 0 && str.charAt(0) == '-') {
+        (str.substring(1), -1)
+      } else {
+        (str, 1)
+      }
+    multiplier * JavaUtils.byteStringAs(input, unit)
+  }
+
+  def byteToString(v: Long, unit: ByteUnit): String = unit.convertTo(v, ByteUnit.BYTE) + "b"
+
+}
+
+/**
+ * A type-safe config builder. Provides methods for transforming the input data (which can be
+ * used, e.g., for validation) and creating the final config entry.
+ *
+ * One of the methods that return a [[ConfigEntry]] must be called to create a config entry that
+ * can be used with [[SparkConf]].
+ */
+private[spark] class TypedConfigBuilder[T](
+    val parent: ConfigBuilder,
+    val converter: String => T,
+    val stringConverter: T => String) {
+
+  import ConfigHelpers._
+
+  def this(parent: ConfigBuilder, converter: String => T) = {
+    this(parent, converter, Option(_).map(_.toString).orNull)
+  }
+
+  def transform(fn: T => T): TypedConfigBuilder[T] = {
+    new TypedConfigBuilder(parent, s => fn(converter(s)), stringConverter)
+  }
+
+  def checkValues(validValues: Set[T]): TypedConfigBuilder[T] = {
+    transform { v =>
+      if (!validValues.contains(v)) {
+        throw new IllegalArgumentException(
+          s"The value of ${parent.key} should be one of ${validValues.mkString(", ")}, but was $v")
+      }
+      v
+    }
+  }
+
+  def toSequence: TypedConfigBuilder[Seq[T]] = {
+    new TypedConfigBuilder(parent, stringToSeq(_, converter), seqToString(_, stringConverter))
+  }
+
+  /** Creates a [[ConfigEntry]] that does not require a default value. */
+  def optional: OptionalConfigEntry[T] = {
+    new OptionalConfigEntry[T](parent.key, converter, stringConverter, parent._doc, parent._public)
+  }
+
+  /** Creates a [[ConfigEntry]] that has a default value. */
+  def withDefault(default: T): ConfigEntry[T] = {
+    val transformedDefault = converter(stringConverter(default))
+    new ConfigEntryWithDefault[T](parent.key, transformedDefault, converter, stringConverter,
+      parent._doc, parent._public)
+  }
+
+  /**
+   * Creates a [[ConfigEntry]] that has a default value. The default value is provided as a
+   * [[String]] and must be a valid value for the entry.
+   */
+  def withDefaultString(default: String): ConfigEntry[T] = {
+    val typedDefault = converter(default)
+    new ConfigEntryWithDefault[T](parent.key, typedDefault, converter, stringConverter, parent._doc,
+      parent._public)
+  }
+
+}
+
+/**
+ * Basic builder for Spark configurations. Provides methods for creating type-specific builders.
+ *
+ * @see TypedConfigBuilder
+ */
+private[spark] case class ConfigBuilder(key: String) {
+
+  import ConfigHelpers._
+
+  var _public = true
+  var _doc = ""
+
+  def internal: ConfigBuilder = {
+    _public = false
+    this
+  }
+
+  def doc(s: String): ConfigBuilder = {
+    _doc = s
+    this
+  }
+
+  def intConf: TypedConfigBuilder[Int] = {
+    new TypedConfigBuilder(this, toNumber(_, _.toInt, key, "int"))
+  }
+
+  def longConf: TypedConfigBuilder[Long] = {
+    new TypedConfigBuilder(this, toNumber(_, _.toLong, key, "long"))
+  }
+
+  def doubleConf: TypedConfigBuilder[Double] = {
+    new TypedConfigBuilder(this, toNumber(_, _.toDouble, key, "double"))
+  }
+
+  def booleanConf: TypedConfigBuilder[Boolean] = {
+    new TypedConfigBuilder(this, toBoolean(_, key))
+  }
+
+  def stringConf: TypedConfigBuilder[String] = {
+    new TypedConfigBuilder(this, v => v)
+  }
+
+  def timeConf(unit: TimeUnit): TypedConfigBuilder[Long] = {
+    new TypedConfigBuilder(this, timeFromString(_, unit), timeToString(_, unit))
+  }
+
+  def bytesConf(unit: ByteUnit): TypedConfigBuilder[Long] = {
+    new TypedConfigBuilder(this, byteFromString(_, unit), byteToString(_, unit))
+  }
+
+  def fallbackConf[T](fallback: ConfigEntry[T]): ConfigEntry[T] = {
+    new FallbackConfigEntry(key, _doc, _public, fallback)
+  }
+
+}
diff --git a/core/src/main/scala/org/apache/spark/internal/config/ConfigEntry.scala b/core/src/main/scala/org/apache/spark/internal/config/ConfigEntry.scala
new file mode 100644
index 0000000000..f7296b487c
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/internal/config/ConfigEntry.scala
@@ -0,0 +1,111 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.internal.config
+
+import org.apache.spark.SparkConf
+
+/**
+ * An entry contains all meta information for a configuration.
+ *
+ * @param key the key for the configuration
+ * @param defaultValue the default value for the configuration
+ * @param valueConverter how to convert a string to the value. It should throw an exception if the
+ *                       string does not have the required format.
+ * @param stringConverter how to convert a value to a string that the user can use as a valid
+ *                        string value. It's usually `toString`. But sometimes, a custom converter
+ *                        is necessary. E.g., if T is List[String], `a, b, c` is better than
+ *                        `List(a, b, c)`.
+ * @param doc the documentation for the configuration
+ * @param isPublic if this configuration is public to the user. If it's `false`, this
+ *                 configuration is only used internally and we should not expose it to users.
+ * @tparam T the value type
+ */
+private[spark] abstract class ConfigEntry[T] (
+    val key: String,
+    val valueConverter: String => T,
+    val stringConverter: T => String,
+    val doc: String,
+    val isPublic: Boolean) {
+
+  def defaultValueString: String
+
+  def readFrom(conf: SparkConf): T
+
+  // This is used by SQLConf, since it doesn't use SparkConf to store settings and thus cannot
+  // use readFrom().
+  def defaultValue: Option[T] = None
+
+  override def toString: String = {
+    s"ConfigEntry(key=$key, defaultValue=$defaultValueString, doc=$doc, public=$isPublic)"
+  }
+}
+
+private class ConfigEntryWithDefault[T] (
+    key: String,
+    _defaultValue: T,
+    valueConverter: String => T,
+    stringConverter: T => String,
+    doc: String,
+    isPublic: Boolean)
+  extends ConfigEntry(key, valueConverter, stringConverter, doc, isPublic) {
+
+  override def defaultValue: Option[T] = Some(_defaultValue)
+
+  override def defaultValueString: String = stringConverter(_defaultValue)
+
+  override def readFrom(conf: SparkConf): T = {
+    conf.getOption(key).map(valueConverter).getOrElse(_defaultValue)
+  }
+
+}
+
+/**
+ * A config entry that does not have a default value.
+ */
+private[spark] class OptionalConfigEntry[T](
+    key: String,
+    val rawValueConverter: String => T,
+    val rawStringConverter: T => String,
+    doc: String,
+    isPublic: Boolean)
+  extends ConfigEntry[Option[T]](key, s => Some(rawValueConverter(s)),
+    v => v.map(rawStringConverter).orNull, doc, isPublic) {
+
+  override def defaultValueString: String = "<undefined>"
+
+  override def readFrom(conf: SparkConf): Option[T] = conf.getOption(key).map(rawValueConverter)
+
+}
+
+/**
+ * A config entry whose default value is defined by another config entry.
+ */
+private class FallbackConfigEntry[T] (
+    key: String,
+    doc: String,
+    isPublic: Boolean,
+    private val fallback: ConfigEntry[T])
+  extends ConfigEntry[T](key, fallback.valueConverter, fallback.stringConverter, doc, isPublic) {
+
+  override def defaultValueString: String = s"<value of ${fallback.key}>"
+
+  override def readFrom(conf: SparkConf): T = {
+    conf.getOption(key).map(valueConverter).getOrElse(fallback.readFrom(conf))
+  }
+
+}
diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala
new file mode 100644
index 0000000000..f2f20b3207
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.internal
+
+import org.apache.spark.launcher.SparkLauncher
+
+package object config {
+
+  private[spark] val DRIVER_CLASS_PATH =
+    ConfigBuilder(SparkLauncher.DRIVER_EXTRA_CLASSPATH).stringConf.optional
+
+  private[spark] val DRIVER_JAVA_OPTIONS =
+    ConfigBuilder(SparkLauncher.DRIVER_EXTRA_JAVA_OPTIONS).stringConf.optional
+
+  private[spark] val DRIVER_LIBRARY_PATH =
+    ConfigBuilder(SparkLauncher.DRIVER_EXTRA_LIBRARY_PATH).stringConf.optional
+
+  private[spark] val DRIVER_USER_CLASS_PATH_FIRST =
+    ConfigBuilder("spark.driver.userClassPathFirst").booleanConf.withDefault(false)
+
+  private[spark] val EXECUTOR_CLASS_PATH =
+    ConfigBuilder(SparkLauncher.EXECUTOR_EXTRA_CLASSPATH).stringConf.optional
+
+  private[spark] val EXECUTOR_JAVA_OPTIONS =
+    ConfigBuilder(SparkLauncher.EXECUTOR_EXTRA_JAVA_OPTIONS).stringConf.optional
+
+  private[spark] val EXECUTOR_LIBRARY_PATH =
+    ConfigBuilder(SparkLauncher.EXECUTOR_EXTRA_LIBRARY_PATH).stringConf.optional
+
+  private[spark] val EXECUTOR_USER_CLASS_PATH_FIRST =
+    ConfigBuilder("spark.executor.userClassPathFirst").booleanConf.withDefault(false)
+
+  private[spark] val IS_PYTHON_APP = ConfigBuilder("spark.yarn.isPython").internal
+    .booleanConf.withDefault(false)
+
+  private[spark] val CPUS_PER_TASK = ConfigBuilder("spark.task.cpus").intConf.withDefault(1)
+
+  private[spark] val DYN_ALLOCATION_MIN_EXECUTORS =
+    ConfigBuilder("spark.dynamicAllocation.minExecutors").intConf.withDefault(0)
+
+  private[spark] val DYN_ALLOCATION_INITIAL_EXECUTORS =
+    ConfigBuilder("spark.dynamicAllocation.initialExecutors")
+      .fallbackConf(DYN_ALLOCATION_MIN_EXECUTORS)
+
+  private[spark] val DYN_ALLOCATION_MAX_EXECUTORS =
+    ConfigBuilder("spark.dynamicAllocation.maxExecutors").intConf.withDefault(Int.MaxValue)
+
+  private[spark] val SHUFFLE_SERVICE_ENABLED =
+    ConfigBuilder("spark.shuffle.service.enabled").booleanConf.withDefault(false)
+
+  private[spark] val KEYTAB = ConfigBuilder("spark.yarn.keytab")
+    .doc("Location of user's keytab.")
+    .stringConf.optional
+
+  private[spark] val PRINCIPAL = ConfigBuilder("spark.yarn.principal")
+    .doc("Name of the Kerberos principal.")
+    .stringConf.optional
+
+  private[spark] val EXECUTOR_INSTANCES = ConfigBuilder("spark.executor.instances").intConf.optional
+
+}
diff --git a/core/src/test/scala/org/apache/spark/internal/config/ConfigEntrySuite.scala b/core/src/test/scala/org/apache/spark/internal/config/ConfigEntrySuite.scala
new file mode 100644
index 0000000000..0644148eae
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/internal/config/ConfigEntrySuite.scala
@@ -0,0 +1,155 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.internal.config
+
+import java.util.concurrent.TimeUnit
+
+import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.network.util.ByteUnit
+
+class ConfigEntrySuite extends SparkFunSuite {
+
+  test("conf entry: int") {
+    val conf = new SparkConf()
+    val iConf = ConfigBuilder("spark.int").intConf.withDefault(1)
+    assert(conf.get(iConf) === 1)
+    conf.set(iConf, 2)
+    assert(conf.get(iConf) === 2)
+  }
+
+  test("conf entry: long") {
+    val conf = new SparkConf()
+    val lConf = ConfigBuilder("spark.long").longConf.withDefault(0L)
+    conf.set(lConf, 1234L)
+    assert(conf.get(lConf) === 1234L)
+  }
+
+  test("conf entry: double") {
+    val conf = new SparkConf()
+    val dConf = ConfigBuilder("spark.double").doubleConf.withDefault(0.0)
+    conf.set(dConf, 20.0)
+    assert(conf.get(dConf) === 20.0)
+  }
+
+  test("conf entry: boolean") {
+    val conf = new SparkConf()
+    val bConf = ConfigBuilder("spark.boolean").booleanConf.withDefault(false)
+    assert(!conf.get(bConf))
+    conf.set(bConf, true)
+    assert(conf.get(bConf))
+  }
+
+  test("conf entry: optional") {
+    val conf = new SparkConf()
+    val optionalConf = ConfigBuilder("spark.optional").intConf.optional
+    assert(conf.get(optionalConf) === None)
+    conf.set(optionalConf, 1)
+    assert(conf.get(optionalConf) === Some(1))
+  }
+
+  test("conf entry: fallback") {
+    val conf = new SparkConf()
+    val parentConf = ConfigBuilder("spark.int").intConf.withDefault(1)
+    val confWithFallback = ConfigBuilder("spark.fallback").fallbackConf(parentConf)
+    assert(conf.get(confWithFallback) === 1)
+    conf.set(confWithFallback, 2)
+    assert(conf.get(parentConf) === 1)
+    assert(conf.get(confWithFallback) === 2)
+  }
+
+  test("conf entry: time") {
+    val conf = new SparkConf()
+    val time = ConfigBuilder("spark.time").timeConf(TimeUnit.SECONDS).withDefaultString("1h")
+    assert(conf.get(time) === 3600L)
+    conf.set(time.key, "1m")
+    assert(conf.get(time) === 60L)
+  }
+
+  test("conf entry: bytes") {
+    val conf = new SparkConf()
+    val bytes = ConfigBuilder("spark.bytes").bytesConf(ByteUnit.KiB).withDefaultString("1m")
+    assert(conf.get(bytes) === 1024L)
+    conf.set(bytes.key, "1k")
+    assert(conf.get(bytes) === 1L)
+  }
+
+  test("conf entry: string seq") {
+    val conf = new SparkConf()
+    val seq = ConfigBuilder("spark.seq").stringConf.toSequence.withDefault(Seq())
+    conf.set(seq.key, "1,,2, 3 , , 4")
+    assert(conf.get(seq) === Seq("1", "2", "3", "4"))
+    conf.set(seq, Seq("1", "2"))
+    assert(conf.get(seq) === Seq("1", "2"))
+  }
+
+  test("conf entry: int seq") {
+    val conf = new SparkConf()
+    val seq = ConfigBuilder("spark.seq").intConf.toSequence.withDefault(Seq())
+    conf.set(seq.key, "1,,2, 3 , , 4")
+    assert(conf.get(seq) === Seq(1, 2, 3, 4))
+    conf.set(seq, Seq(1, 2))
+    assert(conf.get(seq) === Seq(1, 2))
+  }
+
+  test("conf entry: transformation") {
+    val conf = new SparkConf()
+    val transformationConf = ConfigBuilder("spark.transformation")
+      .stringConf
+      .transform(_.toLowerCase())
+      .withDefault("FOO")
+
+    assert(conf.get(transformationConf) === "foo")
+    conf.set(transformationConf, "BAR")
+    assert(conf.get(transformationConf) === "bar")
+  }
+
+  test("conf entry: valid values check") {
+    val conf = new SparkConf()
+    val enum = ConfigBuilder("spark.enum")
+      .stringConf
+      .checkValues(Set("a", "b", "c"))
+      .withDefault("a")
+    assert(conf.get(enum) === "a")
+
+    conf.set(enum, "b")
+    assert(conf.get(enum) === "b")
+
+    conf.set(enum, "d")
+    val enumError = intercept[IllegalArgumentException] {
+      conf.get(enum)
+    }
+    assert(enumError.getMessage === s"The value of ${enum.key} should be one of a, b, c, but was d")
+  }
+
+  test("conf entry: conversion error") {
+    val conf = new SparkConf()
+    val conversionTest = ConfigBuilder("spark.conversionTest").doubleConf.optional
+    conf.set(conversionTest.key, "abc")
+    val conversionError = intercept[IllegalArgumentException] {
+      conf.get(conversionTest)
+    }
+    assert(conversionError.getMessage === s"${conversionTest.key} should be double, but was abc")
+  }
+
+  test("default value handling is null-safe") {
+    val conf = new SparkConf()
+    val stringConf = ConfigBuilder("spark.string").stringConf.withDefault(null)
+    assert(conf.get(stringConf) === null)
+  }
+
+}
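
Usage sketch (editor's note, not part of the commit): the snippet below shows how the pieces added in this diff fit together. ConfigBuilder, timeConf, withDefaultString, and the typed SparkConf get/set come from the files above; the key "spark.example.timeout" and the ConfigBuilderExample object are hypothetical names for illustration. Because the whole API is private[spark], the sketch has to live inside the org.apache.spark package.

    package org.apache.spark

    import java.util.concurrent.TimeUnit

    import org.apache.spark.internal.config._

    private[spark] object ConfigBuilderExample {

      // Define the entry once, carrying its key, type, documentation and
      // default, instead of scattering the key string and ad-hoc parsing
      // across call sites. Hypothetical key, for illustration only.
      val EXAMPLE_TIMEOUT = ConfigBuilder("spark.example.timeout")
        .doc("How long to wait before giving up (hypothetical).")
        .timeConf(TimeUnit.SECONDS)
        .withDefaultString("120s")

      def demo(): Unit = {
        val conf = new SparkConf()
        assert(conf.get(EXAMPLE_TIMEOUT) == 120L)  // default "120s", read back in seconds
        conf.set(EXAMPLE_TIMEOUT.key, "5m")        // users still set plain strings
        assert(conf.get(EXAMPLE_TIMEOUT) == 300L)  // parsed and converted on read
      }
    }

The design benefit this buys, as the tests above exercise, is that conversion errors, fallbacks, and defaults are handled once in the entry definition rather than at every lookup.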