aboutsummaryrefslogtreecommitdiff
path: root/sql/catalyst
diff options
context:
space:
mode:
authorDaoyuan Wang <daoyuan.wang@intel.com>2014-12-17 12:51:27 -0800
committerMichael Armbrust <michael@databricks.com>2014-12-17 12:51:27 -0800
commit902e4d54acbc3c88163a5c6447aff68ed57475c1 (patch)
tree5f84220c9d36d65c418f1a0b5a3adaf365395ecd /sql/catalyst
parent62771353767b5eecf2ec6c732cab07369d784df5 (diff)
downloadspark-902e4d54acbc3c88163a5c6447aff68ed57475c1.tar.gz
spark-902e4d54acbc3c88163a5c6447aff68ed57475c1.tar.bz2
spark-902e4d54acbc3c88163a5c6447aff68ed57475c1.zip
[SPARK-4755] [SQL] sqrt(negative value) should return null
Author: Daoyuan Wang <daoyuan.wang@intel.com> Closes #3616 from adrian-wang/sqrt and squashes the following commits: d877439 [Daoyuan Wang] fix NULLTYPE 3effa2c [Daoyuan Wang] sqrt(negative value) should return null
Diffstat (limited to 'sql/catalyst')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala15
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala2
2 files changed, 15 insertions, 2 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 79a742ad4b..168a963e29 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -38,11 +38,22 @@ case class Sqrt(child: Expression) extends UnaryExpression {
def dataType = DoubleType
override def foldable = child.foldable
- def nullable = child.nullable
+ def nullable = true
override def toString = s"SQRT($child)"
override def eval(input: Row): Any = {
- n1(child, input, (na,a) => math.sqrt(na.toDouble(a)))
+ val evalE = child.eval(input)
+ if (evalE == null) {
+ null
+ } else {
+ child.dataType match {
+ case n: NumericType =>
+ val value = n.numeric.toDouble(evalE.asInstanceOf[n.JvmType])
+ if (value < 0) null
+ else math.sqrt(value)
+ case other => sys.error(s"Type $other does not support non-negative numeric operations")
+ }
+ }
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
index 1e371db315..4ba7d87ba8 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
@@ -1037,6 +1037,8 @@ class ExpressionEvaluationSuite extends FunSuite {
}
checkEvaluation(Sqrt(Literal(null, DoubleType)), null, new GenericRow(Array[Any](null)))
+ checkEvaluation(Sqrt(-1), null, EmptyRow)
+ checkEvaluation(Sqrt(-1.5), null, EmptyRow)
}
test("Bitwise operations") {