aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLiang-Chi Hsieh <viirya@appier.com>2015-07-19 20:53:18 -0700
committerReynold Xin <rxin@databricks.com>2015-07-19 20:53:18 -0700
commitd743bec645fd2a65bd488d2d660b3aa2135b4da6 (patch)
treef8659fb3010353d93063b153e4463502b4051529
parent93eb2acfb287807355ba5d77989d239fdd6e2c30 (diff)
downloadspark-d743bec645fd2a65bd488d2d660b3aa2135b4da6.tar.gz
spark-d743bec645fd2a65bd488d2d660b3aa2135b4da6.tar.bz2
spark-d743bec645fd2a65bd488d2d660b3aa2135b4da6.zip
[SPARK-9172][SQL] Make DecimalPrecision support for Intersect and Except
JIRA: https://issues.apache.org/jira/browse/SPARK-9172 Simply make `DecimalPrecision` support for `Intersect` and `Except` in addition to `Union`. Besides, add unit test for `DecimalPrecision` as well. Author: Liang-Chi Hsieh <viirya@appier.com> Closes #7511 from viirya/more_decimalprecieion and squashes the following commits: 4d29d10 [Liang-Chi Hsieh] Fix code comment. 9fb0d49 [Liang-Chi Hsieh] Make DecimalPrecision support for Intersect and Except.
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala86
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala55
2 files changed, 104 insertions, 37 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index ff20835e82..e214545726 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -335,7 +335,7 @@ object HiveTypeCoercion {
* - INT gets turned into DECIMAL(10, 0)
* - LONG gets turned into DECIMAL(20, 0)
* - FLOAT and DOUBLE
- * 1. Union operation:
+ * 1. Union, Intersect and Except operations:
* FLOAT gets turned into DECIMAL(7, 7), DOUBLE gets turned into DECIMAL(15, 15) (this is the
* same as Hive)
* 2. Other operation:
@@ -362,47 +362,59 @@ object HiveTypeCoercion {
DoubleType -> DecimalType(15, 15)
)
- def apply(plan: LogicalPlan): LogicalPlan = plan transform {
- // fix decimal precision for union
- case u @ Union(left, right) if u.childrenResolved && !u.resolved =>
- val castedInput = left.output.zip(right.output).map {
- case (lhs, rhs) if lhs.dataType != rhs.dataType =>
- (lhs.dataType, rhs.dataType) match {
- case (DecimalType.Fixed(p1, s1), DecimalType.Fixed(p2, s2)) =>
- // Union decimals with precision/scale p1/s2 and p2/s2 will be promoted to
- // DecimalType(max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2))
- val fixedType = DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2), max(s1, s2))
- (Alias(Cast(lhs, fixedType), lhs.name)(), Alias(Cast(rhs, fixedType), rhs.name)())
- case (t, DecimalType.Fixed(p, s)) if intTypeToFixed.contains(t) =>
- (Alias(Cast(lhs, intTypeToFixed(t)), lhs.name)(), rhs)
- case (DecimalType.Fixed(p, s), t) if intTypeToFixed.contains(t) =>
- (lhs, Alias(Cast(rhs, intTypeToFixed(t)), rhs.name)())
- case (t, DecimalType.Fixed(p, s)) if floatTypeToFixed.contains(t) =>
- (Alias(Cast(lhs, floatTypeToFixed(t)), lhs.name)(), rhs)
- case (DecimalType.Fixed(p, s), t) if floatTypeToFixed.contains(t) =>
- (lhs, Alias(Cast(rhs, floatTypeToFixed(t)), rhs.name)())
- case _ => (lhs, rhs)
- }
- case other => other
- }
+ private def castDecimalPrecision(
+ left: LogicalPlan,
+ right: LogicalPlan): (LogicalPlan, LogicalPlan) = {
+ val castedInput = left.output.zip(right.output).map {
+ case (lhs, rhs) if lhs.dataType != rhs.dataType =>
+ (lhs.dataType, rhs.dataType) match {
+ case (DecimalType.Fixed(p1, s1), DecimalType.Fixed(p2, s2)) =>
+ // Decimals with precision/scale p1/s1 and p2/s2 will be promoted to
+ // DecimalType(max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2))
+ val fixedType = DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2), max(s1, s2))
+ (Alias(Cast(lhs, fixedType), lhs.name)(), Alias(Cast(rhs, fixedType), rhs.name)())
+ case (t, DecimalType.Fixed(p, s)) if intTypeToFixed.contains(t) =>
+ (Alias(Cast(lhs, intTypeToFixed(t)), lhs.name)(), rhs)
+ case (DecimalType.Fixed(p, s), t) if intTypeToFixed.contains(t) =>
+ (lhs, Alias(Cast(rhs, intTypeToFixed(t)), rhs.name)())
+ case (t, DecimalType.Fixed(p, s)) if floatTypeToFixed.contains(t) =>
+ (Alias(Cast(lhs, floatTypeToFixed(t)), lhs.name)(), rhs)
+ case (DecimalType.Fixed(p, s), t) if floatTypeToFixed.contains(t) =>
+ (lhs, Alias(Cast(rhs, floatTypeToFixed(t)), rhs.name)())
+ case _ => (lhs, rhs)
+ }
+ case other => other
+ }
- val (castedLeft, castedRight) = castedInput.unzip
+ val (castedLeft, castedRight) = castedInput.unzip
- val newLeft =
- if (castedLeft.map(_.dataType) != left.output.map(_.dataType)) {
- Project(castedLeft, left)
- } else {
- left
- }
+ val newLeft =
+ if (castedLeft.map(_.dataType) != left.output.map(_.dataType)) {
+ Project(castedLeft, left)
+ } else {
+ left
+ }
- val newRight =
- if (castedRight.map(_.dataType) != right.output.map(_.dataType)) {
- Project(castedRight, right)
- } else {
- right
- }
+ val newRight =
+ if (castedRight.map(_.dataType) != right.output.map(_.dataType)) {
+ Project(castedRight, right)
+ } else {
+ right
+ }
+ (newLeft, newRight)
+ }
+ def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ // fix decimal precision for union, intersect and except
+ case u @ Union(left, right) if u.childrenResolved && !u.resolved =>
+ val (newLeft, newRight) = castDecimalPrecision(left, right)
Union(newLeft, newRight)
+ case i @ Intersect(left, right) if i.childrenResolved && !i.resolved =>
+ val (newLeft, newRight) = castDecimalPrecision(left, right)
+ Intersect(newLeft, newRight)
+ case e @ Except(left, right) if e.childrenResolved && !e.resolved =>
+ val (newLeft, newRight) = castDecimalPrecision(left, right)
+ Except(newLeft, newRight)
// fix decimal precision for expressions
case q => q.transformExpressions {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
index 7ee2333a81..835220c563 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
@@ -346,6 +346,61 @@ class HiveTypeCoercionSuite extends PlanTest {
checkOutput(r3.right, expectedTypes)
}
+ test("Transform Decimal precision/scale for union except and intersect") {
+ def checkOutput(logical: LogicalPlan, expectTypes: Seq[DataType]): Unit = {
+ logical.output.zip(expectTypes).foreach { case (attr, dt) =>
+ assert(attr.dataType === dt)
+ }
+ }
+
+ val dp = HiveTypeCoercion.DecimalPrecision
+
+ val left1 = LocalRelation(
+ AttributeReference("l", DecimalType(10, 8))())
+ val right1 = LocalRelation(
+ AttributeReference("r", DecimalType(5, 5))())
+ val expectedType1 = Seq(DecimalType(math.max(8, 5) + math.max(10 - 8, 5 - 5), math.max(8, 5)))
+
+ val r1 = dp(Union(left1, right1)).asInstanceOf[Union]
+ val r2 = dp(Except(left1, right1)).asInstanceOf[Except]
+ val r3 = dp(Intersect(left1, right1)).asInstanceOf[Intersect]
+
+ checkOutput(r1.left, expectedType1)
+ checkOutput(r1.right, expectedType1)
+ checkOutput(r2.left, expectedType1)
+ checkOutput(r2.right, expectedType1)
+ checkOutput(r3.left, expectedType1)
+ checkOutput(r3.right, expectedType1)
+
+ val plan1 = LocalRelation(
+ AttributeReference("l", DecimalType(10, 10))())
+
+ val rightTypes = Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType)
+ val expectedTypes = Seq(DecimalType(3, 0), DecimalType(5, 0), DecimalType(10, 0),
+ DecimalType(20, 0), DecimalType(7, 7), DecimalType(15, 15))
+
+ rightTypes.zip(expectedTypes).map { case (rType, expectedType) =>
+ val plan2 = LocalRelation(
+ AttributeReference("r", rType)())
+
+ val r1 = dp(Union(plan1, plan2)).asInstanceOf[Union]
+ val r2 = dp(Except(plan1, plan2)).asInstanceOf[Except]
+ val r3 = dp(Intersect(plan1, plan2)).asInstanceOf[Intersect]
+
+ checkOutput(r1.right, Seq(expectedType))
+ checkOutput(r2.right, Seq(expectedType))
+ checkOutput(r3.right, Seq(expectedType))
+
+ val r4 = dp(Union(plan2, plan1)).asInstanceOf[Union]
+ val r5 = dp(Except(plan2, plan1)).asInstanceOf[Except]
+ val r6 = dp(Intersect(plan2, plan1)).asInstanceOf[Intersect]
+
+ checkOutput(r4.left, Seq(expectedType))
+ checkOutput(r5.left, Seq(expectedType))
+ checkOutput(r6.left, Seq(expectedType))
+ }
+ }
+
/**
* There are rules that need to not fire before child expressions get resolved.
* We use this test to make sure those rules do not fire early.