author     Mikhail Bautin <mbautin@gmail.com>     2015-11-23 22:26:08 -0800
committer  Reynold Xin <rxin@databricks.com>      2015-11-23 22:26:08 -0800
commit     4021a28ac30b65cb61cf1e041253847253a2d89f (patch)
tree       660f0c887e43f95936eefd6fa05031572d261195 /sql
parent     6cf51a7007bd72eb93ade149ca9fc53be5b32a17 (diff)
[SPARK-10707][SQL] Fix nullability computation in union output
Author: Mikhail Bautin <mbautin@gmail.com>

Closes #9308 from mbautin/SPARK-10707.
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala  | 11
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala                   |  9
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala                              | 31
3 files changed, 46 insertions, 5 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 0c444482c5..737e62fd59 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -92,8 +92,10 @@ case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode {
}
abstract class SetOperation(left: LogicalPlan, right: LogicalPlan) extends BinaryNode {
- // TODO: These aren't really the same attributes as nullability etc might change.
- final override def output: Seq[Attribute] = left.output
+ override def output: Seq[Attribute] =
+ left.output.zip(right.output).map { case (leftAttr, rightAttr) =>
+ leftAttr.withNullability(leftAttr.nullable || rightAttr.nullable)
+ }
final override lazy val resolved: Boolean =
childrenResolved &&
@@ -115,7 +117,10 @@ case class Union(left: LogicalPlan, right: LogicalPlan) extends SetOperation(lef
case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right)
-case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right)
+case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) {
+ /** We don't use right.output because those rows get excluded from the set. */
+ override def output: Seq[Attribute] = left.output
+}
case class Join(
left: LogicalPlan,
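
The Catalyst hunk above boils down to a per-column merge: an output column of a set operation is nullable if it is nullable on either side, while Except keeps left.output because the right side only filters rows out. A minimal, self-contained sketch of that merge (Attr is a hypothetical stand-in for Catalyst's AttributeReference, used here purely for illustration):

// Minimal sketch (not Spark code): the per-column nullability merge that
// SetOperation.output performs after this change.
case class Attr(name: String, nullable: Boolean) {
  def withNullability(n: Boolean): Attr = copy(nullable = n)
}

object NullabilityMergeSketch {
  // A merged output column is nullable if it is nullable on either side.
  def mergeOutputs(left: Seq[Attr], right: Seq[Attr]): Seq[Attr] =
    left.zip(right).map { case (l, r) => l.withNullability(l.nullable || r.nullable) }

  def main(args: Array[String]): Unit = {
    val left  = Seq(Attr("v", nullable = false))  // e.g. SELECT 'foo' AS v
    val right = Seq(Attr("v", nullable = true))   // e.g. SELECT NULL AS v
    println(mergeOutputs(left, right))            // List(Attr(v,true))
  }
}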
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index e79092efda..d57b8e7a9e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -130,8 +130,13 @@ case class Sample(
* Union two plans, without a distinct. This is UNION ALL in SQL.
*/
case class Union(children: Seq[SparkPlan]) extends SparkPlan {
- // TODO: attributes output by union should be distinct for nullability purposes
- override def output: Seq[Attribute] = children.head.output
+ override def output: Seq[Attribute] = {
+ children.tail.foldLeft(children.head.output) { case (currentOutput, child) =>
+ currentOutput.zip(child.output).map { case (a1, a2) =>
+ a1.withNullability(a1.nullable || a2.nullable)
+ }
+ }
+ }
override def outputsUnsafeRows: Boolean = children.forall(_.outputsUnsafeRows)
override def canProcessUnsafeRows: Boolean = true
override def canProcessSafeRows: Boolean = true
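
The physical-plan hunk above generalizes the same merge to a Union with any number of children by folding across the child outputs. Continuing the hypothetical Attr sketch from earlier:

// Builds on the illustrative Attr case class defined in the earlier sketch.
// A column of the union output is nullable if any child marks it nullable.
def mergeAll(childOutputs: Seq[Seq[Attr]]): Seq[Attr] =
  childOutputs.tail.foldLeft(childOutputs.head) { (acc, next) =>
    acc.zip(next).map { case (a1, a2) => a1.withNullability(a1.nullable || a2.nullable) }
  }

// mergeAll(Seq(Seq(Attr("v", false)), Seq(Attr("v", true)), Seq(Attr("v", false))))
// => List(Attr(v,true))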
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 167aea87de..bb82b562aa 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -1997,4 +1997,35 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "true")
verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1)
}
+
+ test("SPARK-10707: nullability should be correctly propagated through set operations (1)") {
+ // This test produced an incorrect result of 1 before the SPARK-10707 fix because of the
+ // NullPropagation rule: COUNT(v) got replaced with COUNT(1) because the output column of
+ // UNION was incorrectly considered non-nullable:
+ checkAnswer(
+ sql("""SELECT count(v) FROM (
+ | SELECT v FROM (
+ | SELECT 'foo' AS v UNION ALL
+ | SELECT NULL AS v
+ | ) my_union WHERE isnull(v)
+ |) my_subview""".stripMargin),
+ Seq(Row(0)))
+ }
+
+ test("SPARK-10707: nullability should be correctly propagated through set operations (2)") {
+ // This test uses RAND() to stop column pruning for Union and checks the resulting isnull
+ // value. This would produce an incorrect result before the fix in SPARK-10707 because the "v"
+ // column of the union was considered non-nullable.
+ checkAnswer(
+ sql(
+ """
+ |SELECT a FROM (
+ | SELECT ISNULL(v) AS a, RAND() FROM (
+ | SELECT 'foo' AS v UNION ALL SELECT null AS v
+ | ) my_union
+ |) my_view
+ """.stripMargin),
+ Row(false) :: Row(true) :: Nil)
+ }
+
}
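
The first regression test can also be reproduced outside the suite. This is a hedged sketch, not part of the patch: it assumes a Spark 1.6-era spark-shell where sqlContext is already in scope. Per the test comment above, the query returned 1 before the fix and returns 0 with it:

// Reproduction sketch for the first SPARK-10707 regression test.
val df = sqlContext.sql(
  """SELECT count(v) FROM (
    |  SELECT v FROM (
    |    SELECT 'foo' AS v UNION ALL
    |    SELECT NULL AS v
    |  ) my_union WHERE isnull(v)
    |) my_subview""".stripMargin)
df.show()  // prints 0 with the fix; printed 1 before it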