about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala6
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala1
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala3
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala29
4 files changed, 29 insertions, 10 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index cfde3bfbec..33ac1fdab4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -84,10 +84,10 @@ object RowEncoder {
"fromJavaDate",
inputObject :: Nil)
- case _: DecimalType =>
+ case d: DecimalType =>
StaticInvoke(
Decimal.getClass,
- DecimalType.SYSTEM_DEFAULT,
+ d,
"fromDecimal",
inputObject :: Nil)
@@ -162,7 +162,7 @@ object RowEncoder {
* `org.apache.spark.sql.types.Decimal`.
*/
private def externalDataTypeForInput(dt: DataType): DataType = dt match {
- // In order to support both Decimal and java BigDecimal in external row, we make this
+ // In order to support both Decimal and java/scala BigDecimal in external row, we make this
// as java.lang.Object.
case _: DecimalType => ObjectType(classOf[java.lang.Object])
case _ => externalDataTypeFor(dt)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
index 6f4ec6b701..2f7422b742 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
@@ -386,6 +386,7 @@ object Decimal {
def fromDecimal(value: Any): Decimal = {
value match {
case j: java.math.BigDecimal => apply(j)
+ case d: BigDecimal => apply(d)
case d: Decimal => d
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
index c3b20e2cc0..177b1390b2 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
@@ -108,7 +108,7 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest {
encodeDecodeTest(new java.lang.Double(-3.7), "boxed double")
encodeDecodeTest(BigDecimal("32131413.211321313"), "scala decimal")
- // encodeDecodeTest(new java.math.BigDecimal("231341.23123"), "java decimal")
+ encodeDecodeTest(new java.math.BigDecimal("231341.23123"), "java decimal")
encodeDecodeTest(Decimal("32131413.211321313"), "catalyst decimal")
@@ -336,6 +336,7 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest {
Arrays.deepEquals(b1.asInstanceOf[Array[AnyRef]], b2.asInstanceOf[Array[AnyRef]])
case (b1: Array[_], b2: Array[_]) =>
Arrays.equals(b1.asInstanceOf[Array[AnyRef]], b2.asInstanceOf[Array[AnyRef]])
+ case (left: Comparable[Any], right: Comparable[Any]) => left.compareTo(right) == 0
case _ => input == convertedBack
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
index 98be3b053d..4800e2e26e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
@@ -143,21 +143,38 @@ class RowEncoderSuite extends SparkFunSuite {
assert(input.getStruct(0) == convertedBack.getStruct(0))
}
- test("encode/decode Decimal") {
+ test("encode/decode decimal type") {
val schema = new StructType()
.add("int", IntegerType)
.add("string", StringType)
.add("double", DoubleType)
- .add("decimal", DecimalType.SYSTEM_DEFAULT)
+ .add("java_decimal", DecimalType.SYSTEM_DEFAULT)
+ .add("scala_decimal", DecimalType.SYSTEM_DEFAULT)
+ .add("catalyst_decimal", DecimalType.SYSTEM_DEFAULT)
val encoder = RowEncoder(schema)
- val input: Row = Row(100, "test", 0.123, Decimal(1234.5678))
+ val javaDecimal = new java.math.BigDecimal("1234.5678")
+ val scalaDecimal = BigDecimal("1234.5678")
+ val catalystDecimal = Decimal("1234.5678")
+
+ val input = Row(100, "test", 0.123, javaDecimal, scalaDecimal, catalystDecimal)
val row = encoder.toRow(input)
val convertedBack = encoder.fromRow(row)
- // Decimal inside external row will be converted back to Java BigDecimal when decoding.
- assert(input.get(3).asInstanceOf[Decimal].toJavaBigDecimal
- .compareTo(convertedBack.getDecimal(3)) == 0)
+ // Decimal will be converted back to Java BigDecimal when decoding.
+ assert(convertedBack.getDecimal(3).compareTo(javaDecimal) == 0)
+ assert(convertedBack.getDecimal(4).compareTo(scalaDecimal.bigDecimal) == 0)
+ assert(convertedBack.getDecimal(5).compareTo(catalystDecimal.toJavaBigDecimal) == 0)
+ }
+
+ test("RowEncoder should preserve decimal precision and scale") {
+ val schema = new StructType().add("decimal", DecimalType(10, 5), false)
+ val encoder = RowEncoder(schema)
+ val decimal = Decimal("67123.45")
+ val input = Row(decimal)
+ val row = encoder.toRow(input)
+
+ assert(row.toSeq(schema).head == decimal)
}
test("RowEncoder should preserve schema nullability") {