aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala14
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala10
2 files changed, 20 insertions, 4 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
index 5a169488c9..f5bd068d60 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
@@ -145,6 +145,14 @@ final class Decimal extends Ordered[Decimal] with Serializable {
}
}
+ def toLimitedBigDecimal: BigDecimal = {
+ if (decimalVal.ne(null)) {
+ decimalVal
+ } else {
+ BigDecimal(longVal, _scale)
+ }
+ }
+
def toJavaBigDecimal: java.math.BigDecimal = toBigDecimal.underlying()
def toUnscaledLong: Long = {
@@ -269,9 +277,9 @@ final class Decimal extends Ordered[Decimal] with Serializable {
if (that.isZero) {
null
} else {
- // To avoid non-terminating decimal expansion problem, we turn to Java BigDecimal's divide
- // with specified ROUNDING_MODE.
- Decimal(toJavaBigDecimal.divide(that.toJavaBigDecimal, ROUNDING_MODE.id))
+ // To avoid non-terminating decimal expansion problem, we get scala's BigDecimal with limited
+ // precision and scala.
+ Decimal(toLimitedBigDecimal / that.toLimitedBigDecimal)
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala
index 5f312964e5..030bb6d21b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala
@@ -170,6 +170,14 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester {
test("fix non-terminating decimal expansion problem") {
val decimal = Decimal(1.0, 10, 3) / Decimal(3.0, 10, 3)
- assert(decimal.toString === "0.333")
+ // The difference between decimal should not be more than 0.001.
+ assert(decimal.toDouble - 0.333 < 0.001)
+ }
+
+ test("fix loss of precision/scale when doing division operation") {
+ val a = Decimal(2) / Decimal(3)
+ assert(a.toDouble < 1.0 && a.toDouble > 0.6)
+ val b = Decimal(1) / Decimal(8)
+ assert(b.toDouble === 0.125)
}
}