about summary refs log tree commit diff
path: root/sql
diff options
context:
space:
mode:
author	Nattavut Sutyanyong <nsy.can@gmail.com>	2016-11-22 12:06:21 -0800
committer	Herman van Hovell <hvanhovell@databricks.com>	2016-11-22 12:06:21 -0800
commit	45ea46b7b397f023b4da878eb11e21b08d931115 (patch)
tree	51be6bfe31812109263bac69f947ef315b5c084c /sql
parent	bb152cdfbb8d02130c71d2326ae81939725c2cf0 (diff)
download	spark-45ea46b7b397f023b4da878eb11e21b08d931115.tar.gz
spark-45ea46b7b397f023b4da878eb11e21b08d931115.tar.bz2
spark-45ea46b7b397f023b4da878eb11e21b08d931115.zip
[SPARK-18504][SQL] Scalar subquery with extra group by columns returning incorrect result
## What changes were proposed in this pull request? This PR blocks an incorrect-result scenario in scalar subqueries where there are GROUP BY column(s) that are not part of the correlated predicate(s). Example: // Incorrect result Seq(1).toDF("c1").createOrReplaceTempView("t1") Seq((1,1),(1,2)).toDF("c1","c2").createOrReplaceTempView("t2") sql("select (select sum(-1) from t2 where t1.c1=t2.c1 group by t2.c2) from t1").show // How can selecting a scalar subquery from a 1-row table return 2 rows? ## How was this patch tested? sql/test, catalyst/test. A new test case covering the reported problem is added to SubquerySuite.scala. Author: Nattavut Sutyanyong <nsy.can@gmail.com> Closes #15936 from nsyca/scalarSubqueryIncorrect-1.
Diffstat (limited to 'sql')
-rw-r--r--	sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala	| 3
-rw-r--r--	sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala	| 30
-rw-r--r--	sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala	| 12
3 files changed, 36 insertions(+), 9 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index ec5f710fd9..0155741ddb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -1241,9 +1241,6 @@ class Analyzer(
*/
private def resolveSubQueries(plan: LogicalPlan, plans: Seq[LogicalPlan]): LogicalPlan = {
plan transformExpressions {
- case s @ ScalarSubquery(sub, conditions, exprId)
- if sub.resolved && conditions.isEmpty && sub.output.size != 1 =>
- failAnalysis(s"Scalar subquery must return only one column, but got ${sub.output.size}")
case s @ ScalarSubquery(sub, _, exprId) if !sub.resolved =>
resolveSubQuery(s, plans, 1)(ScalarSubquery(_, _, exprId))
case e @ Exists(sub, exprId) =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 80e577e5c4..26d2638590 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -117,19 +117,37 @@ trait CheckAnalysis extends PredicateHelper {
failAnalysis(s"Window specification $s is not valid because $m")
case None => w
}
+ case s @ ScalarSubquery(query, conditions, _)
+ // If no correlation, the output must be exactly one column
+ if (conditions.isEmpty && query.output.size != 1) =>
+ failAnalysis(
+ s"Scalar subquery must return only one column, but got ${query.output.size}")
case s @ ScalarSubquery(query, conditions, _) if conditions.nonEmpty =>
- // Make sure correlated scalar subqueries contain one row for every outer row by
- // enforcing that they are aggregates which contain exactly one aggregate expressions.
- // The analyzer has already checked that subquery contained only one output column, and
- // added all the grouping expressions to the aggregate.
- def checkAggregate(a: Aggregate): Unit = {
- val aggregates = a.expressions.flatMap(_.collect {
+ def checkAggregate(agg: Aggregate): Unit = {
+ // Make sure correlated scalar subqueries contain one row for every outer row by
+ // enforcing that they are aggregates which contain exactly one aggregate expressions.
+ // The analyzer has already checked that subquery contained only one output column,
+ // and added all the grouping expressions to the aggregate.
+ val aggregates = agg.expressions.flatMap(_.collect {
case a: AggregateExpression => a
})
if (aggregates.isEmpty) {
failAnalysis("The output of a correlated scalar subquery must be aggregated")
}
+
+ // SPARK-18504: block cases where GROUP BY columns
+ // are not part of the correlated columns
+ val groupByCols = ExpressionSet.apply(agg.groupingExpressions.flatMap(_.references))
+ val predicateCols = ExpressionSet.apply(conditions.flatMap(_.references))
+ val invalidCols = groupByCols.diff(predicateCols)
+ // GROUP BY columns must be a subset of columns in the predicates
+ if (invalidCols.nonEmpty) {
+ failAnalysis(
+ "a GROUP BY clause in a scalar correlated subquery " +
+ "cannot contain non-correlated columns: " +
+ invalidCols.mkString(","))
+ }
}
// Skip projects and subquery aliases added by the Analyzer and the SQLBuilder.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
index c84a6f1618..f1dd1c620e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
@@ -483,6 +483,18 @@ class SubquerySuite extends QueryTest with SharedSQLContext {
Row(1, null) :: Row(2, 6.0) :: Row(3, 2.0) :: Row(null, null) :: Row(6, null) :: Nil)
}
+ test("SPARK-18504 extra GROUP BY column in correlated scalar subquery is not permitted") {
+ withTempView("t") {
+ Seq((1, 1), (1, 2)).toDF("c1", "c2").createOrReplaceTempView("t")
+
+ val errMsg = intercept[AnalysisException] {
+ sql("select (select sum(-1) from t t2 where t1.c2 = t2.c1 group by t2.c2) sum from t t1")
+ }
+ assert(errMsg.getMessage.contains(
+ "a GROUP BY clause in a scalar correlated subquery cannot contain non-correlated columns:"))
+ }
+ }
+
test("non-aggregated correlated scalar subquery") {
val msg1 = intercept[AnalysisException] {
sql("select a, (select b from l l2 where l2.a = l1.a) sum_b from l l1")