aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTakeshi YAMAMURO <linguin.m.s@gmail.com>2016-06-01 22:23:00 -0700
committerCheng Lian <lian@databricks.com>2016-06-01 22:23:00 -0700
commit5eea332307cbed5fc44427959f070afc16a12c02 (patch)
tree8f2cc80645a769fa7908615da3b5c590af8eeab7
parent8288e16a5a5a12a45207c13a1341c707c6b4b940 (diff)
downloadspark-5eea332307cbed5fc44427959f070afc16a12c02.tar.gz
spark-5eea332307cbed5fc44427959f070afc16a12c02.tar.bz2
spark-5eea332307cbed5fc44427959f070afc16a12c02.zip
[SPARK-13484][SQL] Prevent illegal NULL propagation when filtering outer-join results
## What changes were proposed in this pull request? This PR adds a rule at the end of the analyzer to correct nullable fields of attributes in a logical plan by using nullable fields of the corresponding attributes in its children logical plans (these plans generate the input rows). This is another approach for addressing SPARK-13484 (the first approach is https://github.com/apache/spark/pull/11371). Close #11371 Author: Takeshi YAMAMURO <linguin.m.s@gmail.com> Author: Yin Huai <yhuai@databricks.com> Closes #13290 from yhuai/SPARK-13484.
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala37
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala2
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala21
3 files changed, 58 insertions, 2 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index eb46c0e72e..02966796af 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -113,6 +113,8 @@ class Analyzer(
PullOutNondeterministic),
Batch("UDF", Once,
HandleNullInputsForUDF),
+ Batch("FixNullability", Once,
+ FixNullability),
Batch("Cleanup", fixedPoint,
CleanupAliases)
)
@@ -1452,6 +1454,40 @@ class Analyzer(
}
/**
+ * Fixes nullability of Attributes in a resolved LogicalPlan by using the nullability of
+ * corresponding Attributes of its children output Attributes. This step is needed because
+ * users can use a resolved AttributeReference in the Dataset API and outer joins
+ * can change the nullability of an AttributeReference. Without the fix, a nullable column's
+ * nullable field can be actually set as non-nullable, which causes illegal optimization
+ * (e.g., NULL propagation) and wrong answers.
+ * See SPARK-13484 and SPARK-13801 for the concrete queries of this case.
+ */
+ object FixNullability extends Rule[LogicalPlan] {
+
+ def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
+ case p if !p.resolved => p // Skip unresolved nodes.
+ case p: LogicalPlan if p.resolved =>
+ val childrenOutput = p.children.flatMap(c => c.output).groupBy(_.exprId).flatMap {
+ case (exprId, attributes) =>
+ // If there are multiple Attributes having the same ExprId, we need to resolve
+          // the conflict of nullable field. We do not really expect this to happen.
+ val nullable = attributes.exists(_.nullable)
+ attributes.map(attr => attr.withNullability(nullable))
+ }.toSeq
+      // Here, we create an AttributeMap that only compares the exprId for the lookup
+ // operation. So, we can find the corresponding input attribute's nullability.
+ val attributeMap = AttributeMap[Attribute](childrenOutput.map(attr => attr -> attr))
+ // For an Attribute used by the current LogicalPlan, if it is from its children,
+ // we fix the nullable field by using the nullability setting of the corresponding
+ // output Attribute from the children.
+ p.transformExpressions {
+ case attr: Attribute if attributeMap.contains(attr) =>
+ attr.withNullability(attributeMap(attr).nullable)
+ }
+ }
+ }
+
+ /**
* Extracts [[WindowExpression]]s from the projectList of a [[Project]] operator and
* aggregateExpressions of an [[Aggregate]] operator and creates individual [[Window]]
* operators for every distinct [[WindowSpecDefinition]].
@@ -2133,4 +2169,3 @@ object TimeWindowing extends Rule[LogicalPlan] {
}
}
}
-
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala
index 1423a8705a..748579df41 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala
@@ -100,7 +100,7 @@ class ResolveNaturalJoinSuite extends AnalysisTest {
val naturalPlan = r3.join(r4, NaturalJoin(FullOuter), None)
val usingPlan = r3.join(r4, UsingJoin(FullOuter, Seq(UnresolvedAttribute("b"))), None)
val expected = r3.join(r4, FullOuter, Some(EqualTo(bNotNull, bNotNull))).select(
- Alias(Coalesce(Seq(bNotNull, bNotNull)), "b")(), a, c)
+ Alias(Coalesce(Seq(b, b)), "b")(), a, c)
checkAnalysis(naturalPlan, expected)
checkAnalysis(usingPlan, expected)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
index 031e66b57c..4342c039ae 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
@@ -204,4 +204,25 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext {
leftJoin2Inner,
Row(1, 2, "1", 1, 3, "1") :: Nil)
}
+
+ test("process outer join results using the non-nullable columns in the join input") {
+ // Filter data using a non-nullable column from a right table
+ val df1 = Seq((0, 0), (1, 0), (2, 0), (3, 0), (4, 0)).toDF("id", "count")
+ val df2 = Seq(Tuple1(0), Tuple1(1)).toDF("id").groupBy("id").count
+ checkAnswer(
+ df1.join(df2, df1("id") === df2("id"), "left_outer").filter(df2("count").isNull),
+ Row(2, 0, null, null) ::
+ Row(3, 0, null, null) ::
+ Row(4, 0, null, null) :: Nil
+ )
+
+ // Coalesce data using non-nullable columns in input tables
+ val df3 = Seq((1, 1)).toDF("a", "b")
+ val df4 = Seq((2, 2)).toDF("a", "b")
+ checkAnswer(
+ df3.join(df4, df3("a") === df4("a"), "outer")
+ .select(coalesce(df3("a"), df3("b")), coalesce(df4("a"), df4("b"))),
+ Row(1, null) :: Row(null, 2) :: Nil
+ )
+ }
}