aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDavies Liu <davies@databricks.com>2016-03-22 16:45:20 -0700
committerDavies Liu <davies.liu@gmail.com>2016-03-22 16:45:20 -0700
commit4700adb98e4a37c2b0ef7123eca8a9a03bbdbe78 (patch)
tree214f9742d136c34836d9567f48c86a8b368260b8
parentd16710b4c986f0eaf28552ce0e2db33d8c9343b8 (diff)
downloadspark-4700adb98e4a37c2b0ef7123eca8a9a03bbdbe78.tar.gz
spark-4700adb98e4a37c2b0ef7123eca8a9a03bbdbe78.tar.bz2
spark-4700adb98e4a37c2b0ef7123eca8a9a03bbdbe78.zip
[SPARK-13806] [SQL] fix rounding mode of negative float/double
## What changes were proposed in this pull request? Round() in database usually round the number up (away from zero), it's different than Math.round() in Java. For example: ``` scala> java.lang.Math.round(-3.5) res3: Long = -3 ``` In Database, we should return -4.0 in this cases. This PR remove the buggy special case for scale=0. ## How was this patch tested? Add tests for negative values with tie. Author: Davies Liu <davies@databricks.com> Closes #11894 from davies/fix_round.
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala48
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala4
2 files changed, 19 insertions, 33 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
index 12fcc40376..e3d1bc127d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
@@ -748,7 +748,7 @@ case class Round(child: Expression, scale: Expression)
if (f.isNaN || f.isInfinite) {
f
} else {
- BigDecimal(f).setScale(_scale, HALF_UP).toFloat
+ BigDecimal(f.toDouble).setScale(_scale, HALF_UP).toFloat
}
case DoubleType =>
val d = input1.asInstanceOf[Double]
@@ -804,39 +804,21 @@ case class Round(child: Expression, scale: Expression)
s"${ev.value} = ${ce.value};"
}
case FloatType => // if child eval to NaN or Infinity, just return it.
- if (_scale == 0) {
- s"""
- if (Float.isNaN(${ce.value}) || Float.isInfinite(${ce.value})) {
- ${ev.value} = ${ce.value};
- } else {
- ${ev.value} = Math.round(${ce.value});
- }"""
- } else {
- s"""
- if (Float.isNaN(${ce.value}) || Float.isInfinite(${ce.value})) {
- ${ev.value} = ${ce.value};
- } else {
- ${ev.value} = java.math.BigDecimal.valueOf(${ce.value}).
- setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).floatValue();
- }"""
- }
+ s"""
+ if (Float.isNaN(${ce.value}) || Float.isInfinite(${ce.value})) {
+ ${ev.value} = ${ce.value};
+ } else {
+ ${ev.value} = java.math.BigDecimal.valueOf(${ce.value}).
+ setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).floatValue();
+ }"""
case DoubleType => // if child eval to NaN or Infinity, just return it.
- if (_scale == 0) {
- s"""
- if (Double.isNaN(${ce.value}) || Double.isInfinite(${ce.value})) {
- ${ev.value} = ${ce.value};
- } else {
- ${ev.value} = Math.round(${ce.value});
- }"""
- } else {
- s"""
- if (Double.isNaN(${ce.value}) || Double.isInfinite(${ce.value})) {
- ${ev.value} = ${ce.value};
- } else {
- ${ev.value} = java.math.BigDecimal.valueOf(${ce.value}).
- setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).doubleValue();
- }"""
- }
+ s"""
+ if (Double.isNaN(${ce.value}) || Double.isInfinite(${ce.value})) {
+ ${ev.value} = ${ce.value};
+ } else {
+ ${ev.value} = java.math.BigDecimal.valueOf(${ce.value}).
+ setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).doubleValue();
+ }"""
}
if (scaleV == null) { // if scale is null, no need to eval its child at all
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
index bd674dadd0..27195d3458 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
@@ -553,5 +553,9 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(Round(Literal.create(null, dataType),
Literal.create(null, IntegerType)), null)
}
+
+ checkEvaluation(Round(-3.5, 0), -4.0)
+ checkEvaluation(Round(-0.35, 1), -0.4)
+ checkEvaluation(Round(-35, -1), -40)
}
}