author    Davies Liu <davies@databricks.com>  2015-07-23 18:31:13 -0700
committer Reynold Xin <rxin@databricks.com>  2015-07-23 18:31:13 -0700
commit    8a94eb23d53e291441e3144a1b800fe054457040 (patch)
tree      d9594f55c4d07d8e8bd1da4226db5186269c9f93
parent    bebe3f7b45f7b0a96f20d5af9b80633fd40cff06 (diff)
[SPARK-9069] [SPARK-9264] [SQL] remove unlimited precision support for DecimalType
Remove Decimal.Unlimited (change to support precision up to 38, matching Hive and other databases). To keep backward source compatibility, Decimal.Unlimited is still there, but it is now Decimal(38, 18). If no precision and scale are provided, it is Decimal(10, 0) as before.

Author: Davies Liu <davies@databricks.com>

Closes #7605 from davies/decimal_unlimited and squashes the following commits:

aa3f115 [Davies Liu] fix tests and style
fb0d20d [Davies Liu] address comments
bfaae35 [Davies Liu] fix style
df93657 [Davies Liu] address comments and clean up
06727fd [Davies Liu] Merge branch 'master' of github.com:apache/spark into decimal_unlimited
4c28969 [Davies Liu] fix tests
8d783cc [Davies Liu] fix tests
788631c [Davies Liu] fix double with decimal in Union/except
1779bde [Davies Liu] fix scala style
c9c7c78 [Davies Liu] remove Decimal.Unlimited
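The net effect, as a minimal sketch against the new `org.apache.spark.sql.types` API (the defaults and bounds come from the DecimalType.scala changes below):

```scala
import org.apache.spark.sql.types.DecimalType

object DecimalDefaultsSketch {
  def main(args: Array[String]): Unit = {
    // Every decimal type is now bounded: precision <= 38 and scale <= precision.
    val widest  = DecimalType(38, 18)  // what the deprecated unlimited type now maps to
    val default = DecimalType(10, 0)   // what DecimalType() and a bare "decimal" resolve to

    println(widest)                    // DecimalType(38,18)
    println(default)                   // DecimalType(10,0)
  }
}
```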
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala | 2
-rw-r--r--  python/pyspark/sql/types.py | 36
-rw-r--r--  sql/catalyst/src/main/java/org/apache/spark/sql/types/DataTypes.java | 8
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala | 2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala | 10
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala | 5
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala | 255
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala | 2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala | 7
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala | 24
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala | 46
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala | 17
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala | 6
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala | 4
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala | 4
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala | 2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala | 110
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala | 4
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGeneratorSuite.scala | 4
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala | 14
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala | 6
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala | 54
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala | 45
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala | 46
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala | 2
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala | 2
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala | 2
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala | 2
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala | 2
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeParserSuite.scala | 4
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala | 4
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeTestUtils.scala | 2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/Column.scala | 2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala | 2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala | 10
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala | 7
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala | 22
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala | 7
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala | 11
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala | 4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala | 8
-rw-r--r--  sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java | 14
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala | 2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala | 2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala | 4
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala | 20
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala | 57
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala | 8
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala | 2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala | 2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala | 2
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala | 9
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala | 4
53 files changed, 459 insertions, 473 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala
index c5fd2f9d5a..6355e0f179 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala
@@ -218,7 +218,7 @@ class AttributeSuite extends SparkFunSuite {
// Attribute.fromStructField should accept any NumericType, not just DoubleType
val longFldWithMeta = new StructField("x", LongType, false, metadata)
assert(Attribute.fromStructField(longFldWithMeta).isNumeric)
- val decimalFldWithMeta = new StructField("x", DecimalType(None), false, metadata)
+ val decimalFldWithMeta = new StructField("x", DecimalType(38, 18), false, metadata)
assert(Attribute.fromStructField(decimalFldWithMeta).isNumeric)
}
}
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 10ad89ea14..b97d50c945 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -194,30 +194,33 @@ class TimestampType(AtomicType):
class DecimalType(FractionalType):
"""Decimal (decimal.Decimal) data type.
+
+ The DecimalType must have fixed precision (the maximum total number of digits)
+ and scale (the number of digits on the right of the dot). For example, (5, 2) can
+ support values from -999.99 to 999.99.
+
+ The precision can be up to 38; the scale must be less than or equal to the precision.
+
+ When creating a DecimalType, the default precision and scale is (10, 0). When inferring
+ a schema from decimal.Decimal objects, it will be DecimalType(38, 18).
+
+ :param precision: the maximum total number of digits (default: 10)
+ :param scale: the number of digits on the right side of the dot (default: 0)
"""
- def __init__(self, precision=None, scale=None):
+ def __init__(self, precision=10, scale=0):
self.precision = precision
self.scale = scale
- self.hasPrecisionInfo = precision is not None
+ self.hasPrecisionInfo = True # this is public API
def simpleString(self):
- if self.hasPrecisionInfo:
- return "decimal(%d,%d)" % (self.precision, self.scale)
- else:
- return "decimal(10,0)"
+ return "decimal(%d,%d)" % (self.precision, self.scale)
def jsonValue(self):
- if self.hasPrecisionInfo:
- return "decimal(%d,%d)" % (self.precision, self.scale)
- else:
- return "decimal"
+ return "decimal(%d,%d)" % (self.precision, self.scale)
def __repr__(self):
- if self.hasPrecisionInfo:
- return "DecimalType(%d,%d)" % (self.precision, self.scale)
- else:
- return "DecimalType()"
+ return "DecimalType(%d,%d)" % (self.precision, self.scale)
class DoubleType(FractionalType):
@@ -761,7 +764,10 @@ def _infer_type(obj):
return obj.__UDT__
dataType = _type_mappings.get(type(obj))
- if dataType is not None:
+ if dataType is DecimalType:
+ # the precision and scale of `obj` may be different from row to row.
+ return DecimalType(38, 18)
+ elif dataType is not None:
return dataType()
if isinstance(obj, dict):
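The docstring above describes fixed-precision ranges; a small Scala sketch of the same rule, using the `Decimal.changePrecision` check that the Cast changes later in this diff rely on:

```scala
import org.apache.spark.sql.types.Decimal

object FixedPrecisionRangeSketch {
  def main(args: Array[String]): Unit = {
    // decimal(5,2) holds at most 5 digits in total, 2 after the dot: -999.99 to 999.99.
    val fits    = Decimal(BigDecimal("999.99"))
    val tooWide = Decimal(BigDecimal("1000.00"))

    println(fits.changePrecision(5, 2))     // true  -> representable as decimal(5,2)
    println(tooWide.changePrecision(5, 2))  // false -> would overflow decimal(5,2)
  }
}
```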
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/types/DataTypes.java b/sql/catalyst/src/main/java/org/apache/spark/sql/types/DataTypes.java
index d22ad6794d..5703de4239 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/types/DataTypes.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/types/DataTypes.java
@@ -111,12 +111,18 @@ public class DataTypes {
return new ArrayType(elementType, containsNull);
}
+ /**
+ * Creates a DecimalType by specifying the precision and scale.
+ */
public static DecimalType createDecimalType(int precision, int scale) {
return DecimalType$.MODULE$.apply(precision, scale);
}
+ /**
+ * Creates a DecimalType with default precision and scale, which are 10 and 0.
+ */
public static DecimalType createDecimalType() {
- return DecimalType$.MODULE$.Unlimited();
+ return DecimalType$.MODULE$.USER_DEFAULT();
}
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
index 9a3f9694e4..88a457f87c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
@@ -75,7 +75,7 @@ private [sql] object JavaTypeInference {
case c: Class[_] if c == classOf[java.lang.Float] => (FloatType, true)
case c: Class[_] if c == classOf[java.lang.Boolean] => (BooleanType, true)
- case c: Class[_] if c == classOf[java.math.BigDecimal] => (DecimalType(), true)
+ case c: Class[_] if c == classOf[java.math.BigDecimal] => (DecimalType.SYSTEM_DEFAULT, true)
case c: Class[_] if c == classOf[java.sql.Date] => (DateType, true)
case c: Class[_] if c == classOf[java.sql.Timestamp] => (TimestampType, true)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 21b1de1ab9..2442341da1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -131,10 +131,10 @@ trait ScalaReflection {
case t if t <:< localTypeOf[String] => Schema(StringType, nullable = true)
case t if t <:< localTypeOf[java.sql.Timestamp] => Schema(TimestampType, nullable = true)
case t if t <:< localTypeOf[java.sql.Date] => Schema(DateType, nullable = true)
- case t if t <:< localTypeOf[BigDecimal] => Schema(DecimalType.Unlimited, nullable = true)
+ case t if t <:< localTypeOf[BigDecimal] => Schema(DecimalType.SYSTEM_DEFAULT, nullable = true)
case t if t <:< localTypeOf[java.math.BigDecimal] =>
- Schema(DecimalType.Unlimited, nullable = true)
- case t if t <:< localTypeOf[Decimal] => Schema(DecimalType.Unlimited, nullable = true)
+ Schema(DecimalType.SYSTEM_DEFAULT, nullable = true)
+ case t if t <:< localTypeOf[Decimal] => Schema(DecimalType.SYSTEM_DEFAULT, nullable = true)
case t if t <:< localTypeOf[java.lang.Integer] => Schema(IntegerType, nullable = true)
case t if t <:< localTypeOf[java.lang.Long] => Schema(LongType, nullable = true)
case t if t <:< localTypeOf[java.lang.Double] => Schema(DoubleType, nullable = true)
@@ -167,8 +167,8 @@ trait ScalaReflection {
case obj: Float => FloatType
case obj: Double => DoubleType
case obj: java.sql.Date => DateType
- case obj: java.math.BigDecimal => DecimalType.Unlimited
- case obj: Decimal => DecimalType.Unlimited
+ case obj: java.math.BigDecimal => DecimalType.SYSTEM_DEFAULT
+ case obj: Decimal => DecimalType.SYSTEM_DEFAULT
case obj: java.sql.Timestamp => TimestampType
case null => NullType
// For other cases, there is no obvious mapping from the type of the given object to a
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
index 29cfc064da..c494e5d704 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
@@ -322,7 +322,10 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser {
protected lazy val numericLiteral: Parser[Literal] =
( integral ^^ { case i => Literal(toNarrowestIntegerType(i)) }
- | sign.? ~ unsignedFloat ^^ { case s ~ f => Literal((s.getOrElse("") + f).toDouble) }
+ | sign.? ~ unsignedFloat ^^ {
+ // TODO(davies): some precision may be lost; we should create a decimal literal
+ case s ~ f => Literal(BigDecimal(s.getOrElse("") + f).doubleValue())
+ }
)
protected lazy val unsignedFloat: Parser[String] =
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index e214545726..d56ceeadc9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -19,7 +19,9 @@ package org.apache.spark.sql.catalyst.analysis
import javax.annotation.Nullable
+import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.types._
@@ -58,8 +60,7 @@ object HiveTypeCoercion {
IntegerType,
LongType,
FloatType,
- DoubleType,
- DecimalType.Unlimited)
+ DoubleType)
/**
* Find the tightest common type of two types that might be used in a binary expression.
@@ -72,15 +73,16 @@ object HiveTypeCoercion {
case (NullType, t1) => Some(t1)
case (t1, NullType) => Some(t1)
- // Promote numeric types to the highest of the two and all numeric types to unlimited decimal
+ case (t1: IntegralType, t2: DecimalType) if t2.isWiderThan(t1) =>
+ Some(t2)
+ case (t1: DecimalType, t2: IntegralType) if t1.isWiderThan(t2) =>
+ Some(t1)
+
+ // Promote numeric types to the highest of the two
case (t1, t2) if Seq(t1, t2).forall(numericPrecedence.contains) =>
val index = numericPrecedence.lastIndexWhere(t => t == t1 || t == t2)
Some(numericPrecedence(index))
- // Fixed-precision decimals can up-cast into unlimited
- case (DecimalType.Unlimited, _: DecimalType) => Some(DecimalType.Unlimited)
- case (_: DecimalType, DecimalType.Unlimited) => Some(DecimalType.Unlimited)
-
case _ => None
}
@@ -101,7 +103,7 @@ object HiveTypeCoercion {
types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match {
case None => None
case Some(d) =>
- findTightestCommonTypeOfTwo(d, c).orElse(findTightestCommonTypeToString(d, c))
+ findTightestCommonTypeToString(d, c)
})
}
@@ -158,6 +160,9 @@ object HiveTypeCoercion {
* converted to DOUBLE.
* - TINYINT, SMALLINT, and INT can all be converted to FLOAT.
* - BOOLEAN types cannot be converted to any other type.
+ * - Any integral numeric type can be implicitly converted to decimal type.
+ * - Two different decimal types will be converted into a decimal type wide enough for both of them.
+ * - A decimal type will be converted into double if a float or double appears together with it.
*
* Additionally, all types when UNION-ed with strings will be promoted to strings.
* Other string conversions are handled by PromoteStrings.
@@ -166,55 +171,50 @@ object HiveTypeCoercion {
* - IntegerType to FloatType
* - LongType to FloatType
* - LongType to DoubleType
+ * - DecimalType to DoubleType
+ *
+ * This rule is only applied to Union/Except/Intersect
*/
object WidenTypes extends Rule[LogicalPlan] {
- private[this] def widenOutputTypes(planName: String, left: LogicalPlan, right: LogicalPlan):
- (LogicalPlan, LogicalPlan) = {
-
- // TODO: with fixed-precision decimals
- val castedInput = left.output.zip(right.output).map {
- // When a string is found on one side, make the other side a string too.
- case (lhs, rhs) if lhs.dataType == StringType && rhs.dataType != StringType =>
- (lhs, Alias(Cast(rhs, StringType), rhs.name)())
- case (lhs, rhs) if lhs.dataType != StringType && rhs.dataType == StringType =>
- (Alias(Cast(lhs, StringType), lhs.name)(), rhs)
+ private[this] def widenOutputTypes(
+ planName: String,
+ left: LogicalPlan,
+ right: LogicalPlan): (LogicalPlan, LogicalPlan) = {
+ val castedTypes = left.output.zip(right.output).map {
case (lhs, rhs) if lhs.dataType != rhs.dataType =>
- logDebug(s"Resolving mismatched $planName input ${lhs.dataType}, ${rhs.dataType}")
- findTightestCommonTypeOfTwo(lhs.dataType, rhs.dataType).map { widestType =>
- val newLeft =
- if (lhs.dataType == widestType) lhs else Alias(Cast(lhs, widestType), lhs.name)()
- val newRight =
- if (rhs.dataType == widestType) rhs else Alias(Cast(rhs, widestType), rhs.name)()
-
- (newLeft, newRight)
- }.getOrElse {
- // If there is no applicable conversion, leave expression unchanged.
- (lhs, rhs)
+ (lhs.dataType, rhs.dataType) match {
+ case (t1: DecimalType, t2: DecimalType) =>
+ Some(DecimalPrecision.widerDecimalType(t1, t2))
+ case (t: IntegralType, d: DecimalType) =>
+ Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d))
+ case (d: DecimalType, t: IntegralType) =>
+ Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d))
+ case (t: FractionalType, d: DecimalType) =>
+ Some(DoubleType)
+ case (d: DecimalType, t: FractionalType) =>
+ Some(DoubleType)
+ case _ =>
+ findTightestCommonTypeToString(lhs.dataType, rhs.dataType)
}
-
- case other => other
+ case other => None
}
- val (castedLeft, castedRight) = castedInput.unzip
-
- val newLeft =
- if (castedLeft.map(_.dataType) != left.output.map(_.dataType)) {
- logDebug(s"Widening numeric types in $planName $castedLeft ${left.output}")
- Project(castedLeft, left)
- } else {
- left
+ def castOutput(plan: LogicalPlan): LogicalPlan = {
+ val casted = plan.output.zip(castedTypes).map {
+ case (hs, Some(dt)) if dt != hs.dataType =>
+ Alias(Cast(hs, dt), hs.name)()
+ case (hs, _) => hs
}
+ Project(casted, plan)
+ }
- val newRight =
- if (castedRight.map(_.dataType) != right.output.map(_.dataType)) {
- logDebug(s"Widening numeric types in $planName $castedRight ${right.output}")
- Project(castedRight, right)
- } else {
- right
- }
- (newLeft, newRight)
+ if (castedTypes.exists(_.isDefined)) {
+ (castOutput(left), castOutput(right))
+ } else {
+ (left, right)
+ }
}
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
@@ -334,144 +334,94 @@ object HiveTypeCoercion {
* - SHORT gets turned into DECIMAL(5, 0)
* - INT gets turned into DECIMAL(10, 0)
* - LONG gets turned into DECIMAL(20, 0)
- * - FLOAT and DOUBLE
- * 1. Union, Intersect and Except operations:
- * FLOAT gets turned into DECIMAL(7, 7), DOUBLE gets turned into DECIMAL(15, 15) (this is the
- * same as Hive)
- * 2. Other operation:
- * FLOAT and DOUBLE cause fixed-length decimals to turn into DOUBLE (this is the same as Hive,
- * but note that unlimited decimals are considered bigger than doubles in WidenTypes)
+ * - FLOAT and DOUBLE cause fixed-length decimals to turn into DOUBLE
+ *
+ * Note: Union/Except/Intersect is handled by WidenTypes
*/
// scalastyle:on
object DecimalPrecision extends Rule[LogicalPlan] {
import scala.math.{max, min}
- // Conversion rules for integer types into fixed-precision decimals
- private val intTypeToFixed: Map[DataType, DecimalType] = Map(
- ByteType -> DecimalType(3, 0),
- ShortType -> DecimalType(5, 0),
- IntegerType -> DecimalType(10, 0),
- LongType -> DecimalType(20, 0)
- )
-
private def isFloat(t: DataType): Boolean = t == FloatType || t == DoubleType
- // Conversion rules for float and double into fixed-precision decimals
- private val floatTypeToFixed: Map[DataType, DecimalType] = Map(
- FloatType -> DecimalType(7, 7),
- DoubleType -> DecimalType(15, 15)
- )
-
- private def castDecimalPrecision(
- left: LogicalPlan,
- right: LogicalPlan): (LogicalPlan, LogicalPlan) = {
- val castedInput = left.output.zip(right.output).map {
- case (lhs, rhs) if lhs.dataType != rhs.dataType =>
- (lhs.dataType, rhs.dataType) match {
- case (DecimalType.Fixed(p1, s1), DecimalType.Fixed(p2, s2)) =>
- // Decimals with precision/scale p1/s2 and p2/s2 will be promoted to
- // DecimalType(max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2))
- val fixedType = DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2), max(s1, s2))
- (Alias(Cast(lhs, fixedType), lhs.name)(), Alias(Cast(rhs, fixedType), rhs.name)())
- case (t, DecimalType.Fixed(p, s)) if intTypeToFixed.contains(t) =>
- (Alias(Cast(lhs, intTypeToFixed(t)), lhs.name)(), rhs)
- case (DecimalType.Fixed(p, s), t) if intTypeToFixed.contains(t) =>
- (lhs, Alias(Cast(rhs, intTypeToFixed(t)), rhs.name)())
- case (t, DecimalType.Fixed(p, s)) if floatTypeToFixed.contains(t) =>
- (Alias(Cast(lhs, floatTypeToFixed(t)), lhs.name)(), rhs)
- case (DecimalType.Fixed(p, s), t) if floatTypeToFixed.contains(t) =>
- (lhs, Alias(Cast(rhs, floatTypeToFixed(t)), rhs.name)())
- case _ => (lhs, rhs)
- }
- case other => other
- }
-
- val (castedLeft, castedRight) = castedInput.unzip
+ // Returns a decimal type wide enough to hold values of both given types
+ def widerDecimalType(d1: DecimalType, d2: DecimalType): DecimalType = {
+ widerDecimalType(d1.precision, d1.scale, d2.precision, d2.scale)
+ }
+ // precision = max(s1, s2) + max(p1 - s1, p2 - s2), scale = max(s1, s2)
+ def widerDecimalType(p1: Int, s1: Int, p2: Int, s2: Int): DecimalType = {
+ val scale = max(s1, s2)
+ val range = max(p1 - s1, p2 - s2)
+ DecimalType.bounded(range + scale, scale)
+ }
- val newLeft =
- if (castedLeft.map(_.dataType) != left.output.map(_.dataType)) {
- Project(castedLeft, left)
- } else {
- left
- }
+ /**
+ * An expression used to wrap the children when promoting the precision of DecimalType, to avoid
+ * promoting it multiple times.
+ */
+ case class ChangePrecision(child: Expression) extends UnaryExpression {
+ override def dataType: DataType = child.dataType
+ override def eval(input: InternalRow): Any = child.eval(input)
+ override def gen(ctx: CodeGenContext): GeneratedExpressionCode = child.gen(ctx)
+ override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = ""
+ override def prettyName: String = "change_precision"
+ }
- val newRight =
- if (castedRight.map(_.dataType) != right.output.map(_.dataType)) {
- Project(castedRight, right)
- } else {
- right
- }
- (newLeft, newRight)
+ def changePrecision(e: Expression, dataType: DataType): Expression = {
+ ChangePrecision(Cast(e, dataType))
}
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
- // fix decimal precision for union, intersect and except
- case u @ Union(left, right) if u.childrenResolved && !u.resolved =>
- val (newLeft, newRight) = castDecimalPrecision(left, right)
- Union(newLeft, newRight)
- case i @ Intersect(left, right) if i.childrenResolved && !i.resolved =>
- val (newLeft, newRight) = castDecimalPrecision(left, right)
- Intersect(newLeft, newRight)
- case e @ Except(left, right) if e.childrenResolved && !e.resolved =>
- val (newLeft, newRight) = castDecimalPrecision(left, right)
- Except(newLeft, newRight)
-
// fix decimal precision for expressions
case q => q.transformExpressions {
// Skip nodes whose children have not been resolved yet
case e if !e.childrenResolved => e
+ // Skip nodes that are already promoted
+ case e: BinaryArithmetic if e.left.isInstanceOf[ChangePrecision] => e
+
case Add(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
- Cast(
- Add(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)),
- DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2))
- )
+ val dt = DecimalType.bounded(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2))
+ Add(changePrecision(e1, dt), changePrecision(e2, dt))
case Subtract(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
- Cast(
- Subtract(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)),
- DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2))
- )
+ val dt = DecimalType.bounded(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2))
+ Subtract(changePrecision(e1, dt), changePrecision(e2, dt))
case Multiply(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
- Cast(
- Multiply(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)),
- DecimalType(p1 + p2 + 1, s1 + s2)
- )
+ val dt = DecimalType.bounded(p1 + p2 + 1, s1 + s2)
+ Multiply(changePrecision(e1, dt), changePrecision(e2, dt))
case Divide(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
- Cast(
- Divide(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)),
- DecimalType(p1 - s1 + s2 + max(6, s1 + p2 + 1), max(6, s1 + p2 + 1))
- )
+ val dt = DecimalType.bounded(p1 - s1 + s2 + max(6, s1 + p2 + 1), max(6, s1 + p2 + 1))
+ Divide(changePrecision(e1, dt), changePrecision(e2, dt))
case Remainder(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
- Cast(
- Remainder(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)),
- DecimalType(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
- )
+ val resultType = DecimalType.bounded(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
+ // resultType may have lower precision, so we cast the operands to a wider type first.
+ val widerType = widerDecimalType(p1, s1, p2, s2)
+ Cast(Remainder(changePrecision(e1, widerType), changePrecision(e2, widerType)),
+ resultType)
case Pmod(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
- Cast(
- Pmod(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)),
- DecimalType(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
- )
+ val resultType = DecimalType.bounded(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
+ // resultType may have lower precision, so we cast the operands to a wider type first.
+ val widerType = widerDecimalType(p1, s1, p2, s2)
+ Cast(Pmod(changePrecision(e1, widerType), changePrecision(e2, widerType)), resultType)
- // When we compare 2 decimal types with different precisions, cast them to the smallest
- // common precision.
case b @ BinaryComparison(e1 @ DecimalType.Expression(p1, s1),
e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
- val resultType = DecimalType(max(p1, p2), max(s1, s2))
+ val resultType = widerDecimalType(p1, s1, p2, s2)
b.makeCopy(Array(Cast(e1, resultType), Cast(e2, resultType)))
// Promote integers inside a binary expression with fixed-precision decimals to decimals,
// and fixed-precision decimals in an expression with floats / doubles to doubles
case b @ BinaryOperator(left, right) if left.dataType != right.dataType =>
(left.dataType, right.dataType) match {
- case (t, DecimalType.Fixed(p, s)) if intTypeToFixed.contains(t) =>
- b.makeCopy(Array(Cast(left, intTypeToFixed(t)), right))
- case (DecimalType.Fixed(p, s), t) if intTypeToFixed.contains(t) =>
- b.makeCopy(Array(left, Cast(right, intTypeToFixed(t))))
+ case (t: IntegralType, DecimalType.Fixed(p, s)) =>
+ b.makeCopy(Array(Cast(left, DecimalType.forType(t)), right))
+ case (DecimalType.Fixed(p, s), t: IntegralType) =>
+ b.makeCopy(Array(left, Cast(right, DecimalType.forType(t))))
case (t, DecimalType.Fixed(p, s)) if isFloat(t) =>
b.makeCopy(Array(left, Cast(right, DoubleType)))
case (DecimalType.Fixed(p, s), t) if isFloat(t) =>
@@ -485,7 +435,6 @@ object HiveTypeCoercion {
// SUM and AVERAGE are handled by the implementations of those expressions
}
}
-
}
/**
@@ -563,7 +512,7 @@ object HiveTypeCoercion {
case e if !e.childrenResolved => e
case Cast(e @ StringType(), t: IntegralType) =>
- Cast(Cast(e, DecimalType.Unlimited), t)
+ Cast(Cast(e, DecimalType.forType(LongType)), t)
}
}
@@ -756,8 +705,8 @@ object HiveTypeCoercion {
// Implicit cast among numeric types. When we reach here, input type is not acceptable.
// If input is a numeric type but not decimal, and we expect a decimal type,
- // cast the input to unlimited precision decimal.
- case (_: NumericType, DecimalType) => Cast(e, DecimalType.Unlimited)
+ // cast the input to decimal.
+ case (d: NumericType, DecimalType) => Cast(e, DecimalType.forType(d))
// For any other numeric types, implicitly cast to each other, e.g. long -> int, int -> long
case (_: NumericType, target: NumericType) => Cast(e, target)
@@ -766,7 +715,7 @@ object HiveTypeCoercion {
case (TimestampType, DateType) => Cast(e, DateType)
// Implicit cast from/to string
- case (StringType, DecimalType) => Cast(e, DecimalType.Unlimited)
+ case (StringType, DecimalType) => Cast(e, DecimalType.SYSTEM_DEFAULT)
case (StringType, target: NumericType) => Cast(e, target)
case (StringType, DateType) => Cast(e, DateType)
case (StringType, TimestampType) => Cast(e, TimestampType)
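The DecimalPrecision rules above derive a bounded result type per operator. Since `DecimalType.bounded` is `private[sql]`, the sketch below restates the formulas as plain functions rather than calling the internal API:

```scala
import scala.math.{max, min}

object DecimalPrecisionRulesSketch {
  val MaxPrecision = 38

  // Mirrors DecimalType.bounded: clamp precision and scale at 38.
  def bounded(precision: Int, scale: Int): (Int, Int) =
    (min(precision, MaxPrecision), min(scale, MaxPrecision))

  // Wider of two decimals, used for comparisons and as the intermediate type for Remainder/Pmod.
  def wider(p1: Int, s1: Int, p2: Int, s2: Int): (Int, Int) =
    bounded(max(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))

  // Result types for decimal(p1,s1) <op> decimal(p2,s2), per the rules above.
  def add(p1: Int, s1: Int, p2: Int, s2: Int): (Int, Int) =
    bounded(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2))

  def multiply(p1: Int, s1: Int, p2: Int, s2: Int): (Int, Int) =
    bounded(p1 + p2 + 1, s1 + s2)

  def divide(p1: Int, s1: Int, p2: Int, s2: Int): (Int, Int) =
    bounded(p1 - s1 + s2 + max(6, s1 + p2 + 1), max(6, s1 + p2 + 1))

  def main(args: Array[String]): Unit = {
    println(add(10, 2, 5, 3))       // (12,3)
    println(multiply(10, 2, 5, 3))  // (16,5)
    println(divide(10, 2, 5, 3))    // (19,8)
    println(wider(10, 2, 5, 3))     // (11,3)
  }
}
```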
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 5182175796..a7e3a49327 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -201,7 +201,7 @@ package object dsl {
/** Creates a new AttributeReference of type decimal */
def decimal: AttributeReference =
- AttributeReference(s, DecimalType.Unlimited, nullable = true)()
+ AttributeReference(s, DecimalType.SYSTEM_DEFAULT, nullable = true)()
/** Creates a new AttributeReference of type decimal */
def decimal(precision: Int, scale: Int): AttributeReference =
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index e66cd82848..c66854d52c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -300,12 +300,7 @@ case class Cast(child: Expression, dataType: DataType)
* NOTE: this modifies `value` in-place, so don't call it on external data.
*/
private[this] def changePrecision(value: Decimal, decimalType: DecimalType): Decimal = {
- decimalType match {
- case DecimalType.Unlimited =>
- value
- case DecimalType.Fixed(precision, scale) =>
- if (value.changePrecision(precision, scale)) value else null
- }
+ if (value.changePrecision(decimalType.precision, decimalType.scale)) value else null
}
private[this] def castToDecimal(from: DataType, target: DecimalType): Any => Any = from match {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
index b924af4cc8..88fb516e64 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
@@ -36,14 +36,13 @@ case class Average(child: Expression) extends AlgebraicAggregate {
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType, NullType))
private val resultType = child.dataType match {
- case DecimalType.Fixed(precision, scale) =>
- DecimalType(precision + 4, scale + 4)
- case DecimalType.Unlimited => DecimalType.Unlimited
+ case DecimalType.Fixed(p, s) =>
+ DecimalType.bounded(p + 4, s + 4)
case _ => DoubleType
}
private val sumDataType = child.dataType match {
- case _ @ DecimalType() => DecimalType.Unlimited
+ case _ @ DecimalType.Fixed(p, s) => DecimalType.bounded(p + 10, s)
case _ => DoubleType
}
@@ -71,7 +70,14 @@ case class Average(child: Expression) extends AlgebraicAggregate {
)
// If all input are nulls, currentCount will be 0 and we will get null after the division.
- override val evaluateExpression = Cast(currentSum, resultType) / Cast(currentCount, resultType)
+ override val evaluateExpression = child.dataType match {
+ case DecimalType.Fixed(p, s) =>
+ // increase the precision and scale to prevent precision loss
+ val dt = DecimalType.bounded(p + 14, s + 4)
+ Cast(Cast(currentSum, dt) / Cast(currentCount, dt), resultType)
+ case _ =>
+ Cast(currentSum, resultType) / Cast(currentCount, resultType)
+ }
}
case class Count(child: Expression) extends AlgebraicAggregate {
@@ -255,15 +261,11 @@ case class Sum(child: Expression) extends AlgebraicAggregate {
private val resultType = child.dataType match {
case DecimalType.Fixed(precision, scale) =>
- DecimalType(precision + 4, scale + 4)
- case DecimalType.Unlimited => DecimalType.Unlimited
+ DecimalType.bounded(precision + 10, scale)
case _ => child.dataType
}
- private val sumDataType = child.dataType match {
- case _ @ DecimalType() => DecimalType.Unlimited
- case _ => child.dataType
- }
+ private val sumDataType = resultType
private val currentSum = AttributeReference("currentSum", sumDataType)()
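With this change Sum and Average widen bounded decimals instead of escaping to unlimited precision. A plain-arithmetic sketch of the resulting types for a hypothetical `decimal(10,2)` input column (mirroring the rules above, not calling the internal API):

```scala
import scala.math.min

object DecimalAggregateTypesSketch {
  val MaxPrecision = 38

  // Mirrors DecimalType.bounded: clamp precision and scale at 38.
  def bounded(precision: Int, scale: Int): (Int, Int) =
    (min(precision, MaxPrecision), min(scale, MaxPrecision))

  def main(args: Array[String]): Unit = {
    val (p, s) = (10, 2)                   // hypothetical input column: decimal(10,2)

    val sumType = bounded(p + 10, s)       // Sum: 10 extra integer digits     -> (20,2)
    val avgType = bounded(p + 4, s + 4)    // Average result, like Hive        -> (14,6)
    val avgDiv  = bounded(p + 14, s + 4)   // intermediate type for sum/count  -> (24,6)

    println(s"sum=$sumType avg=$avgType intermediate=$avgDiv")
  }
}
```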
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
index d3295b8baf..73fde4e916 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
@@ -390,22 +390,21 @@ case class Average(child: Expression) extends UnaryExpression with PartialAggreg
override def dataType: DataType = child.dataType match {
case DecimalType.Fixed(precision, scale) =>
- DecimalType(precision + 4, scale + 4) // Add 4 digits after decimal point, like Hive
- case DecimalType.Unlimited =>
- DecimalType.Unlimited
+ // Add 4 digits after decimal point, like Hive
+ DecimalType.bounded(precision + 4, scale + 4)
case _ =>
DoubleType
}
override def asPartial: SplitEvaluation = {
child.dataType match {
- case DecimalType.Fixed(_, _) | DecimalType.Unlimited =>
- // Turn the child to unlimited decimals for calculation, before going back to fixed
- val partialSum = Alias(Sum(Cast(child, DecimalType.Unlimited)), "PartialSum")()
+ case DecimalType.Fixed(precision, scale) =>
+ val partialSum = Alias(Sum(child), "PartialSum")()
val partialCount = Alias(Count(child), "PartialCount")()
- val castedSum = Cast(Sum(partialSum.toAttribute), DecimalType.Unlimited)
- val castedCount = Cast(Sum(partialCount.toAttribute), DecimalType.Unlimited)
+ // partialSum already increases the precision by 10
+ val castedSum = Cast(Sum(partialSum.toAttribute), partialSum.dataType)
+ val castedCount = Sum(partialCount.toAttribute)
SplitEvaluation(
Cast(Divide(castedSum, castedCount), dataType),
partialCount :: partialSum :: Nil)
@@ -435,8 +434,8 @@ case class AverageFunction(expr: Expression, base: AggregateExpression1)
private val calcType =
expr.dataType match {
- case DecimalType.Fixed(_, _) =>
- DecimalType.Unlimited
+ case DecimalType.Fixed(precision, scale) =>
+ DecimalType.bounded(precision + 10, scale)
case _ =>
expr.dataType
}
@@ -454,10 +453,9 @@ case class AverageFunction(expr: Expression, base: AggregateExpression1)
null
} else {
expr.dataType match {
- case DecimalType.Fixed(_, _) =>
- Cast(Divide(
- Cast(sum, DecimalType.Unlimited),
- Cast(Literal(count), DecimalType.Unlimited)), dataType).eval(null)
+ case DecimalType.Fixed(precision, scale) =>
+ val dt = DecimalType.bounded(precision + 14, scale + 4)
+ Cast(Divide(Cast(sum, dt), Cast(Literal(count), dt)), dataType).eval(null)
case _ =>
Divide(
Cast(sum, dataType),
@@ -481,9 +479,8 @@ case class Sum(child: Expression) extends UnaryExpression with PartialAggregate1
override def dataType: DataType = child.dataType match {
case DecimalType.Fixed(precision, scale) =>
- DecimalType(precision + 10, scale) // Add 10 digits left of decimal point, like Hive
- case DecimalType.Unlimited =>
- DecimalType.Unlimited
+ // Add 10 digits left of decimal point, like Hive
+ DecimalType.bounded(precision + 10, scale)
case _ =>
child.dataType
}
@@ -491,7 +488,7 @@ case class Sum(child: Expression) extends UnaryExpression with PartialAggregate1
override def asPartial: SplitEvaluation = {
child.dataType match {
case DecimalType.Fixed(_, _) =>
- val partialSum = Alias(Sum(Cast(child, DecimalType.Unlimited)), "PartialSum")()
+ val partialSum = Alias(Sum(child), "PartialSum")()
SplitEvaluation(
Cast(CombineSum(partialSum.toAttribute), dataType),
partialSum :: Nil)
@@ -515,8 +512,8 @@ case class SumFunction(expr: Expression, base: AggregateExpression1) extends Agg
private val calcType =
expr.dataType match {
- case DecimalType.Fixed(_, _) =>
- DecimalType.Unlimited
+ case DecimalType.Fixed(precision, scale) =>
+ DecimalType.bounded(precision + 10, scale)
case _ =>
expr.dataType
}
@@ -572,8 +569,8 @@ case class CombineSumFunction(expr: Expression, base: AggregateExpression1)
private val calcType =
expr.dataType match {
- case DecimalType.Fixed(_, _) =>
- DecimalType.Unlimited
+ case DecimalType.Fixed(precision, scale) =>
+ DecimalType.bounded(precision + 10, scale)
case _ =>
expr.dataType
}
@@ -608,9 +605,8 @@ case class SumDistinct(child: Expression) extends UnaryExpression with PartialAg
override def nullable: Boolean = true
override def dataType: DataType = child.dataType match {
case DecimalType.Fixed(precision, scale) =>
- DecimalType(precision + 10, scale) // Add 10 digits left of decimal point, like Hive
- case DecimalType.Unlimited =>
- DecimalType.Unlimited
+ // Add 10 digits left of decimal point, like Hive
+ DecimalType.bounded(precision + 10, scale)
case _ =>
child.dataType
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 05b5ad88fe..7c254a8750 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -88,6 +88,8 @@ abstract class BinaryArithmetic extends BinaryOperator {
override def dataType: DataType = left.dataType
+ override lazy val resolved = childrenResolved && checkInputDataTypes().isSuccess
+
/** Name of the function for this expression on a [[Decimal]] type. */
def decimalMethod: String =
sys.error("BinaryArithmetics must override either decimalMethod or genCode")
@@ -114,9 +116,6 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
override def symbol: String = "+"
- override lazy val resolved =
- childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType)
-
private lazy val numeric = TypeUtils.getNumeric(dataType)
protected override def nullSafeEval(input1: Any, input2: Any): Any = {
@@ -146,9 +145,6 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti
override def symbol: String = "-"
- override lazy val resolved =
- childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType)
-
private lazy val numeric = TypeUtils.getNumeric(dataType)
protected override def nullSafeEval(input1: Any, input2: Any): Any = {
@@ -179,9 +175,6 @@ case class Multiply(left: Expression, right: Expression) extends BinaryArithmeti
override def symbol: String = "*"
override def decimalMethod: String = "$times"
- override lazy val resolved =
- childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType)
-
private lazy val numeric = TypeUtils.getNumeric(dataType)
protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.times(input1, input2)
@@ -195,9 +188,6 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic
override def decimalMethod: String = "$div"
override def nullable: Boolean = true
- override lazy val resolved =
- childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType)
-
private lazy val div: (Any, Any) => Any = dataType match {
case ft: FractionalType => ft.fractional.asInstanceOf[Fractional[Any]].div
case it: IntegralType => it.integral.asInstanceOf[Integral[Any]].quot
@@ -260,9 +250,6 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet
override def decimalMethod: String = "remainder"
override def nullable: Boolean = true
- override lazy val resolved =
- childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType)
-
private lazy val integral = dataType match {
case i: IntegralType => i.integral.asInstanceOf[Integral[Any]]
case i: FractionalType => i.asIntegral.asInstanceOf[Integral[Any]]
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index f25ac32679..85060b7893 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -36,9 +36,9 @@ object Literal {
case s: Short => Literal(s, ShortType)
case s: String => Literal(UTF8String.fromString(s), StringType)
case b: Boolean => Literal(b, BooleanType)
- case d: BigDecimal => Literal(Decimal(d), DecimalType.Unlimited)
- case d: java.math.BigDecimal => Literal(Decimal(d), DecimalType.Unlimited)
- case d: Decimal => Literal(d, DecimalType.Unlimited)
+ case d: BigDecimal => Literal(Decimal(d), DecimalType(d.precision, d.scale))
+ case d: java.math.BigDecimal => Literal(Decimal(d), DecimalType(d.precision(), d.scale()))
+ case d: Decimal => Literal(d, DecimalType(d.precision, d.scale))
case t: Timestamp => Literal(DateTimeUtils.fromJavaTimestamp(t), TimestampType)
case d: Date => Literal(DateTimeUtils.fromJavaDate(d), DateType)
case a: Array[Byte] => Literal(a, BinaryType)
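Literals built from BigDecimal values now carry the exact precision and scale of the value. A minimal sketch, assuming the `Literal.apply` factory shown above:

```scala
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.types.DecimalType

object DecimalLiteralSketch {
  def main(args: Array[String]): Unit = {
    val lit = Literal(BigDecimal("123.45"))  // precision 5, scale 2
    // The literal's type now reflects the value exactly instead of being unlimited.
    println(lit.dataType)                    // DecimalType(5,2)
    assert(lit.dataType == DecimalType(5, 2))
  }
}
```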
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index d06a7a2add..c610f70d38 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -17,9 +17,9 @@
package org.apache.spark.sql.catalyst.plans
-import org.apache.spark.sql.catalyst.expressions.{VirtualColumn, Attribute, AttributeSet, Expression}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, VirtualColumn}
import org.apache.spark.sql.catalyst.trees.TreeNode
-import org.apache.spark.sql.types.{ArrayType, DataType, StructField, StructType}
+import org.apache.spark.sql.types.{DataType, StructType}
abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanType] {
self: PlanType =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
index e98fd2583b..591fb26e67 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
@@ -106,7 +106,7 @@ object DataType {
private def nameToType(name: String): DataType = {
val FIXED_DECIMAL = """decimal\(\s*(\d+)\s*,\s*(\d+)\s*\)""".r
name match {
- case "decimal" => DecimalType.Unlimited
+ case "decimal" => DecimalType.USER_DEFAULT
case FIXED_DECIMAL(precision, scale) => DecimalType(precision.toInt, scale.toInt)
case other => nonDecimalNameToType(other)
}
@@ -177,7 +177,7 @@ object DataType {
| "BinaryType" ^^^ BinaryType
| "BooleanType" ^^^ BooleanType
| "DateType" ^^^ DateType
- | "DecimalType()" ^^^ DecimalType.Unlimited
+ | "DecimalType()" ^^^ DecimalType.USER_DEFAULT
| fixedDecimalType
| "TimestampType" ^^^ TimestampType
)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala
index 6b43224feb..6e081ea923 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala
@@ -48,7 +48,7 @@ private[sql] trait DataTypeParser extends StandardTokenParsers {
"(?i)binary".r ^^^ BinaryType |
"(?i)boolean".r ^^^ BooleanType |
fixedDecimalType |
- "(?i)decimal".r ^^^ DecimalType.Unlimited |
+ "(?i)decimal".r ^^^ DecimalType.USER_DEFAULT |
"(?i)date".r ^^^ DateType |
"(?i)timestamp".r ^^^ TimestampType |
varchar
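Both the name mapping in DataType.scala and the SQL type parser now resolve a bare `decimal` to the user default (10, 0). A hedged sketch using `DataType.fromJson`, assuming that public factory accepts the simple string names handled by `nameToType` above:

```scala
import org.apache.spark.sql.types.{DataType, DecimalType}

object DecimalNameParsingSketch {
  def main(args: Array[String]): Unit = {
    // A bare "decimal" now means decimal(10,0) rather than an unlimited decimal.
    val bare  = DataType.fromJson("\"decimal\"")
    val fixed = DataType.fromJson("\"decimal(38,18)\"")

    assert(bare == DecimalType(10, 0))
    assert(fixed == DecimalType(38, 18))
  }
}
```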
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
index 377c75f6e8..26b24616d9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
@@ -26,25 +26,46 @@ import org.apache.spark.sql.catalyst.expressions.Expression
/** Precision parameters for a Decimal */
+@deprecated("Use DecimalType(precision, scale) directly", "1.5")
case class PrecisionInfo(precision: Int, scale: Int) {
if (scale > precision) {
throw new AnalysisException(
s"Decimal scale ($scale) cannot be greater than precision ($precision).")
}
+ if (precision > DecimalType.MAX_PRECISION) {
+ throw new AnalysisException(
+ s"DecimalType can only support precision up to 38"
+ )
+ }
}
/**
* :: DeveloperApi ::
* The data type representing `java.math.BigDecimal` values.
- * A Decimal that might have fixed precision and scale, or unlimited values for these.
+ * A Decimal that must have fixed precision (the maximum number of digits) and scale (the number
+ * of digits on the right side of the dot).
+ *
+ * The precision can be up to 38; the scale can also be up to 38 (less than or equal to the precision).
+ *
+ * The default precision and scale is (10, 0).
*
* Please use [[DataTypes.createDecimalType()]] to create a specific instance.
*/
@DeveloperApi
-case class DecimalType(precisionInfo: Option[PrecisionInfo]) extends FractionalType {
+case class DecimalType(precision: Int, scale: Int) extends FractionalType {
+
+ // default constructor for Java
+ def this(precision: Int) = this(precision, 0)
+ def this() = this(10)
+
+ @deprecated("Use DecimalType(precision, scale) instead", "1.5")
+ def this(precisionInfo: Option[PrecisionInfo]) {
+ this(precisionInfo.getOrElse(PrecisionInfo(10, 0)).precision,
+ precisionInfo.getOrElse(PrecisionInfo(10, 0)).scale)
+ }
- /** No-arg constructor for kryo. */
- protected def this() = this(null)
+ @deprecated("Use DecimalType.precision and DecimalType.scale instead", "1.5")
+ val precisionInfo = Some(PrecisionInfo(precision, scale))
private[sql] type InternalType = Decimal
@transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] }
@@ -53,18 +74,16 @@ case class DecimalType(precisionInfo: Option[PrecisionInfo]) extends FractionalT
private[sql] val ordering = Decimal.DecimalIsFractional
private[sql] val asIntegral = Decimal.DecimalAsIfIntegral
- def precision: Int = precisionInfo.map(_.precision).getOrElse(-1)
-
- def scale: Int = precisionInfo.map(_.scale).getOrElse(-1)
+ override def typeName: String = s"decimal($precision,$scale)"
- override def typeName: String = precisionInfo match {
- case Some(PrecisionInfo(precision, scale)) => s"decimal($precision,$scale)"
- case None => "decimal"
- }
+ override def toString: String = s"DecimalType($precision,$scale)"
- override def toString: String = precisionInfo match {
- case Some(PrecisionInfo(precision, scale)) => s"DecimalType($precision,$scale)"
- case None => "DecimalType()"
+ private[sql] def isWiderThan(other: DataType): Boolean = other match {
+ case dt: DecimalType =>
+ (precision - scale) >= (dt.precision - dt.scale) && scale >= dt.scale
+ case dt: IntegralType =>
+ isWiderThan(DecimalType.forType(dt))
+ case _ => false
}
/**
@@ -72,10 +91,7 @@ case class DecimalType(precisionInfo: Option[PrecisionInfo]) extends FractionalT
*/
override def defaultSize: Int = 4096
- override def simpleString: String = precisionInfo match {
- case Some(PrecisionInfo(precision, scale)) => s"decimal($precision,$scale)"
- case None => "decimal(10,0)"
- }
+ override def simpleString: String = s"decimal($precision,$scale)"
private[spark] override def asNullable: DecimalType = this
}
@@ -83,8 +99,47 @@ case class DecimalType(precisionInfo: Option[PrecisionInfo]) extends FractionalT
/** Extra factory methods and pattern matchers for Decimals */
object DecimalType extends AbstractDataType {
+ import scala.math.min
+
+ val MAX_PRECISION = 38
+ val MAX_SCALE = 38
+ val SYSTEM_DEFAULT: DecimalType = DecimalType(MAX_PRECISION, 18)
+ val USER_DEFAULT: DecimalType = DecimalType(10, 0)
+
+ @deprecated("Does not support unlimited precision, please specify the precision and scale", "1.5")
+ val Unlimited: DecimalType = SYSTEM_DEFAULT
+
+ // The decimal types compatible with other numeric types
+ private[sql] val ByteDecimal = DecimalType(3, 0)
+ private[sql] val ShortDecimal = DecimalType(5, 0)
+ private[sql] val IntDecimal = DecimalType(10, 0)
+ private[sql] val LongDecimal = DecimalType(20, 0)
+ private[sql] val FloatDecimal = DecimalType(14, 7)
+ private[sql] val DoubleDecimal = DecimalType(30, 15)
+
+ private[sql] def forType(dataType: DataType): DecimalType = dataType match {
+ case ByteType => ByteDecimal
+ case ShortType => ShortDecimal
+ case IntegerType => IntDecimal
+ case LongType => LongDecimal
+ case FloatType => FloatDecimal
+ case DoubleType => DoubleDecimal
+ }
- override private[sql] def defaultConcreteType: DataType = Unlimited
+ @deprecated("please specify precision and scale", "1.5")
+ def apply(): DecimalType = USER_DEFAULT
+
+ @deprecated("Use DecimalType(precision, scale) instead", "1.5")
+ def apply(precisionInfo: Option[PrecisionInfo]) {
+ this(precisionInfo.getOrElse(PrecisionInfo(10, 0)).precision,
+ precisionInfo.getOrElse(PrecisionInfo(10, 0)).scale)
+ }
+
+ private[sql] def bounded(precision: Int, scale: Int): DecimalType = {
+ DecimalType(min(precision, MAX_PRECISION), min(scale, MAX_SCALE))
+ }
+
+ override private[sql] def defaultConcreteType: DataType = SYSTEM_DEFAULT
override private[sql] def acceptsType(other: DataType): Boolean = {
other.isInstanceOf[DecimalType]
@@ -92,31 +147,18 @@ object DecimalType extends AbstractDataType {
override private[sql] def simpleString: String = "decimal"
- val Unlimited: DecimalType = DecimalType(None)
-
private[sql] object Fixed {
- def unapply(t: DecimalType): Option[(Int, Int)] =
- t.precisionInfo.map(p => (p.precision, p.scale))
+ def unapply(t: DecimalType): Option[(Int, Int)] = Some((t.precision, t.scale))
}
private[sql] object Expression {
def unapply(e: Expression): Option[(Int, Int)] = e.dataType match {
- case t: DecimalType => t.precisionInfo.map(p => (p.precision, p.scale))
+ case t: DecimalType => Some((t.precision, t.scale))
case _ => None
}
}
- def apply(): DecimalType = Unlimited
-
- def apply(precision: Int, scale: Int): DecimalType =
- DecimalType(Some(PrecisionInfo(precision, scale)))
-
def unapply(t: DataType): Boolean = t.isInstanceOf[DecimalType]
def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[DecimalType]
-
- def isFixed(dataType: DataType): Boolean = dataType match {
- case DecimalType.Fixed(_, _) => true
- case _ => false
- }
}
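The new companion object maps each numeric type to a decimal wide enough to hold it and clamps derived types at 38 digits. Because `forType`, `bounded` and `isWiderThan` are `private[sql]`, the sketch below restates that logic as standalone helpers:

```scala
object DecimalWideningSketch {
  // Mirrors DecimalType.forType: the narrowest decimal that can hold each numeric type.
  val byteDecimal   = (3, 0)
  val shortDecimal  = (5, 0)
  val intDecimal    = (10, 0)
  val longDecimal   = (20, 0)
  val floatDecimal  = (14, 7)
  val doubleDecimal = (30, 15)

  // Mirrors isWiderThan: (p1,s1) can hold every value of (p2,s2) if it has at least as many
  // digits before the dot and at least as many digits after it.
  def isWiderThan(p1: Int, s1: Int, p2: Int, s2: Int): Boolean =
    (p1 - s1) >= (p2 - s2) && s1 >= s2

  def main(args: Array[String]): Unit = {
    println(isWiderThan(38, 18, longDecimal._1, longDecimal._2))  // true: the system default holds any long
    println(isWiderThan(10, 0, longDecimal._1, longDecimal._2))   // false: a long may overflow decimal(10,0)
  }
}
```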
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala
index 13aad467fa..b9f2ad7ec0 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala
@@ -94,8 +94,8 @@ object RandomDataGenerator {
case BooleanType => Some(() => rand.nextBoolean())
case DateType => Some(() => new java.sql.Date(rand.nextInt()))
case TimestampType => Some(() => new java.sql.Timestamp(rand.nextLong()))
- case DecimalType.Unlimited => Some(
- () => BigDecimal.apply(rand.nextLong, rand.nextInt, MathContext.UNLIMITED))
+ case DecimalType.Fixed(precision, scale) => Some(
+ () => BigDecimal.apply(rand.nextLong, rand.nextInt, new MathContext(precision)))
case DoubleType => randomNumeric[Double](
rand, r => longBitsToDouble(r.nextLong()), Seq(Double.MinValue, Double.MinPositiveValue,
Double.MaxValue, Double.PositiveInfinity, Double.NegativeInfinity, Double.NaN, 0.0))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGeneratorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGeneratorSuite.scala
index dbba93dba6..677ba0a180 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGeneratorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGeneratorSuite.scala
@@ -50,9 +50,7 @@ class RandomDataGeneratorSuite extends SparkFunSuite {
for (
dataType <- DataTypeTestUtils.atomicTypes;
nullable <- Seq(true, false)
- if !dataType.isInstanceOf[DecimalType] ||
- dataType.asInstanceOf[DecimalType].precisionInfo.isEmpty
- ) {
+ if !dataType.isInstanceOf[DecimalType]) {
test(s"$dataType (nullable=$nullable)") {
testRandomDataGeneration(dataType)
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
index b4b00f5584..3b848cfdf7 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
@@ -102,7 +102,7 @@ class ScalaReflectionSuite extends SparkFunSuite {
StructField("byteField", ByteType, nullable = true),
StructField("booleanField", BooleanType, nullable = true),
StructField("stringField", StringType, nullable = true),
- StructField("decimalField", DecimalType.Unlimited, nullable = true),
+ StructField("decimalField", DecimalType.SYSTEM_DEFAULT, nullable = true),
StructField("dateField", DateType, nullable = true),
StructField("timestampField", TimestampType, nullable = true),
StructField("binaryField", BinaryType, nullable = true))),
@@ -216,7 +216,7 @@ class ScalaReflectionSuite extends SparkFunSuite {
assert(DoubleType === typeOfObject(1.7976931348623157E308))
// DecimalType
- assert(DecimalType.Unlimited ===
+ assert(DecimalType.SYSTEM_DEFAULT ===
typeOfObject(new java.math.BigDecimal("1.7976931348623157E318")))
// DateType
@@ -229,19 +229,19 @@ class ScalaReflectionSuite extends SparkFunSuite {
assert(NullType === typeOfObject(null))
def typeOfObject1: PartialFunction[Any, DataType] = typeOfObject orElse {
- case value: java.math.BigInteger => DecimalType.Unlimited
- case value: java.math.BigDecimal => DecimalType.Unlimited
+ case value: java.math.BigInteger => DecimalType.SYSTEM_DEFAULT
+ case value: java.math.BigDecimal => DecimalType.SYSTEM_DEFAULT
case _ => StringType
}
- assert(DecimalType.Unlimited === typeOfObject1(
+ assert(DecimalType.SYSTEM_DEFAULT === typeOfObject1(
new BigInteger("92233720368547758070")))
- assert(DecimalType.Unlimited === typeOfObject1(
+ assert(DecimalType.SYSTEM_DEFAULT === typeOfObject1(
new java.math.BigDecimal("1.7976931348623157E318")))
assert(StringType === typeOfObject1(BigInt("92233720368547758070")))
def typeOfObject2: PartialFunction[Any, DataType] = typeOfObject orElse {
- case value: java.math.BigInteger => DecimalType.Unlimited
+ case value: java.math.BigInteger => DecimalType.SYSTEM_DEFAULT
}
intercept[MatchError](typeOfObject2(BigInt("92233720368547758070")))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index 58df1de983..7e67427237 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -55,7 +55,7 @@ object AnalysisSuite {
AttributeReference("a", StringType)(),
AttributeReference("b", StringType)(),
AttributeReference("c", DoubleType)(),
- AttributeReference("d", DecimalType.Unlimited)(),
+ AttributeReference("d", DecimalType.SYSTEM_DEFAULT)(),
AttributeReference("e", ShortType)())
val nestedRelation = LocalRelation(
@@ -158,7 +158,7 @@ class AnalysisSuite extends SparkFunSuite with BeforeAndAfter {
AttributeReference("a", StringType)(),
AttributeReference("b", StringType)(),
AttributeReference("c", DoubleType)(),
- AttributeReference("d", DecimalType.Unlimited)(),
+ AttributeReference("d", DecimalType(10, 2))(),
AttributeReference("e", ShortType)())
val plan = caseInsensitiveAnalyzer.execute(
@@ -173,7 +173,7 @@ class AnalysisSuite extends SparkFunSuite with BeforeAndAfter {
assert(pl(0).dataType == DoubleType)
assert(pl(1).dataType == DoubleType)
assert(pl(2).dataType == DoubleType)
- assert(pl(3).dataType == DecimalType.Unlimited)
+ assert(pl(3).dataType == DoubleType) // StringType will be promoted into Double
assert(pl(4).dataType == DoubleType)
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala
index 7bac97b789..f9f15e7a66 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala
@@ -34,7 +34,7 @@ class DecimalPrecisionSuite extends SparkFunSuite with BeforeAndAfter {
AttributeReference("i", IntegerType)(),
AttributeReference("d1", DecimalType(2, 1))(),
AttributeReference("d2", DecimalType(5, 2))(),
- AttributeReference("u", DecimalType.Unlimited)(),
+ AttributeReference("u", DecimalType.SYSTEM_DEFAULT)(),
AttributeReference("f", FloatType)(),
AttributeReference("b", DoubleType)()
)
@@ -92,11 +92,11 @@ class DecimalPrecisionSuite extends SparkFunSuite with BeforeAndAfter {
}
test("Comparison operations") {
- checkComparison(EqualTo(i, d1), DecimalType(10, 1))
+ checkComparison(EqualTo(i, d1), DecimalType(11, 1))
checkComparison(EqualNullSafe(d2, d1), DecimalType(5, 2))
- checkComparison(LessThan(i, d1), DecimalType(10, 1))
+ checkComparison(LessThan(i, d1), DecimalType(11, 1))
checkComparison(LessThanOrEqual(d1, d2), DecimalType(5, 2))
- checkComparison(GreaterThan(d2, u), DecimalType.Unlimited)
+ checkComparison(GreaterThan(d2, u), DecimalType.SYSTEM_DEFAULT)
checkComparison(GreaterThanOrEqual(d1, f), DoubleType)
checkComparison(GreaterThan(d2, d2), DecimalType(5, 2))
}
@@ -106,12 +106,12 @@ class DecimalPrecisionSuite extends SparkFunSuite with BeforeAndAfter {
checkUnion(i, d2, DecimalType(12, 2))
checkUnion(d1, d2, DecimalType(5, 2))
checkUnion(d2, d1, DecimalType(5, 2))
- checkUnion(d1, f, DecimalType(8, 7))
- checkUnion(f, d2, DecimalType(10, 7))
- checkUnion(d1, b, DecimalType(16, 15))
- checkUnion(b, d2, DecimalType(18, 15))
- checkUnion(d1, u, DecimalType.Unlimited)
- checkUnion(u, d2, DecimalType.Unlimited)
+ checkUnion(d1, f, DoubleType)
+ checkUnion(f, d2, DoubleType)
+ checkUnion(d1, b, DoubleType)
+ checkUnion(b, d2, DoubleType)
+ checkUnion(d1, u, DecimalType.SYSTEM_DEFAULT)
+ checkUnion(u, d2, DecimalType.SYSTEM_DEFAULT)
}
test("bringing in primitive types") {
@@ -125,13 +125,33 @@ class DecimalPrecisionSuite extends SparkFunSuite with BeforeAndAfter {
checkType(Add(d1, Cast(i, DoubleType)), DoubleType)
}
- test("unlimited decimals make everything else cast up") {
- for (expr <- Seq(d1, d2, i, f, u)) {
- checkType(Add(expr, u), DecimalType.Unlimited)
- checkType(Subtract(expr, u), DecimalType.Unlimited)
- checkType(Multiply(expr, u), DecimalType.Unlimited)
- checkType(Divide(expr, u), DecimalType.Unlimited)
- checkType(Remainder(expr, u), DecimalType.Unlimited)
+ test("maximum decimals") {
+ for (expr <- Seq(d1, d2, i, u)) {
+ checkType(Add(expr, u), DecimalType.SYSTEM_DEFAULT)
+ checkType(Subtract(expr, u), DecimalType.SYSTEM_DEFAULT)
+ }
+
+ checkType(Multiply(d1, u), DecimalType(38, 19))
+ checkType(Multiply(d2, u), DecimalType(38, 20))
+ checkType(Multiply(i, u), DecimalType(38, 18))
+ checkType(Multiply(u, u), DecimalType(38, 36))
+
+ checkType(Divide(u, d1), DecimalType(38, 21))
+ checkType(Divide(u, d2), DecimalType(38, 24))
+ checkType(Divide(u, i), DecimalType(38, 29))
+ checkType(Divide(u, u), DecimalType(38, 38))
+
+ checkType(Remainder(d1, u), DecimalType(19, 18))
+ checkType(Remainder(d2, u), DecimalType(21, 18))
+ checkType(Remainder(i, u), DecimalType(28, 18))
+ checkType(Remainder(u, u), DecimalType.SYSTEM_DEFAULT)
+
+ for (expr <- Seq(f, b)) {
+ checkType(Add(expr, u), DoubleType)
+ checkType(Subtract(expr, u), DoubleType)
+ checkType(Multiply(expr, u), DoubleType)
+ checkType(Divide(expr, u), DoubleType)
+ checkType(Remainder(expr, u), DoubleType)
}
}
}
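
The rewritten expectations above follow the fixed-precision result rules (Hive-style) that the bounded DecimalPrecision analysis now applies, with every result capped at 38 digits. A small sketch of those formulas, given here only so the expected types above can be checked by hand; the object and method names are illustrative, not Spark APIs:

    object DecimalResultSketch {
      val Max = 38
      private def bounded(p: Int, s: Int) = (math.min(p, Max), math.min(s, Max))

      // Add/Subtract: keep the larger scale, grow the integer part by one carry digit.
      def add(p1: Int, s1: Int, p2: Int, s2: Int) =
        bounded(math.max(s1, s2) + math.max(p1 - s1, p2 - s2) + 1, math.max(s1, s2))

      // Multiply: precisions and scales both add (plus a carry digit).
      def multiply(p1: Int, s1: Int, p2: Int, s2: Int) = bounded(p1 + p2 + 1, s1 + s2)

      // Divide: scale is at least 6, grown by the divisor's precision.
      def divide(p1: Int, s1: Int, p2: Int, s2: Int) = {
        val scale = math.max(6, s1 + p2 + 1)
        bounded(p1 - s1 + s2 + scale, scale)
      }

      // Remainder: bounded by the smaller integer part and the larger scale.
      def remainder(p1: Int, s1: Int, p2: Int, s2: Int) =
        bounded(math.min(p1 - s1, p2 - s2) + math.max(s1, s2), math.max(s1, s2))
    }

    // e.g. DecimalResultSketch.multiply(2, 1, 38, 18) == (38, 19)  -- Multiply(d1, u) above
    //      DecimalResultSketch.divide(38, 18, 2, 1)   == (38, 21)  -- Divide(u, d1) above
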
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
index 835220c563..d0fb95b580 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
@@ -35,14 +35,14 @@ class HiveTypeCoercionSuite extends PlanTest {
shouldCast(NullType, NullType, NullType)
shouldCast(NullType, IntegerType, IntegerType)
- shouldCast(NullType, DecimalType, DecimalType.Unlimited)
+ shouldCast(NullType, DecimalType, DecimalType.SYSTEM_DEFAULT)
shouldCast(ByteType, IntegerType, IntegerType)
shouldCast(IntegerType, IntegerType, IntegerType)
shouldCast(IntegerType, LongType, LongType)
- shouldCast(IntegerType, DecimalType, DecimalType.Unlimited)
+ shouldCast(IntegerType, DecimalType, DecimalType(10, 0))
shouldCast(LongType, IntegerType, IntegerType)
- shouldCast(LongType, DecimalType, DecimalType.Unlimited)
+ shouldCast(LongType, DecimalType, DecimalType(20, 0))
shouldCast(DateType, TimestampType, TimestampType)
shouldCast(TimestampType, DateType, DateType)
@@ -71,8 +71,8 @@ class HiveTypeCoercionSuite extends PlanTest {
shouldCast(IntegerType, TypeCollection(StringType, BinaryType), StringType)
shouldCast(IntegerType, TypeCollection(BinaryType, StringType), StringType)
- shouldCast(
- DecimalType.Unlimited, TypeCollection(IntegerType, DecimalType), DecimalType.Unlimited)
+ shouldCast(DecimalType.SYSTEM_DEFAULT,
+ TypeCollection(IntegerType, DecimalType), DecimalType.SYSTEM_DEFAULT)
shouldCast(DecimalType(10, 2), TypeCollection(IntegerType, DecimalType), DecimalType(10, 2))
shouldCast(DecimalType(10, 2), TypeCollection(DecimalType, IntegerType), DecimalType(10, 2))
shouldCast(IntegerType, TypeCollection(DecimalType(10, 2), StringType), DecimalType(10, 2))
@@ -82,7 +82,7 @@ class HiveTypeCoercionSuite extends PlanTest {
// NumericType should not be changed when function accepts any of them.
Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType,
- DecimalType.Unlimited, DecimalType(10, 2)).foreach { tpe =>
+ DecimalType.SYSTEM_DEFAULT, DecimalType(10, 2)).foreach { tpe =>
shouldCast(tpe, NumericType, tpe)
}
@@ -107,8 +107,8 @@ class HiveTypeCoercionSuite extends PlanTest {
shouldNotCast(IntegerType, TimestampType)
shouldNotCast(LongType, DateType)
shouldNotCast(LongType, TimestampType)
- shouldNotCast(DecimalType.Unlimited, DateType)
- shouldNotCast(DecimalType.Unlimited, TimestampType)
+ shouldNotCast(DecimalType.SYSTEM_DEFAULT, DateType)
+ shouldNotCast(DecimalType.SYSTEM_DEFAULT, TimestampType)
shouldNotCast(IntegerType, TypeCollection(DateType, TimestampType))
@@ -160,14 +160,6 @@ class HiveTypeCoercionSuite extends PlanTest {
widenTest(LongType, FloatType, Some(FloatType))
widenTest(LongType, DoubleType, Some(DoubleType))
- // Casting up to unlimited-precision decimal
- widenTest(IntegerType, DecimalType.Unlimited, Some(DecimalType.Unlimited))
- widenTest(DoubleType, DecimalType.Unlimited, Some(DecimalType.Unlimited))
- widenTest(DecimalType(3, 2), DecimalType.Unlimited, Some(DecimalType.Unlimited))
- widenTest(DecimalType.Unlimited, IntegerType, Some(DecimalType.Unlimited))
- widenTest(DecimalType.Unlimited, DoubleType, Some(DecimalType.Unlimited))
- widenTest(DecimalType.Unlimited, DecimalType(3, 2), Some(DecimalType.Unlimited))
-
// No up-casting for fixed-precision decimal (this is handled by arithmetic rules)
widenTest(DecimalType(2, 1), DecimalType(3, 2), None)
widenTest(DecimalType(2, 1), DoubleType, None)
@@ -242,9 +234,9 @@ class HiveTypeCoercionSuite extends PlanTest {
:: Literal(1)
:: Literal(new java.math.BigDecimal("1000000000000000000000"))
:: Nil),
- Coalesce(Cast(Literal(1L), DecimalType())
- :: Cast(Literal(1), DecimalType())
- :: Cast(Literal(new java.math.BigDecimal("1000000000000000000000")), DecimalType())
+ Coalesce(Cast(Literal(1L), DecimalType(22, 0))
+ :: Cast(Literal(1), DecimalType(22, 0))
+ :: Cast(Literal(new java.math.BigDecimal("1000000000000000000000")), DecimalType(22, 0))
:: Nil))
}
@@ -323,7 +315,7 @@ class HiveTypeCoercionSuite extends PlanTest {
val left = LocalRelation(
AttributeReference("i", IntegerType)(),
- AttributeReference("u", DecimalType.Unlimited)(),
+ AttributeReference("u", DecimalType.SYSTEM_DEFAULT)(),
AttributeReference("b", ByteType)(),
AttributeReference("d", DoubleType)())
val right = LocalRelation(
@@ -333,7 +325,7 @@ class HiveTypeCoercionSuite extends PlanTest {
AttributeReference("l", LongType)())
val wt = HiveTypeCoercion.WidenTypes
- val expectedTypes = Seq(StringType, DecimalType.Unlimited, FloatType, DoubleType)
+ val expectedTypes = Seq(StringType, DecimalType.SYSTEM_DEFAULT, FloatType, DoubleType)
val r1 = wt(Union(left, right)).asInstanceOf[Union]
val r2 = wt(Except(left, right)).asInstanceOf[Except]
@@ -353,13 +345,13 @@ class HiveTypeCoercionSuite extends PlanTest {
}
}
- val dp = HiveTypeCoercion.DecimalPrecision
+ val dp = HiveTypeCoercion.WidenTypes
val left1 = LocalRelation(
AttributeReference("l", DecimalType(10, 8))())
val right1 = LocalRelation(
AttributeReference("r", DecimalType(5, 5))())
- val expectedType1 = Seq(DecimalType(math.max(8, 5) + math.max(10 - 8, 5 - 5), math.max(8, 5)))
+ val expectedType1 = Seq(DecimalType(10, 8))
val r1 = dp(Union(left1, right1)).asInstanceOf[Union]
val r2 = dp(Except(left1, right1)).asInstanceOf[Except]
@@ -372,12 +364,11 @@ class HiveTypeCoercionSuite extends PlanTest {
checkOutput(r3.left, expectedType1)
checkOutput(r3.right, expectedType1)
- val plan1 = LocalRelation(
- AttributeReference("l", DecimalType(10, 10))())
+ val plan1 = LocalRelation(AttributeReference("l", DecimalType(10, 5))())
val rightTypes = Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType)
- val expectedTypes = Seq(DecimalType(3, 0), DecimalType(5, 0), DecimalType(10, 0),
- DecimalType(20, 0), DecimalType(7, 7), DecimalType(15, 15))
+ val expectedTypes = Seq(DecimalType(10, 5), DecimalType(10, 5), DecimalType(15, 5),
+ DecimalType(25, 5), DoubleType, DoubleType)
rightTypes.zip(expectedTypes).map { case (rType, expectedType) =>
val plan2 = LocalRelation(
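
For Union and Except, two decimal columns are widened by keeping the larger scale and the larger number of whole-number digits, and a decimal mixed with a float or double column now simply becomes DoubleType rather than a fabricated wide decimal. A rough sketch of the decimal-to-decimal case, assuming integral operands are treated as decimal(3,0)/(5,0)/(10,0)/(20,0) as in the expectations above (the helper name is illustrative):

    def widerDecimal(p1: Int, s1: Int, p2: Int, s2: Int): (Int, Int) = {
      val scale = math.max(s1, s2)              // keep all fractional digits
      val range = math.max(p1 - s1, p2 - s2)    // keep all whole-number digits
      (math.min(range + scale, 38), math.min(scale, 38))
    }

    // widerDecimal(10, 8, 5, 5)  == (10, 8)   -- the Union/Except pair above
    // widerDecimal(10, 5, 20, 0) == (25, 5)   -- decimal(10, 5) against a LongType column
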
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
index ccf448eee0..facf65c155 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
@@ -185,7 +185,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
checkCast(1, 1.0)
checkCast(123, "123")
- checkEvaluation(cast(123, DecimalType.Unlimited), Decimal(123))
+ checkEvaluation(cast(123, DecimalType.USER_DEFAULT), Decimal(123))
checkEvaluation(cast(123, DecimalType(3, 0)), Decimal(123))
checkEvaluation(cast(123, DecimalType(3, 1)), null)
checkEvaluation(cast(123, DecimalType(2, 0)), null)
@@ -203,7 +203,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
checkCast(1L, 1.0)
checkCast(123L, "123")
- checkEvaluation(cast(123L, DecimalType.Unlimited), Decimal(123))
+ checkEvaluation(cast(123L, DecimalType.USER_DEFAULT), Decimal(123))
checkEvaluation(cast(123L, DecimalType(3, 0)), Decimal(123))
checkEvaluation(cast(123L, DecimalType(3, 1)), Decimal(123.0))
@@ -225,7 +225,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(cast(cast(1000, TimestampType), LongType), 1.toLong)
checkEvaluation(cast(cast(-1200, TimestampType), LongType), -2.toLong)
- checkEvaluation(cast(123, DecimalType.Unlimited), Decimal(123))
+ checkEvaluation(cast(123, DecimalType.USER_DEFAULT), Decimal(123))
checkEvaluation(cast(123, DecimalType(3, 0)), Decimal(123))
checkEvaluation(cast(123, DecimalType(3, 1)), null)
checkEvaluation(cast(123, DecimalType(2, 0)), null)
@@ -267,7 +267,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
assert(cast("abcdef", IntegerType).nullable === true)
assert(cast("abcdef", ShortType).nullable === true)
assert(cast("abcdef", ByteType).nullable === true)
- assert(cast("abcdef", DecimalType.Unlimited).nullable === true)
+ assert(cast("abcdef", DecimalType.USER_DEFAULT).nullable === true)
assert(cast("abcdef", DecimalType(4, 2)).nullable === true)
assert(cast("abcdef", DoubleType).nullable === true)
assert(cast("abcdef", FloatType).nullable === true)
@@ -291,9 +291,9 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
c.getTimeInMillis * 1000)
checkEvaluation(cast("abdef", StringType), "abdef")
- checkEvaluation(cast("abdef", DecimalType.Unlimited), null)
+ checkEvaluation(cast("abdef", DecimalType.USER_DEFAULT), null)
checkEvaluation(cast("abdef", TimestampType), null)
- checkEvaluation(cast("12.65", DecimalType.Unlimited), Decimal(12.65))
+ checkEvaluation(cast("12.65", DecimalType.SYSTEM_DEFAULT), Decimal(12.65))
checkEvaluation(cast(cast(sd, DateType), StringType), sd)
checkEvaluation(cast(cast(d, StringType), DateType), 0)
@@ -311,20 +311,20 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
5.toLong)
checkEvaluation(
cast(cast(cast(cast(cast(cast("5", ByteType), TimestampType),
- DecimalType.Unlimited), LongType), StringType), ShortType),
+ DecimalType.SYSTEM_DEFAULT), LongType), StringType), ShortType),
0.toShort)
checkEvaluation(
cast(cast(cast(cast(cast(cast("5", TimestampType), ByteType),
- DecimalType.Unlimited), LongType), StringType), ShortType),
+ DecimalType.SYSTEM_DEFAULT), LongType), StringType), ShortType),
null)
- checkEvaluation(cast(cast(cast(cast(cast(cast("5", DecimalType.Unlimited),
+ checkEvaluation(cast(cast(cast(cast(cast(cast("5", DecimalType.SYSTEM_DEFAULT),
ByteType), TimestampType), LongType), StringType), ShortType),
0.toShort)
checkEvaluation(cast("23", DoubleType), 23d)
checkEvaluation(cast("23", IntegerType), 23)
checkEvaluation(cast("23", FloatType), 23f)
- checkEvaluation(cast("23", DecimalType.Unlimited), Decimal(23))
+ checkEvaluation(cast("23", DecimalType.USER_DEFAULT), Decimal(23))
checkEvaluation(cast("23", ByteType), 23.toByte)
checkEvaluation(cast("23", ShortType), 23.toShort)
checkEvaluation(cast("2012-12-11", DoubleType), null)
@@ -338,7 +338,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(Add(Literal(23d), cast(true, DoubleType)), 24d)
checkEvaluation(Add(Literal(23), cast(true, IntegerType)), 24)
checkEvaluation(Add(Literal(23f), cast(true, FloatType)), 24f)
- checkEvaluation(Add(Literal(Decimal(23)), cast(true, DecimalType.Unlimited)), Decimal(24))
+ checkEvaluation(Add(Literal(Decimal(23)), cast(true, DecimalType.USER_DEFAULT)), Decimal(24))
checkEvaluation(Add(Literal(23.toByte), cast(true, ByteType)), 24.toByte)
checkEvaluation(Add(Literal(23.toShort), cast(true, ShortType)), 24.toShort)
}
@@ -362,10 +362,10 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
// - Values that would overflow the target precision should turn into null
// - Because of this, casts to fixed-precision decimals should be nullable
- assert(cast(123, DecimalType.Unlimited).nullable === false)
- assert(cast(10.03f, DecimalType.Unlimited).nullable === true)
- assert(cast(10.03, DecimalType.Unlimited).nullable === true)
- assert(cast(Decimal(10.03), DecimalType.Unlimited).nullable === false)
+ assert(cast(123, DecimalType.USER_DEFAULT).nullable === true)
+ assert(cast(10.03f, DecimalType.SYSTEM_DEFAULT).nullable === true)
+ assert(cast(10.03, DecimalType.SYSTEM_DEFAULT).nullable === true)
+ assert(cast(Decimal(10.03), DecimalType.SYSTEM_DEFAULT).nullable === true)
assert(cast(123, DecimalType(2, 1)).nullable === true)
assert(cast(10.03f, DecimalType(2, 1)).nullable === true)
@@ -373,7 +373,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
assert(cast(Decimal(10.03), DecimalType(2, 1)).nullable === true)
- checkEvaluation(cast(10.03, DecimalType.Unlimited), Decimal(10.03))
+ checkEvaluation(cast(10.03, DecimalType.SYSTEM_DEFAULT), Decimal(10.03))
checkEvaluation(cast(10.03, DecimalType(4, 2)), Decimal(10.03))
checkEvaluation(cast(10.03, DecimalType(3, 1)), Decimal(10.0))
checkEvaluation(cast(10.03, DecimalType(2, 0)), Decimal(10))
@@ -383,7 +383,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(cast(Decimal(10.03), DecimalType(3, 1)), Decimal(10.0))
checkEvaluation(cast(Decimal(10.03), DecimalType(3, 2)), null)
- checkEvaluation(cast(10.05, DecimalType.Unlimited), Decimal(10.05))
+ checkEvaluation(cast(10.05, DecimalType.SYSTEM_DEFAULT), Decimal(10.05))
checkEvaluation(cast(10.05, DecimalType(4, 2)), Decimal(10.05))
checkEvaluation(cast(10.05, DecimalType(3, 1)), Decimal(10.1))
checkEvaluation(cast(10.05, DecimalType(2, 0)), Decimal(10))
@@ -409,10 +409,10 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(cast(Decimal(-9.95), DecimalType(3, 1)), Decimal(-10.0))
checkEvaluation(cast(Decimal(-9.95), DecimalType(1, 0)), null)
- checkEvaluation(cast(Double.NaN, DecimalType.Unlimited), null)
- checkEvaluation(cast(1.0 / 0.0, DecimalType.Unlimited), null)
- checkEvaluation(cast(Float.NaN, DecimalType.Unlimited), null)
- checkEvaluation(cast(1.0f / 0.0f, DecimalType.Unlimited), null)
+ checkEvaluation(cast(Double.NaN, DecimalType.SYSTEM_DEFAULT), null)
+ checkEvaluation(cast(1.0 / 0.0, DecimalType.SYSTEM_DEFAULT), null)
+ checkEvaluation(cast(Float.NaN, DecimalType.SYSTEM_DEFAULT), null)
+ checkEvaluation(cast(1.0f / 0.0f, DecimalType.SYSTEM_DEFAULT), null)
checkEvaluation(cast(Double.NaN, DecimalType(2, 1)), null)
checkEvaluation(cast(1.0 / 0.0, DecimalType(2, 1)), null)
@@ -427,7 +427,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(cast(d, LongType), null)
checkEvaluation(cast(d, FloatType), null)
checkEvaluation(cast(d, DoubleType), null)
- checkEvaluation(cast(d, DecimalType.Unlimited), null)
+ checkEvaluation(cast(d, DecimalType.SYSTEM_DEFAULT), null)
checkEvaluation(cast(d, DecimalType(10, 2)), null)
checkEvaluation(cast(d, StringType), "1970-01-01")
checkEvaluation(cast(cast(d, TimestampType), StringType), "1970-01-01 00:00:00")
@@ -454,7 +454,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
cast(cast(millis.toDouble / 1000, TimestampType), DoubleType),
millis.toDouble / 1000)
checkEvaluation(
- cast(cast(Decimal(1), TimestampType), DecimalType.Unlimited),
+ cast(cast(Decimal(1), TimestampType), DecimalType.SYSTEM_DEFAULT),
Decimal(1))
// A test for higher precision than millis
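
With every decimal bounded, a cast to a decimal type can always overflow the target precision, so casts to decimal are now marked nullable even for exact integer inputs, and an overflowing value evaluates to null rather than failing. Restating a few of the evaluations above as a quick reference:

    cast(123, DecimalType(3, 0))   // Decimal(123): three integer digits fit
    cast(123, DecimalType(3, 1))   // null: only two integer digits remain after the scale
    cast(123, DecimalType(2, 0))   // null: 123 needs three integer digits
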
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala
index afa143bd5f..b31d6661c8 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala
@@ -60,7 +60,7 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
testIf(_.toFloat, FloatType)
testIf(_.toDouble, DoubleType)
- testIf(Decimal(_), DecimalType.Unlimited)
+ testIf(Decimal(_), DecimalType.USER_DEFAULT)
testIf(identity, DateType)
testIf(_.toLong, TimestampType)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala
index d924ff7a10..f6404d2161 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala
@@ -33,7 +33,7 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(Literal.create(null, LongType), null)
checkEvaluation(Literal.create(null, StringType), null)
checkEvaluation(Literal.create(null, BinaryType), null)
- checkEvaluation(Literal.create(null, DecimalType()), null)
+ checkEvaluation(Literal.create(null, DecimalType.USER_DEFAULT), null)
checkEvaluation(Literal.create(null, ArrayType(ByteType, true)), null)
checkEvaluation(Literal.create(null, MapType(StringType, IntegerType)), null)
checkEvaluation(Literal.create(null, StructType(Seq.empty)), null)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala
index 0728f6695c..9efe44c832 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala
@@ -30,7 +30,7 @@ class NullFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
testFunc(1L, LongType)
testFunc(1.0F, FloatType)
testFunc(1.0, DoubleType)
- testFunc(Decimal(1.5), DecimalType.Unlimited)
+ testFunc(Decimal(1.5), DecimalType(2, 1))
testFunc(new java.sql.Date(10), DateType)
testFunc(new java.sql.Timestamp(10), TimestampType)
testFunc("abcd", StringType)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala
index 7566cb59e3..48b7dc5745 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala
@@ -121,6 +121,8 @@ class UnsafeFixedWidthAggregationMapSuite
}.toSet
seenKeys.size should be (groupKeys.size)
seenKeys should be (groupKeys)
+
+ map.free()
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
index 8819234e78..a5d9806c20 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
@@ -145,7 +145,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
DoubleType,
StringType,
BinaryType
- // DecimalType.Unlimited,
+ // DecimalType.Default,
// ArrayType(IntegerType)
)
val converter = new UnsafeRowConverter(fieldTypes)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeParserSuite.scala
index c6171b7b69..1ba290753c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeParserSuite.scala
@@ -44,7 +44,7 @@ class DataTypeParserSuite extends SparkFunSuite {
checkDataType("float", FloatType)
checkDataType("dOUBle", DoubleType)
checkDataType("decimal(10, 5)", DecimalType(10, 5))
- checkDataType("decimal", DecimalType.Unlimited)
+ checkDataType("decimal", DecimalType.USER_DEFAULT)
checkDataType("DATE", DateType)
checkDataType("timestamp", TimestampType)
checkDataType("string", StringType)
@@ -87,7 +87,7 @@ class DataTypeParserSuite extends SparkFunSuite {
StructType(
StructField("struct",
StructType(
- StructField("deciMal", DecimalType.Unlimited, true) ::
+ StructField("deciMal", DecimalType.USER_DEFAULT, true) ::
StructField("anotherDecimal", DecimalType(5, 2), true) :: Nil), true) ::
StructField("MAP", MapType(TimestampType, StringType), true) ::
StructField("arrAy", ArrayType(DoubleType, true), true) :: Nil)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
index 14e7b4a956..88b221cd81 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
@@ -185,7 +185,7 @@ class DataTypeSuite extends SparkFunSuite {
checkDataTypeJsonRepr(FloatType)
checkDataTypeJsonRepr(DoubleType)
checkDataTypeJsonRepr(DecimalType(10, 5))
- checkDataTypeJsonRepr(DecimalType.Unlimited)
+ checkDataTypeJsonRepr(DecimalType.SYSTEM_DEFAULT)
checkDataTypeJsonRepr(DateType)
checkDataTypeJsonRepr(TimestampType)
checkDataTypeJsonRepr(StringType)
@@ -219,7 +219,7 @@ class DataTypeSuite extends SparkFunSuite {
checkDefaultSize(FloatType, 4)
checkDefaultSize(DoubleType, 8)
checkDefaultSize(DecimalType(10, 5), 4096)
- checkDefaultSize(DecimalType.Unlimited, 4096)
+ checkDefaultSize(DecimalType.SYSTEM_DEFAULT, 4096)
checkDefaultSize(DateType, 4)
checkDefaultSize(TimestampType, 8)
checkDefaultSize(StringType, 4096)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeTestUtils.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeTestUtils.scala
index 32632b5d6e..0ee9ddac81 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeTestUtils.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeTestUtils.scala
@@ -34,7 +34,7 @@ object DataTypeTestUtils {
* decimal types.
*/
val fractionalTypes: Set[FractionalType] = Set(
- DecimalType(precisionInfo = None),
+ DecimalType.SYSTEM_DEFAULT,
DecimalType(2, 1),
DoubleType,
FloatType
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index 6e2a6525bf..b25dcbca82 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -996,7 +996,7 @@ class ColumnName(name: String) extends Column(name) {
* Creates a new [[StructField]] of type decimal.
* @since 1.3.0
*/
- def decimal: StructField = StructField(name, DecimalType.Unlimited)
+ def decimal: StructField = StructField(name, DecimalType.USER_DEFAULT)
/**
* Creates a new [[StructField]] of type decimal.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
index fc72360c88..9d8415f063 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
@@ -375,7 +375,7 @@ private[sql] object TIMESTAMP extends NativeColumnType(TimestampType, 9, 8) {
private[sql] case class FIXED_DECIMAL(precision: Int, scale: Int)
extends NativeColumnType(
- DecimalType(Some(PrecisionInfo(precision, scale))),
+ DecimalType(precision, scale),
10,
FIXED_DECIMAL.defaultSize) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
index 16176abe3a..5ed158b3d2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
@@ -21,9 +21,9 @@ import org.apache.spark.TaskContext
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.trees._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.catalyst.trees._
import org.apache.spark.sql.types._
case class AggregateEvaluation(
@@ -92,8 +92,8 @@ case class GeneratedAggregate(
case s @ Sum(expr) =>
val calcType =
expr.dataType match {
- case DecimalType.Fixed(_, _) =>
- DecimalType.Unlimited
+ case DecimalType.Fixed(p, s) =>
+ DecimalType.bounded(p + 10, s)
case _ =>
expr.dataType
}
@@ -121,8 +121,8 @@ case class GeneratedAggregate(
case cs @ CombineSum(expr) =>
val calcType =
expr.dataType match {
- case DecimalType.Fixed(_, _) =>
- DecimalType.Unlimited
+ case DecimalType.Fixed(p, s) =>
+ DecimalType.bounded(p + 10, s)
case _ =>
expr.dataType
}
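
Both Sum and CombineSum now accumulate fixed-precision decimals in a buffer that is ten digits wider than the input instead of an unlimited decimal: the extra integer digits leave headroom for roughly 10^10 additions before the running total can overflow, and DecimalType.bounded still caps the buffer at 38 digits. For example (illustrative inputs):

    DecimalType.bounded(10 + 10, 2)   // DecimalType(20, 2): sum buffer for a decimal(10, 2) column
    DecimalType.bounded(35 + 10, 18)  // DecimalType(38, 18): already at the 38-digit ceiling
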
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala
index 6b4a359db2..9d0fa894b9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala
@@ -25,6 +25,7 @@ import scala.util.Try
import org.apache.hadoop.fs.Path
import org.apache.hadoop.util.Shell
+
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Cast, Literal}
import org.apache.spark.sql.types._
@@ -236,7 +237,7 @@ private[sql] object PartitioningUtils {
/**
* Converts a string to a [[Literal]] with automatic type inference. Currently only supports
- * [[IntegerType]], [[LongType]], [[DoubleType]], [[DecimalType.Unlimited]], and
+ * [[IntegerType]], [[LongType]], [[DoubleType]], [[DecimalType.SYSTEM_DEFAULT]], and
* [[StringType]].
*/
private[sql] def inferPartitionColumnValue(
@@ -249,7 +250,7 @@ private[sql] object PartitioningUtils {
.orElse(Try(Literal.create(JLong.parseLong(raw), LongType)))
// Then falls back to fractional types
.orElse(Try(Literal.create(JDouble.parseDouble(raw), DoubleType)))
- .orElse(Try(Literal.create(new JBigDecimal(raw), DecimalType.Unlimited)))
+ .orElse(Try(Literal(new JBigDecimal(raw))))
// Then falls back to string
.getOrElse {
if (raw == defaultPartitionName) {
@@ -268,7 +269,7 @@ private[sql] object PartitioningUtils {
}
private val upCastingOrder: Seq[DataType] =
- Seq(NullType, IntegerType, LongType, FloatType, DoubleType, DecimalType.Unlimited, StringType)
+ Seq(NullType, IntegerType, LongType, FloatType, DoubleType, StringType)
/**
* Given a collection of [[Literal]]s, resolves possible type conflicts by up-casting "lower"
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala
index 7a27fba178..3cf70db6b7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala
@@ -66,8 +66,8 @@ private[sql] object JDBCRDD extends Logging {
case java.sql.Types.DATALINK => null
case java.sql.Types.DATE => DateType
case java.sql.Types.DECIMAL
- if precision != 0 || scale != 0 => DecimalType(precision, scale)
- case java.sql.Types.DECIMAL => DecimalType.Unlimited
+ if precision != 0 || scale != 0 => DecimalType.bounded(precision, scale)
+ case java.sql.Types.DECIMAL => DecimalType.SYSTEM_DEFAULT
case java.sql.Types.DISTINCT => null
case java.sql.Types.DOUBLE => DoubleType
case java.sql.Types.FLOAT => FloatType
@@ -80,8 +80,8 @@ private[sql] object JDBCRDD extends Logging {
case java.sql.Types.NCLOB => StringType
case java.sql.Types.NULL => null
case java.sql.Types.NUMERIC
- if precision != 0 || scale != 0 => DecimalType(precision, scale)
- case java.sql.Types.NUMERIC => DecimalType.Unlimited
+ if precision != 0 || scale != 0 => DecimalType.bounded(precision, scale)
+ case java.sql.Types.NUMERIC => DecimalType.SYSTEM_DEFAULT
case java.sql.Types.NVARCHAR => StringType
case java.sql.Types.OTHER => null
case java.sql.Types.REAL => DoubleType
@@ -314,7 +314,7 @@ private[sql] class JDBCRDD(
abstract class JDBCConversion
case object BooleanConversion extends JDBCConversion
case object DateConversion extends JDBCConversion
- case class DecimalConversion(precisionInfo: Option[(Int, Int)]) extends JDBCConversion
+ case class DecimalConversion(precision: Int, scale: Int) extends JDBCConversion
case object DoubleConversion extends JDBCConversion
case object FloatConversion extends JDBCConversion
case object IntegerConversion extends JDBCConversion
@@ -331,8 +331,7 @@ private[sql] class JDBCRDD(
schema.fields.map(sf => sf.dataType match {
case BooleanType => BooleanConversion
case DateType => DateConversion
- case DecimalType.Unlimited => DecimalConversion(None)
- case DecimalType.Fixed(d) => DecimalConversion(Some(d))
+ case DecimalType.Fixed(p, s) => DecimalConversion(p, s)
case DoubleType => DoubleConversion
case FloatType => FloatConversion
case IntegerType => IntegerConversion
@@ -399,20 +398,13 @@ private[sql] class JDBCRDD(
// DecimalType(12, 2). Thus, after saving the dataframe into parquet file and then
// retrieve it, you will get wrong result 199.99.
// So it is needed to set precision and scale for Decimal based on JDBC metadata.
- case DecimalConversion(Some((p, s))) =>
+ case DecimalConversion(p, s) =>
val decimalVal = rs.getBigDecimal(pos)
if (decimalVal == null) {
mutableRow.update(i, null)
} else {
mutableRow.update(i, Decimal(decimalVal, p, s))
}
- case DecimalConversion(None) =>
- val decimalVal = rs.getBigDecimal(pos)
- if (decimalVal == null) {
- mutableRow.update(i, null)
- } else {
- mutableRow.update(i, Decimal(decimalVal))
- }
case DoubleConversion => mutableRow.setDouble(i, rs.getDouble(pos))
case FloatConversion => mutableRow.setFloat(i, rs.getFloat(pos))
case IntegerConversion => mutableRow.setInt(i, rs.getInt(pos))
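
On the read side, DECIMAL and NUMERIC columns whose JDBC metadata reports a precision or scale are mapped through DecimalType.bounded, which clamps driver-reported values into the supported 38-digit range; columns with no reported precision fall back to DecimalType.SYSTEM_DEFAULT, and each value is then materialised with its precision and scale attached. A condensed sketch of that mapping (the helper name is illustrative):

    def decimalCatalystType(precision: Int, scale: Int): DecimalType =
      if (precision != 0 || scale != 0) DecimalType.bounded(precision, scale)
      else DecimalType.SYSTEM_DEFAULT   // no usable metadata from the driver

    // The conversion then always carries both values, e.g. Decimal(rs.getBigDecimal(pos), p, s)
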
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala
index f7ea852fe7..035e051008 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala
@@ -89,8 +89,7 @@ package object jdbc {
case BinaryType => stmt.setBytes(i + 1, row.getAs[Array[Byte]](i))
case TimestampType => stmt.setTimestamp(i + 1, row.getAs[java.sql.Timestamp](i))
case DateType => stmt.setDate(i + 1, row.getAs[java.sql.Date](i))
- case DecimalType.Unlimited => stmt.setBigDecimal(i + 1,
- row.getAs[java.math.BigDecimal](i))
+ case t: DecimalType => stmt.setBigDecimal(i + 1, row.getDecimal(i))
case _ => throw new IllegalArgumentException(
s"Can't translate non-null value for field $i")
}
@@ -145,7 +144,7 @@ package object jdbc {
case BinaryType => "BLOB"
case TimestampType => "TIMESTAMP"
case DateType => "DATE"
- case DecimalType.Unlimited => "DECIMAL(40,20)"
+ case t: DecimalType => s"DECIMAL(${t.precision}},${t.scale}})"
case _ => throw new IllegalArgumentException(s"Don't know how to save $field to JDBC")
})
val nullable = if (field.nullable) "" else "NOT NULL"
@@ -177,7 +176,7 @@ package object jdbc {
case BinaryType => java.sql.Types.BLOB
case TimestampType => java.sql.Types.TIMESTAMP
case DateType => java.sql.Types.DATE
- case DecimalType.Unlimited => java.sql.Types.DECIMAL
+ case t: DecimalType => java.sql.Types.DECIMAL
case _ => throw new IllegalArgumentException(
s"Can't translate null value for field $field")
})
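
On the write side, the generated DDL for a decimal column is meant to carry the field's own precision and scale rather than the old hard-coded DECIMAL(40,20), and values are bound with setBigDecimal from row.getDecimal(i). A minimal sketch of the intended type mapping, written out as an illustration rather than a copy of the patch:

    def decimalDDL(t: DecimalType): String = s"DECIMAL(${t.precision},${t.scale})"

    decimalDDL(DecimalType(38, 18))   // "DECIMAL(38,18)"
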
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala
index afe2c6c11a..0eb3b04007 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala
@@ -113,7 +113,7 @@ private[sql] object InferSchema {
case INT | LONG => LongType
// Since we do not have a data type backed by BigInteger,
// when we see a Java BigInteger, we use DecimalType.
- case BIG_INTEGER | BIG_DECIMAL => DecimalType.Unlimited
+ case BIG_INTEGER | BIG_DECIMAL => DecimalType.SYSTEM_DEFAULT
case FLOAT | DOUBLE => DoubleType
}
@@ -168,8 +168,13 @@ private[sql] object InferSchema {
HiveTypeCoercion.findTightestCommonTypeOfTwo(t1, t2).getOrElse {
// t1 or t2 is a StructType, ArrayType, or an unexpected type.
(t1, t2) match {
- case (other: DataType, NullType) => other
- case (NullType, other: DataType) => other
+ // Double support larger range than fixed decimal, DecimalType.Maximum should be enough
+ // in most case, also have better precision.
+ case (DoubleType, t: DecimalType) =>
+ if (t == DecimalType.SYSTEM_DEFAULT) t else DoubleType
+ case (t: DecimalType, DoubleType) =>
+ if (t == DecimalType.SYSTEM_DEFAULT) t else DoubleType
+
case (StructType(fields1), StructType(fields2)) =>
val newFields = (fields1 ++ fields2).groupBy(field => field.name).map {
case (name, fieldTypes) =>
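
When JSON schema inference has seen both doubles and fixed-precision decimals for the same field, the merged type now falls back to DoubleType, whose range exceeds any 38-digit decimal; the decimal only wins when it is already DecimalType.SYSTEM_DEFAULT, i.e. when it came from a BigInteger or BigDecimal token and its 38 significant digits give better precision than a double. A sketch of just that branch, as a stand-alone partial function with an illustrative name:

    val mergeFractional: PartialFunction[(DataType, DataType), DataType] = {
      case (DoubleType, d: DecimalType) => if (d == DecimalType.SYSTEM_DEFAULT) d else DoubleType
      case (d: DecimalType, DoubleType) => if (d == DecimalType.SYSTEM_DEFAULT) d else DoubleType
      // every other pair keeps the pre-existing compatibleType behaviour
    }
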
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala
index 1ea6926af6..1d3a0d15d3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala
@@ -439,10 +439,6 @@ private[parquet] class CatalystSchemaConverter(
.length(minBytesForPrecision(precision))
.named(field.name)
- case dec @ DecimalType.Unlimited if followParquetFormatSpec =>
- throw new AnalysisException(
- s"Data type $dec is not supported. Decimal precision and scale must be specified.")
-
// ===================================================
// ArrayType and MapType (for Spark versions <= 1.4.x)
// ===================================================
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
index e8851ddb68..d1040bf556 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
@@ -261,10 +261,10 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo
case BinaryType => writer.addBinary(
Binary.fromByteArray(value.asInstanceOf[Array[Byte]]))
case d: DecimalType =>
- if (d.precisionInfo == None || d.precisionInfo.get.precision > 18) {
+ if (d.precision > 18) {
sys.error(s"Unsupported datatype $d, cannot write to consumer")
}
- writeDecimal(value.asInstanceOf[Decimal], d.precisionInfo.get.precision)
+ writeDecimal(value.asInstanceOf[Decimal], d.precision)
case _ => sys.error(s"Do not know how to writer $schema to consumer")
}
}
@@ -415,10 +415,10 @@ private[parquet] class MutableRowWriteSupport extends RowWriteSupport {
case BinaryType => writer.addBinary(
Binary.fromByteArray(record(index).asInstanceOf[Array[Byte]]))
case d: DecimalType =>
- if (d.precisionInfo == None || d.precisionInfo.get.precision > 18) {
+ if (d.precision > 18) {
sys.error(s"Unsupported datatype $d, cannot write to consumer")
}
- writeDecimal(record(index).asInstanceOf[Decimal], d.precisionInfo.get.precision)
+ writeDecimal(record(index).asInstanceOf[Decimal], d.precision)
case _ => sys.error(s"Unsupported datatype $ctype, cannot write to consumer")
}
}
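
Both Parquet write paths keep the existing restriction that only decimals of precision 18 or less can be written, but the guard and the writeDecimal call now read the precision directly off the type instead of unwrapping an Option. An illustrative restatement of the guard, assuming the 18-digit limit exists because that is the largest unscaled value that still fits in a Long:

    def writableDecimal(d: DecimalType): Boolean = d.precision <= 18  // 10^18 < 2^63
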
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java
index fcb8f5499c..cb84e78d62 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java
@@ -22,7 +22,6 @@ import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
-import org.apache.spark.sql.test.TestSQLContext$;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
@@ -31,8 +30,14 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
-import org.apache.spark.sql.*;
-import org.apache.spark.sql.types.*;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.RowFactory;
+import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.test.TestSQLContext$;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
// The test suite itself is Serializable so that anonymous Function implementations can be
// serialized, as an alternative to converting these anonymous classes to static inner classes;
@@ -159,7 +164,8 @@ public class JavaApplySchemaSuite implements Serializable {
"\"bigInteger\":92233720368547758069, \"double\":1.7976931348623157E305, " +
"\"boolean\":false, \"null\":null}"));
List<StructField> fields = new ArrayList<StructField>(7);
- fields.add(DataTypes.createStructField("bigInteger", DataTypes.createDecimalType(), true));
+ fields.add(DataTypes.createStructField("bigInteger", DataTypes.createDecimalType(38, 18),
+ true));
fields.add(DataTypes.createStructField("boolean", DataTypes.BooleanType, true));
fields.add(DataTypes.createStructField("double", DataTypes.DoubleType, true));
fields.add(DataTypes.createStructField("integer", DataTypes.LongType, true));
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala
index 01bc23277f..037e2048a8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala
@@ -148,7 +148,7 @@ class InMemoryColumnarQuerySuite extends QueryTest {
val dataTypes =
Seq(StringType, BinaryType, NullType, BooleanType,
ByteType, ShortType, IntegerType, LongType,
- FloatType, DoubleType, DecimalType.Unlimited, DecimalType(6, 5),
+ FloatType, DoubleType, DecimalType.SYSTEM_DEFAULT, DecimalType(6, 5),
DateType, TimestampType,
ArrayType(IntegerType), MapType(StringType, LongType), struct)
val fields = dataTypes.zipWithIndex.map { case (dataType, index) =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 3d71deb13e..845ce669f0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -109,7 +109,7 @@ class PlannerSuite extends SparkFunSuite {
FloatType ::
DoubleType ::
DecimalType(10, 5) ::
- DecimalType.Unlimited ::
+ DecimalType.SYSTEM_DEFAULT ::
DateType ::
TimestampType ::
StringType ::
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala
index 4a53fadd7e..54f82f89ed 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala
@@ -54,7 +54,7 @@ class SparkSqlSerializer2DataTypeSuite extends SparkFunSuite {
checkSupported(StringType, isSupported = true)
checkSupported(BinaryType, isSupported = true)
checkSupported(DecimalType(10, 5), isSupported = true)
- checkSupported(DecimalType.Unlimited, isSupported = true)
+ checkSupported(DecimalType.SYSTEM_DEFAULT, isSupported = true)
// If NullType is the only data type in the schema, we do not support it.
checkSupported(NullType, isSupported = false)
@@ -86,7 +86,7 @@ abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll
val supportedTypes =
Seq(StringType, BinaryType, NullType, BooleanType,
ByteType, ShortType, IntegerType, LongType,
- FloatType, DoubleType, DecimalType.Unlimited, DecimalType(6, 5),
+ FloatType, DoubleType, DecimalType.SYSTEM_DEFAULT, DecimalType(6, 5),
DateType, TimestampType)
val fields = supportedTypes.zipWithIndex.map { case (dataType, index) =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
index 0f82f13088..42f2449afb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -134,7 +134,7 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter {
""".stripMargin.replaceAll("\n", " "))
- conn.prepareStatement("create table test.flttypes (a DOUBLE, b REAL, c DECIMAL(40, 20))"
+ conn.prepareStatement("create table test.flttypes (a DOUBLE, b REAL, c DECIMAL(38, 18))"
).executeUpdate()
conn.prepareStatement("insert into test.flttypes values ("
+ "1.0000000000000002220446049250313080847263336181640625, "
@@ -152,7 +152,7 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter {
s"""
|create table test.nulltypes (a INT, b BOOLEAN, c TINYINT, d BINARY(20), e VARCHAR(20),
|f VARCHAR_IGNORECASE(20), g CHAR(20), h BLOB, i CLOB, j TIME, k DATE, l TIMESTAMP,
- |m DOUBLE, n REAL, o DECIMAL(40, 20))
+ |m DOUBLE, n REAL, o DECIMAL(38, 18))
""".stripMargin.replaceAll("\n", " ")).executeUpdate()
conn.prepareStatement("insert into test.nulltypes values ("
+ "null, null, null, null, null, null, null, null, null, "
@@ -357,14 +357,14 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter {
test("H2 floating-point types") {
val rows = sql("SELECT * FROM flttypes").collect()
- assert(rows(0).getDouble(0) === 1.00000000000000022) // Yes, I meant ==.
- assert(rows(0).getDouble(1) === 1.00000011920928955) // Yes, I meant ==.
- assert(rows(0).getAs[BigDecimal](2)
- .equals(new BigDecimal("123456789012345.54321543215432100000")))
- assert(rows(0).schema.fields(2).dataType === DecimalType(40, 20))
- val compareDecimal = sql("SELECT C FROM flttypes where C > C - 1").collect()
- assert(compareDecimal(0).getAs[BigDecimal](0)
- .equals(new BigDecimal("123456789012345.54321543215432100000")))
+ assert(rows(0).getDouble(0) === 1.00000000000000022)
+ assert(rows(0).getDouble(1) === 1.00000011920928955)
+ assert(rows(0).getAs[BigDecimal](2) ===
+ new BigDecimal("123456789012345.543215432154321000"))
+ assert(rows(0).schema.fields(2).dataType === DecimalType(38, 18))
+ val result = sql("SELECT C FROM flttypes where C > C - 1").collect()
+ assert(result(0).getAs[BigDecimal](0) ===
+ new BigDecimal("123456789012345.543215432154321000"))
}
test("SQL query as table name") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
index 1d04513a44..3ac312d6f4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
@@ -63,18 +63,18 @@ class JsonSuite extends QueryTest with TestJsonData {
checkTypePromotion(intNumber.toLong, enforceCorrectType(intNumber, LongType))
checkTypePromotion(intNumber.toDouble, enforceCorrectType(intNumber, DoubleType))
checkTypePromotion(
- Decimal(intNumber), enforceCorrectType(intNumber, DecimalType.Unlimited))
+ Decimal(intNumber), enforceCorrectType(intNumber, DecimalType.SYSTEM_DEFAULT))
val longNumber: Long = 9223372036854775807L
checkTypePromotion(longNumber, enforceCorrectType(longNumber, LongType))
checkTypePromotion(longNumber.toDouble, enforceCorrectType(longNumber, DoubleType))
checkTypePromotion(
- Decimal(longNumber), enforceCorrectType(longNumber, DecimalType.Unlimited))
+ Decimal(longNumber), enforceCorrectType(longNumber, DecimalType.SYSTEM_DEFAULT))
val doubleNumber: Double = 1.7976931348623157E308d
checkTypePromotion(doubleNumber.toDouble, enforceCorrectType(doubleNumber, DoubleType))
checkTypePromotion(
- Decimal(doubleNumber), enforceCorrectType(doubleNumber, DecimalType.Unlimited))
+ Decimal(doubleNumber), enforceCorrectType(doubleNumber, DecimalType.SYSTEM_DEFAULT))
checkTypePromotion(DateTimeUtils.fromJavaTimestamp(new Timestamp(intNumber)),
enforceCorrectType(intNumber, TimestampType))
@@ -115,7 +115,7 @@ class JsonSuite extends QueryTest with TestJsonData {
checkDataType(NullType, IntegerType, IntegerType)
checkDataType(NullType, LongType, LongType)
checkDataType(NullType, DoubleType, DoubleType)
- checkDataType(NullType, DecimalType.Unlimited, DecimalType.Unlimited)
+ checkDataType(NullType, DecimalType.SYSTEM_DEFAULT, DecimalType.SYSTEM_DEFAULT)
checkDataType(NullType, StringType, StringType)
checkDataType(NullType, ArrayType(IntegerType), ArrayType(IntegerType))
checkDataType(NullType, StructType(Nil), StructType(Nil))
@@ -126,7 +126,7 @@ class JsonSuite extends QueryTest with TestJsonData {
checkDataType(BooleanType, IntegerType, StringType)
checkDataType(BooleanType, LongType, StringType)
checkDataType(BooleanType, DoubleType, StringType)
- checkDataType(BooleanType, DecimalType.Unlimited, StringType)
+ checkDataType(BooleanType, DecimalType.SYSTEM_DEFAULT, StringType)
checkDataType(BooleanType, StringType, StringType)
checkDataType(BooleanType, ArrayType(IntegerType), StringType)
checkDataType(BooleanType, StructType(Nil), StringType)
@@ -135,7 +135,7 @@ class JsonSuite extends QueryTest with TestJsonData {
checkDataType(IntegerType, IntegerType, IntegerType)
checkDataType(IntegerType, LongType, LongType)
checkDataType(IntegerType, DoubleType, DoubleType)
- checkDataType(IntegerType, DecimalType.Unlimited, DecimalType.Unlimited)
+ checkDataType(IntegerType, DecimalType.SYSTEM_DEFAULT, DecimalType.SYSTEM_DEFAULT)
checkDataType(IntegerType, StringType, StringType)
checkDataType(IntegerType, ArrayType(IntegerType), StringType)
checkDataType(IntegerType, StructType(Nil), StringType)
@@ -143,23 +143,24 @@ class JsonSuite extends QueryTest with TestJsonData {
// LongType
checkDataType(LongType, LongType, LongType)
checkDataType(LongType, DoubleType, DoubleType)
- checkDataType(LongType, DecimalType.Unlimited, DecimalType.Unlimited)
+ checkDataType(LongType, DecimalType.SYSTEM_DEFAULT, DecimalType.SYSTEM_DEFAULT)
checkDataType(LongType, StringType, StringType)
checkDataType(LongType, ArrayType(IntegerType), StringType)
checkDataType(LongType, StructType(Nil), StringType)
// DoubleType
checkDataType(DoubleType, DoubleType, DoubleType)
- checkDataType(DoubleType, DecimalType.Unlimited, DecimalType.Unlimited)
+ checkDataType(DoubleType, DecimalType.SYSTEM_DEFAULT, DecimalType.SYSTEM_DEFAULT)
checkDataType(DoubleType, StringType, StringType)
checkDataType(DoubleType, ArrayType(IntegerType), StringType)
checkDataType(DoubleType, StructType(Nil), StringType)
- // DoubleType
- checkDataType(DecimalType.Unlimited, DecimalType.Unlimited, DecimalType.Unlimited)
- checkDataType(DecimalType.Unlimited, StringType, StringType)
- checkDataType(DecimalType.Unlimited, ArrayType(IntegerType), StringType)
- checkDataType(DecimalType.Unlimited, StructType(Nil), StringType)
+ // DecimalType
+ checkDataType(DecimalType.SYSTEM_DEFAULT, DecimalType.SYSTEM_DEFAULT,
+ DecimalType.SYSTEM_DEFAULT)
+ checkDataType(DecimalType.SYSTEM_DEFAULT, StringType, StringType)
+ checkDataType(DecimalType.SYSTEM_DEFAULT, ArrayType(IntegerType), StringType)
+ checkDataType(DecimalType.SYSTEM_DEFAULT, StructType(Nil), StringType)
// StringType
checkDataType(StringType, StringType, StringType)
@@ -213,7 +214,7 @@ class JsonSuite extends QueryTest with TestJsonData {
checkDataType(
StructType(
StructField("f1", IntegerType, true) :: Nil),
- DecimalType.Unlimited,
+ DecimalType.SYSTEM_DEFAULT,
StringType)
}
@@ -240,7 +241,7 @@ class JsonSuite extends QueryTest with TestJsonData {
val jsonDF = ctx.read.json(primitiveFieldAndType)
val expectedSchema = StructType(
- StructField("bigInteger", DecimalType.Unlimited, true) ::
+ StructField("bigInteger", DecimalType.SYSTEM_DEFAULT, true) ::
StructField("boolean", BooleanType, true) ::
StructField("double", DoubleType, true) ::
StructField("integer", LongType, true) ::
@@ -270,7 +271,7 @@ class JsonSuite extends QueryTest with TestJsonData {
val expectedSchema = StructType(
StructField("arrayOfArray1", ArrayType(ArrayType(StringType, true), true), true) ::
StructField("arrayOfArray2", ArrayType(ArrayType(DoubleType, true), true), true) ::
- StructField("arrayOfBigInteger", ArrayType(DecimalType.Unlimited, true), true) ::
+ StructField("arrayOfBigInteger", ArrayType(DecimalType.SYSTEM_DEFAULT, true), true) ::
StructField("arrayOfBoolean", ArrayType(BooleanType, true), true) ::
StructField("arrayOfDouble", ArrayType(DoubleType, true), true) ::
StructField("arrayOfInteger", ArrayType(LongType, true), true) ::
@@ -284,7 +285,7 @@ class JsonSuite extends QueryTest with TestJsonData {
StructField("field3", StringType, true) :: Nil), true), true) ::
StructField("struct", StructType(
StructField("field1", BooleanType, true) ::
- StructField("field2", DecimalType.Unlimited, true) :: Nil), true) ::
+ StructField("field2", DecimalType.SYSTEM_DEFAULT, true) :: Nil), true) ::
StructField("structWithArrayFields", StructType(
StructField("field1", ArrayType(LongType, true), true) ::
StructField("field2", ArrayType(StringType, true), true) :: Nil), true) :: Nil)
@@ -385,7 +386,7 @@ class JsonSuite extends QueryTest with TestJsonData {
val expectedSchema = StructType(
StructField("num_bool", StringType, true) ::
StructField("num_num_1", LongType, true) ::
- StructField("num_num_2", DecimalType.Unlimited, true) ::
+ StructField("num_num_2", DecimalType.SYSTEM_DEFAULT, true) ::
StructField("num_num_3", DoubleType, true) ::
StructField("num_str", StringType, true) ::
StructField("str_bool", StringType, true) :: Nil)
@@ -421,11 +422,11 @@ class JsonSuite extends QueryTest with TestJsonData {
Row(-89) :: Row(21474836370L) :: Row(21474836470L) :: Nil
)
- // Widening to DecimalType
+ // Widening to DoubleType
checkAnswer(
- sql("select num_num_2 + 1.2 from jsonTable where num_num_2 > 1.1"),
- Row(new java.math.BigDecimal("21474836472.1")) ::
- Row(new java.math.BigDecimal("92233720368547758071.2")) :: Nil
+ sql("select num_num_2 + 1.3 from jsonTable where num_num_2 > 1.1"),
+ Row(21474836472.2) ::
+ Row(92233720368547758071.3) :: Nil
)
// Widening to DoubleType
@@ -442,8 +443,8 @@ class JsonSuite extends QueryTest with TestJsonData {
// Number and String conflict: resolve the type as number in this query.
checkAnswer(
- sql("select num_str + 1.2 from jsonTable where num_str > 92233720368547758060"),
- Row(new java.math.BigDecimal("92233720368547758061.2").doubleValue)
+ sql("select num_str + 1.2 from jsonTable where num_str >= 92233720368547758060"),
+ Row(new java.math.BigDecimal("92233720368547758071.2").doubleValue)
)
// String and Boolean conflict: resolve the type as string.
@@ -489,9 +490,9 @@ class JsonSuite extends QueryTest with TestJsonData {
// in the Project.
checkAnswer(
jsonDF.
- where('num_str > BigDecimal("92233720368547758060")).
+ where('num_str >= BigDecimal("92233720368547758060")).
select(('num_str + 1.2).as("num")),
- Row(new java.math.BigDecimal("92233720368547758061.2"))
+ Row(new java.math.BigDecimal("92233720368547758071.2").doubleValue())
)
// The following test will fail. The type of num_str is StringType.
@@ -610,7 +611,7 @@ class JsonSuite extends QueryTest with TestJsonData {
val jsonDF = ctx.read.json(path)
val expectedSchema = StructType(
- StructField("bigInteger", DecimalType.Unlimited, true) ::
+ StructField("bigInteger", DecimalType.SYSTEM_DEFAULT, true) ::
StructField("boolean", BooleanType, true) ::
StructField("double", DoubleType, true) ::
StructField("integer", LongType, true) ::
@@ -668,7 +669,7 @@ class JsonSuite extends QueryTest with TestJsonData {
primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path)
val schema = StructType(
- StructField("bigInteger", DecimalType.Unlimited, true) ::
+ StructField("bigInteger", DecimalType.SYSTEM_DEFAULT, true) ::
StructField("boolean", BooleanType, true) ::
StructField("double", DoubleType, true) ::
StructField("integer", IntegerType, true) ::
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala
index 7b16eba00d..3a5b860484 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala
@@ -122,14 +122,6 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest {
sqlContext.read.parquet(dir.getCanonicalPath).collect()
}
}
-
- // Unlimited-length decimals are not yet supported
- intercept[Throwable] {
- withTempPath { dir =>
- makeDecimalRDD(DecimalType.Unlimited).write.parquet(dir.getCanonicalPath)
- sqlContext.read.parquet(dir.getCanonicalPath).collect()
- }
- }
}
test("date type") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala
index 4f98776b91..7f16b1125c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala
@@ -509,7 +509,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest {
FloatType,
DoubleType,
DecimalType(10, 5),
- DecimalType.Unlimited,
+ DecimalType.SYSTEM_DEFAULT,
DateType,
TimestampType,
StringType)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala
index 54e1efb6e3..da53ec16b5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala
@@ -44,7 +44,7 @@ case class SimpleDDLScan(from: Int, to: Int, table: String)(@transient val sqlCo
StructField("doubleType", DoubleType, nullable = false),
StructField("bigintType", LongType, nullable = false),
StructField("tinyintType", ByteType, nullable = false),
- StructField("decimalType", DecimalType.Unlimited, nullable = false),
+ StructField("decimalType", DecimalType.USER_DEFAULT, nullable = false),
StructField("fixedDecimalType", DecimalType(5, 1), nullable = false),
StructField("binaryType", BinaryType, nullable = false),
StructField("booleanType", BooleanType, nullable = false),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
index 2c916f3322..143aadc08b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
@@ -202,7 +202,7 @@ class TableScanSuite extends DataSourceTest {
StructField("longField_:,<>=+/~^", LongType, true) ::
StructField("floatField", FloatType, true) ::
StructField("doubleField", DoubleType, true) ::
- StructField("decimalField1", DecimalType.Unlimited, true) ::
+ StructField("decimalField1", DecimalType.USER_DEFAULT, true) ::
StructField("decimalField2", DecimalType(9, 2), true) ::
StructField("dateField", DateType, true) ::
StructField("timestampField", TimestampType, true) ::
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
index a8f2ee37cb..592cfa0ee8 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
@@ -179,7 +179,7 @@ private[hive] trait HiveInspectors {
// writable
case c: Class[_] if c == classOf[hadoopIo.DoubleWritable] => DoubleType
case c: Class[_] if c == classOf[hiveIo.DoubleWritable] => DoubleType
- case c: Class[_] if c == classOf[hiveIo.HiveDecimalWritable] => DecimalType.Unlimited
+ case c: Class[_] if c == classOf[hiveIo.HiveDecimalWritable] => DecimalType.SYSTEM_DEFAULT
case c: Class[_] if c == classOf[hiveIo.ByteWritable] => ByteType
case c: Class[_] if c == classOf[hiveIo.ShortWritable] => ShortType
case c: Class[_] if c == classOf[hiveIo.DateWritable] => DateType
@@ -195,8 +195,8 @@ private[hive] trait HiveInspectors {
case c: Class[_] if c == classOf[java.lang.String] => StringType
case c: Class[_] if c == classOf[java.sql.Date] => DateType
case c: Class[_] if c == classOf[java.sql.Timestamp] => TimestampType
- case c: Class[_] if c == classOf[HiveDecimal] => DecimalType.Unlimited
- case c: Class[_] if c == classOf[java.math.BigDecimal] => DecimalType.Unlimited
+ case c: Class[_] if c == classOf[HiveDecimal] => DecimalType.SYSTEM_DEFAULT
+ case c: Class[_] if c == classOf[java.math.BigDecimal] => DecimalType.SYSTEM_DEFAULT
case c: Class[_] if c == classOf[Array[Byte]] => BinaryType
case c: Class[_] if c == classOf[java.lang.Short] => ShortType
case c: Class[_] if c == classOf[java.lang.Integer] => IntegerType
@@ -813,9 +813,6 @@ private[hive] trait HiveInspectors {
private def decimalTypeInfo(decimalType: DecimalType): TypeInfo = decimalType match {
case DecimalType.Fixed(precision, scale) => new DecimalTypeInfo(precision, scale)
- case _ => new DecimalTypeInfo(
- HiveShim.UNLIMITED_DECIMAL_PRECISION,
- HiveShim.UNLIMITED_DECIMAL_SCALE)
}
def toTypeInfo: TypeInfo = dt match {
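With Unlimited gone, every DecimalType value carries a concrete precision and scale, so the Fixed extractor used in decimalTypeInfo matches every input and the deleted fallback to HiveShim's unlimited constants had become unreachable. A minimal sketch of that exhaustiveness, assuming the new system default is DecimalType(38, 18):

  // Sketch only: DecimalType.Fixed now matches any DecimalType, including the assumed default.
  import org.apache.spark.sql.types.DecimalType

  val dt = DecimalType(38, 18)                      // assumed DecimalType.SYSTEM_DEFAULT
  val DecimalType.Fixed(precision, scale) = dt      // always matches; no fallback branch needed
  assert(precision == 38 && scale == 18)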
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index 8518e333e8..620b8a44d8 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -377,7 +377,7 @@ private[hive] object HiveQl extends Logging {
DecimalType(precision.getText.toInt, scale.getText.toInt)
case Token("TOK_DECIMAL", precision :: Nil) =>
DecimalType(precision.getText.toInt, 0)
- case Token("TOK_DECIMAL", Nil) => DecimalType.Unlimited
+ case Token("TOK_DECIMAL", Nil) => DecimalType.USER_DEFAULT
case Token("TOK_BIGINT", Nil) => LongType
case Token("TOK_INT", Nil) => IntegerType
case Token("TOK_TINYINT", Nil) => ByteType
@@ -1369,7 +1369,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
case Token("TOK_FUNCTION", Token("TOK_DECIMAL", precision :: Nil) :: arg :: Nil) =>
Cast(nodeToExpr(arg), DecimalType(precision.getText.toInt, 0))
case Token("TOK_FUNCTION", Token("TOK_DECIMAL", Nil) :: arg :: Nil) =>
- Cast(nodeToExpr(arg), DecimalType.Unlimited)
+ Cast(nodeToExpr(arg), DecimalType.USER_DEFAULT)
case Token("TOK_FUNCTION", Token("TOK_TIMESTAMP", Nil) :: arg :: Nil) =>
Cast(nodeToExpr(arg), TimestampType)
case Token("TOK_FUNCTION", Token("TOK_DATE", Nil) :: arg :: Nil) =>