author    Davies Liu <davies@databricks.com>  2015-08-04 23:12:49 -0700
committer Davies Liu <davies.liu@gmail.com>   2015-08-04 23:12:49 -0700
commit    781c8d71a0a6a86c84048a4f22cb3a7d035a5be2 (patch)
tree      2f76317e9764bcbd5fd5811b8c6247ed5dfde997 /sql
parent    d34548587ab55bc2136c8f823b9e6ae96e1355a4 (diff)
[SPARK-9119] [SPARK-8359] [SQL] match Decimal.precision/scale with DecimalType
Let Decimal carry the correct precision and scale with DecimalType.

cc rxin yhuai

Author: Davies Liu <davies@databricks.com>

Closes #7925 from davies/decimal_scale and squashes the following commits:

e19701a [Davies Liu] some tweaks
57d78d2 [Davies Liu] fix tests
5d5bc69 [Davies Liu] match precision and scale with DecimalType
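The core of this change is that a Decimal value should carry exactly the precision and scale declared by its DecimalType, with values that cannot fit becoming null. A minimal standalone sketch of that contract, using plain java.math.BigDecimal and a hypothetical toDeclared helper (not part of Spark):

import java.math.{BigDecimal => JBigDecimal, RoundingMode}

object PrecisionScaleDemo {
  // Coerce a value to a declared (precision, scale); yield null, as the converter
  // in this patch does, when the rounded value needs more digits than declared.
  def toDeclared(v: JBigDecimal, precision: Int, scale: Int): JBigDecimal = {
    val scaled = v.setScale(scale, RoundingMode.HALF_UP)
    if (scaled.precision > precision) null else scaled
  }

  def main(args: Array[String]): Unit = {
    println(toDeclared(new JBigDecimal("1.2345"), 10, 2))      // 1.23
    println(toDeclared(new JBigDecimal("123456789.0"), 5, 2))  // null: does not fit
  }
}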
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala                                   |  4
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala       | 21
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala    |  4
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala             |  6
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala       |  2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala                         | 37
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala            | 21
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala             |  3
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala                      |  3
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala                          | 21
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala                        |  5
-rw-r--r--  sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java                   |  2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala       |  2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala                            | 26
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala                 | 13
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala       |  2
16 files changed, 122 insertions, 50 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
index 91449479fa..40159aaf14 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
@@ -417,6 +417,10 @@ trait Row extends Serializable {
if (!o2.isInstanceOf[Double] || ! java.lang.Double.isNaN(o2.asInstanceOf[Double])) {
return false
}
+ case d1: java.math.BigDecimal if o2.isInstanceOf[java.math.BigDecimal] =>
+ if (d1.compareTo(o2.asInstanceOf[java.math.BigDecimal]) != 0) {
+ return false
+ }
case _ => if (o1 != o2) {
return false
}
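The new case above compares java.math.BigDecimal values with compareTo rather than relying on the generic equality fallback, because BigDecimal.equals is scale-sensitive: the same numeric value at different scales would otherwise make two rows compare unequal. A small standalone illustration (plain JDK, no Spark classes):

import java.math.{BigDecimal => JBigDecimal}

object BigDecimalEquality {
  def main(args: Array[String]): Unit = {
    val a = new JBigDecimal("1.0")
    val b = new JBigDecimal("1.00")
    println(a.equals(b))         // false: equals compares unscaled value and scale
    println(a.compareTo(b) == 0) // true: compareTo compares the numeric value only
  }
}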
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
index c666864e43..8d0c64eae4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
@@ -317,18 +317,23 @@ object CatalystTypeConverters {
private class DecimalConverter(dataType: DecimalType)
extends CatalystTypeConverter[Any, JavaBigDecimal, Decimal] {
- override def toCatalystImpl(scalaValue: Any): Decimal = scalaValue match {
- case d: BigDecimal => Decimal(d)
- case d: JavaBigDecimal => Decimal(d)
- case d: Decimal => d
+ override def toCatalystImpl(scalaValue: Any): Decimal = {
+ val decimal = scalaValue match {
+ case d: BigDecimal => Decimal(d)
+ case d: JavaBigDecimal => Decimal(d)
+ case d: Decimal => d
+ }
+ if (decimal.changePrecision(dataType.precision, dataType.scale)) {
+ decimal
+ } else {
+ null
+ }
}
override def toScala(catalystValue: Decimal): JavaBigDecimal = catalystValue.toJavaBigDecimal
override def toScalaImpl(row: InternalRow, column: Int): JavaBigDecimal =
row.getDecimal(column, dataType.precision, dataType.scale).toJavaBigDecimal
}
- private object BigDecimalConverter extends DecimalConverter(DecimalType.SYSTEM_DEFAULT)
-
private abstract class PrimitiveConverter[T] extends CatalystTypeConverter[T, Any, Any] {
final override def toScala(catalystValue: Any): Any = catalystValue
final override def toCatalystImpl(scalaValue: T): Any = scalaValue
@@ -413,8 +418,8 @@ object CatalystTypeConverters {
case s: String => StringConverter.toCatalyst(s)
case d: Date => DateConverter.toCatalyst(d)
case t: Timestamp => TimestampConverter.toCatalyst(t)
- case d: BigDecimal => BigDecimalConverter.toCatalyst(d)
- case d: JavaBigDecimal => BigDecimalConverter.toCatalyst(d)
+ case d: BigDecimal => new DecimalConverter(DecimalType(d.precision, d.scale)).toCatalyst(d)
+ case d: JavaBigDecimal => new DecimalConverter(DecimalType(d.precision, d.scale)).toCatalyst(d)
case seq: Seq[Any] => new GenericArrayData(seq.map(convertToCatalyst).toArray)
case r: Row => InternalRow(r.toSeq.map(convertToCatalyst): _*)
case arr: Array[Any] => new GenericArrayData(arr.map(convertToCatalyst))
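The rewritten DecimalConverter funnels every external decimal through changePrecision so the stored Decimal matches the column's declared type, returning null when the value cannot be represented. A compact sketch of the same flow, assuming the Spark classes from this branch are on the classpath (the convert helper is illustrative, not Spark API):

import org.apache.spark.sql.types.{Decimal, DecimalType}

object ConverterSketch {
  // Coerce an external decimal to the column's declared DecimalType; null if it does not fit.
  def convert(value: java.math.BigDecimal, dt: DecimalType): Decimal = {
    val d = Decimal(value)
    if (d.changePrecision(dt.precision, dt.scale)) d else null
  }

  def main(args: Array[String]): Unit = {
    println(convert(new java.math.BigDecimal("3.14159"), DecimalType(10, 2))) // 3.14
    println(convert(new java.math.BigDecimal("123456.7"), DecimalType(5, 2))) // null
  }
}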
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 422d423747..490f3dc07b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -442,8 +442,8 @@ object HiveTypeCoercion {
* Changes numeric values to booleans so that expressions like true = 1 can be evaluated.
*/
object BooleanEquality extends Rule[LogicalPlan] {
- private val trueValues = Seq(1.toByte, 1.toShort, 1, 1L, Decimal(1))
- private val falseValues = Seq(0.toByte, 0.toShort, 0, 0L, Decimal(0))
+ private val trueValues = Seq(1.toByte, 1.toShort, 1, 1L, Decimal.ONE)
+ private val falseValues = Seq(0.toByte, 0.toShort, 0, 0L, Decimal.ZERO)
private def buildCaseKeyWhen(booleanExpr: Expression, numericExpr: Expression) = {
CaseKeyWhen(numericExpr, Seq(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 88429bb84b..39f99700c8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -26,8 +26,6 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
-import scala.collection.mutable
-
object Cast {
@@ -157,7 +155,7 @@ case class Cast(child: Expression, dataType: DataType)
case ByteType =>
buildCast[Byte](_, _ != 0)
case DecimalType() =>
- buildCast[Decimal](_, _ != Decimal(0))
+ buildCast[Decimal](_, _ != Decimal.ZERO)
case DoubleType =>
buildCast[Double](_, _ != 0)
case FloatType =>
@@ -311,7 +309,7 @@ case class Cast(child: Expression, dataType: DataType)
case _: NumberFormatException => null
})
case BooleanType =>
- buildCast[Boolean](_, b => changePrecision(if (b) Decimal(1) else Decimal(0), target))
+ buildCast[Boolean](_, b => changePrecision(if (b) Decimal.ONE else Decimal.ZERO, target))
case DateType =>
buildCast[Int](_, d => null) // date can't cast to decimal in Hive
case TimestampType =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 0891b55494..5808e3f66d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -511,6 +511,6 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic {
private def pmod(a: Decimal, n: Decimal): Decimal = {
val r = a % n
- if (r.compare(Decimal(0)) < 0) {(r + n) % n} else r
+ if (r.compare(Decimal.ZERO) < 0) {(r + n) % n} else r
}
}
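Pmod returns a non-negative remainder; the only change in this hunk is reusing the cached Decimal.ZERO, but the positive-modulo idea is worth seeing on its own. A standalone version written against scala.math.BigDecimal:

object PmodDemo {
  // Positive modulo: shift a negative remainder into [0, n), mirroring the Decimal
  // version above but runnable without Spark on the classpath.
  def pmod(a: BigDecimal, n: BigDecimal): BigDecimal = {
    val r = a % n
    if (r.signum < 0) (r + n) % n else r
  }

  def main(args: Array[String]): Unit = {
    println(pmod(BigDecimal(-7), BigDecimal(3))) // 2
    println(pmod(BigDecimal(7), BigDecimal(3)))  // 1
  }
}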
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
index c0155eeb45..624c3f3d7f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.types
+import java.math.{RoundingMode, MathContext}
+
import org.apache.spark.annotation.DeveloperApi
/**
@@ -28,7 +30,7 @@ import org.apache.spark.annotation.DeveloperApi
* - Otherwise, the decimal value is longVal / (10 ** _scale)
*/
final class Decimal extends Ordered[Decimal] with Serializable {
- import org.apache.spark.sql.types.Decimal.{BIG_DEC_ZERO, MAX_LONG_DIGITS, POW_10, ROUNDING_MODE}
+ import org.apache.spark.sql.types.Decimal._
private var decimalVal: BigDecimal = null
private var longVal: Long = 0L
@@ -137,9 +139,9 @@ final class Decimal extends Ordered[Decimal] with Serializable {
def toBigDecimal: BigDecimal = {
if (decimalVal.ne(null)) {
- decimalVal
+ decimalVal(MATH_CONTEXT)
} else {
- BigDecimal(longVal, _scale)
+ BigDecimal(longVal, _scale)(MATH_CONTEXT)
}
}
@@ -261,10 +263,23 @@ final class Decimal extends Ordered[Decimal] with Serializable {
def isZero: Boolean = if (decimalVal.ne(null)) decimalVal == BIG_DEC_ZERO else longVal == 0
- def + (that: Decimal): Decimal = Decimal(toBigDecimal + that.toBigDecimal)
+ def + (that: Decimal): Decimal = {
+ if (decimalVal.eq(null) && that.decimalVal.eq(null) && scale == that.scale) {
+ Decimal(longVal + that.longVal, Math.max(precision, that.precision), scale)
+ } else {
+ Decimal(toBigDecimal + that.toBigDecimal, precision, scale)
+ }
+ }
- def - (that: Decimal): Decimal = Decimal(toBigDecimal - that.toBigDecimal)
+ def - (that: Decimal): Decimal = {
+ if (decimalVal.eq(null) && that.decimalVal.eq(null) && scale == that.scale) {
+ Decimal(longVal - that.longVal, Math.max(precision, that.precision), scale)
+ } else {
+ Decimal(toBigDecimal - that.toBigDecimal, precision, scale)
+ }
+ }
+ // HiveTypeCoercion will take care of the precision, scale of result
def * (that: Decimal): Decimal = Decimal(toBigDecimal * that.toBigDecimal)
def / (that: Decimal): Decimal =
@@ -277,13 +292,13 @@ final class Decimal extends Ordered[Decimal] with Serializable {
def unary_- : Decimal = {
if (decimalVal.ne(null)) {
- Decimal(-decimalVal)
+ Decimal(-decimalVal, precision, scale)
} else {
Decimal(-longVal, precision, scale)
}
}
- def abs: Decimal = if (this.compare(Decimal(0)) < 0) this.unary_- else this
+ def abs: Decimal = if (this.compare(Decimal.ZERO) < 0) this.unary_- else this
}
object Decimal {
@@ -296,6 +311,11 @@ object Decimal {
private val BIG_DEC_ZERO = BigDecimal(0)
+ private val MATH_CONTEXT = new MathContext(DecimalType.MAX_PRECISION, RoundingMode.HALF_UP)
+
+ private[sql] val ZERO = Decimal(0)
+ private[sql] val ONE = Decimal(1)
+
def apply(value: Double): Decimal = new Decimal().set(value)
def apply(value: Long): Decimal = new Decimal().set(value)
@@ -309,6 +329,9 @@ object Decimal {
def apply(value: BigDecimal, precision: Int, scale: Int): Decimal =
new Decimal().set(value, precision, scale)
+ def apply(value: java.math.BigDecimal, precision: Int, scale: Int): Decimal =
+ new Decimal().set(value, precision, scale)
+
def apply(unscaled: Long, precision: Int, scale: Int): Decimal =
new Decimal().set(unscaled, precision, scale)
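Addition and subtraction now take a fast path when both operands are stored as unscaled longs with the same scale, and BigDecimal arithmetic runs under a MathContext capped at the maximum precision with HALF_UP rounding. A standalone sketch of the unscaled-long fast path (overflow checking is omitted here; Spark's Decimal constructor enforces the declared precision):

object UnscaledAddDemo {
  // Two decimals with the same scale can be added on their unscaled longs directly:
  // unscaled 12345 at scale 2 represents 123.45.
  def addSameScale(unscaledA: Long, unscaledB: Long, scale: Int): BigDecimal =
    BigDecimal(unscaledA + unscaledB, scale)

  def main(args: Array[String]): Unit = {
    println(addSameScale(12345L, 55L, 2)) // 124.00  (123.45 + 0.55)
  }
}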
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala
index 1d297beb38..6921d15958 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala
@@ -166,6 +166,27 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester {
assert(Decimal(100) % Decimal(0) === null)
}
+ // regression test for SPARK-8359
+ test("accurate precision after multiplication") {
+ val decimal = (Decimal(Long.MaxValue, 38, 0) * Decimal(Long.MaxValue, 38, 0)).toJavaBigDecimal
+ assert(decimal.unscaledValue.toString === "85070591730234615847396907784232501249")
+ }
+
+ // regression test for SPARK-8677
+ test("fix non-terminating decimal expansion problem") {
+ val decimal = Decimal(1.0, 10, 3) / Decimal(3.0, 10, 3)
+ // The difference between decimal should not be more than 0.001.
+ assert(decimal.toDouble - 0.333 < 0.001)
+ }
+
+ // regression test for SPARK-8800
+ test("fix loss of precision/scale when doing division operation") {
+ val a = Decimal(2) / Decimal(3)
+ assert(a.toDouble < 1.0 && a.toDouble > 0.6)
+ val b = Decimal(1) / Decimal(8)
+ assert(b.toDouble === 0.125)
+ }
+
test("set/setOrNull") {
assert(new Decimal().set(10L, 10, 0).toUnscaledLong === 10L)
assert(new Decimal().set(100L, 10, 0).toUnscaledLong === 100L)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala
index e5bbd0aaed..e811f1de3e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala
@@ -412,7 +412,8 @@ private[sql] object SparkSqlSerializer2 {
// Then, read the scale.
val scale = in.readInt()
// Finally, create the Decimal object and set it in the row.
- mutableRow.update(i, Decimal(new BigDecimal(unscaledVal, scale)))
+ mutableRow.update(i,
+ Decimal(new BigDecimal(unscaledVal, scale), decimal.precision, decimal.scale))
}
}
i += 1
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala
index aade2e769c..dedc7c4dfb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala
@@ -21,7 +21,6 @@ import java.io.OutputStream
import java.util.{List => JList, Map => JMap}
import scala.collection.JavaConversions._
-import scala.collection.JavaConverters._
import net.razorvine.pickle._
@@ -182,7 +181,7 @@ object EvaluatePython {
case (c: Double, DoubleType) => c
- case (c: java.math.BigDecimal, dt: DecimalType) => Decimal(c)
+ case (c: java.math.BigDecimal, dt: DecimalType) => Decimal(c, dt.precision, dt.scale)
case (c: Int, DateType) => c
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala
index 04ab5e2217..ec5668c6b9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala
@@ -113,8 +113,12 @@ private[sql] object InferSchema {
case INT | LONG => LongType
// Since we do not have a data type backed by BigInteger,
// when we see a Java BigInteger, we use DecimalType.
- case BIG_INTEGER | BIG_DECIMAL => DecimalType.SYSTEM_DEFAULT
- case FLOAT | DOUBLE => DoubleType
+ case BIG_INTEGER | BIG_DECIMAL =>
+ val v = parser.getDecimalValue
+ DecimalType(v.precision(), v.scale())
+ case FLOAT | DOUBLE =>
+ // TODO(davies): Should we use decimal if possible?
+ DoubleType
}
case VALUE_TRUE | VALUE_FALSE => BooleanType
@@ -171,9 +175,18 @@ private[sql] object InferSchema {
// Double support larger range than fixed decimal, DecimalType.Maximum should be enough
// in most case, also have better precision.
case (DoubleType, t: DecimalType) =>
- if (t == DecimalType.SYSTEM_DEFAULT) t else DoubleType
+ DoubleType
case (t: DecimalType, DoubleType) =>
- if (t == DecimalType.SYSTEM_DEFAULT) t else DoubleType
+ DoubleType
+ case (t1: DecimalType, t2: DecimalType) =>
+ val scale = math.max(t1.scale, t2.scale)
+ val range = math.max(t1.precision - t1.scale, t2.precision - t2.scale)
+ if (range + scale > 38) {
+ // DecimalType can't support precision > 38
+ DoubleType
+ } else {
+ DecimalType(range + scale, scale)
+ }
case (StructType(fields1), StructType(fields2)) =>
val newFields = (fields1 ++ fields2).groupBy(field => field.name).map {
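The new DecimalType merge rule keeps the larger scale and the larger integer-digit range of the two sides, falling back to DoubleType once the combined precision would exceed 38. A worked standalone example of the same arithmetic (the merge helper is illustrative, not Spark code):

object DecimalTypeMergeDemo {
  // Widest type holding both sides: max(scale) fractional digits plus
  // max(precision - scale) integer digits; None means "widen to DoubleType".
  def merge(p1: Int, s1: Int, p2: Int, s2: Int): Option[(Int, Int)] = {
    val scale = math.max(s1, s2)
    val range = math.max(p1 - s1, p2 - s2)
    if (range + scale > 38) None else Some((range + scale, scale))
  }

  def main(args: Array[String]): Unit = {
    println(merge(10, 2, 5, 4))  // Some((12,4)): 8 integer digits + 4 fractional
    println(merge(38, 0, 10, 5)) // None: 43 digits cannot fit in a DecimalType
  }
}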
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala
index bf0448ee96..f1a66c84fc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala
@@ -84,9 +84,8 @@ private[sql] object JacksonParser {
case (VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT, DoubleType) =>
parser.getDoubleValue
- case (VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT, DecimalType()) =>
- // TODO: add fixed precision and scale handling
- Decimal(parser.getDecimalValue)
+ case (VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT, dt: DecimalType) =>
+ Decimal(parser.getDecimalValue, dt.precision, dt.scale)
case (VALUE_NUMBER_INT, ByteType) =>
parser.getByteValue
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java
index cb84e78d62..e912eb835d 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java
@@ -164,7 +164,7 @@ public class JavaApplySchemaSuite implements Serializable {
"\"bigInteger\":92233720368547758069, \"double\":1.7976931348623157E305, " +
"\"boolean\":false, \"null\":null}"));
List<StructField> fields = new ArrayList<StructField>(7);
- fields.add(DataTypes.createStructField("bigInteger", DataTypes.createDecimalType(38, 18),
+ fields.add(DataTypes.createStructField("bigInteger", DataTypes.createDecimalType(20, 0),
true));
fields.add(DataTypes.createStructField("boolean", DataTypes.BooleanType, true));
fields.add(DataTypes.createStructField("double", DataTypes.DoubleType, true));
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala
index 037e2048a8..9bca4e7e66 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala
@@ -148,7 +148,7 @@ class InMemoryColumnarQuerySuite extends QueryTest {
val dataTypes =
Seq(StringType, BinaryType, NullType, BooleanType,
ByteType, ShortType, IntegerType, LongType,
- FloatType, DoubleType, DecimalType.SYSTEM_DEFAULT, DecimalType(6, 5),
+ FloatType, DoubleType, DecimalType(25, 5), DecimalType(6, 5),
DateType, TimestampType,
ArrayType(IntegerType), MapType(StringType, LongType), struct)
val fields = dataTypes.zipWithIndex.map { case (dataType, index) =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
index f19f22fca7..16a5c57060 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
@@ -73,8 +73,6 @@ class JsonSuite extends QueryTest with TestJsonData {
val doubleNumber: Double = 1.7976931348623157E308d
checkTypePromotion(doubleNumber.toDouble, enforceCorrectType(doubleNumber, DoubleType))
- checkTypePromotion(
- Decimal(doubleNumber), enforceCorrectType(doubleNumber, DecimalType.SYSTEM_DEFAULT))
checkTypePromotion(DateTimeUtils.fromJavaTimestamp(new Timestamp(intNumber)),
enforceCorrectType(intNumber, TimestampType))
@@ -150,7 +148,7 @@ class JsonSuite extends QueryTest with TestJsonData {
// DoubleType
checkDataType(DoubleType, DoubleType, DoubleType)
- checkDataType(DoubleType, DecimalType.SYSTEM_DEFAULT, DecimalType.SYSTEM_DEFAULT)
+ checkDataType(DoubleType, DecimalType.SYSTEM_DEFAULT, DoubleType)
checkDataType(DoubleType, StringType, StringType)
checkDataType(DoubleType, ArrayType(IntegerType), StringType)
checkDataType(DoubleType, StructType(Nil), StringType)
@@ -241,7 +239,7 @@ class JsonSuite extends QueryTest with TestJsonData {
val jsonDF = ctx.read.json(primitiveFieldAndType)
val expectedSchema = StructType(
- StructField("bigInteger", DecimalType.SYSTEM_DEFAULT, true) ::
+ StructField("bigInteger", DecimalType(20, 0), true) ::
StructField("boolean", BooleanType, true) ::
StructField("double", DoubleType, true) ::
StructField("integer", LongType, true) ::
@@ -271,7 +269,7 @@ class JsonSuite extends QueryTest with TestJsonData {
val expectedSchema = StructType(
StructField("arrayOfArray1", ArrayType(ArrayType(StringType, true), true), true) ::
StructField("arrayOfArray2", ArrayType(ArrayType(DoubleType, true), true), true) ::
- StructField("arrayOfBigInteger", ArrayType(DecimalType.SYSTEM_DEFAULT, true), true) ::
+ StructField("arrayOfBigInteger", ArrayType(DecimalType(21, 0), true), true) ::
StructField("arrayOfBoolean", ArrayType(BooleanType, true), true) ::
StructField("arrayOfDouble", ArrayType(DoubleType, true), true) ::
StructField("arrayOfInteger", ArrayType(LongType, true), true) ::
@@ -285,7 +283,7 @@ class JsonSuite extends QueryTest with TestJsonData {
StructField("field3", StringType, true) :: Nil), true), true) ::
StructField("struct", StructType(
StructField("field1", BooleanType, true) ::
- StructField("field2", DecimalType.SYSTEM_DEFAULT, true) :: Nil), true) ::
+ StructField("field2", DecimalType(20, 0), true) :: Nil), true) ::
StructField("structWithArrayFields", StructType(
StructField("field1", ArrayType(LongType, true), true) ::
StructField("field2", ArrayType(StringType, true), true) :: Nil), true) :: Nil)
@@ -386,7 +384,7 @@ class JsonSuite extends QueryTest with TestJsonData {
val expectedSchema = StructType(
StructField("num_bool", StringType, true) ::
StructField("num_num_1", LongType, true) ::
- StructField("num_num_2", DecimalType.SYSTEM_DEFAULT, true) ::
+ StructField("num_num_2", DoubleType, true) ::
StructField("num_num_3", DoubleType, true) ::
StructField("num_str", StringType, true) ::
StructField("str_bool", StringType, true) :: Nil)
@@ -398,11 +396,9 @@ class JsonSuite extends QueryTest with TestJsonData {
checkAnswer(
sql("select * from jsonTable"),
Row("true", 11L, null, 1.1, "13.1", "str1") ::
- Row("12", null, new java.math.BigDecimal("21474836470.9"), null, null, "true") ::
- Row("false", 21474836470L,
- new java.math.BigDecimal("92233720368547758070"), 100, "str1", "false") ::
- Row(null, 21474836570L,
- new java.math.BigDecimal("1.1"), 21474836470L, "92233720368547758070", null) :: Nil
+ Row("12", null, 21474836470.9, null, null, "true") ::
+ Row("false", 21474836470L, 92233720368547758070d, 100, "str1", "false") ::
+ Row(null, 21474836570L, 1.1, 21474836470L, "92233720368547758070", null) :: Nil
)
// Number and Boolean conflict: resolve the type as number in this query.
@@ -425,8 +421,8 @@ class JsonSuite extends QueryTest with TestJsonData {
// Widening to DecimalType
checkAnswer(
sql("select num_num_2 + 1.3 from jsonTable where num_num_2 > 1.1"),
- Row(BigDecimal("21474836472.2")) ::
- Row(BigDecimal("92233720368547758071.3")) :: Nil
+ Row(21474836472.2) ::
+ Row(92233720368547758071.3) :: Nil
)
// Widening to Double
@@ -611,7 +607,7 @@ class JsonSuite extends QueryTest with TestJsonData {
val jsonDF = ctx.read.json(path)
val expectedSchema = StructType(
- StructField("bigInteger", DecimalType.SYSTEM_DEFAULT, true) ::
+ StructField("bigInteger", DecimalType(20, 0), true) ::
StructField("boolean", BooleanType, true) ::
StructField("double", DoubleType, true) ::
StructField("integer", LongType, true) ::
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
index a95f70f2bb..5c65a8ec57 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
@@ -189,4 +189,17 @@ class ParquetQuerySuite extends QueryTest with ParquetTest {
}
}
}
+
+ test("SPARK-9119 Decimal should be correctly written into parquet") {
+ withTempPath { dir =>
+ val basePath = dir.getCanonicalPath
+ val schema = StructType(Array(StructField("name", DecimalType(10, 5), false)))
+ val rowRDD = sqlContext.sparkContext.parallelize(Array(Row(Decimal("67123.45"))))
+ val df = sqlContext.createDataFrame(rowRDD, schema)
+ df.write.parquet(basePath)
+
+ val decimal = sqlContext.read.parquet(basePath).first().getDecimal(0)
+ assert(Decimal("67123.45") === Decimal(decimal))
+ }
+ }
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
index 97e4ea2081..a6a343d395 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
@@ -29,7 +29,6 @@ import org.apache.hadoop.hive.serde2.AbstractSerDe
import org.apache.hadoop.hive.serde2.objectinspector._
import org.apache.hadoop.io.Writable
-import org.apache.spark.{TaskContext, Logging}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.expressions._
@@ -39,6 +38,7 @@ import org.apache.spark.sql.hive.HiveShim._
import org.apache.spark.sql.hive.{HiveContext, HiveInspectors}
import org.apache.spark.sql.types.DataType
import org.apache.spark.util.{CircularBuffer, RedirectThread, Utils}
+import org.apache.spark.{Logging, TaskContext}
/**
* Transforms the input by forking and running the specified script.