author     Nattavut Sutyanyong <nsy.can@gmail.com>        2016-11-14 20:59:15 +0100
committer  Herman van Hovell <hvanhovell@databricks.com>  2016-11-14 20:59:15 +0100
commit     bd85603ba5f9e61e1aa8326d3e4d5703b5977a4c (patch)
tree       19732c8de1a617e2e17543a7cf0d99135cc2850c
parent     75934457d75996be71ffd0d4b448497d656c0d40 (diff)
[SPARK-17348][SQL] Incorrect results from subquery transformation
## What changes were proposed in this pull request?

Return an AnalysisException when there is a correlated non-equality predicate in a subquery and the correlated column from the outer reference is not from the immediate parent operator of the subquery. This PR prevents incorrect results from the subquery transformation in such cases. Test cases, both positive and negative, are added.

## How was this patch tested?

sql/test, catalyst/test, hive/test, plus scenarios that would produce incorrect results without this PR and produce correct results when the subquery transformation does happen.

Author: Nattavut Sutyanyong <nsy.can@gmail.com>

Closes #15763 from nsyca/spark-17348.
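For a concrete picture of the new behaviour, here is a minimal sketch mirroring the simplest negative test added to SubquerySuite in this patch. It assumes a SparkSession named `spark` with its implicits in scope (e.g. the spark-shell), so only the setup boilerplate differs from the test harness.

```scala
import spark.implicits._

// Same temp views as in the new SPARK-17348 error-case test.
Seq((1, 1)).toDF("c1", "c2").createOrReplaceTempView("t1")
Seq((1, 1), (2, 0)).toDF("c1", "c2").createOrReplaceTempView("t2")

// The correlated predicate t1.c2 >= t2.c2 sits below the Aggregate (max), so
// pulling it above max(t2.c1) would change which rows feed the aggregate.
// With this patch the query fails analysis instead of returning a wrong answer.
spark.sql(
  """
    | select t1.c1
    | from t1
    | where t1.c1 in (select max(t2.c1)
    |                 from t2
    |                 where t1.c2 >= t2.c2)""".stripMargin).collect()
// => org.apache.spark.sql.AnalysisException:
//    Correlated column is not allowed in a non-equality predicate: ...
```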
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala       | 44
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala  |  7
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala                        | 95
3 files changed, 137 insertions(+), 9 deletions(-)
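Before reading the Analyzer diff, the core of the new check can be sketched in isolation. This is an editorial illustration, not code from the patch: the Expression classes below are hypothetical stand-ins for Catalyst's, and in the real code the classification runs inside the `transformUp` over Filter nodes shown in the hunk that follows.

```scala
// Hypothetical stand-ins for Catalyst expression classes (illustration only).
sealed trait Expression
case class EqualTo(left: String, right: String) extends Expression
case class EqualNullSafe(left: String, right: String) extends Expression
case class GreaterThanOrEqual(left: String, right: String) extends Expression

// A set of correlated predicates is "unsafe" if any member is not a plain equality:
// only EqualTo / EqualNullSafe may be pulled up past an Aggregate or Window operator.
def hasNonEqualCorrelatedPred(correlated: Seq[Expression]): Boolean =
  correlated.exists {
    case _: EqualTo | _: EqualNullSafe => false
    case _ => true
  }

// outer(c2) >= c2 blocks the rewrite; outer(c2) = c2 does not.
assert(hasNonEqualCorrelatedPred(Seq(GreaterThanOrEqual("outer(c2)", "c2"))))
assert(!hasNonEqualCorrelatedPred(Seq(EqualTo("outer(c2)", "c2"))))
```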
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index dd68d60d3e..c14f353517 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -1031,6 +1031,37 @@ class Analyzer(
}
}
+ // SPARK-17348: A potential incorrect result case.
+ // When a correlated predicate is a non-equality predicate,
+ // certain operators are not permitted from the operator
+ // hosting the correlated predicate up to the operator on the outer table.
+ // Otherwise, pulling up the correlated predicate
+ // will generate a plan with different semantics,
+ // which could return incorrect results.
+ // Currently we check for Aggregate and Window operators.
+ //
+ // Below is an example of a logical plan during the Analyzer phase that
+ // shows this problem. Pulling the correlated predicate [outer(c2#77) >= ..]
+ // through the Aggregate (or Window) operator could alter the result of
+ // the Aggregate.
+ //
+ // Project [c1#76]
+ // +- Project [c1#87, c2#88]
+ // : (Aggregate or Window operator)
+ // : +- Filter [outer(c2#77) >= c2#88]
+ // : +- SubqueryAlias t2, `t2`
+ // : +- Project [_1#84 AS c1#87, _2#85 AS c2#88]
+ // : +- LocalRelation [_1#84, _2#85]
+ // +- SubqueryAlias t1, `t1`
+ // +- Project [_1#73 AS c1#76, _2#74 AS c2#77]
+ // +- LocalRelation [_1#73, _2#74]
+ def failOnNonEqualCorrelatedPredicate(found: Boolean, p: LogicalPlan): Unit = {
+ if (found) {
+ // Report an unsupported case as an exception.
+ failAnalysis(s"Correlated column is not allowed in a non-equality predicate:\n$p")
+ }
+ }
+
/** Determine which correlated predicate references are missing from this plan. */
def missingReferences(p: LogicalPlan): AttributeSet = {
val localPredicateReferences = p.collect(predicateMap)
@@ -1041,12 +1072,20 @@ class Analyzer(
localPredicateReferences -- p.outputSet
}
+ var foundNonEqualCorrelatedPred : Boolean = false
+
// Simplify the predicates before pulling them out.
val transformed = BooleanSimplification(sub) transformUp {
case f @ Filter(cond, child) =>
// Find all predicates with an outer reference.
val (correlated, local) = splitConjunctivePredicates(cond).partition(containsOuter)
+ // Find any non-equality correlated predicates
+ foundNonEqualCorrelatedPred = foundNonEqualCorrelatedPred || correlated.exists {
+ case _: EqualTo | _: EqualNullSafe => false
+ case _ => true
+ }
+
// Rewrite the filter without the correlated predicates if any.
correlated match {
case Nil => f
@@ -1068,12 +1107,17 @@ class Analyzer(
}
case a @ Aggregate(grouping, expressions, child) =>
failOnOuterReference(a)
+ failOnNonEqualCorrelatedPredicate(foundNonEqualCorrelatedPred, a)
+
val referencesToAdd = missingReferences(a)
if (referencesToAdd.nonEmpty) {
Aggregate(grouping ++ referencesToAdd, expressions ++ referencesToAdd, child)
} else {
a
}
+ case w : Window =>
+ failOnNonEqualCorrelatedPredicate(foundNonEqualCorrelatedPred, w)
+ w
case j @ Join(left, _, RightOuter, _) =>
failOnOuterReference(j)
failOnOuterReferenceInSubTree(left, "a RIGHT OUTER JOIN")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 3455a567b7..7b75c1f709 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -119,13 +119,6 @@ trait CheckAnalysis extends PredicateHelper {
}
case s @ ScalarSubquery(query, conditions, _) if conditions.nonEmpty =>
- // Make sure we are using equi-joins.
- conditions.foreach {
- case _: EqualTo | _: EqualNullSafe => // ok
- case e => failAnalysis(
- s"The correlated scalar subquery can only contain equality predicates: $e")
- }
-
// Make sure correlated scalar subqueries contain one row for every outer row by
// enforcing that they are aggregates which contain exactly one aggregate expression.
// The analyzer has already checked that subquery contained only one output column, and
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
index 8934866834..c84a6f1618 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
@@ -498,10 +498,10 @@ class SubquerySuite extends QueryTest with SharedSQLContext {
test("non-equal correlated scalar subquery") {
val msg1 = intercept[AnalysisException] {
- sql("select a, (select b from l l2 where l2.a < l1.a) sum_b from l l1")
+ sql("select a, (select sum(b) from l l2 where l2.a < l1.a) sum_b from l l1")
}
assert(msg1.getMessage.contains(
- "The correlated scalar subquery can only contain equality predicates"))
+ "Correlated column is not allowed in a non-equality predicate:"))
}
test("disjunctive correlated scalar subquery") {
@@ -639,6 +639,97 @@ class SubquerySuite extends QueryTest with SharedSQLContext {
| from t1 left join t2 on t1.c1=t2.c2) t3
| where c3 not in (select c2 from t2)""".stripMargin),
Row(2) :: Nil)
+ }
+ }
+
+ test("SPARK-17348: Correlated subqueries with non-equality predicate (good case)") {
+ withTempView("t1", "t2") {
+ Seq((1, 1)).toDF("c1", "c2").createOrReplaceTempView("t1")
+ Seq((1, 1), (2, 0)).toDF("c1", "c2").createOrReplaceTempView("t2")
+
+ // Simple case
+ checkAnswer(
+ sql(
+ """
+ | select c1
+ | from t1
+ | where c1 in (select t2.c1
+ | from t2
+ | where t1.c2 >= t2.c2)""".stripMargin),
+ Row(1) :: Nil)
+
+ // More complex case with OR predicate
+ checkAnswer(
+ sql(
+ """
+ | select t1.c1
+ | from t1, t1 as t3
+ | where t1.c1 = t3.c1
+ | and (t1.c1 in (select t2.c1
+ | from t2
+ | where t1.c2 >= t2.c2
+ | or t3.c2 < t2.c2)
+ | or t1.c2 >= 0)""".stripMargin),
+ Row(1) :: Nil)
+ }
+ }
+
+ test("SPARK-17348: Correlated subqueries with non-equality predicate (error case)") {
+ withTempView("t1", "t2", "t3", "t4") {
+ Seq((1, 1)).toDF("c1", "c2").createOrReplaceTempView("t1")
+ Seq((1, 1), (2, 0)).toDF("c1", "c2").createOrReplaceTempView("t2")
+ Seq((2, 1)).toDF("c1", "c2").createOrReplaceTempView("t3")
+ Seq((1, 1), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t4")
+
+ // Simplest case
+ intercept[AnalysisException] {
+ sql(
+ """
+ | select t1.c1
+ | from t1
+ | where t1.c1 in (select max(t2.c1)
+ | from t2
+ | where t1.c2 >= t2.c2)""".stripMargin).collect()
+ }
+
+ // Add a HAVING on top and embed the subquery within an OR predicate
+ intercept[AnalysisException] {
+ sql(
+ """
+ | select t1.c1
+ | from t1
+ | where t1.c1 in (select max(t2.c1)
+ | from t2
+ | where t1.c2 >= t2.c2
+ | having count(*) > 0 )
+ | or t1.c2 >= 0""".stripMargin).collect()
+ }
+
+ // Correlated equality predicates combined under an OR form a single non-equality predicate
+ intercept[AnalysisException] {
+ sql(
+ """
+ | select t1.c1
+ | from t1, t1 as t3
+ | where t1.c1 = t3.c1
+ | and (t1.c1 in (select max(t2.c1)
+ | from t2
+ | where t1.c2 = t2.c2
+ | or t3.c2 = t2.c2)
+ | )""".stripMargin).collect()
+ }
+
+ // In a Window expression: the data set is changed to demonstrate that,
+ // had this query run, it would have returned an incorrect result.
+ intercept[AnalysisException] {
+ sql(
+ """
+ | select c1
+ | from t3
+ | where c1 in (select max(t4.c1) over ()
+ | from t4
+ | where t3.c2 >= t4.c2)""".stripMargin).collect()
+ }
}
}
}