about summary refs log tree commit diff
path: root/sql
diff options
context:
space:
mode:
author    Nattavut Sutyanyong <nsy.can@gmail.com>    2016-12-14 11:09:31 +0100
committer Herman van Hovell <hvanhovell@databricks.com>    2016-12-14 11:09:31 +0100
commit   cccd64393ea633e29d4a505fb0a7c01b51a79af8 (patch)
tree     66f39048cd7dde5a2362f0bed2fb74a91b341216 /sql
parent   3e307b4959ecdab3f9c16484d172403357e7d09b (diff)
download spark-cccd64393ea633e29d4a505fb0a7c01b51a79af8.tar.gz
spark-cccd64393ea633e29d4a505fb0a7c01b51a79af8.tar.bz2
spark-cccd64393ea633e29d4a505fb0a7c01b51a79af8.zip
[SPARK-18814][SQL] CheckAnalysis rejects TPCDS query 32
## What changes were proposed in this pull request?

Move the checking of GROUP BY column in correlated scalar subquery from CheckAnalysis to Analysis to fix a regression caused by SPARK-18504.

This problem can be reproduced with a simple script now.

    Seq((1,1)).toDF("pk","pv").createOrReplaceTempView("p")
    Seq((1,1)).toDF("ck","cv").createOrReplaceTempView("c")
    sql("select * from p,c where p.pk=c.ck and c.cv = (select avg(c1.cv) from c c1 where c1.ck = p.pk)").show

The requirements are:
1. We need to reference the same table twice in both the parent and the subquery. Here it is the table c.
2. We need to have a correlated predicate but to a different table. Here it is from c (as c1) in the subquery to p in the parent.
3. We will then "deduplicate" c1.ck in the subquery to `ck#<n1>#<n2>` at `Project` above `Aggregate` of `avg`. Then when we compare `ck#<n1>#<n2>` and the original group-by column `ck#<n1>` by their canonicalized form, #<n2> != #<n1>. That's how we trigger the exception added in SPARK-18504.

## How was this patch tested?

SubquerySuite and a simplified version of TPCDS-Q32

Author: Nattavut Sutyanyong <nsy.can@gmail.com>

