author    Davies Liu <davies@databricks.com>    2015-08-04 23:12:49 -0700
committer Davies Liu <davies.liu@gmail.com>    2015-08-04 23:12:49 -0700
commit    781c8d71a0a6a86c84048a4f22cb3a7d035a5be2
tree      2f76317e9764bcbd5fd5811b8c6247ed5dfde997 /sql/catalyst
parent    d34548587ab55bc2136c8f823b9e6ae96e1355a4
[SPARK-9119] [SPARK-8359] [SQL] match Decimal.precision/scale with DecimalType
Let Decimal carry the correct precision and scale with DecimalType.

cc rxin yhuai

Author: Davies Liu <davies@databricks.com>

Closes #7925 from davies/decimal_scale and squashes the following commits:

e19701a [Davies Liu] some tweaks
57d78d2 [Davies Liu] fix tests
5d5bc69 [Davies Liu] match precision and scale with DecimalType
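For a quick sense of the intent, here is a minimal sketch with made-up values, using only the Decimal and DecimalType constructors that already appear in this patch: a Decimal built for DecimalType(10, 2) reports exactly that precision and scale, and its value is rounded to the declared scale.

    import org.apache.spark.sql.types.{Decimal, DecimalType}

    // A Decimal built for DecimalType(10, 2) carries exactly that precision/scale.
    val dt = DecimalType(10, 2)
    val d  = Decimal(BigDecimal("12345.678"), dt.precision, dt.scale)

    assert(d.precision == 10 && d.scale == 2)
    println(d)  // 12345.68 (rounded to scale 2, assuming HALF_UP rounding)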
Diffstat (limited to 'sql/catalyst')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala                                  |  4
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala      | 21
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala   |  4
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala            |  6
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala      |  2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala                        | 37
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala           | 21
7 files changed, 73 insertions(+), 22 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
index 91449479fa..40159aaf14 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
@@ -417,6 +417,10 @@ trait Row extends Serializable {
if (!o2.isInstanceOf[Double] || ! java.lang.Double.isNaN(o2.asInstanceOf[Double])) {
return false
}
+ case d1: java.math.BigDecimal if o2.isInstanceOf[java.math.BigDecimal] =>
+ if (d1.compareTo(o2.asInstanceOf[java.math.BigDecimal]) != 0) {
+ return false
+ }
case _ => if (o1 != o2) {
return false
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
index c666864e43..8d0c64eae4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
@@ -317,18 +317,23 @@ object CatalystTypeConverters {
private class DecimalConverter(dataType: DecimalType)
extends CatalystTypeConverter[Any, JavaBigDecimal, Decimal] {
- override def toCatalystImpl(scalaValue: Any): Decimal = scalaValue match {
- case d: BigDecimal => Decimal(d)
- case d: JavaBigDecimal => Decimal(d)
- case d: Decimal => d
+ override def toCatalystImpl(scalaValue: Any): Decimal = {
+ val decimal = scalaValue match {
+ case d: BigDecimal => Decimal(d)
+ case d: JavaBigDecimal => Decimal(d)
+ case d: Decimal => d
+ }
+ if (decimal.changePrecision(dataType.precision, dataType.scale)) {
+ decimal
+ } else {
+ null
+ }
}
override def toScala(catalystValue: Decimal): JavaBigDecimal = catalystValue.toJavaBigDecimal
override def toScalaImpl(row: InternalRow, column: Int): JavaBigDecimal =
row.getDecimal(column, dataType.precision, dataType.scale).toJavaBigDecimal
}
- private object BigDecimalConverter extends DecimalConverter(DecimalType.SYSTEM_DEFAULT)
-
private abstract class PrimitiveConverter[T] extends CatalystTypeConverter[T, Any, Any] {
final override def toScala(catalystValue: Any): Any = catalystValue
final override def toCatalystImpl(scalaValue: T): Any = scalaValue
@@ -413,8 +418,8 @@ object CatalystTypeConverters {
case s: String => StringConverter.toCatalyst(s)
case d: Date => DateConverter.toCatalyst(d)
case t: Timestamp => TimestampConverter.toCatalyst(t)
- case d: BigDecimal => BigDecimalConverter.toCatalyst(d)
- case d: JavaBigDecimal => BigDecimalConverter.toCatalyst(d)
+ case d: BigDecimal => new DecimalConverter(DecimalType(d.precision, d.scale)).toCatalyst(d)
+ case d: JavaBigDecimal => new DecimalConverter(DecimalType(d.precision, d.scale)).toCatalyst(d)
case seq: Seq[Any] => new GenericArrayData(seq.map(convertToCatalyst).toArray)
case r: Row => InternalRow(r.toSeq.map(convertToCatalyst): _*)
case arr: Array[Any] => new GenericArrayData(arr.map(convertToCatalyst))
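Net effect of the converter change above: every incoming BigDecimal is normalized to the column's DecimalType, and a value that cannot be represented at that precision is mapped to null rather than carrying extra digits. A rough sketch of the per-value check, assuming only the Decimal.changePrecision call used in toCatalystImpl (inputs are made up):

    import org.apache.spark.sql.types.Decimal

    // Same decision the converter makes: round to the target scale, then
    // reject the value (null at the converter level) if the digits no longer fit.
    val fits   = Decimal(BigDecimal("12.3456"))
    val tooBig = Decimal(BigDecimal("12345678901234567890.1"))

    println(fits.changePrecision(5, 3))    // true: value rounded to scale 3
    println(tooBig.changePrecision(5, 3))  // false: converter would return null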
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 422d423747..490f3dc07b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -442,8 +442,8 @@ object HiveTypeCoercion {
* Changes numeric values to booleans so that expressions like true = 1 can be evaluated.
*/
object BooleanEquality extends Rule[LogicalPlan] {
- private val trueValues = Seq(1.toByte, 1.toShort, 1, 1L, Decimal(1))
- private val falseValues = Seq(0.toByte, 0.toShort, 0, 0L, Decimal(0))
+ private val trueValues = Seq(1.toByte, 1.toShort, 1, 1L, Decimal.ONE)
+ private val falseValues = Seq(0.toByte, 0.toShort, 0, 0L, Decimal.ZERO)
private def buildCaseKeyWhen(booleanExpr: Expression, numericExpr: Expression) = {
CaseKeyWhen(numericExpr, Seq(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 88429bb84b..39f99700c8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -26,8 +26,6 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
-import scala.collection.mutable
-
object Cast {
@@ -157,7 +155,7 @@ case class Cast(child: Expression, dataType: DataType)
case ByteType =>
buildCast[Byte](_, _ != 0)
case DecimalType() =>
- buildCast[Decimal](_, _ != Decimal(0))
+ buildCast[Decimal](_, _ != Decimal.ZERO)
case DoubleType =>
buildCast[Double](_, _ != 0)
case FloatType =>
@@ -311,7 +309,7 @@ case class Cast(child: Expression, dataType: DataType)
case _: NumberFormatException => null
})
case BooleanType =>
- buildCast[Boolean](_, b => changePrecision(if (b) Decimal(1) else Decimal(0), target))
+ buildCast[Boolean](_, b => changePrecision(if (b) Decimal.ONE else Decimal.ZERO, target))
case DateType =>
buildCast[Int](_, d => null) // date can't cast to decimal in Hive
case TimestampType =>
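A hedged sketch of the boolean-to-decimal path these constants feed, using the internal Cast and Literal expressions as they appear in this file (illustrative only; catalyst expressions are not a stable public API, and the no-argument eval() call is an assumption about the Expression interface):

    import org.apache.spark.sql.catalyst.expressions.{Cast, Literal}
    import org.apache.spark.sql.types.DecimalType

    // Boolean -> Decimal reuses the shared Decimal.ONE / Decimal.ZERO constants,
    // then adjusts them to the target precision/scale via changePrecision.
    println(Cast(Literal(true), DecimalType(3, 1)).eval())   // 1.0
    println(Cast(Literal(false), DecimalType(3, 1)).eval())  // 0.0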
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 0891b55494..5808e3f66d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -511,6 +511,6 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic {
private def pmod(a: Decimal, n: Decimal): Decimal = {
val r = a % n
- if (r.compare(Decimal(0)) < 0) {(r + n) % n} else r
+ if (r.compare(Decimal.ZERO) < 0) {(r + n) % n} else r
}
}
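For reference, the Decimal branch of pmod can be exercised on its own; a small sketch with made-up values, using only the Decimal operators visible in this diff:

    import org.apache.spark.sql.types.Decimal

    // Mirror of the helper above: take a % n, then fold a negative remainder
    // back into the [0, n) range so the result is never negative.
    def pmod(a: Decimal, n: Decimal): Decimal = {
      val r = a % n
      if (r.compare(Decimal(0)) < 0) (r + n) % n else r
    }

    println(pmod(Decimal(-7), Decimal(3)))   // 2 rather than -1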
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
index c0155eeb45..624c3f3d7f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.types
+import java.math.{RoundingMode, MathContext}
+
import org.apache.spark.annotation.DeveloperApi
/**
@@ -28,7 +30,7 @@ import org.apache.spark.annotation.DeveloperApi
* - Otherwise, the decimal value is longVal / (10 ** _scale)
*/
final class Decimal extends Ordered[Decimal] with Serializable {
- import org.apache.spark.sql.types.Decimal.{BIG_DEC_ZERO, MAX_LONG_DIGITS, POW_10, ROUNDING_MODE}
+ import org.apache.spark.sql.types.Decimal._
private var decimalVal: BigDecimal = null
private var longVal: Long = 0L
@@ -137,9 +139,9 @@ final class Decimal extends Ordered[Decimal] with Serializable {
def toBigDecimal: BigDecimal = {
if (decimalVal.ne(null)) {
- decimalVal
+ decimalVal(MATH_CONTEXT)
} else {
- BigDecimal(longVal, _scale)
+ BigDecimal(longVal, _scale)(MATH_CONTEXT)
}
}
@@ -261,10 +263,23 @@ final class Decimal extends Ordered[Decimal] with Serializable {
def isZero: Boolean = if (decimalVal.ne(null)) decimalVal == BIG_DEC_ZERO else longVal == 0
- def + (that: Decimal): Decimal = Decimal(toBigDecimal + that.toBigDecimal)
+ def + (that: Decimal): Decimal = {
+ if (decimalVal.eq(null) && that.decimalVal.eq(null) && scale == that.scale) {
+ Decimal(longVal + that.longVal, Math.max(precision, that.precision), scale)
+ } else {
+ Decimal(toBigDecimal + that.toBigDecimal, precision, scale)
+ }
+ }
- def - (that: Decimal): Decimal = Decimal(toBigDecimal - that.toBigDecimal)
+ def - (that: Decimal): Decimal = {
+ if (decimalVal.eq(null) && that.decimalVal.eq(null) && scale == that.scale) {
+ Decimal(longVal - that.longVal, Math.max(precision, that.precision), scale)
+ } else {
+ Decimal(toBigDecimal - that.toBigDecimal, precision, scale)
+ }
+ }
+ // HiveTypeCoercion will take care of the precision, scale of result
def * (that: Decimal): Decimal = Decimal(toBigDecimal * that.toBigDecimal)
def / (that: Decimal): Decimal =
@@ -277,13 +292,13 @@ final class Decimal extends Ordered[Decimal] with Serializable {
def unary_- : Decimal = {
if (decimalVal.ne(null)) {
- Decimal(-decimalVal)
+ Decimal(-decimalVal, precision, scale)
} else {
Decimal(-longVal, precision, scale)
}
}
- def abs: Decimal = if (this.compare(Decimal(0)) < 0) this.unary_- else this
+ def abs: Decimal = if (this.compare(Decimal.ZERO) < 0) this.unary_- else this
}
object Decimal {
@@ -296,6 +311,11 @@ object Decimal {
private val BIG_DEC_ZERO = BigDecimal(0)
+ private val MATH_CONTEXT = new MathContext(DecimalType.MAX_PRECISION, RoundingMode.HALF_UP)
+
+ private[sql] val ZERO = Decimal(0)
+ private[sql] val ONE = Decimal(1)
+
def apply(value: Double): Decimal = new Decimal().set(value)
def apply(value: Long): Decimal = new Decimal().set(value)
@@ -309,6 +329,9 @@ object Decimal {
def apply(value: BigDecimal, precision: Int, scale: Int): Decimal =
new Decimal().set(value, precision, scale)
+ def apply(value: java.math.BigDecimal, precision: Int, scale: Int): Decimal =
+ new Decimal().set(value, precision, scale)
+
def apply(unscaled: Long, precision: Int, scale: Int): Decimal =
new Decimal().set(unscaled, precision, scale)
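The observable effect of the new +/- fast path: when both operands fit in the long representation and share a scale, the result is computed on the unscaled values and keeps that precision and scale. A rough sketch with made-up values, using the Decimal(unscaled, precision, scale) constructor defined above:

    import org.apache.spark.sql.types.Decimal

    // Both operands have scale 2, so the sum stays on the long fast path
    // and keeps scale 2 instead of whatever BigDecimal arithmetic would pick.
    val a = Decimal(1050L, 10, 2)   // 10.50
    val b = Decimal(25L, 10, 2)     //  0.25
    val sum = a + b

    println(sum)         // 10.75
    println(sum.scale)   // 2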
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala
index 1d297beb38..6921d15958 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala
@@ -166,6 +166,27 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester {
assert(Decimal(100) % Decimal(0) === null)
}
+ // regression test for SPARK-8359
+ test("accurate precision after multiplication") {
+ val decimal = (Decimal(Long.MaxValue, 38, 0) * Decimal(Long.MaxValue, 38, 0)).toJavaBigDecimal
+ assert(decimal.unscaledValue.toString === "85070591730234615847396907784232501249")
+ }
+
+ // regression test for SPARK-8677
+ test("fix non-terminating decimal expansion problem") {
+ val decimal = Decimal(1.0, 10, 3) / Decimal(3.0, 10, 3)
+ // The difference between decimal should not be more than 0.001.
+ assert(decimal.toDouble - 0.333 < 0.001)
+ }
+
+ // regression test for SPARK-8800
+ test("fix loss of precision/scale when doing division operation") {
+ val a = Decimal(2) / Decimal(3)
+ assert(a.toDouble < 1.0 && a.toDouble > 0.6)
+ val b = Decimal(1) / Decimal(8)
+ assert(b.toDouble === 0.125)
+ }
+
test("set/setOrNull") {
assert(new Decimal().set(10L, 10, 0).toUnscaledLong === 10L)
assert(new Decimal().set(100L, 10, 0).toUnscaledLong === 100L)