aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala18
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala94
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala8
3 files changed, 111 insertions, 9 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
index 0f43e7bb88..d6a39ecf53 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
@@ -119,14 +119,16 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
.filter(_.isDistinct)
.groupBy(_.aggregateFunction.children.toSet)
- // Aggregation strategy can handle the query with single distinct
- if (distinctAggGroups.size > 1) {
+ // Check if the aggregates contains functions that do not support partial aggregation.
+ val existsNonPartial = aggExpressions.exists(!_.aggregateFunction.supportsPartial)
+
+ // Aggregation strategy can handle queries with a single distinct group and partial aggregates.
+ if (distinctAggGroups.size > 1 || (distinctAggGroups.size == 1 && existsNonPartial)) {
// Create the attributes for the grouping id and the group by clause.
- val gid =
- new AttributeReference("gid", IntegerType, false)(isGenerated = true)
+ val gid = AttributeReference("gid", IntegerType, nullable = false)(isGenerated = true)
val groupByMap = a.groupingExpressions.collect {
case ne: NamedExpression => ne -> ne.toAttribute
- case e => e -> new AttributeReference(e.sql, e.dataType, e.nullable)()
+ case e => e -> AttributeReference(e.sql, e.dataType, e.nullable)()
}
val groupByAttrs = groupByMap.map(_._2)
@@ -135,9 +137,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
def patchAggregateFunctionChildren(
af: AggregateFunction)(
attrs: Expression => Expression): AggregateFunction = {
- af.withNewChildren(af.children.map {
- case afc => attrs(afc)
- }).asInstanceOf[AggregateFunction]
+ af.withNewChildren(af.children.map(attrs)).asInstanceOf[AggregateFunction]
}
// Setup unique distinct aggregate children.
@@ -265,5 +265,5 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
// NamedExpression. This is done to prevent collisions between distinct and regular aggregate
// children, in this case attribute reuse causes the input of the regular aggregate to bound to
// the (nulled out) input of the distinct aggregate.
- e -> new AttributeReference(e.sql, e.dataType, true)()
+ e -> AttributeReference(e.sql, e.dataType, nullable = true)()
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala
new file mode 100644
index 0000000000..0b973c3b65
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.SimpleCatalystConf
+import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry}
+import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions.{If, Literal}
+import org.apache.spark.sql.catalyst.expressions.aggregate.{CollectSet, Count}
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LocalRelation, LogicalPlan}
+import org.apache.spark.sql.types.{IntegerType, StringType}
+
+class RewriteDistinctAggregatesSuite extends PlanTest {
+ val conf = SimpleCatalystConf(caseSensitiveAnalysis = false, groupByOrdinal = false)
+ val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf)
+ val analyzer = new Analyzer(catalog, conf)
+
+ val nullInt = Literal(null, IntegerType)
+ val nullString = Literal(null, StringType)
+ val testRelation = LocalRelation('a.string, 'b.string, 'c.string, 'd.string, 'e.int)
+
+ private def checkRewrite(rewrite: LogicalPlan): Unit = rewrite match {
+ case Aggregate(_, _, Aggregate(_, _, _: Expand)) =>
+ case _ => fail(s"Plan is not rewritten:\n$rewrite")
+ }
+
+ test("single distinct group") {
+ val input = testRelation
+ .groupBy('a)(countDistinct('e))
+ .analyze
+ val rewrite = RewriteDistinctAggregates(input)
+ comparePlans(input, rewrite)
+ }
+
+ test("single distinct group with partial aggregates") {
+ val input = testRelation
+ .groupBy('a, 'd)(
+ countDistinct('e, 'c).as('agg1),
+ max('b).as('agg2))
+ .analyze
+ val rewrite = RewriteDistinctAggregates(input)
+ comparePlans(input, rewrite)
+ }
+
+ test("single distinct group with non-partial aggregates") {
+ val input = testRelation
+ .groupBy('a, 'd)(
+ countDistinct('e, 'c).as('agg1),
+ CollectSet('b).toAggregateExpression().as('agg2))
+ .analyze
+ checkRewrite(RewriteDistinctAggregates(input))
+ }
+
+ test("multiple distinct groups") {
+ val input = testRelation
+ .groupBy('a)(countDistinct('b, 'c), countDistinct('d))
+ .analyze
+ checkRewrite(RewriteDistinctAggregates(input))
+ }
+
+ test("multiple distinct groups with partial aggregates") {
+ val input = testRelation
+ .groupBy('a)(countDistinct('b, 'c), countDistinct('d), sum('e))
+ .analyze
+ checkRewrite(RewriteDistinctAggregates(input))
+ }
+
+ test("multiple distinct groups with non-partial aggregates") {
+ val input = testRelation
+ .groupBy('a)(
+ countDistinct('b, 'c),
+ countDistinct('d),
+ CollectSet('b).toAggregateExpression())
+ .analyze
+ checkRewrite(RewriteDistinctAggregates(input))
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 427390a90f..0e172bee4f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -493,4 +493,12 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
Row(new java.math.BigDecimal(2.0), new java.math.BigDecimal(1.5)),
Row(new java.math.BigDecimal(3.0), new java.math.BigDecimal(1.5))))
}
+
+ test("SPARK-17616: distinct aggregate combined with a non-partial aggregate") {
+ val df = Seq((1, 3, "a"), (1, 2, "b"), (3, 4, "c"), (3, 4, "c"), (3, 5, "d"))
+ .toDF("x", "y", "z")
+ checkAnswer(
+ df.groupBy($"x").agg(countDistinct($"y"), sort_array(collect_list($"z"))),
+ Seq(Row(1, 2, Seq("a", "b")), Row(3, 2, Seq("c", "c", "d"))))
+ }
}