diff options
author | Herman van Hovell <hvanhovell@questtec.nl> | 2015-11-07 13:37:37 -0800 |
---|---|---|
committer | Yin Huai <yhuai@databricks.com> | 2015-11-07 13:37:37 -0800 |
commit | ef362846eb448769bcf774fc9090a5013d459464 (patch) | |
tree | 028176c2c4cd5fdc3c3c3a2bcaf2e3a9022b2c86 /sql | |
parent | 2ff0e79a8647cca5c9c57f613a07e739ac4f677e (diff) | |
download | spark-ef362846eb448769bcf774fc9090a5013d459464.tar.gz spark-ef362846eb448769bcf774fc9090a5013d459464.tar.bz2 spark-ef362846eb448769bcf774fc9090a5013d459464.zip |
[SPARK-9241][SQL] Supporting multiple DISTINCT columns - follow-up
This PR is a follow up for PR https://github.com/apache/spark/pull/9406. It adds more documentation to the rewriting rule, removes a redundant if expression in the non-distinct aggregation path and adds a multiple distinct test to the AggregationQuerySuite.
cc yhuai marmbrus
Author: Herman van Hovell <hvanhovell@questtec.nl>
Closes #9541 from hvanhovell/SPARK-9241-followup.
Diffstat (limited to 'sql')
2 files changed, 108 insertions, 23 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala index 39010c3be6..ac23f72782 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala @@ -222,10 +222,76 @@ object Utils { * aggregation in which the regular aggregation expressions and every distinct clause is aggregated * in a separate group. The results are then combined in a second aggregate. * - * TODO Expression cannocalization - * TODO Eliminate foldable expressions from distinct clauses. - * TODO This eliminates all distinct expressions. We could safely pass one to the aggregate - * operator. Perhaps this is a good thing? It is much simpler to plan later on... + * For example (in scala): + * {{{ + * val data = Seq( + * ("a", "ca1", "cb1", 10), + * ("a", "ca1", "cb2", 5), + * ("b", "ca1", "cb1", 13)) + * .toDF("key", "cat1", "cat2", "value") + * data.registerTempTable("data") + * + * val agg = data.groupBy($"key") + * .agg( + * countDistinct($"cat1").as("cat1_cnt"), + * countDistinct($"cat2").as("cat2_cnt"), + * sum($"value").as("total")) + * }}} + * + * This translates to the following (pseudo) logical plan: + * {{{ + * Aggregate( + * key = ['key] + * functions = [COUNT(DISTINCT 'cat1), + * COUNT(DISTINCT 'cat2), + * sum('value)] + * output = ['key, 'cat1_cnt, 'cat2_cnt, 'total]) + * LocalTableScan [...] + * }}} + * + * This rule rewrites this logical plan to the following (pseudo) logical plan: + * {{{ + * Aggregate( + * key = ['key] + * functions = [count(if (('gid = 1)) 'cat1 else null), + * count(if (('gid = 2)) 'cat2 else null), + * first(if (('gid = 0)) 'total else null) ignore nulls] + * output = ['key, 'cat1_cnt, 'cat2_cnt, 'total]) + * Aggregate( + * key = ['key, 'cat1, 'cat2, 'gid] + * functions = [sum('value)] + * output = ['key, 'cat1, 'cat2, 'gid, 'total]) + * Expand( + * projections = [('key, null, null, 0, cast('value as bigint)), + * ('key, 'cat1, null, 1, null), + * ('key, null, 'cat2, 2, null)] + * output = ['key, 'cat1, 'cat2, 'gid, 'value]) + * LocalTableScan [...] + * }}} + * + * The rule does the following things here: + * 1. Expand the data. There are three aggregation groups in this query: + * i. the non-distinct group; + * ii. the distinct 'cat1 group; + * iii. the distinct 'cat2 group. + * An expand operator is inserted to expand the child data for each group. The expand will null + * out all unused columns for the given group; this must be done in order to ensure correctness + * later on. Groups can by identified by a group id (gid) column added by the expand operator. + * 2. De-duplicate the distinct paths and aggregate the non-aggregate path. The group by clause of + * this aggregate consists of the original group by clause, all the requested distinct columns + * and the group id. Both de-duplication of distinct column and the aggregation of the + * non-distinct group take advantage of the fact that we group by the group id (gid) and that we + * have nulled out all non-relevant columns for the the given group. + * 3. Aggregating the distinct groups and combining this with the results of the non-distinct + * aggregation. In this step we use the group id to filter the inputs for the aggregate + * functions. The result of the non-distinct group are 'aggregated' by using the first operator, + * it might be more elegant to use the native UDAF merge mechanism for this in the future. + * + * This rule duplicates the input data by two or more times (# distinct groups + an optional + * non-distinct group). This will put quite a bit of memory pressure of the used aggregate and + * exchange operators. Keeping the number of distinct groups as low a possible should be priority, + * we could improve this in the current rule by applying more advanced expression cannocalization + * techniques. */ object MultipleDistinctRewriter extends Rule[LogicalPlan] { @@ -261,11 +327,10 @@ object MultipleDistinctRewriter extends Rule[LogicalPlan] { // Functions used to modify aggregate functions and their inputs. def evalWithinGroup(id: Literal, e: Expression) = If(EqualTo(gid, id), e, nullify(e)) def patchAggregateFunctionChildren( - af: AggregateFunction2, - id: Literal, - attrs: Map[Expression, Expression]): AggregateFunction2 = { - af.withNewChildren(af.children.map { case afc => - evalWithinGroup(id, attrs(afc)) + af: AggregateFunction2)( + attrs: Expression => Expression): AggregateFunction2 = { + af.withNewChildren(af.children.map { + case afc => attrs(afc) }).asInstanceOf[AggregateFunction2] } @@ -288,7 +353,9 @@ object MultipleDistinctRewriter extends Rule[LogicalPlan] { // Final aggregate val operators = expressions.map { e => val af = e.aggregateFunction - val naf = patchAggregateFunctionChildren(af, id, distinctAggChildAttrMap) + val naf = patchAggregateFunctionChildren(af) { x => + evalWithinGroup(id, distinctAggChildAttrMap(x)) + } (e, e.copy(aggregateFunction = naf, isDistinct = false)) } @@ -304,26 +371,27 @@ object MultipleDistinctRewriter extends Rule[LogicalPlan] { val regularGroupId = Literal(0) val regularAggOperatorMap = regularAggExprs.map { e => // Perform the actual aggregation in the initial aggregate. - val af = patchAggregateFunctionChildren( - e.aggregateFunction, - regularGroupId, - regularAggChildAttrMap) - val a = Alias(e.copy(aggregateFunction = af), e.toString)() - - // Get the result of the first aggregate in the last aggregate. - val b = AggregateExpression2( - aggregate.First(evalWithinGroup(regularGroupId, a.toAttribute), Literal(true)), + val af = patchAggregateFunctionChildren(e.aggregateFunction)(regularAggChildAttrMap) + val operator = Alias(e.copy(aggregateFunction = af), e.toString)() + + // Select the result of the first aggregate in the last aggregate. + val result = AggregateExpression2( + aggregate.First(evalWithinGroup(regularGroupId, operator.toAttribute), Literal(true)), mode = Complete, isDistinct = false) // Some aggregate functions (COUNT) have the special property that they can return a // non-null result without any input. We need to make sure we return a result in this case. - val c = af.defaultResult match { - case Some(lit) => Coalesce(Seq(b, lit)) - case None => b + val resultWithDefault = af.defaultResult match { + case Some(lit) => Coalesce(Seq(result, lit)) + case None => result } - (e, a, c) + // Return a Tuple3 containing: + // i. The original aggregate expression (used for look ups). + // ii. The actual aggregation operator (used in the first aggregate). + // iii. The operator that selects and returns the result (used in the second aggregate). + (e, operator, resultWithDefault) } // Construct the regular aggregate input projection only if we need one. diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index ea80060e37..7f6fe33923 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -516,6 +516,23 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te Row(3, 4, 4, 3, null) :: Nil) } + test("multiple distinct column sets") { + checkAnswer( + sqlContext.sql( + """ + |SELECT + | key, + | count(distinct value1), + | count(distinct value2) + |FROM agg2 + |GROUP BY key + """.stripMargin), + Row(null, 3, 3) :: + Row(1, 2, 3) :: + Row(2, 2, 1) :: + Row(3, 0, 1) :: Nil) + } + test("test count") { checkAnswer( sqlContext.sql( |