path: root/sql/catalyst
author     Reynold Xin <rxin@databricks.com>    2016-05-12 22:18:39 -0700
committer  Yin Huai <yhuai@databricks.com>      2016-05-12 22:18:39 -0700
commit     eda2800d44843b6478e22d2c99bca4af7e9c9613 (patch)
tree       5c0a9eb8fac7f45cae1f0b6b372b4341e7b84606 /sql/catalyst
parent     ba169c3230e7d6cb192ec4bd567a1fef7b93b29f (diff)
[SPARK-14541][SQL] Support IFNULL, NULLIF, NVL and NVL2
## What changes were proposed in this pull request?

This patch adds support for a few SQL functions to improve compatibility with other databases: IFNULL, NULLIF, NVL and NVL2. To do this, it introduces a RuntimeReplaceable expression trait that allows an unevaluable expression to be replaced in the optimizer before evaluation. Note that the semantics are not completely identical to those of other databases in esoteric cases.

## How was this patch tested?

Added a new test suite, SQLCompatibilityFunctionSuite.

Closes #12373.

Author: Reynold Xin <rxin@databricks.com>

Closes #13084 from rxin/SPARK-14541.
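For orientation, a minimal usage sketch of the four new functions (assumes a local SparkSession built from a branch containing this patch; the values in the comments follow the semantics described above):

import org.apache.spark.sql.SparkSession

// Illustrative only: exercises the four functions registered by this patch.
val spark = SparkSession.builder().master("local[*]").appName("null-function-demo").getOrCreate()

spark.sql("SELECT ifnull(NULL, 'b')").show()     // b    -- first argument is null, return the second
spark.sql("SELECT nullif('a', 'a')").show()      // NULL -- arguments are equal
spark.sql("SELECT nullif('a', 'b')").show()      // a    -- arguments differ, return the first
spark.sql("SELECT nvl(NULL, 2)").show()          // 2    -- same behavior as ifnull/coalesce
spark.sql("SELECT nvl2(NULL, 'x', 'y')").show()  // y    -- first argument is null, return the third
spark.sql("SELECT nvl2(1, 'x', 'y')").show()     // x    -- first argument is non-null, return the second

spark.stop()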
Diffstat (limited to 'sql/catalyst')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala    |  5
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala    |  2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala       | 27
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala  | 78
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala          | 12
5 files changed, 122 insertions(+), 2 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index c459fe5878..eca837ccf0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -165,13 +165,16 @@ object FunctionRegistry {
expression[Greatest]("greatest"),
expression[If]("if"),
expression[IsNaN]("isnan"),
+ expression[IfNull]("ifnull"),
expression[IsNull]("isnull"),
expression[IsNotNull]("isnotnull"),
expression[Least]("least"),
expression[CreateMap]("map"),
expression[CreateNamedStruct]("named_struct"),
expression[NaNvl]("nanvl"),
- expression[Coalesce]("nvl"),
+ expression[NullIf]("nullif"),
+ expression[Nvl]("nvl"),
+ expression[Nvl2]("nvl2"),
expression[Rand]("rand"),
expression[Randn]("randn"),
expression[CreateStruct]("struct"),
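Note that "nvl" was previously a straight alias for Coalesce; it now maps to a dedicated two-argument Nvl expression so it can participate in the RuntimeReplaceable machinery below. A simplified, self-contained sketch of the name-to-builder lookup (hypothetical toy types; the real FunctionRegistry derives builders from the case class constructors via reflection, which is not shown here):

// Simplified sketch of a name -> builder registry.
sealed trait Expr
case class Col(name: String) extends Expr
case class MyCoalesce(children: Seq[Expr]) extends Expr
case class MyNvl(left: Expr, right: Expr) extends Expr

object MiniRegistry {
  type Builder = Seq[Expr] => Expr
  private val builders: Map[String, Builder] = Map(
    "coalesce" -> ((args: Seq[Expr]) => MyCoalesce(args)),
    // Dedicated two-argument expression, no longer an alias; any other arity fails to build.
    "nvl"      -> ((args: Seq[Expr]) => args match { case Seq(l, r) => MyNvl(l, r) })
  )
  def lookup(name: String, args: Seq[Expr]): Expr = builders(name.toLowerCase)(args)
}

// MiniRegistry.lookup("nvl", Seq(Col("a"), Col("b")))  returns  MyNvl(Col(a), Col(b))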
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 8319ec0a82..537dda60af 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -521,6 +521,8 @@ object HiveTypeCoercion {
NaNvl(l, Cast(r, DoubleType))
case NaNvl(l, r) if l.dataType == FloatType && r.dataType == DoubleType =>
NaNvl(Cast(l, DoubleType), r)
+
+ case e: RuntimeReplaceable => e.replaceForTypeCoercion()
}
}
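The new case hands type reconciliation to the expression itself; each implementation below casts both arguments to their tightest common type when the two sides disagree. A hedged sketch of the observable effect, reusing the SparkSession from the earlier sketch (an int argument paired with a double should be widened to double):

// The int side should be widened to double before the expression is replaced with Coalesce.
val widened = spark.sql("SELECT nvl(CAST(1 AS INT), CAST(2.5 AS DOUBLE)) AS v")
widened.printSchema()  // column v is expected to come out as double
widened.show()         // 1.0 -- the non-null first argument, cast up to double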
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index c26faee2f4..fab163476f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -222,6 +222,33 @@ trait Unevaluable extends Expression {
/**
+ * An expression that is replaced at runtime (currently by the optimizer) with a different
+ * expression for evaluation. This is mainly used to provide compatibility with other databases.
+ * For example, we use this to support "nvl" by replacing it with "coalesce".
+ */
+trait RuntimeReplaceable extends Unevaluable {
+ /**
+ * Method for concrete implementations to override that specifies how to construct the expression
+ * that should replace the current one.
+ */
+ def replaceForEvaluation(): Expression
+
+ /**
+ * Method for concrete implementations to override that specifies how to coerce the input types.
+ */
+ def replaceForTypeCoercion(): Expression
+
+ /** The expression that should be used during evaluation. */
+ lazy val replaced: Expression = replaceForEvaluation()
+
+ override def nullable: Boolean = replaced.nullable
+ override def foldable: Boolean = replaced.foldable
+ override def dataType: DataType = replaced.dataType
+ override def checkInputDataTypes(): TypeCheckResult = replaced.checkInputDataTypes()
+}
+
+
+/**
* Expressions that don't have SQL representation should extend this trait. Examples are
* `ScalaUDF`, `ScalaUDAF`, and object expressions like `MapObjects` and `Invoke`.
*/
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
index 421200e147..641c81b247 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.analysis.{HiveTypeCoercion, TypeCheckResult}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
@@ -88,6 +88,82 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
}
+@ExpressionDescription(usage = "_FUNC_(a,b) - Returns b if a is null, or a otherwise.")
+case class IfNull(left: Expression, right: Expression) extends RuntimeReplaceable {
+ override def children: Seq[Expression] = Seq(left, right)
+
+ override def replaceForEvaluation(): Expression = Coalesce(Seq(left, right))
+
+ override def replaceForTypeCoercion(): Expression = {
+ if (left.dataType != right.dataType) {
+ HiveTypeCoercion.findTightestCommonTypeOfTwo(left.dataType, right.dataType).map { dtype =>
+ copy(left = Cast(left, dtype), right = Cast(right, dtype))
+ }.getOrElse(this)
+ } else {
+ this
+ }
+ }
+}
+
+
+@ExpressionDescription(usage = "_FUNC_(a,b) - Returns null if a equals b, or a otherwise.")
+case class NullIf(left: Expression, right: Expression) extends RuntimeReplaceable {
+ override def children: Seq[Expression] = Seq(left, right)
+
+ override def replaceForEvaluation(): Expression = {
+ If(EqualTo(left, right), Literal.create(null, left.dataType), left)
+ }
+
+ override def replaceForTypeCoercion(): Expression = {
+ if (left.dataType != right.dataType) {
+ HiveTypeCoercion.findTightestCommonTypeOfTwo(left.dataType, right.dataType).map { dtype =>
+ copy(left = Cast(left, dtype), right = Cast(right, dtype))
+ }.getOrElse(this)
+ } else {
+ this
+ }
+ }
+}
+
+
+@ExpressionDescription(usage = "_FUNC_(a,b) - Returns b if a is null, or a otherwise.")
+case class Nvl(left: Expression, right: Expression) extends RuntimeReplaceable {
+ override def children: Seq[Expression] = Seq(left, right)
+
+ override def replaceForEvaluation(): Expression = Coalesce(Seq(left, right))
+
+ override def replaceForTypeCoercion(): Expression = {
+ if (left.dataType != right.dataType) {
+ HiveTypeCoercion.findTightestCommonTypeOfTwo(left.dataType, right.dataType).map { dtype =>
+ copy(left = Cast(left, dtype), right = Cast(right, dtype))
+ }.getOrElse(this)
+ } else {
+ this
+ }
+ }
+}
+
+
+@ExpressionDescription(usage = "_FUNC_(a,b,c) - Returns b if a is not null, or c otherwise.")
+case class Nvl2(expr1: Expression, expr2: Expression, expr3: Expression)
+ extends RuntimeReplaceable {
+
+ override def replaceForEvaluation(): Expression = If(IsNotNull(expr1), expr2, expr3)
+
+ override def children: Seq[Expression] = Seq(expr1, expr2, expr3)
+
+ override def replaceForTypeCoercion(): Expression = {
+ if (expr2.dataType != expr3.dataType) {
+ HiveTypeCoercion.findTightestCommonTypeOfTwo(expr2.dataType, expr3.dataType).map { dtype =>
+ copy(expr2 = Cast(expr2, dtype), expr3 = Cast(expr3, dtype))
+ }.getOrElse(this)
+ } else {
+ this
+ }
+ }
+}
+
+
/**
* Evaluates to `true` iff it's NaN.
*/
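Taken together, each new function is sugar over expressions Spark can already evaluate. A hedged sketch of the equivalences, in the spirit of (but not copied from) the SQLCompatibilityFunctionSuite mentioned in the description; both columns of every query should agree once ReplaceExpressions has run:

// Left-hand expression vs. the expression it is replaced with.
spark.sql("SELECT ifnull(CAST(NULL AS INT), 1),  coalesce(CAST(NULL AS INT), 1)").show()           // 1, 1
spark.sql("SELECT nullif(1, 1),                  if(1 = 1, NULL, 1)").show()                        // NULL, NULL
spark.sql("SELECT nvl(CAST(NULL AS INT), 2),     coalesce(CAST(NULL AS INT), 2)").show()           // 2, 2
spark.sql("SELECT nvl2(CAST(NULL AS INT), 1, 2), if(CAST(NULL AS INT) IS NOT NULL, 1, 2)").show()  // 2, 2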
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 928ba213b5..af7532e0c0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -49,6 +49,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf)
// we do not eliminate subqueries or compute current time in the analyzer.
Batch("Finish Analysis", Once,
EliminateSubqueryAliases,
+ ReplaceExpressions,
ComputeCurrentTime,
GetCurrentDatabase(sessionCatalog),
DistinctAggregationRewriter) ::
@@ -1512,6 +1513,17 @@ object RemoveRepetitionFromGroupExpressions extends Rule[LogicalPlan] {
}
/**
+ * Finds all [[RuntimeReplaceable]] expressions and replaces them with the expressions that can
+ * be evaluated. This is mainly used to provide compatibility with other databases.
+ * For example, we use this to support "nvl" by replacing it with "coalesce".
+ */
+object ReplaceExpressions extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+ case e: RuntimeReplaceable => e.replaced
+ }
+}
+
+/**
* Computes the current date and time to make sure we return the same result in a single query.
*/
object ComputeCurrentTime extends Rule[LogicalPlan] {
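The rule body is a one-liner because transformAllExpressions already walks every expression in every operator of the plan. A simplified, Catalyst-independent sketch of that kind of expression rewrite (hypothetical toy types; Catalyst's TreeNode machinery handles the traversal generically):

// Toy expression tree and a rewrite pass in the spirit of ReplaceExpressions.
sealed trait TExpr
case class TCol(name: String) extends TExpr
case class TCoalesce(children: Seq[TExpr]) extends TExpr
case class TNvl(left: TExpr, right: TExpr) extends TExpr {
  def replaced: TExpr = TCoalesce(Seq(left, right))  // what the node should evaluate as
}

def replaceExpressions(e: TExpr): TExpr = {
  // Apply the substitution at this node, then recurse into the children of the result.
  val substituted = e match {
    case n: TNvl => n.replaced
    case other   => other
  }
  substituted match {
    case TCoalesce(cs) => TCoalesce(cs.map(replaceExpressions))
    case TNvl(l, r)    => TNvl(replaceExpressions(l), replaceExpressions(r))
    case c: TCol       => c
  }
}

// replaceExpressions(TNvl(TCol("a"), TCol("b")))  returns  TCoalesce(Seq(TCol(a), TCol(b)))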