From d743bec645fd2a65bd488d2d660b3aa2135b4da6 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sun, 19 Jul 2015 20:53:18 -0700 Subject: [SPARK-9172][SQL] Make DecimalPrecision support for Intersect and Except JIRA: https://issues.apache.org/jira/browse/SPARK-9172 Simply make `DecimalPrecision` support for `Intersect` and `Except` in addition to `Union`. Besides, add unit test for `DecimalPrecision` as well. Author: Liang-Chi Hsieh Closes #7511 from viirya/more_decimalprecieion and squashes the following commits: 4d29d10 [Liang-Chi Hsieh] Fix code comment. 9fb0d49 [Liang-Chi Hsieh] Make DecimalPrecision support for Intersect and Except. --- .../sql/catalyst/analysis/HiveTypeCoercion.scala | 86 ++++++++++++---------- .../catalyst/analysis/HiveTypeCoercionSuite.scala | 55 ++++++++++++++ 2 files changed, 104 insertions(+), 37 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index ff20835e82..e214545726 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -335,7 +335,7 @@ object HiveTypeCoercion { * - INT gets turned into DECIMAL(10, 0) * - LONG gets turned into DECIMAL(20, 0) * - FLOAT and DOUBLE - * 1. Union operation: + * 1. Union, Intersect and Except operations: * FLOAT gets turned into DECIMAL(7, 7), DOUBLE gets turned into DECIMAL(15, 15) (this is the * same as Hive) * 2. Other operation: @@ -362,47 +362,59 @@ object HiveTypeCoercion { DoubleType -> DecimalType(15, 15) ) - def apply(plan: LogicalPlan): LogicalPlan = plan transform { - // fix decimal precision for union - case u @ Union(left, right) if u.childrenResolved && !u.resolved => - val castedInput = left.output.zip(right.output).map { - case (lhs, rhs) if lhs.dataType != rhs.dataType => - (lhs.dataType, rhs.dataType) match { - case (DecimalType.Fixed(p1, s1), DecimalType.Fixed(p2, s2)) => - // Union decimals with precision/scale p1/s2 and p2/s2 will be promoted to - // DecimalType(max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2)) - val fixedType = DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2), max(s1, s2)) - (Alias(Cast(lhs, fixedType), lhs.name)(), Alias(Cast(rhs, fixedType), rhs.name)()) - case (t, DecimalType.Fixed(p, s)) if intTypeToFixed.contains(t) => - (Alias(Cast(lhs, intTypeToFixed(t)), lhs.name)(), rhs) - case (DecimalType.Fixed(p, s), t) if intTypeToFixed.contains(t) => - (lhs, Alias(Cast(rhs, intTypeToFixed(t)), rhs.name)()) - case (t, DecimalType.Fixed(p, s)) if floatTypeToFixed.contains(t) => - (Alias(Cast(lhs, floatTypeToFixed(t)), lhs.name)(), rhs) - case (DecimalType.Fixed(p, s), t) if floatTypeToFixed.contains(t) => - (lhs, Alias(Cast(rhs, floatTypeToFixed(t)), rhs.name)()) - case _ => (lhs, rhs) - } - case other => other - } + private def castDecimalPrecision( + left: LogicalPlan, + right: LogicalPlan): (LogicalPlan, LogicalPlan) = { + val castedInput = left.output.zip(right.output).map { + case (lhs, rhs) if lhs.dataType != rhs.dataType => + (lhs.dataType, rhs.dataType) match { + case (DecimalType.Fixed(p1, s1), DecimalType.Fixed(p2, s2)) => + // Decimals with precision/scale p1/s2 and p2/s2 will be promoted to + // DecimalType(max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2)) + val fixedType = DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2), max(s1, s2)) + (Alias(Cast(lhs, fixedType), lhs.name)(), Alias(Cast(rhs, fixedType), rhs.name)()) + case (t, DecimalType.Fixed(p, s)) if intTypeToFixed.contains(t) => + (Alias(Cast(lhs, intTypeToFixed(t)), lhs.name)(), rhs) + case (DecimalType.Fixed(p, s), t) if intTypeToFixed.contains(t) => + (lhs, Alias(Cast(rhs, intTypeToFixed(t)), rhs.name)()) + case (t, DecimalType.Fixed(p, s)) if floatTypeToFixed.contains(t) => + (Alias(Cast(lhs, floatTypeToFixed(t)), lhs.name)(), rhs) + case (DecimalType.Fixed(p, s), t) if floatTypeToFixed.contains(t) => + (lhs, Alias(Cast(rhs, floatTypeToFixed(t)), rhs.name)()) + case _ => (lhs, rhs) + } + case other => other + } - val (castedLeft, castedRight) = castedInput.unzip + val (castedLeft, castedRight) = castedInput.unzip - val newLeft = - if (castedLeft.map(_.dataType) != left.output.map(_.dataType)) { - Project(castedLeft, left) - } else { - left - } + val newLeft = + if (castedLeft.map(_.dataType) != left.output.map(_.dataType)) { + Project(castedLeft, left) + } else { + left + } - val newRight = - if (castedRight.map(_.dataType) != right.output.map(_.dataType)) { - Project(castedRight, right) - } else { - right - } + val newRight = + if (castedRight.map(_.dataType) != right.output.map(_.dataType)) { + Project(castedRight, right) + } else { + right + } + (newLeft, newRight) + } + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + // fix decimal precision for union, intersect and except + case u @ Union(left, right) if u.childrenResolved && !u.resolved => + val (newLeft, newRight) = castDecimalPrecision(left, right) Union(newLeft, newRight) + case i @ Intersect(left, right) if i.childrenResolved && !i.resolved => + val (newLeft, newRight) = castDecimalPrecision(left, right) + Intersect(newLeft, newRight) + case e @ Except(left, right) if e.childrenResolved && !e.resolved => + val (newLeft, newRight) = castDecimalPrecision(left, right) + Except(newLeft, newRight) // fix decimal precision for expressions case q => q.transformExpressions { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala index 7ee2333a81..835220c563 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala @@ -346,6 +346,61 @@ class HiveTypeCoercionSuite extends PlanTest { checkOutput(r3.right, expectedTypes) } + test("Transform Decimal precision/scale for union except and intersect") { + def checkOutput(logical: LogicalPlan, expectTypes: Seq[DataType]): Unit = { + logical.output.zip(expectTypes).foreach { case (attr, dt) => + assert(attr.dataType === dt) + } + } + + val dp = HiveTypeCoercion.DecimalPrecision + + val left1 = LocalRelation( + AttributeReference("l", DecimalType(10, 8))()) + val right1 = LocalRelation( + AttributeReference("r", DecimalType(5, 5))()) + val expectedType1 = Seq(DecimalType(math.max(8, 5) + math.max(10 - 8, 5 - 5), math.max(8, 5))) + + val r1 = dp(Union(left1, right1)).asInstanceOf[Union] + val r2 = dp(Except(left1, right1)).asInstanceOf[Except] + val r3 = dp(Intersect(left1, right1)).asInstanceOf[Intersect] + + checkOutput(r1.left, expectedType1) + checkOutput(r1.right, expectedType1) + checkOutput(r2.left, expectedType1) + checkOutput(r2.right, expectedType1) + checkOutput(r3.left, expectedType1) + checkOutput(r3.right, expectedType1) + + val plan1 = LocalRelation( + AttributeReference("l", DecimalType(10, 10))()) + + val rightTypes = Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType) + val expectedTypes = Seq(DecimalType(3, 0), DecimalType(5, 0), DecimalType(10, 0), + DecimalType(20, 0), DecimalType(7, 7), DecimalType(15, 15)) + + rightTypes.zip(expectedTypes).map { case (rType, expectedType) => + val plan2 = LocalRelation( + AttributeReference("r", rType)()) + + val r1 = dp(Union(plan1, plan2)).asInstanceOf[Union] + val r2 = dp(Except(plan1, plan2)).asInstanceOf[Except] + val r3 = dp(Intersect(plan1, plan2)).asInstanceOf[Intersect] + + checkOutput(r1.right, Seq(expectedType)) + checkOutput(r2.right, Seq(expectedType)) + checkOutput(r3.right, Seq(expectedType)) + + val r4 = dp(Union(plan2, plan1)).asInstanceOf[Union] + val r5 = dp(Except(plan2, plan1)).asInstanceOf[Except] + val r6 = dp(Intersect(plan2, plan1)).asInstanceOf[Intersect] + + checkOutput(r4.left, Seq(expectedType)) + checkOutput(r5.left, Seq(expectedType)) + checkOutput(r6.left, Seq(expectedType)) + } + } + /** * There are rules that need to not fire before child expressions get resolved. * We use this test to make sure those rules do not fire early. -- cgit v1.2.3