From 902e4d54acbc3c88163a5c6447aff68ed57475c1 Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Wed, 17 Dec 2014 12:51:27 -0800 Subject: [SPARK-4755] [SQL] sqrt(negative value) should return null Author: Daoyuan Wang Closes #3616 from adrian-wang/sqrt and squashes the following commits: d877439 [Daoyuan Wang] fix NULLTYPE 3effa2c [Daoyuan Wang] sqrt(negative value) should return null --- .../spark/sql/catalyst/expressions/arithmetic.scala | 15 +++++++++++++-- .../catalyst/expressions/ExpressionEvaluationSuite.scala | 2 ++ 2 files changed, 15 insertions(+), 2 deletions(-) (limited to 'sql') diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 79a742ad4b..168a963e29 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -38,11 +38,22 @@ case class Sqrt(child: Expression) extends UnaryExpression { def dataType = DoubleType override def foldable = child.foldable - def nullable = child.nullable + def nullable = true override def toString = s"SQRT($child)" override def eval(input: Row): Any = { - n1(child, input, (na,a) => math.sqrt(na.toDouble(a))) + val evalE = child.eval(input) + if (evalE == null) { + null + } else { + child.dataType match { + case n: NumericType => + val value = n.numeric.toDouble(evalE.asInstanceOf[n.JvmType]) + if (value < 0) null + else math.sqrt(value) + case other => sys.error(s"Type $other does not support non-negative numeric operations") + } + } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index 1e371db315..4ba7d87ba8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -1037,6 +1037,8 @@ class ExpressionEvaluationSuite extends FunSuite { } checkEvaluation(Sqrt(Literal(null, DoubleType)), null, new GenericRow(Array[Any](null))) + checkEvaluation(Sqrt(-1), null, EmptyRow) + checkEvaluation(Sqrt(-1.5), null, EmptyRow) } test("Bitwise operations") { -- cgit v1.2.3