Closes #16246 from nsyca/18814.
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala | 31
-rw-r--r--  sql/core/src/test/resources/sql-tests/inputs/scalar-subquery.sql | 20
-rw-r--r--  sql/core/src/test/resources/sql-tests/results/scalar-subquery.sql.out | 46
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala | 2
4 files changed, 90 insertions(+), 9 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 235a79973d..aa77a6efef 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -124,6 +124,10 @@ trait CheckAnalysis extends PredicateHelper {
s"Scalar subquery must return only one column, but got ${query.output.size}")
case s @ ScalarSubquery(query, conditions, _) if conditions.nonEmpty =>
+
+ // Collect the columns from the subquery for further checking.
+ var subqueryColumns = conditions.flatMap(_.references).filter(query.output.contains)
+
def checkAggregate(agg: Aggregate): Unit = {
// Make sure correlated scalar subqueries contain one row for every outer row by
// enforcing that they are aggregates which contain exactly one aggregate expressions.
@@ -136,24 +140,35 @@ trait CheckAnalysis extends PredicateHelper {
failAnalysis("The output of a correlated scalar subquery must be aggregated")
}
- // SPARK-18504: block cases where GROUP BY columns
- // are not part of the correlated columns
- val groupByCols = ExpressionSet.apply(agg.groupingExpressions.flatMap(_.references))
- val predicateCols = ExpressionSet.apply(conditions.flatMap(_.references))
- val invalidCols = groupByCols.diff(predicateCols)
+ // SPARK-18504/SPARK-18814: Block cases where GROUP BY columns
+ // are not part of the correlated columns.
+ val groupByCols = AttributeSet(agg.groupingExpressions.flatMap(_.references))
+ val correlatedCols = AttributeSet(subqueryColumns)
+ val invalidCols = groupByCols -- correlatedCols
// GROUP BY columns must be a subset of columns in the predicates
if (invalidCols.nonEmpty) {
failAnalysis(
- "a GROUP BY clause in a scalar correlated subquery " +
+ "A GROUP BY clause in a scalar correlated subquery " +
"cannot contain non-correlated columns: " +
invalidCols.mkString(","))
}
}
- // Skip projects and subquery aliases added by the Analyzer and the SQLBuilder.
+ // Skip subquery aliases added by the Analyzer and the SQLBuilder.
+ // For projects, do the necessary mapping and skip to its child.
def cleanQuery(p: LogicalPlan): LogicalPlan = p match {
case s: SubqueryAlias => cleanQuery(s.child)
- case p: Project => cleanQuery(p.child)
+ case p: Project =>
+ // SPARK-18814: Map any aliases to their AttributeReference children
+ // for the checking in the Aggregate operators below this Project.
+ subqueryColumns = subqueryColumns.map {
+ xs => p.projectList.collectFirst {
+ case e @ Alias(child : AttributeReference, _) if e.exprId == xs.exprId =>
+ child
+ }.getOrElse(xs)
+ }
+
+ cleanQuery(p.child)
case child => child
}
diff --git a/sql/core/src/test/resources/sql-tests/inputs/scalar-subquery.sql b/sql/core/src/test/resources/sql-tests/inputs/scalar-subquery.sql
new file mode 100644
index 0000000000..3acc9db09c
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/scalar-subquery.sql
@@ -0,0 +1,20 @@
+CREATE OR REPLACE TEMPORARY VIEW p AS VALUES (1, 1) AS T(pk, pv);
+CREATE OR REPLACE TEMPORARY VIEW c AS VALUES (1, 1) AS T(ck, cv);
+
+-- SPARK-18814.1: Simplified version of TPCDS-Q32
+SELECT pk, cv
+FROM p, c
+WHERE p.pk = c.ck
+AND c.cv = (SELECT avg(c1.cv)
+ FROM c c1
+ WHERE c1.ck = p.pk);
+
+-- SPARK-18814.2: Adding stack of aggregates
+SELECT pk, cv
+FROM p, c
+WHERE p.pk = c.ck
+AND c.cv = (SELECT max(avg)
+ FROM (SELECT c1.cv, avg(c1.cv) avg
+ FROM c c1
+ WHERE c1.ck = p.pk
+ GROUP BY c1.cv));
diff --git a/sql/core/src/test/resources/sql-tests/results/scalar-subquery.sql.out b/sql/core/src/test/resources/sql-tests/results/scalar-subquery.sql.out
new file mode 100644
index 0000000000..c249329d6a
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/results/scalar-subquery.sql.out
@@ -0,0 +1,46 @@
+-- Automatically generated by SQLQueryTestSuite
+-- Number of queries: 4
+
+
+-- !query 0
+CREATE OR REPLACE TEMPORARY VIEW p AS VALUES (1, 1) AS T(pk, pv)
+-- !query 0 schema
+struct<>
+-- !query 0 output
+
+
+
+-- !query 1
+CREATE OR REPLACE TEMPORARY VIEW c AS VALUES (1, 1) AS T(ck, cv)
+-- !query 1 schema
+struct<>
+-- !query 1 output
+
+
+
+-- !query 2
+SELECT pk, cv
+FROM p, c
+WHERE p.pk = c.ck
+AND c.cv = (SELECT avg(c1.cv)
+ FROM c c1
+ WHERE c1.ck = p.pk)
+-- !query 2 schema
+struct<pk:int,cv:int>
+-- !query 2 output
+1 1
+
+
+-- !query 3
+SELECT pk, cv
+FROM p, c
+WHERE p.pk = c.ck
+AND c.cv = (SELECT max(avg)
+ FROM (SELECT c1.cv, avg(c1.cv) avg
+ FROM c c1
+ WHERE c1.ck = p.pk
+ GROUP BY c1.cv))
+-- !query 3 schema
+struct<pk:int,cv:int>
+-- !query 3 output
+1 1
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
index 0f2f520006..5a4b1cfe95 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
@@ -491,7 +491,7 @@ class SubquerySuite extends QueryTest with SharedSQLContext {
sql("select (select sum(-1) from t t2 where t1.c2 = t2.c1 group by t2.c2) sum from t t1")
}
assert(errMsg.getMessage.contains(
- "a GROUP BY clause in a scalar correlated subquery cannot contain non-correlated columns:"))
+ "A GROUP BY clause in a scalar correlated subquery cannot contain non-correlated columns:"))
}
}