-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala                     16
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/internal/VariableSubstitution.scala       121
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/internal/VariableSubstitutionSuite.scala    78
3 files changed, 215 insertions, 0 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 80e2c1986d..2770552050 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -457,6 +457,18 @@ object SQLConf {
     .booleanConf
     .createWithDefault(true)
+  val VARIABLE_SUBSTITUTE_ENABLED =
+    SQLConfigBuilder("spark.sql.variable.substitute")
+      .doc("This enables substitution using syntax like ${var}, ${system:var}, and ${env:var}.")
+      .booleanConf
+      .createWithDefault(true)
+
+  val VARIABLE_SUBSTITUTE_DEPTH =
+    SQLConfigBuilder("spark.sql.variable.substitute.depth")
+      .doc("The maximum number of replacements the substitution engine will do.")
+      .intConf
+      .createWithDefault(40)
+
   // TODO: This is still WIP and shouldn't be turned on without extensive test coverage
   val COLUMNAR_AGGREGATE_MAP_ENABLED = SQLConfigBuilder("spark.sql.codegen.aggregate.map.enabled")
     .internal()
@@ -615,6 +627,10 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging {
   def columnarAggregateMapEnabled: Boolean = getConf(COLUMNAR_AGGREGATE_MAP_ENABLED)
+  def variableSubstituteEnabled: Boolean = getConf(VARIABLE_SUBSTITUTE_ENABLED)
+
+  def variableSubstituteDepth: Int = getConf(VARIABLE_SUBSTITUTE_DEPTH)
+
   override def orderByOrdinal: Boolean = getConf(ORDER_BY_ORDINAL)
   override def groupByOrdinal: Boolean = getConf(GROUP_BY_ORDINAL)
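
For orientation, a minimal sketch (not part of this patch) of how the two entries above behave on a standalone SQLConf. It assumes the code sits in the org.apache.spark.sql.internal package, since SQLConf is private[sql], and it uses only APIs that appear elsewhere in this diff (the new accessors, setConfString, and the entries' key field):

    val conf = new SQLConf
    assert(conf.variableSubstituteEnabled)        // spark.sql.variable.substitute defaults to true
    assert(conf.variableSubstituteDepth == 40)    // spark.sql.variable.substitute.depth defaults to 40

    // Overriding by key, the same way the test suite at the end of this diff does.
    conf.setConfString(SQLConf.VARIABLE_SUBSTITUTE_ENABLED.key, "false")
    conf.setConfString(SQLConf.VARIABLE_SUBSTITUTE_DEPTH.key, "10")
    assert(!conf.variableSubstituteEnabled)
    assert(conf.variableSubstituteDepth == 10)
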
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/VariableSubstitution.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/VariableSubstitution.scala
new file mode 100644
index 0000000000..0982f1d687
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/VariableSubstitution.scala
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.internal
+
+import java.util.regex.Pattern
+
+import org.apache.spark.sql.AnalysisException
+
+/**
+ * A helper class that enables substitution using syntax like
+ * `${var}`, `${system:var}` and `${env:var}`.
+ *
+ * Variable substitution is controlled by [[SQLConf.variableSubstituteEnabled]].
+ */
+class VariableSubstitution(conf: SQLConf) {
+
+  private val pattern = Pattern.compile("\\$\\{[^\\}\\$ ]+\\}")
+
+  /**
+   * Given a query, performs variable substitution and returns the result.
+   */
+  def substitute(input: String): String = {
+    // Note that this function is mostly copied from Hive's SystemVariables, so the style is
+    // very Java/Hive like.
+    if (input eq null) {
+      return null
+    }
+
+    if (!conf.variableSubstituteEnabled) {
+      return input
+    }
+
+    var eval = input
+    val depth = conf.variableSubstituteDepth
+    val builder = new StringBuilder
+    val m = pattern.matcher("")
+
+    var s = 0
+    while (s <= depth) {
+      m.reset(eval)
+      builder.setLength(0)
+
+      var prev = 0
+      var found = false
+      while (m.find(prev)) {
+        val group = m.group()
+        var substitute = substituteVariable(group.substring(2, group.length - 1))
+        if (substitute.isEmpty) {
+          substitute = group
+        } else {
+          found = true
+        }
+        builder.append(eval.substring(prev, m.start())).append(substitute)
+        prev = m.end()
+      }
+
+      if (!found) {
+        return eval
+      }
+
+      builder.append(eval.substring(prev))
+      eval = builder.toString
+      s += 1
+    }
+
+    if (s > depth) {
+      throw new AnalysisException(
+        "Variable substitution depth is deeper than " + depth + " for input " + input)
+    } else {
+      return eval
+    }
+  }
+
+  /**
+   * Given a variable name, returns its substituted value (defaulting to "").
+   */
+  private def substituteVariable(variable: String): String = {
+    var value: String = null
+
+    if (variable.startsWith("system:")) {
+      value = System.getProperty(variable.substring("system:".length()))
+    }
+
+    if (value == null && variable.startsWith("env:")) {
+      value = System.getenv(variable.substring("env:".length()))
+    }
+
+    if (value == null && conf != null && variable.startsWith("hiveconf:")) {
+      value = conf.getConfString(variable.substring("hiveconf:".length()), "")
+    }
+
+    if (value == null && conf != null && variable.startsWith("sparkconf:")) {
+      value = conf.getConfString(variable.substring("sparkconf:".length()), "")
+    }
+
+    if (value == null && conf != null && variable.startsWith("spark:")) {
+      value = conf.getConfString(variable.substring("spark:".length()), "")
+    }
+
+    if (value == null && conf != null) {
+      value = conf.getConfString(variable, "")
+    }
+
+    value
+  }
+}
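
As an aside (not part of the patch), a small sketch of what the pattern above matches and how each token is reduced to the name handed to substituteVariable; the character class excludes spaces, '$' and '}', so a brace group containing a space is left untouched:

    import java.util.regex.Pattern

    val pattern = Pattern.compile("\\$\\{[^\\}\\$ ]+\\}")   // same pattern as above
    val m = pattern.matcher("select ${env:USER} from t where note = '${not a token}'")
    while (m.find()) {
      val group = m.group()                             // e.g. "${env:USER}"
      val name = group.substring(2, group.length - 1)   // "env:USER", stripped of "${" and "}"
      println(group + " -> " + name)
    }
    // Prints only "${env:USER} -> env:USER"; '${not a token}' is never matched.
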
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/VariableSubstitutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/VariableSubstitutionSuite.scala
new file mode 100644
index 0000000000..deac95918b
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/VariableSubstitutionSuite.scala
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.internal
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.AnalysisException
+
+class VariableSubstitutionSuite extends SparkFunSuite {
+
+  private lazy val conf = new SQLConf
+  private lazy val sub = new VariableSubstitution(conf)
+
+  test("system property") {
+    System.setProperty("varSubSuite.var", "abcd")
+    assert(sub.substitute("${system:varSubSuite.var}") == "abcd")
+  }
+
+  test("environmental variables") {
+    assert(sub.substitute("${env:SPARK_TESTING}") == "1")
+  }
+
+  test("Spark configuration variable") {
+    conf.setConfString("some-random-string-abcd", "1234abcd")
+    assert(sub.substitute("${hiveconf:some-random-string-abcd}") == "1234abcd")
+    assert(sub.substitute("${sparkconf:some-random-string-abcd}") == "1234abcd")
+    assert(sub.substitute("${spark:some-random-string-abcd}") == "1234abcd")
+    assert(sub.substitute("${some-random-string-abcd}") == "1234abcd")
+  }
+
+  test("multiple substitutes") {
+    val q = "select ${bar} ${foo} ${doo} this is great"
+    conf.setConfString("bar", "1")
+    conf.setConfString("foo", "2")
+    conf.setConfString("doo", "3")
+    assert(sub.substitute(q) == "select 1 2 3 this is great")
+  }
+
+  test("test nested substitutes") {
+    val q = "select ${bar} ${foo} this is great"
+    conf.setConfString("bar", "1")
+    conf.setConfString("foo", "${bar}")
+    assert(sub.substitute(q) == "select 1 1 this is great")
+  }
+
+  test("depth limit") {
+    val q = "select ${bar} ${foo} ${doo}"
+    conf.setConfString(SQLConf.VARIABLE_SUBSTITUTE_DEPTH.key, "2")
+
+    // This should be OK since it is not nested.
+    conf.setConfString("bar", "1")
+    conf.setConfString("foo", "2")
+    conf.setConfString("doo", "3")
+    assert(sub.substitute(q) == "select 1 2 3")
+
+    // This should not be OK since it is nested three levels deep.
+ conf.setConfString("bar", "1")
+ conf.setConfString("foo", "${bar}")
+ conf.setConfString("doo", "${foo}")
+ intercept[AnalysisException] {
+ sub.substitute(q)
+ }
+ }
+}
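
Finally, a hypothetical sketch of how a caller that already holds a SQLConf might put the substitution step in front of SQL parsing. The QueryPreprocessor class and its method are illustrative names only, not part of this change, and the code is assumed to live under org.apache.spark.sql where the private[sql] SQLConf is visible:

    import org.apache.spark.sql.internal.{SQLConf, VariableSubstitution}

    class QueryPreprocessor(conf: SQLConf) {
      private val substitutor = new VariableSubstitution(conf)

      // Expand ${...} references before the text reaches a parser.
      def preprocess(sqlText: String): String = substitutor.substitute(sqlText)
    }
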