diff options
author | William Benton <willb@redhat.com> | 2014-08-29 15:26:59 -0700 |
---|---|---|
committer | Michael Armbrust <michael@databricks.com> | 2014-08-29 15:26:59 -0700 |
commit | 2f1519defaba4f3c7d536669f909bfd9e13e4069 (patch) | |
tree | dc19633311874e2ac5a5df8a94165981e69d44bc /sql/catalyst | |
parent | 53aa8316e88980c6f46d3b9fc90d935a4738a370 (diff) | |
download | spark-2f1519defaba4f3c7d536669f909bfd9e13e4069.tar.gz spark-2f1519defaba4f3c7d536669f909bfd9e13e4069.tar.bz2 spark-2f1519defaba4f3c7d536669f909bfd9e13e4069.zip |
SPARK-2813: [SQL] Implement SQRT() directly in Spark SQL
This PR adds a native implementation for SQL SQRT() and thus avoids delegating this function to Hive.
Author: William Benton <willb@redhat.com>
Closes #1750 from willb/spark-2813 and squashes the following commits:
22c8a79 [William Benton] Fixed missed newline from rebase
d673861 [William Benton] Added string coercions for SQRT and associated test case
e125df4 [William Benton] Added ExpressionEvaluationSuite test cases for SQRT
7b84bcd [William Benton] SQL SQRT now properly returns NULL for NULL inputs
8256971 [William Benton] added SQRT test to SqlQuerySuite
504d2e5 [William Benton] Added native SQRT implementation
Diffstat (limited to 'sql/catalyst')
4 files changed, 30 insertions, 0 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index 2c73a80f64..4f166c06b6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -122,6 +122,7 @@ class SqlParser extends StandardTokenParsers with PackratParsers { protected val EXCEPT = Keyword("EXCEPT") protected val SUBSTR = Keyword("SUBSTR") protected val SUBSTRING = Keyword("SUBSTRING") + protected val SQRT = Keyword("SQRT") // Use reflection to find the reserved words defined in this class. protected val reservedWords = @@ -323,6 +324,7 @@ class SqlParser extends StandardTokenParsers with PackratParsers { (SUBSTR | SUBSTRING) ~> "(" ~> expression ~ "," ~ expression ~ "," ~ expression <~ ")" ^^ { case s ~ "," ~ p ~ "," ~ l => Substring(s,p,l) } | + SQRT ~> "(" ~> expression <~ ")" ^^ { case exp => Sqrt(exp) } | ident ~ "(" ~ repsep(expression, ",") <~ ")" ^^ { case udfName ~ _ ~ exprs => UnresolvedFunction(udfName, exprs) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 15eb5982a4..ecfcd62d20 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -227,6 +227,8 @@ trait HiveTypeCoercion { Sum(Cast(e, DoubleType)) case Average(e) if e.dataType == StringType => Average(Cast(e, DoubleType)) + case Sqrt(e) if e.dataType == StringType => + Sqrt(Cast(e, DoubleType)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index aae86a3628..56f042891a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -33,6 +33,19 @@ case class UnaryMinus(child: Expression) extends UnaryExpression { } } +case class Sqrt(child: Expression) extends UnaryExpression { + type EvaluatedType = Any + + def dataType = child.dataType + override def foldable = child.foldable + def nullable = child.nullable + override def toString = s"SQRT($child)" + + override def eval(input: Row): Any = { + n1(child, input, ((na,a) => math.sqrt(na.toDouble(a)))) + } +} + abstract class BinaryArithmetic extends BinaryExpression { self: Product => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index f1df817c41..b961346dfc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -577,4 +577,17 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(s.substring(0, 2), "ex", row) checkEvaluation(s.substring(0), "example", row) } + + test("SQRT") { + val inputSequence = (1 to (1<<24) by 511).map(_ * (1L<<24)) + val expectedResults = inputSequence.map(l => math.sqrt(l.toDouble)) + val rowSequence = inputSequence.map(l => new GenericRow(Array[Any](l.toDouble))) + val d = 'a.double.at(0) + + for ((row, expected) <- rowSequence zip expectedResults) { + checkEvaluation(Sqrt(d), expected, row) + } + + checkEvaluation(Sqrt(Literal(null, DoubleType)), null, new GenericRow(Array[Any](null))) + } } |