diff options
3 files changed, 36 insertions, 9 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index ec5f710fd9..0155741ddb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1241,9 +1241,6 @@ class Analyzer( */ private def resolveSubQueries(plan: LogicalPlan, plans: Seq[LogicalPlan]): LogicalPlan = { plan transformExpressions { - case s @ ScalarSubquery(sub, conditions, exprId) - if sub.resolved && conditions.isEmpty && sub.output.size != 1 => - failAnalysis(s"Scalar subquery must return only one column, but got ${sub.output.size}") case s @ ScalarSubquery(sub, _, exprId) if !sub.resolved => resolveSubQuery(s, plans, 1)(ScalarSubquery(_, _, exprId)) case e @ Exists(sub, exprId) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 80e577e5c4..26d2638590 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -117,19 +117,37 @@ trait CheckAnalysis extends PredicateHelper { failAnalysis(s"Window specification $s is not valid because $m") case None => w } + case s @ ScalarSubquery(query, conditions, _) + // If no correlation, the output must be exactly one column + if (conditions.isEmpty && query.output.size != 1) => + failAnalysis( + s"Scalar subquery must return only one column, but got ${query.output.size}") case s @ ScalarSubquery(query, conditions, _) if conditions.nonEmpty => - // Make sure correlated scalar subqueries contain one row for every outer row by - // enforcing that they are aggregates which contain exactly one aggregate expressions. - // The analyzer has already checked that subquery contained only one output column, and - // added all the grouping expressions to the aggregate. - def checkAggregate(a: Aggregate): Unit = { - val aggregates = a.expressions.flatMap(_.collect { + def checkAggregate(agg: Aggregate): Unit = { + // Make sure correlated scalar subqueries contain one row for every outer row by + // enforcing that they are aggregates which contain exactly one aggregate expressions. + // The analyzer has already checked that subquery contained only one output column, + // and added all the grouping expressions to the aggregate. + val aggregates = agg.expressions.flatMap(_.collect { case a: AggregateExpression => a }) if (aggregates.isEmpty) { failAnalysis("The output of a correlated scalar subquery must be aggregated") } + + // SPARK-18504: block cases where GROUP BY columns + // are not part of the correlated columns + val groupByCols = ExpressionSet.apply(agg.groupingExpressions.flatMap(_.references)) + val predicateCols = ExpressionSet.apply(conditions.flatMap(_.references)) + val invalidCols = groupByCols.diff(predicateCols) + // GROUP BY columns must be a subset of columns in the predicates + if (invalidCols.nonEmpty) { + failAnalysis( + "a GROUP BY clause in a scalar correlated subquery " + + "cannot contain non-correlated columns: " + + invalidCols.mkString(",")) + } } // Skip projects and subquery aliases added by the Analyzer and the SQLBuilder. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index c84a6f1618..f1dd1c620e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -483,6 +483,18 @@ class SubquerySuite extends QueryTest with SharedSQLContext { Row(1, null) :: Row(2, 6.0) :: Row(3, 2.0) :: Row(null, null) :: Row(6, null) :: Nil) } + test("SPARK-18504 extra GROUP BY column in correlated scalar subquery is not permitted") { + withTempView("t") { + Seq((1, 1), (1, 2)).toDF("c1", "c2").createOrReplaceTempView("t") + + val errMsg = intercept[AnalysisException] { + sql("select (select sum(-1) from t t2 where t1.c2 = t2.c1 group by t2.c2) sum from t t1") + } + assert(errMsg.getMessage.contains( + "a GROUP BY clause in a scalar correlated subquery cannot contain non-correlated columns:")) + } + } + test("non-aggregated correlated scalar subquery") { val msg1 = intercept[AnalysisException] { sql("select a, (select b from l l2 where l2.a = l1.a) sum_b from l l1") |