diff options
Diffstat (limited to 'sql')
-rw-r--r-- | sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala | 14 | ||||
-rw-r--r-- | sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala | 10 |
2 files changed, 20 insertions, 4 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index 5a169488c9..f5bd068d60 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -145,6 +145,14 @@ final class Decimal extends Ordered[Decimal] with Serializable { } } + def toLimitedBigDecimal: BigDecimal = { + if (decimalVal.ne(null)) { + decimalVal + } else { + BigDecimal(longVal, _scale) + } + } + def toJavaBigDecimal: java.math.BigDecimal = toBigDecimal.underlying() def toUnscaledLong: Long = { @@ -269,9 +277,9 @@ final class Decimal extends Ordered[Decimal] with Serializable { if (that.isZero) { null } else { - // To avoid non-terminating decimal expansion problem, we turn to Java BigDecimal's divide - // with specified ROUNDING_MODE. - Decimal(toJavaBigDecimal.divide(that.toJavaBigDecimal, ROUNDING_MODE.id)) + // To avoid non-terminating decimal expansion problem, we get scala's BigDecimal with limited + // precision and scala. + Decimal(toLimitedBigDecimal / that.toLimitedBigDecimal) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala index 5f312964e5..030bb6d21b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala @@ -170,6 +170,14 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester { test("fix non-terminating decimal expansion problem") { val decimal = Decimal(1.0, 10, 3) / Decimal(3.0, 10, 3) - assert(decimal.toString === "0.333") + // The difference between decimal should not be more than 0.001. + assert(decimal.toDouble - 0.333 < 0.001) + } + + test("fix loss of precision/scale when doing division operation") { + val a = Decimal(2) / Decimal(3) + assert(a.toDouble < 1.0 && a.toDouble > 0.6) + val b = Decimal(1) / Decimal(8) + assert(b.toDouble === 0.125) } } |