aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala9
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala9
2 files changed, 14 insertions, 4 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index 310d127506..b4c445b3ba 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -141,10 +141,11 @@ object PartialAggregation {
// We need to pass all grouping expressions though so the grouping can happen a second
// time. However some of them might be unnamed so we alias them allowing them to be
// referenced in the second aggregation.
- val namedGroupingExpressions: Map[Expression, NamedExpression] = groupingExpressions.map {
- case n: NamedExpression => (n, n)
- case other => (other, Alias(other, "PartialGroup")())
- }.toMap
+ val namedGroupingExpressions: Map[Expression, NamedExpression] =
+ groupingExpressions.filter(!_.isInstanceOf[Literal]).map {
+ case n: NamedExpression => (n, n)
+ case other => (other, Alias(other, "PartialGroup")())
+ }.toMap
// Replace aggregations with a new expression that computes the result from the already
// computed partial evaluations and grouping values.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index e03444d496..d684278f11 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -186,6 +186,15 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
Seq(Row(1,3), Row(2,3), Row(3,3)))
}
+ test("literal in agg grouping expressions") {
+ checkAnswer(
+ sql("SELECT a, count(1) FROM testData2 GROUP BY a, 1"),
+ Seq(Row(1,2), Row(2,2), Row(3,2)))
+ checkAnswer(
+ sql("SELECT a, count(2) FROM testData2 GROUP BY a, 2"),
+ Seq(Row(1,2), Row(2,2), Row(3,2)))
+ }
+
test("aggregates with nulls") {
checkAnswer(
sql("SELECT MIN(a), MAX(a), AVG(a), SUM(a), COUNT(a) FROM nullInts"),