path: root/sql
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala       |  3 -
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala  | 30 +-
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala                        | 12 +
3 files changed, 36 insertions(+), 9 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index ec5f710fd9..0155741ddb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -1241,9 +1241,6 @@ class Analyzer(
    */
   private def resolveSubQueries(plan: LogicalPlan, plans: Seq[LogicalPlan]): LogicalPlan = {
     plan transformExpressions {
-      case s @ ScalarSubquery(sub, conditions, exprId)
-          if sub.resolved && conditions.isEmpty && sub.output.size != 1 =>
-        failAnalysis(s"Scalar subquery must return only one column, but got ${sub.output.size}")
       case s @ ScalarSubquery(sub, _, exprId) if !sub.resolved =>
         resolveSubQuery(s, plans, 1)(ScalarSubquery(_, _, exprId))
       case e @ Exists(sub, exprId) =>
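
For illustration (not part of this patch): with the removed lines gone, the one-column rule for
uncorrelated scalar subqueries no longer fires during resolution in the Analyzer; CheckAnalysis
enforces it instead (next hunk). A minimal sketch of a query that trips the rule, assuming a
SparkSession named spark is in scope; the view t(c1, c2) mirrors the one in the test suite below:

    // Hedged sketch, not committed code. spark.sql analyzes the query eagerly,
    // so the failure surfaces here, now raised from CheckAnalysis.
    import spark.implicits._
    Seq((1, 1), (1, 2)).toDF("c1", "c2").createOrReplaceTempView("t")
    // Throws AnalysisException:
    //   "Scalar subquery must return only one column, but got 2"
    spark.sql("select (select c1, c2 from t) from t")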
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 80e577e5c4..26d2638590 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -117,19 +117,37 @@ trait CheckAnalysis extends PredicateHelper {
failAnalysis(s"Window specification $s is not valid because $m")
case None => w
}
+ case s @ ScalarSubquery(query, conditions, _)
+ // If no correlation, the output must be exactly one column
+ if (conditions.isEmpty && query.output.size != 1) =>
+ failAnalysis(
+ s"Scalar subquery must return only one column, but got ${query.output.size}")
case s @ ScalarSubquery(query, conditions, _) if conditions.nonEmpty =>
- // Make sure correlated scalar subqueries contain one row for every outer row by
- // enforcing that they are aggregates which contain exactly one aggregate expressions.
- // The analyzer has already checked that subquery contained only one output column, and
- // added all the grouping expressions to the aggregate.
- def checkAggregate(a: Aggregate): Unit = {
- val aggregates = a.expressions.flatMap(_.collect {
+ def checkAggregate(agg: Aggregate): Unit = {
+ // Make sure correlated scalar subqueries contain one row for every outer row by
+ // enforcing that they are aggregates which contain exactly one aggregate expressions.
+ // The analyzer has already checked that subquery contained only one output column,
+ // and added all the grouping expressions to the aggregate.
+ val aggregates = agg.expressions.flatMap(_.collect {
case a: AggregateExpression => a
})
if (aggregates.isEmpty) {
failAnalysis("The output of a correlated scalar subquery must be aggregated")
}
+
+ // SPARK-18504: block cases where GROUP BY columns
+ // are not part of the correlated columns
+ val groupByCols = ExpressionSet.apply(agg.groupingExpressions.flatMap(_.references))
+ val predicateCols = ExpressionSet.apply(conditions.flatMap(_.references))
+ val invalidCols = groupByCols.diff(predicateCols)
+ // GROUP BY columns must be a subset of columns in the predicates
+ if (invalidCols.nonEmpty) {
+ failAnalysis(
+ "a GROUP BY clause in a scalar correlated subquery " +
+ "cannot contain non-correlated columns: " +
+ invalidCols.mkString(","))
+ }
}
// Skip projects and subquery aliases added by the Analyzer and the SQLBuilder.
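
For illustration (not part of this patch): the SPARK-18504 check requires every GROUP BY column
of a correlated scalar subquery to be referenced by a correlation predicate. Otherwise a single
outer row could match several groups, and the subquery would produce several "scalar" values for
that row. A hedged sketch, again assuming a SparkSession named spark and the same t(c1, c2) view:

    import scala.util.Try
    // Rejected after this patch: t2.c2 is grouped on but bound by no
    // correlation predicate, so an outer row may see more than one group.
    assert(Try(spark.sql(
      "select (select sum(-1) from t t2 where t1.c2 = t2.c1 group by t2.c2) s from t t1"
    )).isFailure)
    // Accepted: the only grouping column, t2.c1, is covered by the predicate
    // t1.c2 = t2.c1, so each outer row sees at most one group.
    spark.sql(
      "select (select sum(-1) from t t2 where t1.c2 = t2.c1 group by t2.c1) s from t t1")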
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
index c84a6f1618..f1dd1c620e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
@@ -483,6 +483,18 @@ class SubquerySuite extends QueryTest with SharedSQLContext {
       Row(1, null) :: Row(2, 6.0) :: Row(3, 2.0) :: Row(null, null) :: Row(6, null) :: Nil)
   }
 
+  test("SPARK-18504 extra GROUP BY column in correlated scalar subquery is not permitted") {
+    withTempView("t") {
+      Seq((1, 1), (1, 2)).toDF("c1", "c2").createOrReplaceTempView("t")
+
+      val errMsg = intercept[AnalysisException] {
+        sql("select (select sum(-1) from t t2 where t1.c2 = t2.c1 group by t2.c2) sum from t t1")
+      }
+      assert(errMsg.getMessage.contains(
+        "a GROUP BY clause in a scalar correlated subquery cannot contain non-correlated columns:"))
+    }
+  }
+
   test("non-aggregated correlated scalar subquery") {
     val msg1 = intercept[AnalysisException] {
       sql("select a, (select b from l l2 where l2.a = l1.a) sum_b from l l1")