aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorWilliam Benton <willb@redhat.com>2014-08-29 15:26:59 -0700
committerMichael Armbrust <michael@databricks.com>2014-08-29 15:26:59 -0700
commit2f1519defaba4f3c7d536669f909bfd9e13e4069 (patch)
treedc19633311874e2ac5a5df8a94165981e69d44bc /sql
parent53aa8316e88980c6f46d3b9fc90d935a4738a370 (diff)
downloadspark-2f1519defaba4f3c7d536669f909bfd9e13e4069.tar.gz
spark-2f1519defaba4f3c7d536669f909bfd9e13e4069.tar.bz2
spark-2f1519defaba4f3c7d536669f909bfd9e13e4069.zip
SPARK-2813: [SQL] Implement SQRT() directly in Spark SQL
This PR adds a native implementation for SQL SQRT() and thus avoids delegating this function to Hive. Author: William Benton <willb@redhat.com> Closes #1750 from willb/spark-2813 and squashes the following commits: 22c8a79 [William Benton] Fixed missed newline from rebase d673861 [William Benton] Added string coercions for SQRT and associated test case e125df4 [William Benton] Added ExpressionEvaluationSuite test cases for SQRT 7b84bcd [William Benton] SQL SQRT now properly returns NULL for NULL inputs 8256971 [William Benton] added SQRT test to SqlQuerySuite 504d2e5 [William Benton] Added native SQRT implementation
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala2
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala2
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala13
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala13
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala14
-rw-r--r--sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala2
6 files changed, 46 insertions, 0 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
index 2c73a80f64..4f166c06b6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
@@ -122,6 +122,7 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
protected val EXCEPT = Keyword("EXCEPT")
protected val SUBSTR = Keyword("SUBSTR")
protected val SUBSTRING = Keyword("SUBSTRING")
+ protected val SQRT = Keyword("SQRT")
// Use reflection to find the reserved words defined in this class.
protected val reservedWords =
@@ -323,6 +324,7 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
(SUBSTR | SUBSTRING) ~> "(" ~> expression ~ "," ~ expression ~ "," ~ expression <~ ")" ^^ {
case s ~ "," ~ p ~ "," ~ l => Substring(s,p,l)
} |
+ SQRT ~> "(" ~> expression <~ ")" ^^ { case exp => Sqrt(exp) } |
ident ~ "(" ~ repsep(expression, ",") <~ ")" ^^ {
case udfName ~ _ ~ exprs => UnresolvedFunction(udfName, exprs)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 15eb5982a4..ecfcd62d20 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -227,6 +227,8 @@ trait HiveTypeCoercion {
Sum(Cast(e, DoubleType))
case Average(e) if e.dataType == StringType =>
Average(Cast(e, DoubleType))
+ case Sqrt(e) if e.dataType == StringType =>
+ Sqrt(Cast(e, DoubleType))
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index aae86a3628..56f042891a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -33,6 +33,19 @@ case class UnaryMinus(child: Expression) extends UnaryExpression {
}
}
+case class Sqrt(child: Expression) extends UnaryExpression {
+ type EvaluatedType = Any
+
+ def dataType = child.dataType
+ override def foldable = child.foldable
+ def nullable = child.nullable
+ override def toString = s"SQRT($child)"
+
+ override def eval(input: Row): Any = {
+ n1(child, input, ((na,a) => math.sqrt(na.toDouble(a))))
+ }
+}
+
abstract class BinaryArithmetic extends BinaryExpression {
self: Product =>
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
index f1df817c41..b961346dfc 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
@@ -577,4 +577,17 @@ class ExpressionEvaluationSuite extends FunSuite {
checkEvaluation(s.substring(0, 2), "ex", row)
checkEvaluation(s.substring(0), "example", row)
}
+
+ test("SQRT") {
+ val inputSequence = (1 to (1<<24) by 511).map(_ * (1L<<24))
+ val expectedResults = inputSequence.map(l => math.sqrt(l.toDouble))
+ val rowSequence = inputSequence.map(l => new GenericRow(Array[Any](l.toDouble)))
+ val d = 'a.double.at(0)
+
+ for ((row, expected) <- rowSequence zip expectedResults) {
+ checkEvaluation(Sqrt(d), expected, row)
+ }
+
+ checkEvaluation(Sqrt(Literal(null, DoubleType)), null, new GenericRow(Array[Any](null)))
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 9b2a36d33f..4047bc0672 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -34,6 +34,20 @@ class SQLQuerySuite extends QueryTest {
"test")
}
+ test("SQRT") {
+ checkAnswer(
+ sql("SELECT SQRT(key) FROM testData"),
+ (1 to 100).map(x => Row(math.sqrt(x.toDouble))).toSeq
+ )
+ }
+
+ test("SQRT with automatic string casts") {
+ checkAnswer(
+ sql("SELECT SQRT(CAST(key AS STRING)) FROM testData"),
+ (1 to 100).map(x => Row(math.sqrt(x.toDouble))).toSeq
+ )
+ }
+
test("SPARK-2407 Added Parser of SQL SUBSTR()") {
checkAnswer(
sql("SELECT substr(tableName, 1, 2) FROM tableName"),
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index fa3adfdf58..a4dd6be5f9 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -889,6 +889,7 @@ private[hive] object HiveQl {
val WHEN = "(?i)WHEN".r
val CASE = "(?i)CASE".r
val SUBSTR = "(?i)SUBSTR(?:ING)?".r
+ val SQRT = "(?i)SQRT".r
protected def nodeToExpr(node: Node): Expression = node match {
/* Attribute References */
@@ -958,6 +959,7 @@ private[hive] object HiveQl {
case Token(DIV(), left :: right:: Nil) =>
Cast(Divide(nodeToExpr(left), nodeToExpr(right)), LongType)
case Token("%", left :: right:: Nil) => Remainder(nodeToExpr(left), nodeToExpr(right))
+ case Token("TOK_FUNCTION", Token(SQRT(), Nil) :: arg :: Nil) => Sqrt(nodeToExpr(arg))
/* Comparisons */
case Token("=", left :: right:: Nil) => EqualTo(nodeToExpr(left), nodeToExpr(right))