aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDaoyuan Wang <daoyuan.wang@intel.com>2014-12-02 14:21:12 -0800
committerMichael Armbrust <michael@databricks.com>2014-12-02 14:22:44 -0800
commit97dc2384ad4cb555200bbe994b5470f81fe4671f (patch)
tree74b97b2fe0e0903e27b96073c88a7acb46d6edde
parent06129cde4dc035b31fcd8e5870a2030be2f2a8b7 (diff)
downloadspark-97dc2384ad4cb555200bbe994b5470f81fe4671f.tar.gz
spark-97dc2384ad4cb555200bbe994b5470f81fe4671f.tar.bz2
spark-97dc2384ad4cb555200bbe994b5470f81fe4671f.zip
[SPARK-4593][SQL] Return null when denominator is 0
SELECT max(1/0) FROM src would return a very large number, which is obviously not right. For hive-0.12, hive would return `Infinity` for 1/0, while for hive-0.13.1, it is `NULL` for 1/0. I think it is better to keep our behavior with newer Hive version. This PR ensures that when the divider is 0, the result of expression should be NULL, same with hive-0.13.1 Author: Daoyuan Wang <daoyuan.wang@intel.com> Closes #3443 from adrian-wang/div and squashes the following commits: 2e98677 [Daoyuan Wang] fix code gen for divide 0 85c28ba [Daoyuan Wang] temp 36236a5 [Daoyuan Wang] add test cases 6f5716f [Daoyuan Wang] fix comments cee92bd [Daoyuan Wang] avoid evaluation 2 times 22ecd9a [Daoyuan Wang] fix style cf28c58 [Daoyuan Wang] divide fix 2dfe50f [Daoyuan Wang] return null when divider is 0 of Double type (cherry picked from commit f6df609dcc4f4a18c0f1c74b1ae0800cf09fa7ae) Signed-off-by: Michael Armbrust <michael@databricks.com>
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala41
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala13
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala19
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala15
4 files changed, 83 insertions, 5 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 39b120e8de..bc45881e42 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -154,6 +154,25 @@ abstract class Expression extends TreeNode[Expression] {
}
/**
+ * Evaluation helper function for one Fractional child expression.
+ * If the expression result is null, the evaluation result should be null.
+ */
+ @inline
+ protected final def f1(i: Row, e1: Expression, f: ((Fractional[Any], Any) => Any)): Any = {
+ val evalE1 = e1.eval(i: Row)
+ if(evalE1 == null) {
+ null
+ } else {
+ e1.dataType match {
+ case ft: FractionalType =>
+ f.asInstanceOf[(Fractional[ft.JvmType], ft.JvmType) => ft.JvmType](
+ ft.fractional, evalE1.asInstanceOf[ft.JvmType])
+ case other => sys.error(s"Type $other does not support fractional operations")
+ }
+ }
+ }
+
+ /**
* Evaluation helper function for two Integral children expressions. Those expressions are
* supposed to be in the same data type, which is also the return type.
* If either of the expressions' results is null, the evaluation result should be null.
@@ -190,6 +209,28 @@ abstract class Expression extends TreeNode[Expression] {
}
/**
+ * Evaluation helper function for one Integral child expression.
+ * If the expression result is null, the evaluation result should be null.
+ */
+ @inline
+ protected final def i1(i: Row, e1: Expression, f: ((Integral[Any], Any) => Any)): Any = {
+ val evalE1 = e1.eval(i)
+ if(evalE1 == null) {
+ null
+ } else {
+ e1.dataType match {
+ case i: IntegralType =>
+ f.asInstanceOf[(Integral[i.JvmType], i.JvmType) => i.JvmType](
+ i.integral, evalE1.asInstanceOf[i.JvmType])
+ case i: FractionalType =>
+ f.asInstanceOf[(Integral[i.JvmType], i.JvmType) => i.JvmType](
+ i.asIntegral, evalE1.asInstanceOf[i.JvmType])
+ case other => sys.error(s"Type $other does not support numeric operations")
+ }
+ }
+ }
+
+ /**
* Evaluation helper function for 2 Comparable children expressions. Those expressions are
* supposed to be in the same data type, and the return type should be Integer:
* Negative value: 1st argument less than 2nd argument
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 900b7586ad..7ec18b8419 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -105,11 +105,16 @@ case class Multiply(left: Expression, right: Expression) extends BinaryArithmeti
case class Divide(left: Expression, right: Expression) extends BinaryArithmetic {
def symbol = "/"
- override def nullable = left.nullable || right.nullable || dataType.isInstanceOf[DecimalType]
+ override def nullable = true
- override def eval(input: Row): Any = dataType match {
- case _: FractionalType => f2(input, left, right, _.div(_, _))
- case _: IntegralType => i2(input, left , right, _.quot(_, _))
+ override def eval(input: Row): Any = {
+ val evalE2 = right.eval(input)
+ dataType match {
+ case _ if evalE2 == null => null
+ case _ if evalE2 == 0 => null
+ case ft: FractionalType => f1(input, left, _.div(_, evalE2.asInstanceOf[ft.JvmType]))
+ case it: IntegralType => i1(input, left, _.quot(_, evalE2.asInstanceOf[it.JvmType]))
+ }
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 67f8d411b6..ab71e15e1f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -359,7 +359,24 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
case Add(e1, e2) => (e1, e2) evaluate { case (eval1, eval2) => q"$eval1 + $eval2" }
case Subtract(e1, e2) => (e1, e2) evaluate { case (eval1, eval2) => q"$eval1 - $eval2" }
case Multiply(e1, e2) => (e1, e2) evaluate { case (eval1, eval2) => q"$eval1 * $eval2" }
- case Divide(e1, e2) => (e1, e2) evaluate { case (eval1, eval2) => q"$eval1 / $eval2" }
+ case Divide(e1, e2) =>
+ val eval1 = expressionEvaluator(e1)
+ val eval2 = expressionEvaluator(e2)
+
+ eval1.code ++ eval2.code ++
+ q"""
+ var $nullTerm = false
+ var $primitiveTerm: ${termForType(e1.dataType)} = 0
+
+ if (${eval1.nullTerm} || ${eval2.nullTerm} ) {
+ $nullTerm = true
+ } else if (${eval2.primitiveTerm} == 0)
+ $nullTerm = true
+ else {
+ $nullTerm = false
+ $primitiveTerm = ${eval1.primitiveTerm} / ${eval2.primitiveTerm}
+ }
+ """.children
case IsNotNull(e) =>
val eval = expressionEvaluator(e)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
index 3f5b9f698f..25f5642488 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
@@ -149,6 +149,21 @@ class ExpressionEvaluationSuite extends FunSuite {
checkEvaluation(In(Literal(1), Seq(Literal(1), Literal(2))) && In(Literal(2), Seq(Literal(1), Literal(2))), true)
}
+ test("Divide") {
+ checkEvaluation(Divide(Literal(2), Literal(1)), 2)
+ checkEvaluation(Divide(Literal(1.0), Literal(2.0)), 0.5)
+ checkEvaluation(Divide(Literal(1), Literal(2)), 0)
+ checkEvaluation(Divide(Literal(1), Literal(0)), null)
+ checkEvaluation(Divide(Literal(1.0), Literal(0.0)), null)
+ checkEvaluation(Divide(Literal(0.0), Literal(0.0)), null)
+ checkEvaluation(Divide(Literal(0), Literal(null, IntegerType)), null)
+ checkEvaluation(Divide(Literal(1), Literal(null, IntegerType)), null)
+ checkEvaluation(Divide(Literal(null, IntegerType), Literal(0)), null)
+ checkEvaluation(Divide(Literal(null, DoubleType), Literal(0.0)), null)
+ checkEvaluation(Divide(Literal(null, IntegerType), Literal(1)), null)
+ checkEvaluation(Divide(Literal(null, IntegerType), Literal(null, IntegerType)), null)
+ }
+
test("INSET") {
val hS = HashSet[Any]() + 1 + 2
val nS = HashSet[Any]() + 1 + 2 + null