aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorCheng Hao <hao.cheng@intel.com>2015-10-01 11:48:15 -0700
committerDavies Liu <davies.liu@gmail.com>2015-10-01 11:48:15 -0700
commit4d8c7c6d1c973f07e210e27936b063b5a763e9a3 (patch)
tree8266040f961d60b64327eb8d74c598f1b9546e44 /sql
parent9b3e7768a27d51ddd4711c4a68a428a6875bd6d7 (diff)
downloadspark-4d8c7c6d1c973f07e210e27936b063b5a763e9a3.tar.gz
spark-4d8c7c6d1c973f07e210e27936b063b5a763e9a3.tar.bz2
spark-4d8c7c6d1c973f07e210e27936b063b5a763e9a3.zip
[SPARK-10865] [SPARK-10866] [SQL] Fix bug of ceil/floor, which should returns long instead of the Double type
Floor & Ceiling function should returns Long type, rather than Double. Verified with MySQL & Hive. Author: Cheng Hao <hao.cheng@intel.com> Closes #8933 from chenghao-intel/ceiling.
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala24
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala4
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala14
3 files changed, 31 insertions, 11 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
index 15ceb9193a..39de0e8f44 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
@@ -52,7 +52,7 @@ abstract class LeafMathExpression(c: Double, name: String)
* @param f The math function.
* @param name The short name of the function
*/
-abstract class UnaryMathExpression(f: Double => Double, name: String)
+abstract class UnaryMathExpression(val f: Double => Double, name: String)
extends UnaryExpression with Serializable with ImplicitCastInputTypes {
override def inputTypes: Seq[DataType] = Seq(DoubleType)
@@ -152,7 +152,16 @@ case class Atan(child: Expression) extends UnaryMathExpression(math.atan, "ATAN"
case class Cbrt(child: Expression) extends UnaryMathExpression(math.cbrt, "CBRT")
-case class Ceil(child: Expression) extends UnaryMathExpression(math.ceil, "CEIL")
+case class Ceil(child: Expression) extends UnaryMathExpression(math.ceil, "CEIL") {
+ override def dataType: DataType = LongType
+ protected override def nullSafeEval(input: Any): Any = {
+ f(input.asInstanceOf[Double]).toLong
+ }
+
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ defineCodeGen(ctx, ev, c => s"(long)(java.lang.Math.${funcName}($c))")
+ }
+}
case class Cos(child: Expression) extends UnaryMathExpression(math.cos, "COS")
@@ -195,7 +204,16 @@ case class Exp(child: Expression) extends UnaryMathExpression(math.exp, "EXP")
case class Expm1(child: Expression) extends UnaryMathExpression(math.expm1, "EXPM1")
-case class Floor(child: Expression) extends UnaryMathExpression(math.floor, "FLOOR")
+case class Floor(child: Expression) extends UnaryMathExpression(math.floor, "FLOOR") {
+ override def dataType: DataType = LongType
+ protected override def nullSafeEval(input: Any): Any = {
+ f(input.asInstanceOf[Double]).toLong
+ }
+
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ defineCodeGen(ctx, ev, c => s"(long)(java.lang.Math.${funcName}($c))")
+ }
+}
object Factorial {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
index 90c59f240b..1b2a9163a3 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
@@ -244,12 +244,12 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
}
test("ceil") {
- testUnary(Ceil, math.ceil)
+ testUnary(Ceil, (d: Double) => math.ceil(d).toLong)
checkConsistencyBetweenInterpretedAndCodegen(Ceil, DoubleType)
}
test("floor") {
- testUnary(Floor, math.floor)
+ testUnary(Floor, (d: Double) => math.floor(d).toLong)
checkConsistencyBetweenInterpretedAndCodegen(Floor, DoubleType)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
index 30289c3c1d..58f982c2bc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
@@ -37,9 +37,11 @@ class MathExpressionsSuite extends QueryTest with SharedSQLContext {
private lazy val nullDoubles =
Seq(NullDoubles(1.0), NullDoubles(2.0), NullDoubles(3.0), NullDoubles(null)).toDF()
- private def testOneToOneMathFunction[@specialized(Int, Long, Float, Double) T](
+ private def testOneToOneMathFunction[
+ @specialized(Int, Long, Float, Double) T,
+ @specialized(Int, Long, Float, Double) U](
c: Column => Column,
- f: T => T): Unit = {
+ f: T => U): Unit = {
checkAnswer(
doubleData.select(c('a)),
(1 to 10).map(n => Row(f((n * 0.2 - 1).asInstanceOf[T])))
@@ -165,10 +167,10 @@ class MathExpressionsSuite extends QueryTest with SharedSQLContext {
}
test("ceil and ceiling") {
- testOneToOneMathFunction(ceil, math.ceil)
+ testOneToOneMathFunction(ceil, (d: Double) => math.ceil(d).toLong)
checkAnswer(
sql("SELECT ceiling(0), ceiling(1), ceiling(1.5)"),
- Row(0.0, 1.0, 2.0))
+ Row(0L, 1L, 2L))
}
test("conv") {
@@ -184,7 +186,7 @@ class MathExpressionsSuite extends QueryTest with SharedSQLContext {
}
test("floor") {
- testOneToOneMathFunction(floor, math.floor)
+ testOneToOneMathFunction(floor, (d: Double) => math.floor(d).toLong)
}
test("factorial") {
@@ -228,7 +230,7 @@ class MathExpressionsSuite extends QueryTest with SharedSQLContext {
}
test("signum / sign") {
- testOneToOneMathFunction[Double](signum, math.signum)
+ testOneToOneMathFunction[Double, Double](signum, math.signum)
checkAnswer(
sql("SELECT sign(10), signum(-11)"),