diff options
author | guowei2 <guowei2@asiainfo.com> | 2015-04-04 02:02:30 +0800 |
---|---|---|
committer | Cheng Lian <lian@databricks.com> | 2015-04-04 02:02:30 +0800 |
commit | c23ba81b8cf86c3a085de8ddfef9403ff6fcd87f (patch) | |
tree | 9f42058a301732153b85ede6fed4d8316d2f2892 /sql | |
parent | dc6dff248d8f5d7de22af64b0586dfe3885731df (diff) | |
download | spark-c23ba81b8cf86c3a085de8ddfef9403ff6fcd87f.tar.gz spark-c23ba81b8cf86c3a085de8ddfef9403ff6fcd87f.tar.bz2 spark-c23ba81b8cf86c3a085de8ddfef9403ff6fcd87f.zip |
[SPARK-5203][SQL] fix union with different decimal type
When union non-decimal types with decimals, we use the following rules:
- FIRST `intTypeToFixed`, then fixed union decimals with precision/scale p1/s2 and p2/s2 will be promoted to
DecimalType(max(p1, p2), max(s1, s2))
- FLOAT and DOUBLE cause fixed-length decimals to turn into DOUBLE (this is the same as Hive,
but note that unlimited decimals are considered bigger than doubles in WidenTypes)
Author: guowei2 <guowei2@asiainfo.com>
Closes #4004 from guowei2/SPARK-5203 and squashes the following commits:
ff50f5f [guowei2] fix code style
11df1bf [guowei2] fix decimal union with double, double->Decimal(15,15)
0f345f9 [guowei2] fix structType merge with decimal
101ed4d [guowei2] fix build error after rebase
0b196e4 [guowei2] code style
fe2c2ca [guowei2] handle union decimal precision in 'DecimalPrecision'
421d840 [guowei2] fix union types for decimal precision
ef2c661 [guowei2] fix union with different decimal type
Diffstat (limited to 'sql')
4 files changed, 167 insertions, 69 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 9a33eb1452..3aeb964994 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -285,6 +285,7 @@ trait HiveTypeCoercion { * Calculates and propagates precision for fixed-precision decimals. Hive has a number of * rules for this based on the SQL standard and MS SQL: * https://cwiki.apache.org/confluence/download/attachments/27362075/Hive_Decimal_Precision_Scale_Support.pdf + * https://msdn.microsoft.com/en-us/library/ms190476.aspx * * In particular, if we have expressions e1 and e2 with precision/scale p1/s2 and p2/s2 * respectively, then the following operations have the following precision / scale: @@ -296,6 +297,7 @@ trait HiveTypeCoercion { * e1 * e2 p1 + p2 + 1 s1 + s2 * e1 / e2 p1 - s1 + s2 + max(6, s1 + p2 + 1) max(6, s1 + p2 + 1) * e1 % e2 min(p1-s1, p2-s2) + max(s1, s2) max(s1, s2) + * e1 union e2 max(s1, s2) + max(p1-s1, p2-s2) max(s1, s2) * sum(e1) p1 + 10 s1 * avg(e1) p1 + 4 s1 + 4 * @@ -311,7 +313,12 @@ trait HiveTypeCoercion { * - SHORT gets turned into DECIMAL(5, 0) * - INT gets turned into DECIMAL(10, 0) * - LONG gets turned into DECIMAL(20, 0) - * - FLOAT and DOUBLE cause fixed-length decimals to turn into DOUBLE (this is the same as Hive, + * - FLOAT and DOUBLE + * 1. Union operation: + * FLOAT gets turned into DECIMAL(7, 7), DOUBLE gets turned into DECIMAL(15, 15) (this is the + * same as Hive) + * 2. Other operation: + * FLOAT and DOUBLE cause fixed-length decimals to turn into DOUBLE (this is the same as Hive, * but note that unlimited decimals are considered bigger than doubles in WidenTypes) */ // scalastyle:on @@ -328,76 +335,127 @@ trait HiveTypeCoercion { def isFloat(t: DataType): Boolean = t == FloatType || t == DoubleType - def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { - // Skip nodes whose children have not been resolved yet - case e if !e.childrenResolved => e + // Conversion rules for float and double into fixed-precision decimals + val floatTypeToFixed: Map[DataType, DecimalType] = Map( + FloatType -> DecimalType(7, 7), + DoubleType -> DecimalType(15, 15) + ) - case Add(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - Cast( - Add(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)), - DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2)) - ) - - case Subtract(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - Cast( - Subtract(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)), - DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2)) - ) - - case Multiply(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - Cast( - Multiply(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)), - DecimalType(p1 + p2 + 1, s1 + s2) - ) - - case Divide(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - Cast( - Divide(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)), - DecimalType(p1 - s1 + s2 + max(6, s1 + p2 + 1), max(6, s1 + p2 + 1)) - ) - - case Remainder(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - Cast( - Remainder(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)), - DecimalType(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) - ) - - case LessThan(e1 @ DecimalType.Expression(p1, s1), - e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => - LessThan(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)) - - case LessThanOrEqual(e1 @ DecimalType.Expression(p1, s1), - e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => - LessThanOrEqual(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)) - - case GreaterThan(e1 @ DecimalType.Expression(p1, s1), - e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => - GreaterThan(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)) - - case GreaterThanOrEqual(e1 @ DecimalType.Expression(p1, s1), - e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => - GreaterThanOrEqual(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)) - - // Promote integers inside a binary expression with fixed-precision decimals to decimals, - // and fixed-precision decimals in an expression with floats / doubles to doubles - case b: BinaryExpression if b.left.dataType != b.right.dataType => - (b.left.dataType, b.right.dataType) match { - case (t, DecimalType.Fixed(p, s)) if intTypeToFixed.contains(t) => - b.makeCopy(Array(Cast(b.left, intTypeToFixed(t)), b.right)) - case (DecimalType.Fixed(p, s), t) if intTypeToFixed.contains(t) => - b.makeCopy(Array(b.left, Cast(b.right, intTypeToFixed(t)))) - case (t, DecimalType.Fixed(p, s)) if isFloat(t) => - b.makeCopy(Array(b.left, Cast(b.right, DoubleType))) - case (DecimalType.Fixed(p, s), t) if isFloat(t) => - b.makeCopy(Array(Cast(b.left, DoubleType), b.right)) - case _ => - b + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + // fix decimal precision for union + case u @ Union(left, right) if u.childrenResolved && !u.resolved => + val castedInput = left.output.zip(right.output).map { + case (l, r) if l.dataType != r.dataType => + (l.dataType, r.dataType) match { + case (DecimalType.Fixed(p1, s1), DecimalType.Fixed(p2, s2)) => + // Union decimals with precision/scale p1/s2 and p2/s2 will be promoted to + // DecimalType(max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2)) + val fixedType = DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2), max(s1, s2)) + (Alias(Cast(l, fixedType), l.name)(), Alias(Cast(r, fixedType), r.name)()) + case (t, DecimalType.Fixed(p, s)) if intTypeToFixed.contains(t) => + (Alias(Cast(l, intTypeToFixed(t)), l.name)(), r) + case (DecimalType.Fixed(p, s), t) if intTypeToFixed.contains(t) => + (l, Alias(Cast(r, intTypeToFixed(t)), r.name)()) + case (t, DecimalType.Fixed(p, s)) if floatTypeToFixed.contains(t) => + (Alias(Cast(l, floatTypeToFixed(t)), l.name)(), r) + case (DecimalType.Fixed(p, s), t) if floatTypeToFixed.contains(t) => + (l, Alias(Cast(r, floatTypeToFixed(t)), r.name)()) + case _ => (l, r) + } + case other => other } - // TODO: MaxOf, MinOf, etc might want other rules + val (castedLeft, castedRight) = castedInput.unzip - // SUM and AVERAGE are handled by the implementations of those expressions + val newLeft = + if (castedLeft.map(_.dataType) != left.output.map(_.dataType)) { + Project(castedLeft, left) + } else { + left + } + + val newRight = + if (castedRight.map(_.dataType) != right.output.map(_.dataType)) { + Project(castedRight, right) + } else { + right + } + + Union(newLeft, newRight) + + // fix decimal precision for expressions + case q => q.transformExpressions { + // Skip nodes whose children have not been resolved yet + case e if !e.childrenResolved => e + + case Add(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => + Cast( + Add(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)), + DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2)) + ) + + case Subtract(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => + Cast( + Subtract(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)), + DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2)) + ) + + case Multiply(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => + Cast( + Multiply(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)), + DecimalType(p1 + p2 + 1, s1 + s2) + ) + + case Divide(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => + Cast( + Divide(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)), + DecimalType(p1 - s1 + s2 + max(6, s1 + p2 + 1), max(6, s1 + p2 + 1)) + ) + + case Remainder(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => + Cast( + Remainder(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)), + DecimalType(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) + ) + + case LessThan(e1 @ DecimalType.Expression(p1, s1), + e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => + LessThan(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)) + + case LessThanOrEqual(e1 @ DecimalType.Expression(p1, s1), + e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => + LessThanOrEqual(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)) + + case GreaterThan(e1 @ DecimalType.Expression(p1, s1), + e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => + GreaterThan(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)) + + case GreaterThanOrEqual(e1 @ DecimalType.Expression(p1, s1), + e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => + GreaterThanOrEqual(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)) + + // Promote integers inside a binary expression with fixed-precision decimals to decimals, + // and fixed-precision decimals in an expression with floats / doubles to doubles + case b: BinaryExpression if b.left.dataType != b.right.dataType => + (b.left.dataType, b.right.dataType) match { + case (t, DecimalType.Fixed(p, s)) if intTypeToFixed.contains(t) => + b.makeCopy(Array(Cast(b.left, intTypeToFixed(t)), b.right)) + case (DecimalType.Fixed(p, s), t) if intTypeToFixed.contains(t) => + b.makeCopy(Array(b.left, Cast(b.right, intTypeToFixed(t)))) + case (t, DecimalType.Fixed(p, s)) if isFloat(t) => + b.makeCopy(Array(b.left, Cast(b.right, DoubleType))) + case (DecimalType.Fixed(p, s), t) if isFloat(t) => + b.makeCopy(Array(Cast(b.left, DoubleType), b.right)) + case _ => + b + } + + // TODO: MaxOf, MinOf, etc might want other rules + + // SUM and AVERAGE are handled by the implementations of those expressions + } } + } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala index 952cf5c756..cdf2bc68d9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.types import java.sql.Timestamp import scala.collection.mutable.ArrayBuffer +import scala.math._ import scala.math.Numeric.{FloatAsIfIntegral, DoubleAsIfIntegral} import scala.reflect.ClassTag import scala.reflect.runtime.universe.{TypeTag, runtimeMirror, typeTag} @@ -934,7 +935,9 @@ object StructType { case (DecimalType.Fixed(leftPrecision, leftScale), DecimalType.Fixed(rightPrecision, rightScale)) => - DecimalType(leftPrecision.max(rightPrecision), leftScale.max(rightScale)) + DecimalType( + max(leftScale, rightScale) + max(leftPrecision - leftScale, rightPrecision - rightScale), + max(leftScale, rightScale)) case (leftUdt: UserDefinedType[_], rightUdt: UserDefinedType[_]) if leftUdt.userClass == rightUdt.userClass => leftUdt diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala index bc2ec754d5..67bec999df 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.{Project, LocalRelation} +import org.apache.spark.sql.catalyst.plans.logical.{Union, Project, LocalRelation} import org.apache.spark.sql.types._ import org.scalatest.{BeforeAndAfter, FunSuite} @@ -31,7 +31,8 @@ class DecimalPrecisionSuite extends FunSuite with BeforeAndAfter { AttributeReference("d1", DecimalType(2, 1))(), AttributeReference("d2", DecimalType(5, 2))(), AttributeReference("u", DecimalType.Unlimited)(), - AttributeReference("f", FloatType)() + AttributeReference("f", FloatType)(), + AttributeReference("b", DoubleType)() ) val i: Expression = UnresolvedAttribute("i") @@ -39,6 +40,7 @@ class DecimalPrecisionSuite extends FunSuite with BeforeAndAfter { val d2: Expression = UnresolvedAttribute("d2") val u: Expression = UnresolvedAttribute("u") val f: Expression = UnresolvedAttribute("f") + val b: Expression = UnresolvedAttribute("b") before { catalog.registerTable(Seq("table"), relation) @@ -58,6 +60,17 @@ class DecimalPrecisionSuite extends FunSuite with BeforeAndAfter { assert(comparison.right.dataType === expectedType) } + private def checkUnion(left: Expression, right: Expression, expectedType: DataType): Unit = { + val plan = + Union(Project(Seq(Alias(left, "l")()), relation), + Project(Seq(Alias(right, "r")()), relation)) + val (l, r) = analyzer(plan).collect { + case Union(left, right) => (left.output.head, right.output.head) + }.head + assert(l.dataType === expectedType) + assert(r.dataType === expectedType) + } + test("basic operations") { checkType(Add(d1, d2), DecimalType(6, 2)) checkType(Subtract(d1, d2), DecimalType(6, 2)) @@ -82,6 +95,19 @@ class DecimalPrecisionSuite extends FunSuite with BeforeAndAfter { checkComparison(GreaterThan(d2, d2), DecimalType(5, 2)) } + test("decimal precision for union") { + checkUnion(d1, i, DecimalType(11, 1)) + checkUnion(i, d2, DecimalType(12, 2)) + checkUnion(d1, d2, DecimalType(5, 2)) + checkUnion(d2, d1, DecimalType(5, 2)) + checkUnion(d1, f, DecimalType(8, 7)) + checkUnion(f, d2, DecimalType(10, 7)) + checkUnion(d1, b, DecimalType(16, 15)) + checkUnion(b, d2, DecimalType(18, 15)) + checkUnion(d1, u, DecimalType.Unlimited) + checkUnion(u, d2, DecimalType.Unlimited) + } + test("bringing in primitive types") { checkType(Add(d1, i), DecimalType(12, 1)) checkType(Add(d1, f), DoubleType) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 2065f0d60d..817b9dcb8f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -468,4 +468,15 @@ class SQLQuerySuite extends QueryTest { sql(s"DROP TABLE $tableName") } } + + test("SPARK-5203 union with different decimal precision") { + Seq.empty[(Decimal, Decimal)] + .toDF("d1", "d2") + .select($"d1".cast(DecimalType(10, 15)).as("d")) + .registerTempTable("dn") + + sql("select d from dn union all select d * 2 from dn") + .queryExecution.analyzed + } + } |