aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDavies Liu <davies@databricks.com>2016-09-21 21:02:30 -0700
committerReynold Xin <rxin@databricks.com>2016-09-21 21:02:30 -0700
commit8bde03bf9a0896ea59ceaa699df7700351a130fb (patch)
treef1f02561c47129d2ee6ce52ad32e129dc0af715c
parent3497ebe511fee67e66387e9e737c843a2939ce45 (diff)
downloadspark-8bde03bf9a0896ea59ceaa699df7700351a130fb.tar.gz
spark-8bde03bf9a0896ea59ceaa699df7700351a130fb.tar.bz2
spark-8bde03bf9a0896ea59ceaa699df7700351a130fb.zip
[SPARK-17494][SQL] changePrecision() on compact decimal should respect rounding mode
## What changes were proposed in this pull request? Floor()/Ceil() of decimal is implemented using changePrecision() by passing a rounding mode, but the rounding mode is not respected when the decimal is in compact mode (could fit within a Long). This Update the changePrecision() to respect rounding mode, which could be ROUND_FLOOR, ROUND_CEIL, ROUND_HALF_UP, ROUND_HALF_EVEN. ## How was this patch tested? Added regression tests. Author: Davies Liu <davies@databricks.com> Closes #15154 from davies/decimal_round.
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala28
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala15
2 files changed, 39 insertions, 4 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
index cc8175c0a3..7085905287 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
@@ -242,10 +242,30 @@ final class Decimal extends Ordered[Decimal] with Serializable {
if (scale < _scale) {
// Easier case: we just need to divide our scale down
val diff = _scale - scale
- val droppedDigits = longVal % POW_10(diff)
- longVal /= POW_10(diff)
- if (math.abs(droppedDigits) * 2 >= POW_10(diff)) {
- longVal += (if (longVal < 0) -1L else 1L)
+ val pow10diff = POW_10(diff)
+ // % and / always round to 0
+ val droppedDigits = longVal % pow10diff
+ longVal /= pow10diff
+ roundMode match {
+ case ROUND_FLOOR =>
+ if (droppedDigits < 0) {
+ longVal += -1L
+ }
+ case ROUND_CEILING =>
+ if (droppedDigits > 0) {
+ longVal += 1L
+ }
+ case ROUND_HALF_UP =>
+ if (math.abs(droppedDigits) * 2 >= pow10diff) {
+ longVal += (if (droppedDigits < 0) -1L else 1L)
+ }
+ case ROUND_HALF_EVEN =>
+ val doubled = math.abs(droppedDigits) * 2
+ if (doubled > pow10diff || doubled == pow10diff && longVal % 2 != 0) {
+ longVal += (if (droppedDigits < 0) -1L else 1L)
+ }
+ case _ =>
+ sys.error(s"Not supported rounding mode: $roundMode")
}
} else if (scale > _scale) {
// We might be able to multiply longVal by a power of 10 and not overflow, but if not,
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala
index a10c0e39eb..52d0692524 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.types
import org.scalatest.PrivateMethodTester
import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.types.Decimal._
class DecimalSuite extends SparkFunSuite with PrivateMethodTester {
/** Check that a Decimal has the given string representation, precision and scale */
@@ -191,4 +192,18 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester {
assert(new Decimal().set(100L, 10, 0).toUnscaledLong === 100L)
assert(Decimal(Long.MaxValue, 100, 0).toUnscaledLong === Long.MaxValue)
}
+
+ test("changePrecision() on compact decimal should respect rounding mode") {
+ Seq(ROUND_FLOOR, ROUND_CEILING, ROUND_HALF_UP, ROUND_HALF_EVEN).foreach { mode =>
+ Seq("0.4", "0.5", "0.6", "1.0", "1.1", "1.6", "2.5", "5.5").foreach { n =>
+ Seq("", "-").foreach { sign =>
+ val bd = BigDecimal(sign + n)
+ val unscaled = (bd * 10).toLongExact
+ val d = Decimal(unscaled, 8, 1)
+ assert(d.changePrecision(10, 0, mode))
+ assert(d.toString === bd.setScale(0, mode).toString(), s"num: $sign$n, mode: $mode")
+ }
+ }
+ }
+ }
}