From ddc7ba31cb1062acb182293b2698b1b20ea56a46 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Tue, 16 Dec 2014 21:19:57 -0800 Subject: [SPARK-4720][SQL] Remainder should also return null if the divider is 0. This is a follow-up of SPARK-4593 (#3443). Author: Takuya UESHIN Closes #3581 from ueshin/issues/SPARK-4720 and squashes the following commits: c3959d4 [Takuya UESHIN] Make Remainder return null if the divider is 0. --- .../spark/sql/catalyst/expressions/arithmetic.scala | 11 +++++++++-- .../catalyst/expressions/codegen/CodeGenerator.scala | 19 +++++++++++++++++++ .../expressions/ExpressionEvaluationSuite.scala | 15 +++++++++++++++ 3 files changed, 43 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 61c26c50a6..79a742ad4b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -122,9 +122,16 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic case class Remainder(left: Expression, right: Expression) extends BinaryArithmetic { def symbol = "%" - override def nullable = left.nullable || right.nullable || dataType.isInstanceOf[DecimalType] + override def nullable = true - override def eval(input: Row): Any = i2(input, left, right, _.rem(_, _)) + override def eval(input: Row): Any = { + val evalE2 = right.eval(input) + dataType match { + case _ if evalE2 == null => null + case _ if evalE2 == 0 => null + case nt: NumericType => i1(input, left, _.rem(_, evalE2.asInstanceOf[nt.JvmType])) + } + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 48727d5e90..90c81b2631 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -379,6 +379,25 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin } """.children + case Remainder(e1, e2) => + val eval1 = expressionEvaluator(e1) + val eval2 = expressionEvaluator(e2) + + eval1.code ++ eval2.code ++ + q""" + var $nullTerm = false + var $primitiveTerm: ${termForType(e1.dataType)} = 0 + + if (${eval1.nullTerm} || ${eval2.nullTerm} ) { + $nullTerm = true + } else if (${eval2.primitiveTerm} == 0) + $nullTerm = true + else { + $nullTerm = false + $primitiveTerm = ${eval1.primitiveTerm} % ${eval2.primitiveTerm} + } + """.children + case IsNotNull(e) => val eval = expressionEvaluator(e) q""" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index b030483223..1e371db315 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -179,6 +179,21 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(Divide(Literal(null, IntegerType), Literal(null, IntegerType)), null) } + test("Remainder") { + checkEvaluation(Remainder(Literal(2), Literal(1)), 0) + checkEvaluation(Remainder(Literal(1.0), Literal(2.0)), 1.0) + checkEvaluation(Remainder(Literal(1), Literal(2)), 1) + checkEvaluation(Remainder(Literal(1), Literal(0)), null) + checkEvaluation(Remainder(Literal(1.0), Literal(0.0)), null) + checkEvaluation(Remainder(Literal(0.0), Literal(0.0)), null) + checkEvaluation(Remainder(Literal(0), Literal(null, IntegerType)), null) + checkEvaluation(Remainder(Literal(1), Literal(null, IntegerType)), null) + checkEvaluation(Remainder(Literal(null, IntegerType), Literal(0)), null) + checkEvaluation(Remainder(Literal(null, DoubleType), Literal(0.0)), null) + checkEvaluation(Remainder(Literal(null, IntegerType), Literal(1)), null) + checkEvaluation(Remainder(Literal(null, IntegerType), Literal(null, IntegerType)), null) + } + test("INSET") { val hS = HashSet[Any]() + 1 + 2 val nS = HashSet[Any]() + 1 + 2 + null -- cgit v1.2.3