 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala      | 44
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala |  7
 sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala                       | 95
 3 files changed, 137 insertions(+), 9 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index dd68d60d3e..c14f353517 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -1031,6 +1031,37 @@ class Analyzer(
}
}
+ // SPARK-17348: A potential incorrect result case.
+ // When a correlated predicate is a non-equality predicate,
+ // certain operators are not permitted on the path from the operator
+ // hosting the correlated predicate up to the operator over the outer table.
+ // Otherwise, pulling up the correlated predicate
+ // will generate a plan with different semantics,
+ // which could return incorrect results.
+ // Currently we check for Aggregate and Window operators.
+ //
+ // Below is an example of a logical plan during the Analyzer phase that
+ // shows this problem. Pulling the correlated predicate [outer(c2#77) >= ..]
+ // through the Aggregate (or Window) operator could alter the result of
+ // the Aggregate.
+ //
+ // Project [c1#76]
+ // +- Project [c1#87, c2#88]
+ // : (Aggregate or Window operator)
+ // : +- Filter [outer(c2#77) >= c2#88]
+ // : +- SubqueryAlias t2, `t2`
+ // : +- Project [_1#84 AS c1#87, _2#85 AS c2#88]
+ // : +- LocalRelation [_1#84, _2#85]
+ // +- SubqueryAlias t1, `t1`
+ // +- Project [_1#73 AS c1#76, _2#74 AS c2#77]
+ // +- LocalRelation [_1#73, _2#74]
+ def failOnNonEqualCorrelatedPredicate(found: Boolean, p: LogicalPlan): Unit = {
+ if (found) {
+ // Report an unsupported case as an exception
+ failAnalysis(s"Correlated column is not allowed in a non-equality predicate:\n$p")
+ }
+ }
+
/** Determine which correlated predicate references are missing from this plan. */
def missingReferences(p: LogicalPlan): AttributeSet = {
val localPredicateReferences = p.collect(predicateMap)
@@ -1041,12 +1072,20 @@ class Analyzer(
localPredicateReferences -- p.outputSet
}
+ var foundNonEqualCorrelatedPred: Boolean = false
+
// Simplify the predicates before pulling them out.
val transformed = BooleanSimplification(sub) transformUp {
case f @ Filter(cond, child) =>
// Find all predicates with an outer reference.
val (correlated, local) = splitConjunctivePredicates(cond).partition(containsOuter)
+ // Find any non-equality correlated predicates
+ foundNonEqualCorrelatedPred = foundNonEqualCorrelatedPred || correlated.exists {
+ case _: EqualTo | _: EqualNullSafe => false
+ case _ => true
+ }
+
// Rewrite the filter without the correlated predicates if any.
correlated match {
case Nil => f
@@ -1068,12 +1107,17 @@ class Analyzer(
}
case a @ Aggregate(grouping, expressions, child) =>
failOnOuterReference(a)
+ failOnNonEqualCorrelatedPredicate(foundNonEqualCorrelatedPred, a)
+
val referencesToAdd = missingReferences(a)
if (referencesToAdd.nonEmpty) {
Aggregate(grouping ++ referencesToAdd, expressions ++ referencesToAdd, child)
} else {
a
}
+ case w: Window =>
+ failOnNonEqualCorrelatedPredicate(foundNonEqualCorrelatedPred, w)
+ w
case j @ Join(left, _, RightOuter, _) =>
failOnOuterReference(j)
failOnOuterReferenceInSubTree(left, "a RIGHT OUTER JOIN")
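To make the concern in the new comment block concrete, here is a minimal sketch in plain Scala (not Spark; the object name is hypothetical, and the rows mirror the t3/t4 data used in the new Window test in SubquerySuite below). It evaluates the same correlated >= predicate once before and once after the aggregate, and the two orders give different answers, which is exactly the incorrect-result risk the new check guards against.

// A self-contained sketch of the two evaluation orders, using plain Scala
// collections. Data mirrors the new SubquerySuite tests: t3 = (2, 1),
// t4 = (1, 1), (2, 2).
object NonEqualCorrelationSketch extends App {
  val t3 = Seq((2, 1))          // (c1, c2)
  val t4 = Seq((1, 1), (2, 2))  // (c1, c2)

  // Correct semantics: the correlated predicate t3.c2 >= t4.c2 filters t4
  // before the aggregate, so max(c1) is taken only over the surviving rows.
  val correct = t3.filter { case (c1, outerC2) =>
    val survivors = t4.filter { case (_, innerC2) => outerC2 >= innerC2 }
    survivors.nonEmpty && c1 == survivors.map(_._1).max
  }

  // Pulled-up semantics: max(c1) is computed over all of t4 first, and the
  // correlated predicate is applied afterwards, as if it had been pulled
  // above the Aggregate (or Window) operator.
  val pulledUp = t3.filter { case (c1, outerC2) =>
    t4.exists { case (_, innerC2) => outerC2 >= innerC2 } && c1 == t4.map(_._1).max
  }

  println(correct)   // List()       -- max over the filtered rows is 1, and 2 != 1
  println(pulledUp)  // List((2,1))  -- max over all rows is 2, and 2 == 2
}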
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 3455a567b7..7b75c1f709 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -119,13 +119,6 @@ trait CheckAnalysis extends PredicateHelper {
}
case s @ ScalarSubquery(query, conditions, _) if conditions.nonEmpty =>
- // Make sure we are using equi-joins.
- conditions.foreach {
- case _: EqualTo | _: EqualNullSafe => // ok
- case e => failAnalysis(
- s"The correlated scalar subquery can only contain equality predicates: $e")
- }
-
// Make sure correlated scalar subqueries contain one row for every outer row by
// enforcing that they are aggregates which contain exactly one aggregate expression.
// The analyzer has already checked that the subquery contains only one output column, and
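The equality-only check removed here now lives in the Analyzer's pull-up logic shown above. For contrast, a hedged sketch (a hypothetical driver and made-up data, not part of this patch) of the shape the remaining CheckAnalysis rule still accepts: an equality correlation feeding a single aggregate, so the subquery yields at most one row per outer row.

import org.apache.spark.sql.SparkSession

// Hypothetical example, not from the patch: an accepted correlated scalar
// subquery. The table name `l` mirrors SubquerySuite; the rows are invented.
object AcceptedCorrelatedScalarSketch extends App {
  val spark = SparkSession.builder().master("local[2]").appName("sketch").getOrCreate()
  import spark.implicits._

  Seq((1, 10), (1, 20), (2, 30)).toDF("a", "b").createOrReplaceTempView("l")

  // Equality correlation (l2.a = l1.a) plus a single aggregate (sum(b)):
  // passes analysis both before and after this change.
  spark.sql(
    """select a, (select sum(b) from l l2 where l2.a = l1.a) as sum_b
      |from l l1""".stripMargin).show()

  spark.stop()
}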
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
index 8934866834..c84a6f1618 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
@@ -498,10 +498,10 @@ class SubquerySuite extends QueryTest with SharedSQLContext {
test("non-equal correlated scalar subquery") {
val msg1 = intercept[AnalysisException] {
- sql("select a, (select b from l l2 where l2.a < l1.a) sum_b from l l1")
+ sql("select a, (select sum(b) from l l2 where l2.a < l1.a) sum_b from l l1")
}
assert(msg1.getMessage.contains(
- "The correlated scalar subquery can only contain equality predicates"))
+ "Correlated column is not allowed in a non-equality predicate:"))
}
test("disjunctive correlated scalar subquery") {
@@ -639,6 +639,97 @@ class SubquerySuite extends QueryTest with SharedSQLContext {
| from t1 left join t2 on t1.c1=t2.c2) t3
| where c3 not in (select c2 from t2)""".stripMargin),
Row(2) :: Nil)
+ }
+ }
+
+ test("SPARK-17348: Correlated subqueries with non-equality predicate (good case)") {
+ withTempView("t1", "t2") {
+ Seq((1, 1)).toDF("c1", "c2").createOrReplaceTempView("t1")
+ Seq((1, 1), (2, 0)).toDF("c1", "c2").createOrReplaceTempView("t2")
+
+ // Simple case
+ checkAnswer(
+ sql(
+ """
+ | select c1
+ | from t1
+ | where c1 in (select t2.c1
+ | from t2
+ | where t1.c2 >= t2.c2)""".stripMargin),
+ Row(1) :: Nil)
+
+ // A more complex case with an OR predicate
+ checkAnswer(
+ sql(
+ """
+ | select t1.c1
+ | from t1, t1 as t3
+ | where t1.c1 = t3.c1
+ | and (t1.c1 in (select t2.c1
+ | from t2
+ | where t1.c2 >= t2.c2
+ | or t3.c2 < t2.c2)
+ | or t1.c2 >= 0)""".stripMargin),
+ Row(1) :: Nil)
+ }
+ }
+
+ test("SPARK-17348: Correlated subqueries with non-equality predicate (error case)") {
+ withTempView("t1", "t2", "t3", "t4") {
+ Seq((1, 1)).toDF("c1", "c2").createOrReplaceTempView("t1")
+ Seq((1, 1), (2, 0)).toDF("c1", "c2").createOrReplaceTempView("t2")
+ Seq((2, 1)).toDF("c1", "c2").createOrReplaceTempView("t3")
+ Seq((1, 1), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t4")
+
+ // Simplest case
+ intercept[AnalysisException] {
+ sql(
+ """
+ | select t1.c1
+ | from t1
+ | where t1.c1 in (select max(t2.c1)
+ | from t2
+ | where t1.c2 >= t2.c2)""".stripMargin).collect()
+ }
+
+ // Add a HAVING clause on top of the subquery and nest the IN predicate within an OR
+ intercept[AnalysisException] {
+ sql(
+ """
+ | select t1.c1
+ | from t1
+ | where t1.c1 in (select max(t2.c1)
+ | from t2
+ | where t1.c2 >= t2.c2
+ | having count(*) > 0 )
+ | or t1.c2 >= 0""".stripMargin).collect()
+ }
+
+ // Correlated equality predicates referencing two outer tables, combined under an OR
+ intercept[AnalysisException] {
+ sql(
+ """
+ | select t1.c1
+ | from t1, t1 as t3
+ | where t1.c1 = t3.c1
+ | and (t1.c1 in (select max(t2.c1)
+ | from t2
+ | where t1.c2 = t2.c2
+ | or t3.c2 = t2.c2)
+ | )""".stripMargin).collect()
+ }
+
+ // In a Window expression: the data set is changed to demonstrate
+ // that, if this query ran, it would return an incorrect result.
+ intercept[AnalysisException] {
+ sql(
+ """
+ | select c1
+ | from t3
+ | where c1 in (select max(t4.c1) over ()
+ | from t4
+ | where t3.c2 >= t4.c2)""".stripMargin).collect()
+ }
}
}
}