aboutsummaryrefslogtreecommitdiff
path: root/sql/catalyst
diff options
context:
space:
mode:
Diffstat (limited to 'sql/catalyst')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala41
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala13
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala19
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala15
4 files changed, 83 insertions, 5 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 39b120e8de..bc45881e42 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -154,6 +154,25 @@ abstract class Expression extends TreeNode[Expression] {
}
/**
+ * Evaluation helper function for 1 Fractional children expression.
+ * if the expression result is null, the evaluation result should be null.
+ */
+ @inline
+ protected final def f1(i: Row, e1: Expression, f: ((Fractional[Any], Any) => Any)): Any = {
+ val evalE1 = e1.eval(i: Row)
+ if(evalE1 == null) {
+ null
+ } else {
+ e1.dataType match {
+ case ft: FractionalType =>
+ f.asInstanceOf[(Fractional[ft.JvmType], ft.JvmType) => ft.JvmType](
+ ft.fractional, evalE1.asInstanceOf[ft.JvmType])
+ case other => sys.error(s"Type $other does not support fractional operations")
+ }
+ }
+ }
+
+ /**
* Evaluation helper function for 2 Integral children expressions. Those expressions are
* supposed to be in the same data type, and also the return type.
* Either one of the expressions result is null, the evaluation result should be null.
@@ -190,6 +209,28 @@ abstract class Expression extends TreeNode[Expression] {
}
/**
+ * Evaluation helper function for 1 Integral children expression.
+ * if the expression result is null, the evaluation result should be null.
+ */
+ @inline
+ protected final def i1(i: Row, e1: Expression, f: ((Integral[Any], Any) => Any)): Any = {
+ val evalE1 = e1.eval(i)
+ if(evalE1 == null) {
+ null
+ } else {
+ e1.dataType match {
+ case i: IntegralType =>
+ f.asInstanceOf[(Integral[i.JvmType], i.JvmType) => i.JvmType](
+ i.integral, evalE1.asInstanceOf[i.JvmType])
+ case i: FractionalType =>
+ f.asInstanceOf[(Integral[i.JvmType], i.JvmType) => i.JvmType](
+ i.asIntegral, evalE1.asInstanceOf[i.JvmType])
+ case other => sys.error(s"Type $other does not support numeric operations")
+ }
+ }
+ }
+
+ /**
* Evaluation helper function for 2 Comparable children expressions. Those expressions are
* supposed to be in the same data type, and the return type should be Integer:
* Negative value: 1st argument less than 2nd argument
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 900b7586ad..7ec18b8419 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -105,11 +105,16 @@ case class Multiply(left: Expression, right: Expression) extends BinaryArithmeti
case class Divide(left: Expression, right: Expression) extends BinaryArithmetic {
def symbol = "/"
- override def nullable = left.nullable || right.nullable || dataType.isInstanceOf[DecimalType]
+ override def nullable = true
- override def eval(input: Row): Any = dataType match {
- case _: FractionalType => f2(input, left, right, _.div(_, _))
- case _: IntegralType => i2(input, left , right, _.quot(_, _))
+ override def eval(input: Row): Any = {
+ val evalE2 = right.eval(input)
+ dataType match {
+ case _ if evalE2 == null => null
+ case _ if evalE2 == 0 => null
+ case ft: FractionalType => f1(input, left, _.div(_, evalE2.asInstanceOf[ft.JvmType]))
+ case it: IntegralType => i1(input, left, _.quot(_, evalE2.asInstanceOf[it.JvmType]))
+ }
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 67f8d411b6..ab71e15e1f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -359,7 +359,24 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
case Add(e1, e2) => (e1, e2) evaluate { case (eval1, eval2) => q"$eval1 + $eval2" }
case Subtract(e1, e2) => (e1, e2) evaluate { case (eval1, eval2) => q"$eval1 - $eval2" }
case Multiply(e1, e2) => (e1, e2) evaluate { case (eval1, eval2) => q"$eval1 * $eval2" }
- case Divide(e1, e2) => (e1, e2) evaluate { case (eval1, eval2) => q"$eval1 / $eval2" }
+ case Divide(e1, e2) =>
+ val eval1 = expressionEvaluator(e1)
+ val eval2 = expressionEvaluator(e2)
+
+ eval1.code ++ eval2.code ++
+ q"""
+ var $nullTerm = false
+ var $primitiveTerm: ${termForType(e1.dataType)} = 0
+
+ if (${eval1.nullTerm} || ${eval2.nullTerm} ) {
+ $nullTerm = true
+ } else if (${eval2.primitiveTerm} == 0)
+ $nullTerm = true
+ else {
+ $nullTerm = false
+ $primitiveTerm = ${eval1.primitiveTerm} / ${eval2.primitiveTerm}
+ }
+ """.children
case IsNotNull(e) =>
val eval = expressionEvaluator(e)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
index 3f5b9f698f..25f5642488 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
@@ -149,6 +149,21 @@ class ExpressionEvaluationSuite extends FunSuite {
checkEvaluation(In(Literal(1), Seq(Literal(1), Literal(2))) && In(Literal(2), Seq(Literal(1), Literal(2))), true)
}
+ test("Divide") {
+ checkEvaluation(Divide(Literal(2), Literal(1)), 2)
+ checkEvaluation(Divide(Literal(1.0), Literal(2.0)), 0.5)
+ checkEvaluation(Divide(Literal(1), Literal(2)), 0)
+ checkEvaluation(Divide(Literal(1), Literal(0)), null)
+ checkEvaluation(Divide(Literal(1.0), Literal(0.0)), null)
+ checkEvaluation(Divide(Literal(0.0), Literal(0.0)), null)
+ checkEvaluation(Divide(Literal(0), Literal(null, IntegerType)), null)
+ checkEvaluation(Divide(Literal(1), Literal(null, IntegerType)), null)
+ checkEvaluation(Divide(Literal(null, IntegerType), Literal(0)), null)
+ checkEvaluation(Divide(Literal(null, DoubleType), Literal(0.0)), null)
+ checkEvaluation(Divide(Literal(null, IntegerType), Literal(1)), null)
+ checkEvaluation(Divide(Literal(null, IntegerType), Literal(null, IntegerType)), null)
+ }
+
test("INSET") {
val hS = HashSet[Any]() + 1 + 2
val nS = HashSet[Any]() + 1 + 2 + null