diff options
author | Cheng Hao <hao.cheng@intel.com> | 2015-10-01 11:48:15 -0700 |
---|---|---|
committer | Davies Liu <davies.liu@gmail.com> | 2015-10-01 11:48:15 -0700 |
commit | 4d8c7c6d1c973f07e210e27936b063b5a763e9a3 (patch) | |
tree | 8266040f961d60b64327eb8d74c598f1b9546e44 /sql | |
parent | 9b3e7768a27d51ddd4711c4a68a428a6875bd6d7 (diff) | |
download | spark-4d8c7c6d1c973f07e210e27936b063b5a763e9a3.tar.gz spark-4d8c7c6d1c973f07e210e27936b063b5a763e9a3.tar.bz2 spark-4d8c7c6d1c973f07e210e27936b063b5a763e9a3.zip |
[SPARK-10865] [SPARK-10866] [SQL] Fix bug of ceil/floor, which should returns long instead of the Double type
Floor & Ceiling function should returns Long type, rather than Double.
Verified with MySQL & Hive.
Author: Cheng Hao <hao.cheng@intel.com>
Closes #8933 from chenghao-intel/ceiling.
Diffstat (limited to 'sql')
3 files changed, 31 insertions, 11 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index 15ceb9193a..39de0e8f44 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -52,7 +52,7 @@ abstract class LeafMathExpression(c: Double, name: String) * @param f The math function. * @param name The short name of the function */ -abstract class UnaryMathExpression(f: Double => Double, name: String) +abstract class UnaryMathExpression(val f: Double => Double, name: String) extends UnaryExpression with Serializable with ImplicitCastInputTypes { override def inputTypes: Seq[DataType] = Seq(DoubleType) @@ -152,7 +152,16 @@ case class Atan(child: Expression) extends UnaryMathExpression(math.atan, "ATAN" case class Cbrt(child: Expression) extends UnaryMathExpression(math.cbrt, "CBRT") -case class Ceil(child: Expression) extends UnaryMathExpression(math.ceil, "CEIL") +case class Ceil(child: Expression) extends UnaryMathExpression(math.ceil, "CEIL") { + override def dataType: DataType = LongType + protected override def nullSafeEval(input: Any): Any = { + f(input.asInstanceOf[Double]).toLong + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + defineCodeGen(ctx, ev, c => s"(long)(java.lang.Math.${funcName}($c))") + } +} case class Cos(child: Expression) extends UnaryMathExpression(math.cos, "COS") @@ -195,7 +204,16 @@ case class Exp(child: Expression) extends UnaryMathExpression(math.exp, "EXP") case class Expm1(child: Expression) extends UnaryMathExpression(math.expm1, "EXPM1") -case class Floor(child: Expression) extends UnaryMathExpression(math.floor, "FLOOR") +case class Floor(child: Expression) extends UnaryMathExpression(math.floor, "FLOOR") { + override def dataType: DataType = LongType + protected override def nullSafeEval(input: Any): Any = { + f(input.asInstanceOf[Double]).toLong + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + defineCodeGen(ctx, ev, c => s"(long)(java.lang.Math.${funcName}($c))") + } +} object Factorial { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala index 90c59f240b..1b2a9163a3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala @@ -244,12 +244,12 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("ceil") { - testUnary(Ceil, math.ceil) + testUnary(Ceil, (d: Double) => math.ceil(d).toLong) checkConsistencyBetweenInterpretedAndCodegen(Ceil, DoubleType) } test("floor") { - testUnary(Floor, math.floor) + testUnary(Floor, (d: Double) => math.floor(d).toLong) checkConsistencyBetweenInterpretedAndCodegen(Floor, DoubleType) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala index 30289c3c1d..58f982c2bc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala @@ -37,9 +37,11 @@ class MathExpressionsSuite extends QueryTest with SharedSQLContext { private lazy val nullDoubles = Seq(NullDoubles(1.0), NullDoubles(2.0), NullDoubles(3.0), NullDoubles(null)).toDF() - private def testOneToOneMathFunction[@specialized(Int, Long, Float, Double) T]( + private def testOneToOneMathFunction[ + @specialized(Int, Long, Float, Double) T, + @specialized(Int, Long, Float, Double) U]( c: Column => Column, - f: T => T): Unit = { + f: T => U): Unit = { checkAnswer( doubleData.select(c('a)), (1 to 10).map(n => Row(f((n * 0.2 - 1).asInstanceOf[T]))) @@ -165,10 +167,10 @@ class MathExpressionsSuite extends QueryTest with SharedSQLContext { } test("ceil and ceiling") { - testOneToOneMathFunction(ceil, math.ceil) + testOneToOneMathFunction(ceil, (d: Double) => math.ceil(d).toLong) checkAnswer( sql("SELECT ceiling(0), ceiling(1), ceiling(1.5)"), - Row(0.0, 1.0, 2.0)) + Row(0L, 1L, 2L)) } test("conv") { @@ -184,7 +186,7 @@ class MathExpressionsSuite extends QueryTest with SharedSQLContext { } test("floor") { - testOneToOneMathFunction(floor, math.floor) + testOneToOneMathFunction(floor, (d: Double) => math.floor(d).toLong) } test("factorial") { @@ -228,7 +230,7 @@ class MathExpressionsSuite extends QueryTest with SharedSQLContext { } test("signum / sign") { - testOneToOneMathFunction[Double](signum, math.signum) + testOneToOneMathFunction[Double, Double](signum, math.signum) checkAnswer( sql("SELECT sign(10), signum(-11)"), |