about summary refs log tree commit diff
path: root/sql/core
diff options
context:
space:
mode:
authorHerman van Hovell <hvanhovell@questtec.nl>2015-11-29 14:13:11 -0800
committerYin Huai <yhuai@databricks.com>2015-11-29 14:13:11 -0800
commit3d28081e53698ed77e93c04299957c02bcaba9bf (patch)
tree8c4c791b93a06e975a31b140784c38bc6980b303 /sql/core
parentcc7a1bc9370b163f51230e5ca4be612d133a5086 (diff)
downloadspark-3d28081e53698ed77e93c04299957c02bcaba9bf.tar.gz
spark-3d28081e53698ed77e93c04299957c02bcaba9bf.tar.bz2
spark-3d28081e53698ed77e93c04299957c02bcaba9bf.zip
[SPARK-12024][SQL] More efficient multi-column counting.
In https://github.com/apache/spark/pull/9409 we enabled multi-column counting. The approach taken in that PR introduces a bit of overhead by first creating a row only to check if all of the columns are non-null. This PR fixes that technical debt. Count now takes multiple columns as its input. In order to make this work I have also added support for multiple columns in the single distinct code path. cc yhuai Author: Herman van Hovell <hvanhovell@questtec.nl> Closes #10015 from hvanhovell/SPARK-12024.
Diffstat (limited to 'sql/core')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala | 39
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala | 4
2 files changed, 21 insertions, 22 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
index a70e41436c..76b938cdb6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
@@ -146,20 +146,16 @@ object Utils {
aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes))
// functionsWithDistinct is guaranteed to be non-empty. Even though it may contain more than one
- // DISTINCT aggregate function, all of those functions will have the same column expression.
+ // DISTINCT aggregate function, all of those functions will have the same column expressions.
// For example, it would be valid for functionsWithDistinct to be
// [COUNT(DISTINCT foo), MAX(DISTINCT foo)], but [COUNT(DISTINCT bar), COUNT(DISTINCT foo)] is
// disallowed because those two distinct aggregates have different column expressions.
- val distinctColumnExpression: Expression = {
- val allDistinctColumnExpressions = functionsWithDistinct.head.aggregateFunction.children
- assert(allDistinctColumnExpressions.length == 1)
- allDistinctColumnExpressions.head
- }
- val namedDistinctColumnExpression: NamedExpression = distinctColumnExpression match {
+ val distinctColumnExpressions = functionsWithDistinct.head.aggregateFunction.children
+ val namedDistinctColumnExpressions = distinctColumnExpressions.map {
case ne: NamedExpression => ne
case other => Alias(other, other.toString)()
}
- val distinctColumnAttribute: Attribute = namedDistinctColumnExpression.toAttribute
+ val distinctColumnAttributes = namedDistinctColumnExpressions.map(_.toAttribute)
val groupingAttributes = groupingExpressions.map(_.toAttribute)
// 1. Create an Aggregate Operator for partial aggregations.
@@ -170,10 +166,11 @@ object Utils {
// We will group by the original grouping expression, plus an additional expression for the
// DISTINCT column. For example, for AVG(DISTINCT value) GROUP BY key, the grouping
// expressions will be [key, value].
- val partialAggregateGroupingExpressions = groupingExpressions :+ namedDistinctColumnExpression
+ val partialAggregateGroupingExpressions =
+ groupingExpressions ++ namedDistinctColumnExpressions
val partialAggregateResult =
groupingAttributes ++
- Seq(distinctColumnAttribute) ++
+ distinctColumnAttributes ++
partialAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes)
if (usesTungstenAggregate) {
TungstenAggregate(
@@ -208,28 +205,28 @@ object Utils {
partialMergeAggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
val partialMergeAggregateResult =
groupingAttributes ++
- Seq(distinctColumnAttribute) ++
+ distinctColumnAttributes ++
partialMergeAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes)
if (usesTungstenAggregate) {
TungstenAggregate(
requiredChildDistributionExpressions = Some(groupingAttributes),
- groupingExpressions = groupingAttributes :+ distinctColumnAttribute,
+ groupingExpressions = groupingAttributes ++ distinctColumnAttributes,
nonCompleteAggregateExpressions = partialMergeAggregateExpressions,
nonCompleteAggregateAttributes = partialMergeAggregateAttributes,
completeAggregateExpressions = Nil,
completeAggregateAttributes = Nil,
- initialInputBufferOffset = (groupingAttributes :+ distinctColumnAttribute).length,
+ initialInputBufferOffset = (groupingAttributes ++ distinctColumnAttributes).length,
resultExpressions = partialMergeAggregateResult,
child = partialAggregate)
} else {
SortBasedAggregate(
requiredChildDistributionExpressions = Some(groupingAttributes),
- groupingExpressions = groupingAttributes :+ distinctColumnAttribute,
+ groupingExpressions = groupingAttributes ++ distinctColumnAttributes,
nonCompleteAggregateExpressions = partialMergeAggregateExpressions,
nonCompleteAggregateAttributes = partialMergeAggregateAttributes,
completeAggregateExpressions = Nil,
completeAggregateAttributes = Nil,
- initialInputBufferOffset = (groupingAttributes :+ distinctColumnAttribute).length,
+ initialInputBufferOffset = (groupingAttributes ++ distinctColumnAttributes).length,
resultExpressions = partialMergeAggregateResult,
child = partialAggregate)
}
@@ -244,14 +241,16 @@ object Utils {
expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct)
}
+ val distinctColumnAttributeLookup =
+ distinctColumnExpressions.zip(distinctColumnAttributes).toMap
val (completeAggregateExpressions, completeAggregateAttributes) = functionsWithDistinct.map {
// Children of an AggregateFunction with DISTINCT keyword has already
// been evaluated. At here, we need to replace original children
// to AttributeReferences.
case agg @ AggregateExpression(aggregateFunction, mode, true) =>
- val rewrittenAggregateFunction = aggregateFunction.transformDown {
- case expr if expr == distinctColumnExpression => distinctColumnAttribute
- }.asInstanceOf[AggregateFunction]
+ val rewrittenAggregateFunction = aggregateFunction
+ .transformDown(distinctColumnAttributeLookup)
+ .asInstanceOf[AggregateFunction]
// We rewrite the aggregate function to a non-distinct aggregation because
// its input will have distinct arguments.
// We just keep the isDistinct setting to true, so when users look at the query plan,
@@ -270,7 +269,7 @@ object Utils {
nonCompleteAggregateAttributes = finalAggregateAttributes,
completeAggregateExpressions = completeAggregateExpressions,
completeAggregateAttributes = completeAggregateAttributes,
- initialInputBufferOffset = (groupingAttributes :+ distinctColumnAttribute).length,
+ initialInputBufferOffset = (groupingAttributes ++ distinctColumnAttributes).length,
resultExpressions = resultExpressions,
child = partialMergeAggregate)
} else {
@@ -281,7 +280,7 @@ object Utils {
nonCompleteAggregateAttributes = finalAggregateAttributes,
completeAggregateExpressions = completeAggregateExpressions,
completeAggregateAttributes = completeAggregateAttributes,
- initialInputBufferOffset = (groupingAttributes :+ distinctColumnAttribute).length,
+ initialInputBufferOffset = (groupingAttributes ++ distinctColumnAttributes).length,
resultExpressions = resultExpressions,
child = partialMergeAggregate)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala
index fc873c04f8..893e800a61 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala
@@ -152,8 +152,8 @@ class WindowSpec private[sql](
case Sum(child) => WindowExpression(
UnresolvedWindowFunction("sum", child :: Nil),
WindowSpecDefinition(partitionSpec, orderSpec, frame))
- case Count(child) => WindowExpression(
- UnresolvedWindowFunction("count", child :: Nil),
+ case Count(children) => WindowExpression(
+ UnresolvedWindowFunction("count", children),
WindowSpecDefinition(partitionSpec, orderSpec, frame))
case First(child, ignoreNulls) => WindowExpression(
// TODO this is a hack for Hive UDAF first_value