diff options
author | Yijie Shen <henry.yijieshen@gmail.com> | 2015-07-18 12:57:53 -0700 |
---|---|---|
committer | Reynold Xin <rxin@databricks.com> | 2015-07-18 12:57:53 -0700 |
commit | 3d2134fc0d90379b89da08de7614aef1ac674b1b (patch) | |
tree | 501304881704016c2c02bde6a052569af3884ffd /sql | |
parent | cdc36eef4160dbae32e19a1eadbb4cf062f2fb2b (diff) | |
download | spark-3d2134fc0d90379b89da08de7614aef1ac674b1b.tar.gz spark-3d2134fc0d90379b89da08de7614aef1ac674b1b.tar.bz2 spark-3d2134fc0d90379b89da08de7614aef1ac674b1b.zip |
[SPARK-9055][SQL] WidenTypes should also support Intersect and Except
JIRA: https://issues.apache.org/jira/browse/SPARK-9055
cc rxin
Author: Yijie Shen <henry.yijieshen@gmail.com>
Closes #7491 from yijieshen/widen and squashes the following commits:
079fa52 [Yijie Shen] widenType support for intersect and except
Diffstat (limited to 'sql')
3 files changed, 94 insertions, 41 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 50db7d21f0..ff20835e82 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.analysis import javax.annotation.Nullable import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Union} +import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.types._ @@ -168,52 +168,65 @@ object HiveTypeCoercion { * - LongType to DoubleType */ object WidenTypes extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transform { - // TODO: unions with fixed-precision decimals - case u @ Union(left, right) if u.childrenResolved && !u.resolved => - val castedInput = left.output.zip(right.output).map { - // When a string is found on one side, make the other side a string too. 
- case (lhs, rhs) if lhs.dataType == StringType && rhs.dataType != StringType => - (lhs, Alias(Cast(rhs, StringType), rhs.name)()) - case (lhs, rhs) if lhs.dataType != StringType && rhs.dataType == StringType => - (Alias(Cast(lhs, StringType), lhs.name)(), rhs) - case (lhs, rhs) if lhs.dataType != rhs.dataType => - logDebug(s"Resolving mismatched union input ${lhs.dataType}, ${rhs.dataType}") - findTightestCommonTypeOfTwo(lhs.dataType, rhs.dataType).map { widestType => - val newLeft = - if (lhs.dataType == widestType) lhs else Alias(Cast(lhs, widestType), lhs.name)() - val newRight = - if (rhs.dataType == widestType) rhs else Alias(Cast(rhs, widestType), rhs.name)() - - (newLeft, newRight) - }.getOrElse { - // If there is no applicable conversion, leave expression unchanged. - (lhs, rhs) - } + private[this] def widenOutputTypes(planName: String, left: LogicalPlan, right: LogicalPlan): + (LogicalPlan, LogicalPlan) = { + + // TODO: with fixed-precision decimals + val castedInput = left.output.zip(right.output).map { + // When a string is found on one side, make the other side a string too. + case (lhs, rhs) if lhs.dataType == StringType && rhs.dataType != StringType => + (lhs, Alias(Cast(rhs, StringType), rhs.name)()) + case (lhs, rhs) if lhs.dataType != StringType && rhs.dataType == StringType => + (Alias(Cast(lhs, StringType), lhs.name)(), rhs) + + case (lhs, rhs) if lhs.dataType != rhs.dataType => + logDebug(s"Resolving mismatched $planName input ${lhs.dataType}, ${rhs.dataType}") + findTightestCommonTypeOfTwo(lhs.dataType, rhs.dataType).map { widestType => + val newLeft = + if (lhs.dataType == widestType) lhs else Alias(Cast(lhs, widestType), lhs.name)() + val newRight = + if (rhs.dataType == widestType) rhs else Alias(Cast(rhs, widestType), rhs.name)() + + (newLeft, newRight) + }.getOrElse { + // If there is no applicable conversion, leave expression unchanged. 
+ (lhs, rhs) + } - case other => other - } + case other => other + } - val (castedLeft, castedRight) = castedInput.unzip + val (castedLeft, castedRight) = castedInput.unzip - val newLeft = - if (castedLeft.map(_.dataType) != left.output.map(_.dataType)) { - logDebug(s"Widening numeric types in union $castedLeft ${left.output}") - Project(castedLeft, left) - } else { - left - } + val newLeft = + if (castedLeft.map(_.dataType) != left.output.map(_.dataType)) { + logDebug(s"Widening numeric types in $planName $castedLeft ${left.output}") + Project(castedLeft, left) + } else { + left + } - val newRight = - if (castedRight.map(_.dataType) != right.output.map(_.dataType)) { - logDebug(s"Widening numeric types in union $castedRight ${right.output}") - Project(castedRight, right) - } else { - right - } + val newRight = + if (castedRight.map(_.dataType) != right.output.map(_.dataType)) { + logDebug(s"Widening numeric types in $planName $castedRight ${right.output}") + Project(castedRight, right) + } else { + right + } + (newLeft, newRight) + } + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case u @ Union(left, right) if u.childrenResolved && !u.resolved => + val (newLeft, newRight) = widenOutputTypes(u.nodeName, left, right) Union(newLeft, newRight) + case e @ Except(left, right) if e.childrenResolved && !e.resolved => + val (newLeft, newRight) = widenOutputTypes(e.nodeName, left, right) + Except(newLeft, newRight) + case i @ Intersect(left, right) if i.childrenResolved && !i.resolved => + val (newLeft, newRight) = widenOutputTypes(i.nodeName, left, right) + Intersect(newLeft, newRight) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 17a9124732..986c315b31 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -141,6 +141,10 @@ case class BroadcastHint(child: LogicalPlan) extends UnaryNode { case class Except(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { override def output: Seq[Attribute] = left.output + + override lazy val resolved: Boolean = + childrenResolved && + left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType } } case class InsertIntoTable( @@ -437,4 +441,8 @@ case object OneRowRelation extends LeafNode { case class Intersect(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { override def output: Seq[Attribute] = left.output + + override lazy val resolved: Boolean = + childrenResolved && + left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala index d0fd033b98..c9b3c69c6d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, LocalRelation, Project} +import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.types._ @@ -305,6 +305,38 @@ class HiveTypeCoercionSuite extends PlanTest { ) } + test("WidenTypes for union except and intersect") { + def checkOutput(logical: LogicalPlan, expectTypes: Seq[DataType]): Unit = { + logical.output.zip(expectTypes).foreach { case (attr, dt) => + assert(attr.dataType === dt) + } + } + + val left = 
LocalRelation( + AttributeReference("i", IntegerType)(), + AttributeReference("u", DecimalType.Unlimited)(), + AttributeReference("b", ByteType)(), + AttributeReference("d", DoubleType)()) + val right = LocalRelation( + AttributeReference("s", StringType)(), + AttributeReference("d", DecimalType(2, 1))(), + AttributeReference("f", FloatType)(), + AttributeReference("l", LongType)()) + + val wt = HiveTypeCoercion.WidenTypes + val expectedTypes = Seq(StringType, DecimalType.Unlimited, FloatType, DoubleType) + + val r1 = wt(Union(left, right)).asInstanceOf[Union] + val r2 = wt(Except(left, right)).asInstanceOf[Except] + val r3 = wt(Intersect(left, right)).asInstanceOf[Intersect] + checkOutput(r1.left, expectedTypes) + checkOutput(r1.right, expectedTypes) + checkOutput(r2.left, expectedTypes) + checkOutput(r2.right, expectedTypes) + checkOutput(r3.left, expectedTypes) + checkOutput(r3.right, expectedTypes) + } + /** * There are rules that need to not fire before child expressions get resolved. * We use this test to make sure those rules do not fire early. |