path: root/sql/catalyst
author     William Benton <willb@redhat.com>	2014-08-29 15:26:59 -0700
committer  Michael Armbrust <michael@databricks.com>	2014-08-29 15:26:59 -0700
commit     2f1519defaba4f3c7d536669f909bfd9e13e4069 (patch)
tree       dc19633311874e2ac5a5df8a94165981e69d44bc /sql/catalyst
parent     53aa8316e88980c6f46d3b9fc90d935a4738a370 (diff)
SPARK-2813: [SQL] Implement SQRT() directly in Spark SQL
This PR adds a native implementation for SQL SQRT() and thus avoids delegating this function to Hive.

Author: William Benton <willb@redhat.com>

Closes #1750 from willb/spark-2813 and squashes the following commits:

22c8a79 [William Benton] Fixed missed newline from rebase
d673861 [William Benton] Added string coercions for SQRT and associated test case
e125df4 [William Benton] Added ExpressionEvaluationSuite test cases for SQRT
7b84bcd [William Benton] SQL SQRT now properly returns NULL for NULL inputs
8256971 [William Benton] added SQRT test to SqlQuerySuite
504d2e5 [William Benton] Added native SQRT implementation
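For orientation, this moves SQRT entirely into Catalyst: SqlParser recognizes the keyword, HiveTypeCoercion casts string arguments, and a new Sqrt expression evaluates it, so no Hive UDF is involved. A minimal usage sketch against the SQLContext API of that release (the table and column names "records" and "value" are made up for illustration):

    // hypothetical usage; "records"/"value" are illustrative names only
    val sqlContext = new org.apache.spark.sql.SQLContext(sc)
    // SQRT(...) is now parsed and evaluated natively by Catalyst
    sqlContext.sql("SELECT SQRT(value) FROM records").collect()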
Diffstat (limited to 'sql/catalyst')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala                             |  2 +
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala             |  2 +
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala                | 13 +
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala | 13 +
4 files changed, 30 insertions(+), 0 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
index 2c73a80f64..4f166c06b6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
@@ -122,6 +122,7 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
protected val EXCEPT = Keyword("EXCEPT")
protected val SUBSTR = Keyword("SUBSTR")
protected val SUBSTRING = Keyword("SUBSTRING")
+ protected val SQRT = Keyword("SQRT")
// Use reflection to find the reserved words defined in this class.
protected val reservedWords =
@@ -323,6 +324,7 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
(SUBSTR | SUBSTRING) ~> "(" ~> expression ~ "," ~ expression ~ "," ~ expression <~ ")" ^^ {
case s ~ "," ~ p ~ "," ~ l => Substring(s,p,l)
} |
+ SQRT ~> "(" ~> expression <~ ")" ^^ { case exp => Sqrt(exp) } |
ident ~ "(" ~ repsep(expression, ",") <~ ")" ^^ {
case udfName ~ _ ~ exprs => UnresolvedFunction(udfName, exprs)
}
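The new alternative above lowers SQRT(expr) straight into the Catalyst Sqrt node, ahead of the generic UnresolvedFunction fallback. As a rough illustration (not part of the commit), a projection such as SQRT(a + 1) would parse into an expression tree along these lines:

    // illustrative tree built by the new grammar alternative for SQRT(a + 1)
    Sqrt(Add(UnresolvedAttribute("a"), Literal(1)))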
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 15eb5982a4..ecfcd62d20 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -227,6 +227,8 @@ trait HiveTypeCoercion {
Sum(Cast(e, DoubleType))
case Average(e) if e.dataType == StringType =>
Average(Cast(e, DoubleType))
+ case Sqrt(e) if e.dataType == StringType =>
+ Sqrt(Cast(e, DoubleType))
}
}
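This mirrors the existing Sum and Average cases: a string-typed argument is wrapped in a cast to DoubleType during analysis so evaluation always sees a numeric input. Roughly, for a hypothetical StringType column s (sketch only, using the catalyst DSL from the test suite):

    // hypothetical string-typed bound reference
    val s = 'price.string.at(0)
    // after the coercion rule, Sqrt(s) is rewritten to:
    Sqrt(Cast(s, DoubleType))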
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index aae86a3628..56f042891a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -33,6 +33,19 @@ case class UnaryMinus(child: Expression) extends UnaryExpression {
}
}
+case class Sqrt(child: Expression) extends UnaryExpression {
+ type EvaluatedType = Any
+
+ def dataType = child.dataType
+ override def foldable = child.foldable
+ def nullable = child.nullable
+ override def toString = s"SQRT($child)"
+
+ override def eval(input: Row): Any = {
+ n1(child, input, ((na,a) => math.sqrt(na.toDouble(a))))
+ }
+}
+
abstract class BinaryArithmetic extends BinaryExpression {
self: Product =>
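The Sqrt expression follows the pattern of UnaryMinus just above it: the n1 helper evaluates the child, returns null for a null input, and otherwise converts the value to a double before calling math.sqrt. A small evaluation sketch in the spirit of the new test (assumes the catalyst DSL and GenericRow imports used by ExpressionEvaluationSuite):

    // sketch: evaluating the new expression against in-memory rows
    val d = 'a.double.at(0)                            // bound double column at ordinal 0
    Sqrt(d).eval(new GenericRow(Array[Any](16.0)))     // returns 4.0
    Sqrt(d).eval(new GenericRow(Array[Any](null)))     // returns null (null-safe via n1)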
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
index f1df817c41..b961346dfc 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
@@ -577,4 +577,17 @@ class ExpressionEvaluationSuite extends FunSuite {
checkEvaluation(s.substring(0, 2), "ex", row)
checkEvaluation(s.substring(0), "example", row)
}
+
+ test("SQRT") {
+ val inputSequence = (1 to (1<<24) by 511).map(_ * (1L<<24))
+ val expectedResults = inputSequence.map(l => math.sqrt(l.toDouble))
+ val rowSequence = inputSequence.map(l => new GenericRow(Array[Any](l.toDouble)))
+ val d = 'a.double.at(0)
+
+ for ((row, expected) <- rowSequence zip expectedResults) {
+ checkEvaluation(Sqrt(d), expected, row)
+ }
+
+ checkEvaluation(Sqrt(Literal(null, DoubleType)), null, new GenericRow(Array[Any](null)))
+ }
}