author    Yijie Shen <henry.yijieshen@gmail.com>  2015-07-18 12:57:53 -0700
committer Reynold Xin <rxin@databricks.com>       2015-07-18 12:57:53 -0700
commit    3d2134fc0d90379b89da08de7614aef1ac674b1b (patch)
tree      501304881704016c2c02bde6a052569af3884ffd
parent    cdc36eef4160dbae32e19a1eadbb4cf062f2fb2b (diff)
[SPARK-9055][SQL] WidenTypes should also support Intersect and Except
JIRA: https://issues.apache.org/jira/browse/SPARK-9055

cc rxin

Author: Yijie Shen <henry.yijieshen@gmail.com>

Closes #7491 from yijieshen/widen and squashes the following commits:

079fa52 [Yijie Shen] widenType support for intersect and except
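For illustration only (this snippet is not part of the commit, and it assumes access to the org.apache.spark.sql.catalyst.analysis package, as the test suite changed below has): a minimal Catalyst-level sketch of the effect. Before this change only Union inputs were widened; with it, an Intersect or Except over mismatched column types gets both sides cast to the tightest common type.

import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.logical.{Intersect, LocalRelation}
import org.apache.spark.sql.types.{IntegerType, LongType}

// Two single-column relations whose column types disagree (Int vs. Long).
val ints  = LocalRelation(AttributeReference("v", IntegerType)())
val longs = LocalRelation(AttributeReference("v", LongType)())

// WidenTypes now fires for Intersect (and Except) as well as Union:
// the Int side is wrapped in a Project that casts "v" to LongType,
// so both outputs end up with the common type.
val coerced = HiveTypeCoercion.WidenTypes(Intersect(ints, longs)).asInstanceOf[Intersect]
coerced.left.output.map(_.dataType)   // Seq(LongType)
coerced.right.output.map(_.dataType)  // Seq(LongType)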
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala     93
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala    8
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala  34
3 files changed, 94 insertions, 41 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 50db7d21f0..ff20835e82 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.analysis
import javax.annotation.Nullable
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Union}
+import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.types._
@@ -168,52 +168,65 @@ object HiveTypeCoercion {
* - LongType to DoubleType
*/
object WidenTypes extends Rule[LogicalPlan] {
- def apply(plan: LogicalPlan): LogicalPlan = plan transform {
- // TODO: unions with fixed-precision decimals
- case u @ Union(left, right) if u.childrenResolved && !u.resolved =>
- val castedInput = left.output.zip(right.output).map {
- // When a string is found on one side, make the other side a string too.
- case (lhs, rhs) if lhs.dataType == StringType && rhs.dataType != StringType =>
- (lhs, Alias(Cast(rhs, StringType), rhs.name)())
- case (lhs, rhs) if lhs.dataType != StringType && rhs.dataType == StringType =>
- (Alias(Cast(lhs, StringType), lhs.name)(), rhs)
- case (lhs, rhs) if lhs.dataType != rhs.dataType =>
- logDebug(s"Resolving mismatched union input ${lhs.dataType}, ${rhs.dataType}")
- findTightestCommonTypeOfTwo(lhs.dataType, rhs.dataType).map { widestType =>
- val newLeft =
- if (lhs.dataType == widestType) lhs else Alias(Cast(lhs, widestType), lhs.name)()
- val newRight =
- if (rhs.dataType == widestType) rhs else Alias(Cast(rhs, widestType), rhs.name)()
-
- (newLeft, newRight)
- }.getOrElse {
- // If there is no applicable conversion, leave expression unchanged.
- (lhs, rhs)
- }
+ private[this] def widenOutputTypes(planName: String, left: LogicalPlan, right: LogicalPlan):
+ (LogicalPlan, LogicalPlan) = {
+
+ // TODO: with fixed-precision decimals
+ val castedInput = left.output.zip(right.output).map {
+ // When a string is found on one side, make the other side a string too.
+ case (lhs, rhs) if lhs.dataType == StringType && rhs.dataType != StringType =>
+ (lhs, Alias(Cast(rhs, StringType), rhs.name)())
+ case (lhs, rhs) if lhs.dataType != StringType && rhs.dataType == StringType =>
+ (Alias(Cast(lhs, StringType), lhs.name)(), rhs)
+
+ case (lhs, rhs) if lhs.dataType != rhs.dataType =>
+ logDebug(s"Resolving mismatched $planName input ${lhs.dataType}, ${rhs.dataType}")
+ findTightestCommonTypeOfTwo(lhs.dataType, rhs.dataType).map { widestType =>
+ val newLeft =
+ if (lhs.dataType == widestType) lhs else Alias(Cast(lhs, widestType), lhs.name)()
+ val newRight =
+ if (rhs.dataType == widestType) rhs else Alias(Cast(rhs, widestType), rhs.name)()
+
+ (newLeft, newRight)
+ }.getOrElse {
+ // If there is no applicable conversion, leave expression unchanged.
+ (lhs, rhs)
+ }
- case other => other
- }
+ case other => other
+ }
- val (castedLeft, castedRight) = castedInput.unzip
+ val (castedLeft, castedRight) = castedInput.unzip
- val newLeft =
- if (castedLeft.map(_.dataType) != left.output.map(_.dataType)) {
- logDebug(s"Widening numeric types in union $castedLeft ${left.output}")
- Project(castedLeft, left)
- } else {
- left
- }
+ val newLeft =
+ if (castedLeft.map(_.dataType) != left.output.map(_.dataType)) {
+ logDebug(s"Widening numeric types in $planName $castedLeft ${left.output}")
+ Project(castedLeft, left)
+ } else {
+ left
+ }
- val newRight =
- if (castedRight.map(_.dataType) != right.output.map(_.dataType)) {
- logDebug(s"Widening numeric types in union $castedRight ${right.output}")
- Project(castedRight, right)
- } else {
- right
- }
+ val newRight =
+ if (castedRight.map(_.dataType) != right.output.map(_.dataType)) {
+ logDebug(s"Widening numeric types in $planName $castedRight ${right.output}")
+ Project(castedRight, right)
+ } else {
+ right
+ }
+ (newLeft, newRight)
+ }
+ def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ case u @ Union(left, right) if u.childrenResolved && !u.resolved =>
+ val (newLeft, newRight) = widenOutputTypes(u.nodeName, left, right)
Union(newLeft, newRight)
+ case e @ Except(left, right) if e.childrenResolved && !e.resolved =>
+ val (newLeft, newRight) = widenOutputTypes(e.nodeName, left, right)
+ Except(newLeft, newRight)
+ case i @ Intersect(left, right) if i.childrenResolved && !i.resolved =>
+ val (newLeft, newRight) = widenOutputTypes(i.nodeName, left, right)
+ Intersect(newLeft, newRight)
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 17a9124732..986c315b31 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -141,6 +141,10 @@ case class BroadcastHint(child: LogicalPlan) extends UnaryNode {
case class Except(left: LogicalPlan, right: LogicalPlan) extends BinaryNode {
override def output: Seq[Attribute] = left.output
+
+ override lazy val resolved: Boolean =
+ childrenResolved &&
+ left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType }
}
case class InsertIntoTable(
@@ -437,4 +441,8 @@ case object OneRowRelation extends LeafNode {
case class Intersect(left: LogicalPlan, right: LogicalPlan) extends BinaryNode {
override def output: Seq[Attribute] = left.output
+
+ override lazy val resolved: Boolean =
+ childrenResolved &&
+ left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType }
}
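A short sketch (not part of the patch) of why the new resolved overrides matter: with them in place, an Except or Intersect whose children have mismatched column types reports itself as unresolved even though its children are resolved, which is exactly the childrenResolved && !resolved guard under which the WidenTypes rule above rewrites it.

import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.logical.{Except, LocalRelation}
import org.apache.spark.sql.types.{IntegerType, LongType}

val ints  = LocalRelation(AttributeReference("v", IntegerType)())
val longs = LocalRelation(AttributeReference("v", LongType)())

val except = Except(ints, longs)
assert(except.childrenResolved)  // both inputs are fully resolved
assert(!except.resolved)         // but the set operation is not, because IntegerType != LongType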
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
index d0fd033b98..c9b3c69c6d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, LocalRelation, Project}
+import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.types._
@@ -305,6 +305,38 @@ class HiveTypeCoercionSuite extends PlanTest {
)
}
+ test("WidenTypes for union except and intersect") {
+ def checkOutput(logical: LogicalPlan, expectTypes: Seq[DataType]): Unit = {
+ logical.output.zip(expectTypes).foreach { case (attr, dt) =>
+ assert(attr.dataType === dt)
+ }
+ }
+
+ val left = LocalRelation(
+ AttributeReference("i", IntegerType)(),
+ AttributeReference("u", DecimalType.Unlimited)(),
+ AttributeReference("b", ByteType)(),
+ AttributeReference("d", DoubleType)())
+ val right = LocalRelation(
+ AttributeReference("s", StringType)(),
+ AttributeReference("d", DecimalType(2, 1))(),
+ AttributeReference("f", FloatType)(),
+ AttributeReference("l", LongType)())
+
+ val wt = HiveTypeCoercion.WidenTypes
+ val expectedTypes = Seq(StringType, DecimalType.Unlimited, FloatType, DoubleType)
+
+ val r1 = wt(Union(left, right)).asInstanceOf[Union]
+ val r2 = wt(Except(left, right)).asInstanceOf[Except]
+ val r3 = wt(Intersect(left, right)).asInstanceOf[Intersect]
+ checkOutput(r1.left, expectedTypes)
+ checkOutput(r1.right, expectedTypes)
+ checkOutput(r2.left, expectedTypes)
+ checkOutput(r2.right, expectedTypes)
+ checkOutput(r3.left, expectedTypes)
+ checkOutput(r3.right, expectedTypes)
+ }
+
/**
* There are rules that need to not fire before child expressions get resolved.
* We use this test to make sure those rules do not fire early.