aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala3
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala30
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala19
3 files changed, 31 insertions, 21 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 9c06069f24..9a7c2a944b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -287,7 +287,8 @@ trait CheckAnalysis extends PredicateHelper {
}
// Check if the data types match.
dataTypes(child).zip(ref).zipWithIndex.foreach { case ((dt1, dt2), ci) =>
- if (dt1 != dt2) {
+ // SPARK-18058: ignore nullability when comparing column data types
+ if (dt1.asNullable != dt2.asNullable) {
failAnalysis(
s"""
|${operator.nodeName} can only be performed on tables with the compatible
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index d2d33e40a8..64a787a7ae 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -117,6 +117,8 @@ case class Filter(condition: Expression, child: LogicalPlan)
abstract class SetOperation(left: LogicalPlan, right: LogicalPlan) extends BinaryNode {
+ def duplicateResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty
+
protected def leftConstraints: Set[Expression] = left.constraints
protected def rightConstraints: Set[Expression] = {
@@ -126,6 +128,13 @@ abstract class SetOperation(left: LogicalPlan, right: LogicalPlan) extends Binar
case a: Attribute => attributeRewrites(a)
})
}
+
+ override lazy val resolved: Boolean =
+ childrenResolved &&
+ left.output.length == right.output.length &&
+ left.output.zip(right.output).forall { case (l, r) =>
+ l.dataType.asNullable == r.dataType.asNullable
+ } && duplicateResolved
}
object SetOperation {
@@ -134,8 +143,6 @@ object SetOperation {
case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) {
- def duplicateResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty
-
override def output: Seq[Attribute] =
left.output.zip(right.output).map { case (leftAttr, rightAttr) =>
leftAttr.withNullability(leftAttr.nullable && rightAttr.nullable)
@@ -144,14 +151,6 @@ case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation
override protected def validConstraints: Set[Expression] =
leftConstraints.union(rightConstraints)
- // Intersect are only resolved if they don't introduce ambiguous expression ids,
- // since the Optimizer will convert Intersect to Join.
- override lazy val resolved: Boolean =
- childrenResolved &&
- left.output.length == right.output.length &&
- left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType } &&
- duplicateResolved
-
override def maxRows: Option[Long] = {
if (children.exists(_.maxRows.isEmpty)) {
None
@@ -172,19 +171,11 @@ case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation
case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) {
- def duplicateResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty
-
/** We don't use right.output because those rows get excluded from the set. */
override def output: Seq[Attribute] = left.output
override protected def validConstraints: Set[Expression] = leftConstraints
- override lazy val resolved: Boolean =
- childrenResolved &&
- left.output.length == right.output.length &&
- left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType } &&
- duplicateResolved
-
override lazy val statistics: Statistics = {
left.statistics.copy()
}
@@ -219,9 +210,8 @@ case class Union(children: Seq[LogicalPlan]) extends LogicalPlan {
child.output.length == children.head.output.length &&
// compare the data types with the first child
child.output.zip(children.head.output).forall {
- case (l, r) => l.dataType == r.dataType }
+ case (l, r) => l.dataType.asNullable == r.dataType.asNullable }
)
-
children.length > 1 && childrenResolved && allChildrenCompatible
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index 50ebad25cd..590774c043 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -377,4 +377,23 @@ class AnalysisSuite extends AnalysisTest {
assertExpressionType(sum(Divide(Decimal(1), 2.0)), DoubleType)
assertExpressionType(sum(Divide(1.0, Decimal(2.0))), DoubleType)
}
+
+ test("SPARK-18058: union and set operations shall not care about the nullability" +
+ " when comparing column types") {
+ val firstTable = LocalRelation(
+ AttributeReference("a",
+ StructType(Seq(StructField("a", IntegerType, nullable = true))), nullable = false)())
+ val secondTable = LocalRelation(
+ AttributeReference("a",
+ StructType(Seq(StructField("a", IntegerType, nullable = false))), nullable = false)())
+
+ val unionPlan = Union(firstTable, secondTable)
+ assertAnalysisSuccess(unionPlan)
+
+ val r1 = Except(firstTable, secondTable)
+ val r2 = Intersect(firstTable, secondTable)
+
+ assertAnalysisSuccess(r1)
+ assertAnalysisSuccess(r2)
+ }
}