author     guowei2 <guowei2@asiainfo.com>     2015-04-04 02:02:30 +0800
committer  Cheng Lian <lian@databricks.com>   2015-04-04 02:02:30 +0800
commit     c23ba81b8cf86c3a085de8ddfef9403ff6fcd87f (patch)
tree       9f42058a301732153b85ede6fed4d8316d2f2892 /sql
parent     dc6dff248d8f5d7de22af64b0586dfe3885731df (diff)
[SPARK-5203][SQL] fix union with different decimal type
When unioning non-decimal types with decimals, the following rules are used:
- FIRST apply `intTypeToFixed`; then a union of fixed decimals with precision/scale p1/s1 and p2/s2 is promoted to DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2), max(s1, s2))
- For unions, FLOAT is turned into DECIMAL(7, 7) and DOUBLE into DECIMAL(15, 15) (this is the same as Hive); for other operations FLOAT and DOUBLE still cause fixed-length decimals to turn into DOUBLE (note that unlimited decimals are considered bigger than doubles in WidenTypes)

Author: guowei2 <guowei2@asiainfo.com>

Closes #4004 from guowei2/SPARK-5203 and squashes the following commits:

ff50f5f [guowei2] fix code style
11df1bf [guowei2] fix decimal union with double, double -> Decimal(15, 15)
0f345f9 [guowei2] fix structType merge with decimal
101ed4d [guowei2] fix build error after rebase
0b196e4 [guowei2] code style
fe2c2ca [guowei2] handle union decimal precision in 'DecimalPrecision'
421d840 [guowei2] fix union types for decimal precision
ef2c661 [guowei2] fix union with different decimal type
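To make the promotion rule concrete, here is a minimal, self-contained Scala sketch (not Spark code; the `Dec` case class and `unionWiden` helper are illustrative names invented for this example) that reproduces the arithmetic applied by the new union rule and the expected results checked in DecimalPrecisionSuite below:

// Self-contained sketch (not the patch itself): models the union widening rule
// DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2), max(s1, s2)) described above.
object UnionDecimalSketch {
  import scala.math.max

  final case class Dec(precision: Int, scale: Int)

  // Widened result of unioning two fixed-precision decimal columns.
  def unionWiden(l: Dec, r: Dec): Dec =
    Dec(max(l.scale, r.scale) + max(l.precision - l.scale, r.precision - r.scale),
        max(l.scale, r.scale))

  // Conversions applied first to non-decimal inputs, mirroring intTypeToFixed /
  // floatTypeToFixed in the diff: INT -> DECIMAL(10, 0), FLOAT -> DECIMAL(7, 7),
  // DOUBLE -> DECIMAL(15, 15).
  val intAsDec = Dec(10, 0)
  val floatAsDec = Dec(7, 7)
  val doubleAsDec = Dec(15, 15)

  def main(args: Array[String]): Unit = {
    val d1 = Dec(2, 1)  // matches "d1" in DecimalPrecisionSuite
    val d2 = Dec(5, 2)  // matches "d2"
    println(unionWiden(d1, intAsDec))    // Dec(11,1)  -> DecimalType(11, 1)
    println(unionWiden(d1, d2))          // Dec(5,2)   -> DecimalType(5, 2)
    println(unionWiden(floatAsDec, d2))  // Dec(10,7)  -> DecimalType(10, 7)
    println(unionWiden(doubleAsDec, d2)) // Dec(18,15) -> DecimalType(18, 15)
  }
}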
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala  190
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala  5
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala  30
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala  11
4 files changed, 167 insertions(+), 69 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 9a33eb1452..3aeb964994 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -285,6 +285,7 @@ trait HiveTypeCoercion {
* Calculates and propagates precision for fixed-precision decimals. Hive has a number of
* rules for this based on the SQL standard and MS SQL:
* https://cwiki.apache.org/confluence/download/attachments/27362075/Hive_Decimal_Precision_Scale_Support.pdf
+ * https://msdn.microsoft.com/en-us/library/ms190476.aspx
*
* In particular, if we have expressions e1 and e2 with precision/scale p1/s1 and p2/s2
* respectively, then the following operations have the following precision / scale:
@@ -296,6 +297,7 @@ trait HiveTypeCoercion {
* e1 * e2 p1 + p2 + 1 s1 + s2
* e1 / e2 p1 - s1 + s2 + max(6, s1 + p2 + 1) max(6, s1 + p2 + 1)
* e1 % e2 min(p1-s1, p2-s2) + max(s1, s2) max(s1, s2)
+ * e1 union e2 max(s1, s2) + max(p1-s1, p2-s2) max(s1, s2)
* sum(e1) p1 + 10 s1
* avg(e1) p1 + 4 s1 + 4
*
@@ -311,7 +313,12 @@ trait HiveTypeCoercion {
* - SHORT gets turned into DECIMAL(5, 0)
* - INT gets turned into DECIMAL(10, 0)
* - LONG gets turned into DECIMAL(20, 0)
- * - FLOAT and DOUBLE cause fixed-length decimals to turn into DOUBLE (this is the same as Hive,
+ * - FLOAT and DOUBLE
+ * 1. Union operation:
+ * FLOAT gets turned into DECIMAL(7, 7), DOUBLE gets turned into DECIMAL(15, 15) (this is the
+ * same as Hive)
+ * 2. Other operation:
+ * FLOAT and DOUBLE cause fixed-length decimals to turn into DOUBLE (this is the same as Hive,
* but note that unlimited decimals are considered bigger than doubles in WidenTypes)
*/
// scalastyle:on
@@ -328,76 +335,127 @@ trait HiveTypeCoercion {
def isFloat(t: DataType): Boolean = t == FloatType || t == DoubleType
- def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
- // Skip nodes whose children have not been resolved yet
- case e if !e.childrenResolved => e
+ // Conversion rules for float and double into fixed-precision decimals
+ val floatTypeToFixed: Map[DataType, DecimalType] = Map(
+ FloatType -> DecimalType(7, 7),
+ DoubleType -> DecimalType(15, 15)
+ )
- case Add(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
- Cast(
- Add(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)),
- DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2))
- )
-
- case Subtract(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
- Cast(
- Subtract(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)),
- DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2))
- )
-
- case Multiply(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
- Cast(
- Multiply(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)),
- DecimalType(p1 + p2 + 1, s1 + s2)
- )
-
- case Divide(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
- Cast(
- Divide(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)),
- DecimalType(p1 - s1 + s2 + max(6, s1 + p2 + 1), max(6, s1 + p2 + 1))
- )
-
- case Remainder(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
- Cast(
- Remainder(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)),
- DecimalType(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
- )
-
- case LessThan(e1 @ DecimalType.Expression(p1, s1),
- e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
- LessThan(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited))
-
- case LessThanOrEqual(e1 @ DecimalType.Expression(p1, s1),
- e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
- LessThanOrEqual(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited))
-
- case GreaterThan(e1 @ DecimalType.Expression(p1, s1),
- e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
- GreaterThan(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited))
-
- case GreaterThanOrEqual(e1 @ DecimalType.Expression(p1, s1),
- e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
- GreaterThanOrEqual(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited))
-
- // Promote integers inside a binary expression with fixed-precision decimals to decimals,
- // and fixed-precision decimals in an expression with floats / doubles to doubles
- case b: BinaryExpression if b.left.dataType != b.right.dataType =>
- (b.left.dataType, b.right.dataType) match {
- case (t, DecimalType.Fixed(p, s)) if intTypeToFixed.contains(t) =>
- b.makeCopy(Array(Cast(b.left, intTypeToFixed(t)), b.right))
- case (DecimalType.Fixed(p, s), t) if intTypeToFixed.contains(t) =>
- b.makeCopy(Array(b.left, Cast(b.right, intTypeToFixed(t))))
- case (t, DecimalType.Fixed(p, s)) if isFloat(t) =>
- b.makeCopy(Array(b.left, Cast(b.right, DoubleType)))
- case (DecimalType.Fixed(p, s), t) if isFloat(t) =>
- b.makeCopy(Array(Cast(b.left, DoubleType), b.right))
- case _ =>
- b
+ def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ // fix decimal precision for union
+ case u @ Union(left, right) if u.childrenResolved && !u.resolved =>
+ val castedInput = left.output.zip(right.output).map {
+ case (l, r) if l.dataType != r.dataType =>
+ (l.dataType, r.dataType) match {
+ case (DecimalType.Fixed(p1, s1), DecimalType.Fixed(p2, s2)) =>
+ // Union decimals with precision/scale p1/s1 and p2/s2 will be promoted to
+ // DecimalType(max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2))
+ val fixedType = DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2), max(s1, s2))
+ (Alias(Cast(l, fixedType), l.name)(), Alias(Cast(r, fixedType), r.name)())
+ case (t, DecimalType.Fixed(p, s)) if intTypeToFixed.contains(t) =>
+ (Alias(Cast(l, intTypeToFixed(t)), l.name)(), r)
+ case (DecimalType.Fixed(p, s), t) if intTypeToFixed.contains(t) =>
+ (l, Alias(Cast(r, intTypeToFixed(t)), r.name)())
+ case (t, DecimalType.Fixed(p, s)) if floatTypeToFixed.contains(t) =>
+ (Alias(Cast(l, floatTypeToFixed(t)), l.name)(), r)
+ case (DecimalType.Fixed(p, s), t) if floatTypeToFixed.contains(t) =>
+ (l, Alias(Cast(r, floatTypeToFixed(t)), r.name)())
+ case _ => (l, r)
+ }
+ case other => other
}
- // TODO: MaxOf, MinOf, etc might want other rules
+ val (castedLeft, castedRight) = castedInput.unzip
- // SUM and AVERAGE are handled by the implementations of those expressions
+ val newLeft =
+ if (castedLeft.map(_.dataType) != left.output.map(_.dataType)) {
+ Project(castedLeft, left)
+ } else {
+ left
+ }
+
+ val newRight =
+ if (castedRight.map(_.dataType) != right.output.map(_.dataType)) {
+ Project(castedRight, right)
+ } else {
+ right
+ }
+
+ Union(newLeft, newRight)
+
+ // fix decimal precision for expressions
+ case q => q.transformExpressions {
+ // Skip nodes whose children have not been resolved yet
+ case e if !e.childrenResolved => e
+
+ case Add(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
+ Cast(
+ Add(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)),
+ DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2))
+ )
+
+ case Subtract(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
+ Cast(
+ Subtract(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)),
+ DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2))
+ )
+
+ case Multiply(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
+ Cast(
+ Multiply(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)),
+ DecimalType(p1 + p2 + 1, s1 + s2)
+ )
+
+ case Divide(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
+ Cast(
+ Divide(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)),
+ DecimalType(p1 - s1 + s2 + max(6, s1 + p2 + 1), max(6, s1 + p2 + 1))
+ )
+
+ case Remainder(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
+ Cast(
+ Remainder(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)),
+ DecimalType(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
+ )
+
+ case LessThan(e1 @ DecimalType.Expression(p1, s1),
+ e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
+ LessThan(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited))
+
+ case LessThanOrEqual(e1 @ DecimalType.Expression(p1, s1),
+ e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
+ LessThanOrEqual(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited))
+
+ case GreaterThan(e1 @ DecimalType.Expression(p1, s1),
+ e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
+ GreaterThan(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited))
+
+ case GreaterThanOrEqual(e1 @ DecimalType.Expression(p1, s1),
+ e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
+ GreaterThanOrEqual(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited))
+
+ // Promote integers inside a binary expression with fixed-precision decimals to decimals,
+ // and fixed-precision decimals in an expression with floats / doubles to doubles
+ case b: BinaryExpression if b.left.dataType != b.right.dataType =>
+ (b.left.dataType, b.right.dataType) match {
+ case (t, DecimalType.Fixed(p, s)) if intTypeToFixed.contains(t) =>
+ b.makeCopy(Array(Cast(b.left, intTypeToFixed(t)), b.right))
+ case (DecimalType.Fixed(p, s), t) if intTypeToFixed.contains(t) =>
+ b.makeCopy(Array(b.left, Cast(b.right, intTypeToFixed(t))))
+ case (t, DecimalType.Fixed(p, s)) if isFloat(t) =>
+ b.makeCopy(Array(b.left, Cast(b.right, DoubleType)))
+ case (DecimalType.Fixed(p, s), t) if isFloat(t) =>
+ b.makeCopy(Array(Cast(b.left, DoubleType), b.right))
+ case _ =>
+ b
+ }
+
+ // TODO: MaxOf, MinOf, etc might want other rules
+
+ // SUM and AVERAGE are handled by the implementations of those expressions
+ }
}
+
}
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala
index 952cf5c756..cdf2bc68d9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.types
import java.sql.Timestamp
import scala.collection.mutable.ArrayBuffer
+import scala.math._
import scala.math.Numeric.{FloatAsIfIntegral, DoubleAsIfIntegral}
import scala.reflect.ClassTag
import scala.reflect.runtime.universe.{TypeTag, runtimeMirror, typeTag}
@@ -934,7 +935,9 @@ object StructType {
case (DecimalType.Fixed(leftPrecision, leftScale),
DecimalType.Fixed(rightPrecision, rightScale)) =>
- DecimalType(leftPrecision.max(rightPrecision), leftScale.max(rightScale))
+ DecimalType(
+ max(leftScale, rightScale) + max(leftPrecision - leftScale, rightPrecision - rightScale),
+ max(leftScale, rightScale))
case (leftUdt: UserDefinedType[_], rightUdt: UserDefinedType[_])
if leftUdt.userClass == rightUdt.userClass => leftUdt
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala
index bc2ec754d5..67bec999df 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical.{Project, LocalRelation}
+import org.apache.spark.sql.catalyst.plans.logical.{Union, Project, LocalRelation}
import org.apache.spark.sql.types._
import org.scalatest.{BeforeAndAfter, FunSuite}
@@ -31,7 +31,8 @@ class DecimalPrecisionSuite extends FunSuite with BeforeAndAfter {
AttributeReference("d1", DecimalType(2, 1))(),
AttributeReference("d2", DecimalType(5, 2))(),
AttributeReference("u", DecimalType.Unlimited)(),
- AttributeReference("f", FloatType)()
+ AttributeReference("f", FloatType)(),
+ AttributeReference("b", DoubleType)()
)
val i: Expression = UnresolvedAttribute("i")
@@ -39,6 +40,7 @@ class DecimalPrecisionSuite extends FunSuite with BeforeAndAfter {
val d2: Expression = UnresolvedAttribute("d2")
val u: Expression = UnresolvedAttribute("u")
val f: Expression = UnresolvedAttribute("f")
+ val b: Expression = UnresolvedAttribute("b")
before {
catalog.registerTable(Seq("table"), relation)
@@ -58,6 +60,17 @@ class DecimalPrecisionSuite extends FunSuite with BeforeAndAfter {
assert(comparison.right.dataType === expectedType)
}
+ private def checkUnion(left: Expression, right: Expression, expectedType: DataType): Unit = {
+ val plan =
+ Union(Project(Seq(Alias(left, "l")()), relation),
+ Project(Seq(Alias(right, "r")()), relation))
+ val (l, r) = analyzer(plan).collect {
+ case Union(left, right) => (left.output.head, right.output.head)
+ }.head
+ assert(l.dataType === expectedType)
+ assert(r.dataType === expectedType)
+ }
+
test("basic operations") {
checkType(Add(d1, d2), DecimalType(6, 2))
checkType(Subtract(d1, d2), DecimalType(6, 2))
@@ -82,6 +95,19 @@ class DecimalPrecisionSuite extends FunSuite with BeforeAndAfter {
checkComparison(GreaterThan(d2, d2), DecimalType(5, 2))
}
+ test("decimal precision for union") {
+ checkUnion(d1, i, DecimalType(11, 1))
+ checkUnion(i, d2, DecimalType(12, 2))
+ checkUnion(d1, d2, DecimalType(5, 2))
+ checkUnion(d2, d1, DecimalType(5, 2))
+ checkUnion(d1, f, DecimalType(8, 7))
+ checkUnion(f, d2, DecimalType(10, 7))
+ checkUnion(d1, b, DecimalType(16, 15))
+ checkUnion(b, d2, DecimalType(18, 15))
+ checkUnion(d1, u, DecimalType.Unlimited)
+ checkUnion(u, d2, DecimalType.Unlimited)
+ }
+
test("bringing in primitive types") {
checkType(Add(d1, i), DecimalType(12, 1))
checkType(Add(d1, f), DoubleType)
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
index 2065f0d60d..817b9dcb8f 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -468,4 +468,15 @@ class SQLQuerySuite extends QueryTest {
sql(s"DROP TABLE $tableName")
}
}
+
+ test("SPARK-5203 union with different decimal precision") {
+ Seq.empty[(Decimal, Decimal)]
+ .toDF("d1", "d2")
+ .select($"d1".cast(DecimalType(10, 15)).as("d"))
+ .registerTempTable("dn")
+
+ sql("select d from dn union all select d * 2 from dn")
+ .queryExecution.analyzed
+ }
+
}