author	Herman van Hovell <hvanhovell@questtec.nl>	2015-11-07 13:37:37 -0800
committer	Yin Huai <yhuai@databricks.com>	2015-11-07 13:37:37 -0800
commit	ef362846eb448769bcf774fc9090a5013d459464 (patch)
tree	028176c2c4cd5fdc3c3c3a2bcaf2e3a9022b2c86
parent	2ff0e79a8647cca5c9c57f613a07e739ac4f677e (diff)
[SPARK-9241][SQL] Supporting multiple DISTINCT columns - follow-up
This PR is a follow-up for PR https://github.com/apache/spark/pull/9406. It adds more documentation to the rewriting rule, removes a redundant if expression in the non-distinct aggregation path, and adds a multiple-distinct test to the AggregationQuerySuite.

cc yhuai marmbrus

Author: Herman van Hovell <hvanhovell@questtec.nl>

Closes #9541 from hvanhovell/SPARK-9241-followup.
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala | 114
-rw-r--r-- sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala | 17
2 files changed, 108 insertions(+), 23 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala
index 39010c3be6..ac23f72782 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala
@@ -222,10 +222,76 @@ object Utils {
* aggregation in which the regular aggregation expressions and every distinct clause are each
* aggregated in a separate group. The results are then combined in a second aggregate.
*
- * TODO Expression cannocalization
- * TODO Eliminate foldable expressions from distinct clauses.
- * TODO This eliminates all distinct expressions. We could safely pass one to the aggregate
- * operator. Perhaps this is a good thing? It is much simpler to plan later on...
+ * For example (in Scala):
+ * {{{
+ * val data = Seq(
+ * ("a", "ca1", "cb1", 10),
+ * ("a", "ca1", "cb2", 5),
+ * ("b", "ca1", "cb1", 13))
+ * .toDF("key", "cat1", "cat2", "value")
+ * data.registerTempTable("data")
+ *
+ * val agg = data.groupBy($"key")
+ * .agg(
+ * countDistinct($"cat1").as("cat1_cnt"),
+ * countDistinct($"cat2").as("cat2_cnt"),
+ * sum($"value").as("total"))
+ * }}}
+ *
+ * This translates to the following (pseudo) logical plan:
+ * {{{
+ * Aggregate(
+ * key = ['key]
+ * functions = [COUNT(DISTINCT 'cat1),
+ * COUNT(DISTINCT 'cat2),
+ * sum('value)]
+ * output = ['key, 'cat1_cnt, 'cat2_cnt, 'total])
+ * LocalTableScan [...]
+ * }}}
+ *
+ * This rule rewrites this logical plan to the following (pseudo) logical plan:
+ * {{{
+ * Aggregate(
+ * key = ['key]
+ * functions = [count(if (('gid = 1)) 'cat1 else null),
+ * count(if (('gid = 2)) 'cat2 else null),
+ * first(if (('gid = 0)) 'total else null) ignore nulls]
+ * output = ['key, 'cat1_cnt, 'cat2_cnt, 'total])
+ * Aggregate(
+ * key = ['key, 'cat1, 'cat2, 'gid]
+ * functions = [sum('value)]
+ * output = ['key, 'cat1, 'cat2, 'gid, 'total])
+ * Expand(
+ * projections = [('key, null, null, 0, cast('value as bigint)),
+ * ('key, 'cat1, null, 1, null),
+ * ('key, null, 'cat2, 2, null)]
+ * output = ['key, 'cat1, 'cat2, 'gid, 'value])
+ * LocalTableScan [...]
+ * }}}
+ *
+ * The rule does the following:
+ * 1. Expand the data. There are three aggregation groups in this query:
+ * i. the non-distinct group;
+ * ii. the distinct 'cat1 group;
+ * iii. the distinct 'cat2 group.
+ * An expand operator is inserted to expand the child data for each group. The expand will null
+ * out all unused columns for the given group; this must be done in order to ensure correctness
+ * later on. Groups can be identified by a group id (gid) column added by the expand operator.
+ * 2. De-duplicate the distinct paths and aggregate the non-distinct path. The group by clause of
+ * this aggregate consists of the original group by clause, all the requested distinct columns
+ * and the group id. Both the de-duplication of the distinct columns and the aggregation of the
+ * non-distinct group take advantage of the fact that we group by the group id (gid) and that we
+ * have nulled out all non-relevant columns for the given group.
+ * 3. Aggregate the distinct groups and combine them with the results of the non-distinct
+ * aggregation. In this step we use the group id to filter the inputs of the aggregate
+ * functions. The results of the non-distinct group are 'aggregated' using the first operator;
+ * it might be more elegant to use the native UDAF merge mechanism for this in the future.
+ *
+ * This rule duplicates the input data two or more times (once per distinct group, plus once for
+ * an optional non-distinct group). This puts quite a bit of memory pressure on the aggregate and
+ * exchange operators involved. Keeping the number of distinct groups as low as possible should
+ * therefore be a priority; we could improve this in the current rule by applying more advanced
+ * expression canonicalization techniques. (The three steps above are also sketched against the
+ * DataFrame API right after this rule's declaration.)
*/
object MultipleDistinctRewriter extends Rule[LogicalPlan] {
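To make the three documented steps concrete, here is a minimal sketch of the same rewrite expressed against the public DataFrame API rather than the logical plan, reusing the `data` DataFrame (and the `$`/`toDF` implicits) from the scaladoc example. The names `expanded`, `deduped` and `result` are ours, and the `first(column, ignoreNulls)` overload is a Spark 2.x API; on older versions the gid-0 row can be recovered with a plain sum instead.

{{{
import org.apache.spark.sql.functions._

// Step 1 (Expand): one copy of the input per aggregation group, with the
// columns the group does not use nulled out and a group id added.
val expanded = data
  .select($"key", lit(null).cast("string").as("cat1"),
    lit(null).cast("string").as("cat2"), lit(0).as("gid"),
    $"value".cast("bigint").as("value"))
  .unionAll(data.select($"key", $"cat1",
    lit(null).cast("string").as("cat2"), lit(1).as("gid"),
    lit(null).cast("bigint").as("value")))
  .unionAll(data.select($"key", lit(null).cast("string").as("cat1"),
    $"cat2", lit(2).as("gid"), lit(null).cast("bigint").as("value")))

// Step 2 (first Aggregate): grouping by the key, the distinct columns and
// the group id de-duplicates the distinct values; sum only sees non-null
// input for the gid = 0 rows, which aggregates the non-distinct group.
val deduped = expanded
  .groupBy($"key", $"cat1", $"cat2", $"gid")
  .agg(sum($"value").as("total"))

// Step 3 (second Aggregate): the group id filters the input of every
// aggregate function; first(..., ignoreNulls = true) picks up the
// non-distinct result computed in step 2.
val result = deduped.groupBy($"key").agg(
  count(when($"gid" === 1, $"cat1")).as("cat1_cnt"),
  count(when($"gid" === 2, $"cat2")).as("cat2_cnt"),
  first(when($"gid" === 0, $"total"), ignoreNulls = true).as("total"))
}}}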
@@ -261,11 +327,10 @@ object MultipleDistinctRewriter extends Rule[LogicalPlan] {
// Functions used to modify aggregate functions and their inputs.
def evalWithinGroup(id: Literal, e: Expression) = If(EqualTo(gid, id), e, nullify(e))
def patchAggregateFunctionChildren(
- af: AggregateFunction2,
- id: Literal,
- attrs: Map[Expression, Expression]): AggregateFunction2 = {
- af.withNewChildren(af.children.map { case afc =>
- evalWithinGroup(id, attrs(afc))
+ af: AggregateFunction2)(
+ attrs: Expression => Expression): AggregateFunction2 = {
+ af.withNewChildren(af.children.map {
+ case afc => attrs(afc)
}).asInstanceOf[AggregateFunction2]
}
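The change above is the "redundant if expression" fix from the commit message: the helper is now curried over a plain `Expression => Expression`, so the distinct path can wrap children in the `evalWithinGroup` group-id check while the non-distinct path substitutes attributes directly. A standalone sketch of the pattern, using simplified stand-in types of our own rather than Spark's classes:

{{{
// Fn stands in for AggregateFunction2; children are plain strings here.
case class Fn(children: Seq[String]) {
  def withNewChildren(cs: Seq[String]): Fn = copy(children = cs)
}

// The second parameter list carries the per-child rewrite, so each call
// site decides how children are transformed.
def patchChildren(fn: Fn)(rewrite: String => String): Fn =
  fn.withNewChildren(fn.children.map(rewrite))

val attrMap = Map("cat1" -> "cat1#1", "value" -> "value#2")

// Distinct path: wrap every child in a group-id check.
val distinctFn = patchChildren(Fn(Seq("cat1"))) { c =>
  s"if (gid = 1) ${attrMap(c)} else null"
}

// Non-distinct path: a Map is already a function, so the attribute map can
// be passed as-is -- no redundant if around the children.
val regularFn = patchChildren(Fn(Seq("value")))(attrMap)
}}}

This mirrors the two call sites below: the final aggregate passes an `evalWithinGroup` wrapper, while the regular path passes `regularAggChildAttrMap` directly.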
@@ -288,7 +353,9 @@ object MultipleDistinctRewriter extends Rule[LogicalPlan] {
// Final aggregate
val operators = expressions.map { e =>
val af = e.aggregateFunction
- val naf = patchAggregateFunctionChildren(af, id, distinctAggChildAttrMap)
+ val naf = patchAggregateFunctionChildren(af) { x =>
+ evalWithinGroup(id, distinctAggChildAttrMap(x))
+ }
(e, e.copy(aggregateFunction = naf, isDistinct = false))
}
@@ -304,26 +371,27 @@ object MultipleDistinctRewriter extends Rule[LogicalPlan] {
val regularGroupId = Literal(0)
val regularAggOperatorMap = regularAggExprs.map { e =>
// Perform the actual aggregation in the initial aggregate.
- val af = patchAggregateFunctionChildren(
- e.aggregateFunction,
- regularGroupId,
- regularAggChildAttrMap)
- val a = Alias(e.copy(aggregateFunction = af), e.toString)()
-
- // Get the result of the first aggregate in the last aggregate.
- val b = AggregateExpression2(
- aggregate.First(evalWithinGroup(regularGroupId, a.toAttribute), Literal(true)),
+ val af = patchAggregateFunctionChildren(e.aggregateFunction)(regularAggChildAttrMap)
+ val operator = Alias(e.copy(aggregateFunction = af), e.toString)()
+
+ // Select the result of the first aggregate in the last aggregate.
+ val result = AggregateExpression2(
+ aggregate.First(evalWithinGroup(regularGroupId, operator.toAttribute), Literal(true)),
mode = Complete,
isDistinct = false)
// Some aggregate functions (COUNT) have the special property that they can return a
// non-null result without any input. We need to make sure we return a result in this case.
- val c = af.defaultResult match {
- case Some(lit) => Coalesce(Seq(b, lit))
- case None => b
+ val resultWithDefault = af.defaultResult match {
+ case Some(lit) => Coalesce(Seq(result, lit))
+ case None => result
}
- (e, a, c)
+ // Return a Tuple3 containing:
+ // i. The original aggregate expression (used for lookups).
+ // ii. The actual aggregation operator (used in the first aggregate).
+ // iii. The operator that selects and returns the result (used in the second aggregate).
+ (e, operator, resultWithDefault)
}
// Construct the regular aggregate input projection only if we need one.
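The `defaultResult` handling deserves a note: after the rewrite, the last aggregate recovers the regular result via `first(..., ignore nulls)`, which yields null when no non-null input arrives, while COUNT must return 0 even on empty input. A tiny sketch of the fallback semantics the `Coalesce` restores, using a hypothetical helper in plain Scala:

{{{
// Mirrors Coalesce(Seq(result, default)): fall back to the aggregate
// function's default result when the merged value is null.
def withDefault(merged: Option[Long], default: Option[Long]): Option[Long] =
  merged.orElse(default)

assert(withDefault(None, Some(0L)) == Some(0L))  // COUNT defaults to 0
assert(withDefault(Some(7L), Some(0L)) == Some(7L))
assert(withDefault(None, None) == None)          // SUM has no default
}}}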
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
index ea80060e37..7f6fe33923 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
@@ -516,6 +516,23 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
Row(3, 4, 4, 3, null) :: Nil)
}
+ test("multiple distinct column sets") {
+ checkAnswer(
+ sqlContext.sql(
+ """
+ |SELECT
+ | key,
+ | count(distinct value1),
+ | count(distinct value2)
+ |FROM agg2
+ |GROUP BY key
+ """.stripMargin),
+ Row(null, 3, 3) ::
+ Row(1, 2, 3) ::
+ Row(2, 2, 1) ::
+ Row(3, 0, 1) :: Nil)
+ }
+
test("test count") {
checkAnswer(
sqlContext.sql(
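For reference, what the rewrite computes for the new test query can be written out as plain SQL over an expanded input. This is only a sketch (the rule operates on logical plans, not SQL text, and the null literals may need explicit casts in real HiveQL); gid = 1 tags the value1 distinct group and gid = 2 the value2 group:

{{{
val rewritten = sqlContext.sql(
  """
    |SELECT
    |  key,
    |  count(IF(gid = 1, value1, null)),
    |  count(IF(gid = 2, value2, null))
    |FROM (
    |  SELECT key, value1, value2, gid
    |  FROM (
    |    SELECT key, value1, null AS value2, 1 AS gid FROM agg2
    |    UNION ALL
    |    SELECT key, null AS value1, value2, 2 AS gid FROM agg2
    |  ) expanded
    |  GROUP BY key, value1, value2, gid
    |) deduped
    |GROUP BY key
  """.stripMargin)
}}}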