diff options
3 files changed, 31 insertions, 21 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 9c06069f24..9a7c2a944b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -287,7 +287,8 @@ trait CheckAnalysis extends PredicateHelper { } // Check if the data types match. dataTypes(child).zip(ref).zipWithIndex.foreach { case ((dt1, dt2), ci) => - if (dt1 != dt2) { + // SPARK-18058: we shall not care about the nullability of columns + if (dt1.asNullable != dt2.asNullable) { failAnalysis( s""" |${operator.nodeName} can only be performed on tables with the compatible diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index d2d33e40a8..64a787a7ae 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -117,6 +117,8 @@ case class Filter(condition: Expression, child: LogicalPlan) abstract class SetOperation(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { + def duplicateResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty + protected def leftConstraints: Set[Expression] = left.constraints protected def rightConstraints: Set[Expression] = { @@ -126,6 +128,13 @@ abstract class SetOperation(left: LogicalPlan, right: LogicalPlan) extends Binar case a: Attribute => attributeRewrites(a) }) } + + override lazy val resolved: Boolean = + childrenResolved && + left.output.length == right.output.length && + left.output.zip(right.output).forall { case (l, r) => + l.dataType.asNullable == r.dataType.asNullable + } && duplicateResolved } object SetOperation { @@ -134,8 +143,6 @@ object SetOperation { case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) { - def duplicateResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty - override def output: Seq[Attribute] = left.output.zip(right.output).map { case (leftAttr, rightAttr) => leftAttr.withNullability(leftAttr.nullable && rightAttr.nullable) @@ -144,14 +151,6 @@ case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation override protected def validConstraints: Set[Expression] = leftConstraints.union(rightConstraints) - // Intersect are only resolved if they don't introduce ambiguous expression ids, - // since the Optimizer will convert Intersect to Join. - override lazy val resolved: Boolean = - childrenResolved && - left.output.length == right.output.length && - left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType } && - duplicateResolved - override def maxRows: Option[Long] = { if (children.exists(_.maxRows.isEmpty)) { None @@ -172,19 +171,11 @@ case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) { - def duplicateResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty - /** We don't use right.output because those rows get excluded from the set. */ override def output: Seq[Attribute] = left.output override protected def validConstraints: Set[Expression] = leftConstraints - override lazy val resolved: Boolean = - childrenResolved && - left.output.length == right.output.length && - left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType } && - duplicateResolved - override lazy val statistics: Statistics = { left.statistics.copy() } @@ -219,9 +210,8 @@ case class Union(children: Seq[LogicalPlan]) extends LogicalPlan { child.output.length == children.head.output.length && // compare the data types with the first child child.output.zip(children.head.output).forall { - case (l, r) => l.dataType == r.dataType } + case (l, r) => l.dataType.asNullable == r.dataType.asNullable } ) - children.length > 1 && childrenResolved && allChildrenCompatible } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 50ebad25cd..590774c043 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -377,4 +377,23 @@ class AnalysisSuite extends AnalysisTest { assertExpressionType(sum(Divide(Decimal(1), 2.0)), DoubleType) assertExpressionType(sum(Divide(1.0, Decimal(2.0))), DoubleType) } + + test("SPARK-18058: union and set operations shall not care about the nullability" + + " when comparing column types") { + val firstTable = LocalRelation( + AttributeReference("a", + StructType(Seq(StructField("a", IntegerType, nullable = true))), nullable = false)()) + val secondTable = LocalRelation( + AttributeReference("a", + StructType(Seq(StructField("a", IntegerType, nullable = false))), nullable = false)()) + + val unionPlan = Union(firstTable, secondTable) + assertAnalysisSuccess(unionPlan) + + val r1 = Except(firstTable, secondTable) + val r2 = Intersect(firstTable, secondTable) + + assertAnalysisSuccess(r1) + assertAnalysisSuccess(r2) + } } |