aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorReynold Xin <rxin@databricks.com>2016-05-12 22:18:39 -0700
committerYin Huai <yhuai@databricks.com>2016-05-12 22:18:39 -0700
commiteda2800d44843b6478e22d2c99bca4af7e9c9613 (patch)
tree5c0a9eb8fac7f45cae1f0b6b372b4341e7b84606
parentba169c3230e7d6cb192ec4bd567a1fef7b93b29f (diff)
downloadspark-eda2800d44843b6478e22d2c99bca4af7e9c9613.tar.gz
spark-eda2800d44843b6478e22d2c99bca4af7e9c9613.tar.bz2
spark-eda2800d44843b6478e22d2c99bca4af7e9c9613.zip
[SPARK-14541][SQL] Support IFNULL, NULLIF, NVL and NVL2
## What changes were proposed in this pull request? This patch adds support for a few SQL functions to improve compatibility with other databases: IFNULL, NULLIF, NVL and NVL2. In order to do this, this patch introduced a RuntimeReplaceable expression trait that allows replacing an unevaluable expression in the optimizer before evaluation. Note that the semantics are not completely identical to other databases in esoteric cases. ## How was this patch tested? Added a new test suite SQLCompatibilityFunctionSuite. Closes #12373. Author: Reynold Xin <rxin@databricks.com> Closes #13084 from rxin/SPARK-14541.
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala5
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala2
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala27
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala78
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala12
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala6
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/SQLCompatibilityFunctionSuite.scala72
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala1
8 files changed, 194 insertions, 9 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index c459fe5878..eca837ccf0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -165,13 +165,16 @@ object FunctionRegistry {
expression[Greatest]("greatest"),
expression[If]("if"),
expression[IsNaN]("isnan"),
+ expression[IfNull]("ifnull"),
expression[IsNull]("isnull"),
expression[IsNotNull]("isnotnull"),
expression[Least]("least"),
expression[CreateMap]("map"),
expression[CreateNamedStruct]("named_struct"),
expression[NaNvl]("nanvl"),
- expression[Coalesce]("nvl"),
+ expression[NullIf]("nullif"),
+ expression[Nvl]("nvl"),
+ expression[Nvl2]("nvl2"),
expression[Rand]("rand"),
expression[Randn]("randn"),
expression[CreateStruct]("struct"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 8319ec0a82..537dda60af 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -521,6 +521,8 @@ object HiveTypeCoercion {
NaNvl(l, Cast(r, DoubleType))
case NaNvl(l, r) if l.dataType == FloatType && r.dataType == DoubleType =>
NaNvl(Cast(l, DoubleType), r)
+
+ case e: RuntimeReplaceable => e.replaceForTypeCoercion()
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index c26faee2f4..fab163476f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -222,6 +222,33 @@ trait Unevaluable extends Expression {
/**
+ * An expression that gets replaced at runtime (currently by the optimizer) into a different
+ * expression for evaluation. This is mainly used to provide compatibility with other databases.
+ * For example, we use this to support "nvl" by replacing it with "coalesce".
+ */
+trait RuntimeReplaceable extends Unevaluable {
+ /**
+ * Method for concrete implementations to override that specifies how to construct the expression
+ * that should replace the current one.
+ */
+ def replaceForEvaluation(): Expression
+
+ /**
+ * Method for concrete implementations to override that specifies how to coerce the input types.
+ */
+ def replaceForTypeCoercion(): Expression
+
+ /** The expression that should be used during evaluation. */
+ lazy val replaced: Expression = replaceForEvaluation()
+
+ override def nullable: Boolean = replaced.nullable
+ override def foldable: Boolean = replaced.foldable
+ override def dataType: DataType = replaced.dataType
+ override def checkInputDataTypes(): TypeCheckResult = replaced.checkInputDataTypes()
+}
+
+
+/**
* Expressions that don't have SQL representation should extend this trait. Examples are
* `ScalaUDF`, `ScalaUDAF`, and object expressions like `MapObjects` and `Invoke`.
*/
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
index 421200e147..641c81b247 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.analysis.{HiveTypeCoercion, TypeCheckResult}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
@@ -88,6 +88,82 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
}
+@ExpressionDescription(usage = "_FUNC_(a,b) - Returns b if a is null, or a otherwise.")
+case class IfNull(left: Expression, right: Expression) extends RuntimeReplaceable {
+ override def children: Seq[Expression] = Seq(left, right)
+
+ override def replaceForEvaluation(): Expression = Coalesce(Seq(left, right))
+
+ override def replaceForTypeCoercion(): Expression = {
+ if (left.dataType != right.dataType) {
+ HiveTypeCoercion.findTightestCommonTypeOfTwo(left.dataType, right.dataType).map { dtype =>
+ copy(left = Cast(left, dtype), right = Cast(right, dtype))
+ }.getOrElse(this)
+ } else {
+ this
+ }
+ }
+}
+
+
+@ExpressionDescription(usage = "_FUNC_(a,b) - Returns null if a equals to b, or a otherwise.")
+case class NullIf(left: Expression, right: Expression) extends RuntimeReplaceable {
+ override def children: Seq[Expression] = Seq(left, right)
+
+ override def replaceForEvaluation(): Expression = {
+ If(EqualTo(left, right), Literal.create(null, left.dataType), left)
+ }
+
+ override def replaceForTypeCoercion(): Expression = {
+ if (left.dataType != right.dataType) {
+ HiveTypeCoercion.findTightestCommonTypeOfTwo(left.dataType, right.dataType).map { dtype =>
+ copy(left = Cast(left, dtype), right = Cast(right, dtype))
+ }.getOrElse(this)
+ } else {
+ this
+ }
+ }
+}
+
+
+@ExpressionDescription(usage = "_FUNC_(a,b) - Returns b if a is null, or a otherwise.")
+case class Nvl(left: Expression, right: Expression) extends RuntimeReplaceable {
+ override def children: Seq[Expression] = Seq(left, right)
+
+ override def replaceForEvaluation(): Expression = Coalesce(Seq(left, right))
+
+ override def replaceForTypeCoercion(): Expression = {
+ if (left.dataType != right.dataType) {
+ HiveTypeCoercion.findTightestCommonTypeOfTwo(left.dataType, right.dataType).map { dtype =>
+ copy(left = Cast(left, dtype), right = Cast(right, dtype))
+ }.getOrElse(this)
+ } else {
+ this
+ }
+ }
+}
+
+
+@ExpressionDescription(usage = "_FUNC_(a,b,c) - Returns b if a is not null, or c otherwise.")
+case class Nvl2(expr1: Expression, expr2: Expression, expr3: Expression)
+ extends RuntimeReplaceable {
+
+ override def replaceForEvaluation(): Expression = If(IsNotNull(expr1), expr2, expr3)
+
+ override def children: Seq[Expression] = Seq(expr1, expr2, expr3)
+
+ override def replaceForTypeCoercion(): Expression = {
+ if (expr2.dataType != expr3.dataType) {
+ HiveTypeCoercion.findTightestCommonTypeOfTwo(expr2.dataType, expr3.dataType).map { dtype =>
+ copy(expr2 = Cast(expr2, dtype), expr3 = Cast(expr3, dtype))
+ }.getOrElse(this)
+ } else {
+ this
+ }
+ }
+}
+
+
/**
* Evaluates to `true` iff it's NaN.
*/
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 928ba213b5..af7532e0c0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -49,6 +49,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf)
// we do not eliminate subqueries or compute current time in the analyzer.
Batch("Finish Analysis", Once,
EliminateSubqueryAliases,
+ ReplaceExpressions,
ComputeCurrentTime,
GetCurrentDatabase(sessionCatalog),
DistinctAggregationRewriter) ::
@@ -1512,6 +1513,17 @@ object RemoveRepetitionFromGroupExpressions extends Rule[LogicalPlan] {
}
/**
+ * Finds all [[RuntimeReplaceable]] expressions and replace them with the expressions that can
+ * be evaluated. This is mainly used to provide compatibility with other databases.
+ * For example, we use this to support "nvl" by replacing it with "coalesce".
+ */
+object ReplaceExpressions extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+ case e: RuntimeReplaceable => e.replaced
+ }
+}
+
+/**
* Computes the current date and time to make sure we return the same result in a single query.
*/
object ComputeCurrentTime extends Rule[LogicalPlan] {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 746e25a0c3..73d77651a0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -152,12 +152,6 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
Row("one", "not_one"))
}
- test("nvl function") {
- checkAnswer(
- sql("SELECT nvl(null, 'x'), nvl('y', 'x'), nvl(null, null)"),
- Row("x", "y", null))
- }
-
test("misc md5 function") {
val df = Seq(("ABC", Array[Byte](1, 2, 3, 4, 5, 6))).toDF("a", "b")
checkAnswer(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLCompatibilityFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLCompatibilityFunctionSuite.scala
new file mode 100644
index 0000000000..1e3239550f
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLCompatibilityFunctionSuite.scala
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.test.SharedSQLContext
+
+/**
+ * A test suite for functions added for compatibility with other databases such as Oracle, MSSQL.
+ * These functions are typically implemented using the trait
+ * [[org.apache.spark.sql.catalyst.expressions.RuntimeReplaceable]].
+ */
+class SQLCompatibilityFunctionSuite extends QueryTest with SharedSQLContext {
+
+ test("ifnull") {
+ checkAnswer(
+ sql("SELECT ifnull(null, 'x'), ifnull('y', 'x'), ifnull(null, null)"),
+ Row("x", "y", null))
+
+ // Type coercion
+ checkAnswer(
+ sql("SELECT ifnull(1, 2.1d), ifnull(null, 2.1d)"),
+ Row(1.0, 2.1))
+ }
+
+ test("nullif") {
+ checkAnswer(
+ sql("SELECT nullif('x', 'x'), nullif('x', 'y')"),
+ Row(null, "x"))
+
+ // Type coercion
+ checkAnswer(
+ sql("SELECT nullif(1, 2.1d), nullif(1, 1.0d)"),
+ Row(1.0, null))
+ }
+
+ test("nvl") {
+ checkAnswer(
+ sql("SELECT nvl(null, 'x'), nvl('y', 'x'), nvl(null, null)"),
+ Row("x", "y", null))
+
+ // Type coercion
+ checkAnswer(
+ sql("SELECT nvl(1, 2.1d), nvl(null, 2.1d)"),
+ Row(1.0, 2.1))
+ }
+
+ test("nvl2") {
+ checkAnswer(
+ sql("SELECT nvl2(null, 'x', 'y'), nvl2('n', 'x', 'y'), nvl2(null, null, null)"),
+ Row("y", "x", null))
+
+ // Type coercion
+ checkAnswer(
+ sql("SELECT nvl2(null, 1, 2.1d), nvl2('n', 1, 2.1d)"),
+ Row(2.1, 1.0))
+ }
+}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala
index 72736ee55b..b4eb50e331 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala
@@ -102,7 +102,6 @@ class ExpressionToSQLSuite extends SQLBuilderTest with SQLTestUtils {
checkSqlGeneration("SELECT map(1, 'a', 2, 'b')")
checkSqlGeneration("SELECT named_struct('c1',1,'c2',2,'c3',3)")
checkSqlGeneration("SELECT nanvl(a, 5), nanvl(b, 10), nanvl(d, c) from t2")
- checkSqlGeneration("SELECT nvl(null, 1, 2)")
checkSqlGeneration("SELECT rand(1)")
checkSqlGeneration("SELECT randn(3)")
checkSqlGeneration("SELECT struct(1,2,3)")