author     Liang-Chi Hsieh <viirya@gmail.com>        2015-06-18 13:00:31 -0700
committer  Reynold Xin <rxin@databricks.com>         2015-06-18 13:00:31 -0700
commit     31641128b34d6f2aa7cb67324c24dd8b3ed84689 (patch)
tree       782d1d60e86de06322fa81187db969f1d866adf5 /sql
parent     ddc5baf17d7b09623b91190ee7754a6c8f7b5d10 (diff)
[SPARK-8363][SQL] Move sqrt to math and extend UnaryMathExpression
JIRA: https://issues.apache.org/jira/browse/SPARK-8363

Author: Liang-Chi Hsieh <viirya@gmail.com>

Closes #6823 from viirya/move_sqrt and squashes the following commits:

8977e11 [Liang-Chi Hsieh] Remove unnecessary old tests.
d23e79e [Liang-Chi Hsieh] Explicitly indicate sqrt value sequence.
699f48b [Liang-Chi Hsieh] Use correct @since tag.
8dff6d1 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into move_sqrt
bc2ed77 [Liang-Chi Hsieh] Remove/move arithmetic expression test and expression type checking test. Remove unnecessary Sqrt type rule.
d38492f [Liang-Chi Hsieh] Now sqrt accepts boolean because type casting is handled by HiveTypeCoercion.
297cc90 [Liang-Chi Hsieh] Sqrt only accepts double input.
ef4a21a [Liang-Chi Hsieh] Move sqrt to math.
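For context on why the explicit negative-value check from the old arithmetic.scala implementation can be dropped: a UnaryMathExpression wraps a plain Double => Double function and maps a NaN result to SQL NULL, and math.sqrt returns NaN for negative input, which is the convention the moved tests below rely on (expectNull for negative values). The following is a minimal sketch of that pattern with simplified stand-in types, not the actual Spark classes:

// A minimal, self-contained sketch of the UnaryMathExpression idea (simplified;
// these are NOT the actual Spark classes). Option[Double] stands in for a
// nullable SQL DoubleType value.
abstract class UnaryMathSketch(f: Double => Double, name: String) {
  def eval(input: Option[Double]): Option[Double] = input.flatMap { v =>
    val result = f(v)
    if (result.isNaN) None else Some(result) // NaN (e.g. sqrt of a negative) becomes NULL
  }
  override def toString: String = name
}

// Sqrt just plugs math.sqrt into the shared wrapper, with no special-case code of its own.
object SqrtSketch extends UnaryMathSketch(math.sqrt, "SQRT")

object SqrtSketchDemo extends App {
  println(SqrtSketch.eval(Some(4.0)))  // Some(2.0)
  println(SqrtSketch.eval(Some(-1.0))) // None, i.e. SQL NULL
  println(SqrtSketch.eval(None))       // None, i.e. SQL NULL
}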
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala              |  1
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala                 | 32
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala                       |  2
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala  | 15
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala |  2
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala         | 10
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/functions.scala                                           | 10
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala                                | 10
8 files changed, 31 insertions(+), 51 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 189451d0d9..8012b224eb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -307,7 +307,6 @@ trait HiveTypeCoercion {
case Sum(e @ StringType()) => Sum(Cast(e, DoubleType))
case Average(e @ StringType()) => Average(Cast(e, DoubleType))
- case Sqrt(e @ StringType()) => Sqrt(Cast(e, DoubleType))
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 167e460d5a..ace8427c8d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -67,38 +67,6 @@ case class UnaryPositive(child: Expression) extends UnaryArithmetic {
protected override def evalInternal(evalE: Any) = evalE
}
-case class Sqrt(child: Expression) extends UnaryArithmetic {
- override def dataType: DataType = DoubleType
- override def nullable: Boolean = true
- override def toString: String = s"SQRT($child)"
-
- override def checkInputDataTypes(): TypeCheckResult =
- TypeUtils.checkForNumericExpr(child.dataType, "function sqrt")
-
- private lazy val numeric = TypeUtils.getNumeric(child.dataType)
-
- protected override def evalInternal(evalE: Any) = {
- val value = numeric.toDouble(evalE)
- if (value < 0) null
- else math.sqrt(value)
- }
-
- override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
- val eval = child.gen(ctx)
- eval.code + s"""
- boolean ${ev.isNull} = ${eval.isNull};
- ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
- if (!${ev.isNull}) {
- if (${eval.primitive} < 0.0) {
- ${ev.isNull} = true;
- } else {
- ${ev.primitive} = java.lang.Math.sqrt(${eval.primitive});
- }
- }
- """
- }
-}
-
/**
* A function that get the absolute value of the numeric value.
*/
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
index 67cb0b508c..3b83c6da0e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
@@ -193,6 +193,8 @@ case class Sin(child: Expression) extends UnaryMathExpression(math.sin, "SIN")
case class Sinh(child: Expression) extends UnaryMathExpression(math.sinh, "SINH")
+case class Sqrt(child: Expression) extends UnaryMathExpression(math.sqrt, "SQRT")
+
case class Tan(child: Expression) extends UnaryMathExpression(math.tan, "TAN")
case class Tanh(child: Expression) extends UnaryMathExpression(math.tanh, "TANH")
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
index 3f4843259e..4bbbbe6c7f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
@@ -142,19 +142,4 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(MinOf(Literal.create(null, IntegerType), 1), 1)
checkEvaluation(MinOf(1, Literal.create(null, IntegerType)), 1)
}
-
- test("SQRT") {
- val inputSequence = (1 to (1<<24) by 511).map(_ * (1L<<24))
- val expectedResults = inputSequence.map(l => math.sqrt(l.toDouble))
- val rowSequence = inputSequence.map(l => create_row(l.toDouble))
- val d = 'a.double.at(0)
-
- for ((row, expected) <- rowSequence zip expectedResults) {
- checkEvaluation(Sqrt(d), expected, row)
- }
-
- checkEvaluation(Sqrt(Literal.create(null, DoubleType)), null, create_row(null))
- checkEvaluation(Sqrt(-1), null, EmptyRow)
- checkEvaluation(Sqrt(-1.5), null, EmptyRow)
- }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala
index dcb3635c5c..49b1119897 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala
@@ -54,8 +54,6 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
test("check types for unary arithmetic") {
assertError(UnaryMinus('stringField), "operator - accepts numeric type")
- assertSuccess(Sqrt('stringField)) // We will cast String to Double for sqrt
- assertError(Sqrt('booleanField), "function sqrt accepts numeric type")
assertError(Abs('stringField), "function abs accepts numeric type")
assertError(BitwiseNot('stringField), "operator ~ accepts integral type")
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
index 0050ad3fe8..21e9b92b72 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.catalyst.expressions
+import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types.DoubleType
@@ -191,6 +192,15 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
testUnary(Log2, f, (-5 to -1).map(_ * 1.0), expectNull = true)
}
+ test("sqrt") {
+ testUnary(Sqrt, math.sqrt, (0 to 20).map(_ * 0.1))
+ testUnary(Sqrt, math.sqrt, (-5 to -1).map(_ * 1.0), expectNull = true)
+
+ checkEvaluation(Sqrt(Literal.create(null, DoubleType)), null, create_row(null))
+ checkEvaluation(Sqrt(Literal(-1.0)), null, EmptyRow)
+ checkEvaluation(Sqrt(Literal(-1.5)), null, EmptyRow)
+ }
+
test("pow") {
testBinary(Pow, math.pow, (-5 to 5).map(v => (v * 1.0, v * 1.0)))
testBinary(Pow, math.pow, Seq((-1.0, 0.9), (-2.2, 1.7), (-2.2, -1.7)), expectNull = true)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index dff0932c45..d8a91bead7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -707,12 +707,20 @@ object functions {
/**
* Computes the square root of the specified float value.
*
- * @group normal_funcs
+ * @group math_funcs
* @since 1.3.0
*/
def sqrt(e: Column): Column = Sqrt(e.expr)
/**
+ * Computes the square root of the specified float value.
+ *
+ * @group math_funcs
+ * @since 1.5.0
+ */
+ def sqrt(colName: String): Column = sqrt(Column(colName))
+
+ /**
* Creates a new struct column. The input column must be a column in a [[DataFrame]], or
* a derived column expression that is named (i.e. aliased).
*
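A quick, hypothetical usage sketch of the sqrt(colName: String) overload added above (not part of the patch; assumes Spark 1.5-era SparkContext/SQLContext APIs and local mode):

// Hypothetical standalone usage of the new sqrt(colName: String) overload.
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.functions.sqrt

object SqrtOverloadDemo extends App {
  val sc = new SparkContext(new SparkConf().setAppName("sqrt-demo").setMaster("local[1]"))
  val sqlContext = new SQLContext(sc)
  import sqlContext.implicits._

  val df = Seq((1, 4)).toDF("a", "b")
  // Column-name variant added by this patch; equivalent to sqrt(df("a")).
  df.select(sqrt("a"), sqrt("b")).show()  // 1.0 and 2.0
  sc.stop()
}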
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
index 7c9c121b95..2768d7dfc8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
@@ -270,6 +270,16 @@ class MathExpressionsSuite extends QueryTest {
checkAnswer(ctx.sql("SELECT LOG2(8), LOG2(null)"), Row(3, null))
}
+ test("sqrt") {
+ val df = Seq((1, 4)).toDF("a", "b")
+ checkAnswer(
+ df.select(sqrt("a"), sqrt("b")),
+ Row(1.0, 2.0))
+
+ checkAnswer(ctx.sql("SELECT SQRT(4.0), SQRT(null)"), Row(2.0, null))
+ checkAnswer(df.selectExpr("sqrt(a)", "sqrt(b)", "sqrt(null)"), Row(1.0, 2.0, null))
+ }
+
test("negative") {
checkAnswer(
ctx.sql("SELECT negative(1), negative(0), negative(-1)"),