author    Wenchen Fan <wenchen@databricks.com>  2016-09-01 13:19:15 +0800
committer Wenchen Fan <wenchen@databricks.com>  2016-09-01 13:19:15 +0800
commit    aaf632b2132750c697dddd0469b902d9308dbf36 (patch)
tree      45f8c6d5d852f2ec8ad8b100969c482b18a8b68f
parent    7a5000f39ef4f195696836f8a4e8ab4ff5c14dd2 (diff)
revert PR#10896 and PR#14865
## What changes were proposed in this pull request?

According to the discussion in the original PR #10896 and the new-approach PR #14876, we decided to revert these two PRs and go with the new approach.

## How was this patch tested?

N/A

Author: Wenchen Fan <wenchen@databricks.com>

Closes #14909 from cloud-fan/revert.
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala             |  17
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala          | 250
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateExec.scala     |  56
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala |  22
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala |  24
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala |  39
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala                        |  15
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala                |  77
8 files changed, 223 insertions(+), 277 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index cda3b2b75e..4aaf454285 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -259,17 +259,24 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
}
val aggregateOperator =
- if (functionsWithDistinct.isEmpty) {
+ if (aggregateExpressions.map(_.aggregateFunction).exists(!_.supportsPartial)) {
+ if (functionsWithDistinct.nonEmpty) {
+ sys.error("Distinct columns cannot exist in Aggregate operator containing " +
+ "aggregate functions which don't support partial aggregation.")
+ } else {
+ aggregate.AggUtils.planAggregateWithoutPartial(
+ groupingExpressions,
+ aggregateExpressions,
+ resultExpressions,
+ planLater(child))
+ }
+ } else if (functionsWithDistinct.isEmpty) {
aggregate.AggUtils.planAggregateWithoutDistinct(
groupingExpressions,
aggregateExpressions,
resultExpressions,
planLater(child))
} else {
- if (aggregateExpressions.map(_.aggregateFunction).exists(!_.supportsPartial)) {
- sys.error("Distinct columns cannot exist in Aggregate operator containing " +
- "aggregate functions which don't support partial aggregation.")
- }
aggregate.AggUtils.planAggregateWithOneDistinct(
groupingExpressions,
functionsWithDistinct,
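
The restored branching above can be summarized without Spark's internal classes. The following is a minimal sketch using placeholder types (AggExpr, Plan) invented for illustration; only the decision order mirrors the hunk.

```scala
// Simplified model of the restored dispatch in SparkStrategies.
// AggExpr and Plan are stand-ins, not Spark's AggregateExpression or SparkPlan.
object AggregateDispatchSketch {
  final case class AggExpr(name: String, isDistinct: Boolean, supportsPartial: Boolean)

  sealed trait Plan
  final case class WithoutPartial(aggs: Seq[AggExpr]) extends Plan
  final case class WithoutDistinct(aggs: Seq[AggExpr]) extends Plan
  final case class WithOneDistinct(aggs: Seq[AggExpr]) extends Plan

  def plan(aggs: Seq[AggExpr]): Plan = {
    val functionsWithDistinct = aggs.filter(_.isDistinct)
    if (aggs.exists(!_.supportsPartial)) {
      // Non-partial-capable functions cannot be combined with DISTINCT aggregates.
      require(functionsWithDistinct.isEmpty,
        "Distinct columns cannot exist in Aggregate operator containing " +
          "aggregate functions which don't support partial aggregation.")
      WithoutPartial(aggs)
    } else if (functionsWithDistinct.isEmpty) {
      WithoutDistinct(aggs)
    } else {
      WithOneDistinct(aggs)
    }
  }

  def main(args: Array[String]): Unit = {
    println(plan(Seq(AggExpr("count", isDistinct = false, supportsPartial = true))))
    println(plan(Seq(AggExpr("count", isDistinct = true, supportsPartial = true))))
  }
}
```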
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
index fe75ecea17..4fbb9d554c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
@@ -19,97 +19,34 @@ package org.apache.spark.sql.execution.aggregate
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.catalyst.plans.physical.Distribution
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.streaming.{StateStoreRestoreExec, StateStoreSaveExec}
/**
- * A pattern that finds aggregate operators to support partial aggregations.
- */
-object PartialAggregate {
-
- def unapply(plan: SparkPlan): Option[Distribution] = plan match {
- case agg: AggregateExec if AggUtils.supportPartialAggregate(agg.aggregateExpressions) =>
- Some(agg.requiredChildDistribution.head)
- case _ =>
- None
- }
-}
-
-/**
* Utility functions used by the query planner to convert our plan to new aggregation code path.
*/
object AggUtils {
- def supportPartialAggregate(aggregateExpressions: Seq[AggregateExpression]): Boolean = {
- aggregateExpressions.map(_.aggregateFunction).forall(_.supportsPartial)
- }
-
- private def createPartialAggregateExec(
+ def planAggregateWithoutPartial(
groupingExpressions: Seq[NamedExpression],
aggregateExpressions: Seq[AggregateExpression],
- child: SparkPlan): SparkPlan = {
- val groupingAttributes = groupingExpressions.map(_.toAttribute)
- val functionsWithDistinct = aggregateExpressions.filter(_.isDistinct)
- val partialAggregateExpressions = aggregateExpressions.map {
- case agg @ AggregateExpression(_, _, false, _) if functionsWithDistinct.length > 0 =>
- agg.copy(mode = PartialMerge)
- case agg =>
- agg.copy(mode = Partial)
- }
- val partialAggregateAttributes =
- partialAggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
- val partialResultExpressions =
- groupingAttributes ++
- partialAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes)
+ resultExpressions: Seq[NamedExpression],
+ child: SparkPlan): Seq[SparkPlan] = {
- createAggregateExec(
- requiredChildDistributionExpressions = None,
+ val completeAggregateExpressions = aggregateExpressions.map(_.copy(mode = Complete))
+ val completeAggregateAttributes = completeAggregateExpressions.map(_.resultAttribute)
+ SortAggregateExec(
+ requiredChildDistributionExpressions = Some(groupingExpressions),
groupingExpressions = groupingExpressions,
- aggregateExpressions = partialAggregateExpressions,
- aggregateAttributes = partialAggregateAttributes,
- initialInputBufferOffset = if (functionsWithDistinct.length > 0) {
- groupingExpressions.length + functionsWithDistinct.head.aggregateFunction.children.length
- } else {
- 0
- },
- resultExpressions = partialResultExpressions,
- child = child)
- }
-
- private def updateMergeAggregateMode(aggregateExpressions: Seq[AggregateExpression]) = {
- def updateMode(mode: AggregateMode) = mode match {
- case Partial => PartialMerge
- case Complete => Final
- case mode => mode
- }
- aggregateExpressions.map(e => e.copy(mode = updateMode(e.mode)))
- }
-
- /**
- * Builds new merge and map-side [[AggregateExec]]s from an input aggregate operator.
- * If an aggregation needs a shuffle for satisfying its own distribution and supports partial
- * aggregations, a map-side aggregation is appended before the shuffle in
- * [[org.apache.spark.sql.execution.exchange.EnsureRequirements]].
- */
- def createMapMergeAggregatePair(operator: SparkPlan): (SparkPlan, SparkPlan) = operator match {
- case agg: AggregateExec =>
- val mapSideAgg = createPartialAggregateExec(
- agg.groupingExpressions, agg.aggregateExpressions, agg.child)
- val mergeAgg = createAggregateExec(
- requiredChildDistributionExpressions = agg.requiredChildDistributionExpressions,
- groupingExpressions = agg.groupingExpressions.map(_.toAttribute),
- aggregateExpressions = updateMergeAggregateMode(agg.aggregateExpressions),
- aggregateAttributes = agg.aggregateAttributes,
- initialInputBufferOffset = agg.groupingExpressions.length,
- resultExpressions = agg.resultExpressions,
- child = mapSideAgg
- )
-
- (mergeAgg, mapSideAgg)
+ aggregateExpressions = completeAggregateExpressions,
+ aggregateAttributes = completeAggregateAttributes,
+ initialInputBufferOffset = 0,
+ resultExpressions = resultExpressions,
+ child = child
+ ) :: Nil
}
- private def createAggregateExec(
+ private def createAggregate(
requiredChildDistributionExpressions: Option[Seq[Expression]] = None,
groupingExpressions: Seq[NamedExpression] = Nil,
aggregateExpressions: Seq[AggregateExpression] = Nil,
@@ -118,8 +55,7 @@ object AggUtils {
resultExpressions: Seq[NamedExpression] = Nil,
child: SparkPlan): SparkPlan = {
val useHash = HashAggregateExec.supportsAggregate(
- aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)) &&
- supportPartialAggregate(aggregateExpressions)
+ aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes))
if (useHash) {
HashAggregateExec(
requiredChildDistributionExpressions = requiredChildDistributionExpressions,
@@ -146,21 +82,43 @@ object AggUtils {
aggregateExpressions: Seq[AggregateExpression],
resultExpressions: Seq[NamedExpression],
child: SparkPlan): Seq[SparkPlan] = {
+ // Check if we can use HashAggregate.
+
+ // 1. Create an Aggregate Operator for partial aggregations.
+
val groupingAttributes = groupingExpressions.map(_.toAttribute)
- val completeAggregateExpressions = aggregateExpressions.map(_.copy(mode = Complete))
- val completeAggregateAttributes = completeAggregateExpressions.map(_.resultAttribute)
- val supportPartial = supportPartialAggregate(aggregateExpressions)
+ val partialAggregateExpressions = aggregateExpressions.map(_.copy(mode = Partial))
+ val partialAggregateAttributes =
+ partialAggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
+ val partialResultExpressions =
+ groupingAttributes ++
+ partialAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes)
- createAggregateExec(
- requiredChildDistributionExpressions =
- Some(if (supportPartial) groupingAttributes else groupingExpressions),
- groupingExpressions = groupingExpressions,
- aggregateExpressions = completeAggregateExpressions,
- aggregateAttributes = completeAggregateAttributes,
- initialInputBufferOffset = 0,
- resultExpressions = resultExpressions,
- child = child
- ) :: Nil
+ val partialAggregate = createAggregate(
+ requiredChildDistributionExpressions = None,
+ groupingExpressions = groupingExpressions,
+ aggregateExpressions = partialAggregateExpressions,
+ aggregateAttributes = partialAggregateAttributes,
+ initialInputBufferOffset = 0,
+ resultExpressions = partialResultExpressions,
+ child = child)
+
+ // 2. Create an Aggregate Operator for final aggregations.
+ val finalAggregateExpressions = aggregateExpressions.map(_.copy(mode = Final))
+ // The attributes of the final aggregation buffer, which is presented as input to the result
+ // projection:
+ val finalAggregateAttributes = finalAggregateExpressions.map(_.resultAttribute)
+
+ val finalAggregate = createAggregate(
+ requiredChildDistributionExpressions = Some(groupingAttributes),
+ groupingExpressions = groupingAttributes,
+ aggregateExpressions = finalAggregateExpressions,
+ aggregateAttributes = finalAggregateAttributes,
+ initialInputBufferOffset = groupingExpressions.length,
+ resultExpressions = resultExpressions,
+ child = partialAggregate)
+
+ finalAggregate :: Nil
}
def planAggregateWithOneDistinct(
@@ -183,23 +141,39 @@ object AggUtils {
val distinctAttributes = namedDistinctExpressions.map(_.toAttribute)
val groupingAttributes = groupingExpressions.map(_.toAttribute)
- // 1. Create an Aggregate Operator for non-distinct aggregations.
+ // 1. Create an Aggregate Operator for partial aggregations.
val partialAggregate: SparkPlan = {
val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial))
val aggregateAttributes = aggregateExpressions.map(_.resultAttribute)
- createAggregateExec(
+ // We will group by the original grouping expression, plus an additional expression for the
+ // DISTINCT column. For example, for AVG(DISTINCT value) GROUP BY key, the grouping
+ // expressions will be [key, value].
+ createAggregate(
+ groupingExpressions = groupingExpressions ++ namedDistinctExpressions,
+ aggregateExpressions = aggregateExpressions,
+ aggregateAttributes = aggregateAttributes,
+ resultExpressions = groupingAttributes ++ distinctAttributes ++
+ aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes),
+ child = child)
+ }
+
+ // 2. Create an Aggregate Operator for partial merge aggregations.
+ val partialMergeAggregate: SparkPlan = {
+ val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge))
+ val aggregateAttributes = aggregateExpressions.map(_.resultAttribute)
+ createAggregate(
requiredChildDistributionExpressions =
Some(groupingAttributes ++ distinctAttributes),
- groupingExpressions = groupingExpressions ++ namedDistinctExpressions,
+ groupingExpressions = groupingAttributes ++ distinctAttributes,
aggregateExpressions = aggregateExpressions,
aggregateAttributes = aggregateAttributes,
initialInputBufferOffset = (groupingAttributes ++ distinctAttributes).length,
resultExpressions = groupingAttributes ++ distinctAttributes ++
aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes),
- child = child)
+ child = partialAggregate)
}
- // 2. Create an Aggregate Operator for the final aggregation.
+ // 3. Create an Aggregate operator for partial aggregation (for distinct)
val distinctColumnAttributeLookup = distinctExpressions.zip(distinctAttributes).toMap
val rewrittenDistinctFunctions = functionsWithDistinct.map {
// Children of an AggregateFunction with DISTINCT keyword has already
@@ -209,6 +183,38 @@ object AggUtils {
aggregateFunction.transformDown(distinctColumnAttributeLookup)
.asInstanceOf[AggregateFunction]
}
+
+ val partialDistinctAggregate: SparkPlan = {
+ val mergeAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge))
+ // The attributes of the final aggregation buffer, which is presented as input to the result
+ // projection:
+ val mergeAggregateAttributes = mergeAggregateExpressions.map(_.resultAttribute)
+ val (distinctAggregateExpressions, distinctAggregateAttributes) =
+ rewrittenDistinctFunctions.zipWithIndex.map { case (func, i) =>
+ // We rewrite the aggregate function to a non-distinct aggregation because
+ // its input will have distinct arguments.
+ // We just keep the isDistinct setting to true, so when users look at the query plan,
+ // they still can see distinct aggregations.
+ val expr = AggregateExpression(func, Partial, isDistinct = true)
+ // Use original AggregationFunction to lookup attributes, which is used to build
+ // aggregateFunctionToAttribute
+ val attr = functionsWithDistinct(i).resultAttribute
+ (expr, attr)
+ }.unzip
+
+ val partialAggregateResult = groupingAttributes ++
+ mergeAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) ++
+ distinctAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes)
+ createAggregate(
+ groupingExpressions = groupingAttributes,
+ aggregateExpressions = mergeAggregateExpressions ++ distinctAggregateExpressions,
+ aggregateAttributes = mergeAggregateAttributes ++ distinctAggregateAttributes,
+ initialInputBufferOffset = (groupingAttributes ++ distinctAttributes).length,
+ resultExpressions = partialAggregateResult,
+ child = partialMergeAggregate)
+ }
+
+ // 4. Create an Aggregate Operator for the final aggregation.
val finalAndCompleteAggregate: SparkPlan = {
val finalAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Final))
// The attributes of the final aggregation buffer, which is presented as input to the result
@@ -219,23 +225,23 @@ object AggUtils {
rewrittenDistinctFunctions.zipWithIndex.map { case (func, i) =>
// We rewrite the aggregate function to a non-distinct aggregation because
// its input will have distinct arguments.
- // We keep the isDistinct setting to true because this flag is used to generate partial
- // aggregations and it is easy to see aggregation types in the query plan.
- val expr = AggregateExpression(func, Complete, isDistinct = true)
+ // We just keep the isDistinct setting to true, so when users look at the query plan,
+ // they still can see distinct aggregations.
+ val expr = AggregateExpression(func, Final, isDistinct = true)
// Use original AggregationFunction to lookup attributes, which is used to build
// aggregateFunctionToAttribute
val attr = functionsWithDistinct(i).resultAttribute
(expr, attr)
- }.unzip
+ }.unzip
- createAggregateExec(
+ createAggregate(
requiredChildDistributionExpressions = Some(groupingAttributes),
groupingExpressions = groupingAttributes,
aggregateExpressions = finalAggregateExpressions ++ distinctAggregateExpressions,
aggregateAttributes = finalAggregateAttributes ++ distinctAggregateAttributes,
initialInputBufferOffset = groupingAttributes.length,
resultExpressions = resultExpressions,
- child = partialAggregate)
+ child = partialDistinctAggregate)
}
finalAndCompleteAggregate :: Nil
@@ -243,14 +249,13 @@ object AggUtils {
/**
* Plans a streaming aggregation using the following progression:
- * - Partial Aggregation (now there is at most 1 tuple per group)
+ * - Partial Aggregation
+ * - Shuffle
+ * - Partial Merge (now there is at most 1 tuple per group)
* - StateStoreRestore (now there is 1 tuple from this batch + optionally one from the previous)
* - PartialMerge (now there is at most 1 tuple per group)
* - StateStoreSave (saves the tuple for the next batch)
* - Complete (output the current result of the aggregation)
- *
- * If the first aggregation needs a shuffle to satisfy its distribution, a map-side partial
- * an aggregation and a shuffle are added in `EnsureRequirements`.
*/
def planStreamingAggregation(
groupingExpressions: Seq[NamedExpression],
@@ -263,24 +268,39 @@ object AggUtils {
val partialAggregate: SparkPlan = {
val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial))
val aggregateAttributes = aggregateExpressions.map(_.resultAttribute)
- createAggregateExec(
+ // We will group by the original grouping expression, plus an additional expression for the
+ // DISTINCT column. For example, for AVG(DISTINCT value) GROUP BY key, the grouping
+ // expressions will be [key, value].
+ createAggregate(
+ groupingExpressions = groupingExpressions,
+ aggregateExpressions = aggregateExpressions,
+ aggregateAttributes = aggregateAttributes,
+ resultExpressions = groupingAttributes ++
+ aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes),
+ child = child)
+ }
+
+ val partialMerged1: SparkPlan = {
+ val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge))
+ val aggregateAttributes = aggregateExpressions.map(_.resultAttribute)
+ createAggregate(
requiredChildDistributionExpressions =
Some(groupingAttributes),
- groupingExpressions = groupingExpressions,
+ groupingExpressions = groupingAttributes,
aggregateExpressions = aggregateExpressions,
aggregateAttributes = aggregateAttributes,
initialInputBufferOffset = groupingAttributes.length,
resultExpressions = groupingAttributes ++
aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes),
- child = child)
+ child = partialAggregate)
}
- val restored = StateStoreRestoreExec(groupingAttributes, None, partialAggregate)
+ val restored = StateStoreRestoreExec(groupingAttributes, None, partialMerged1)
- val partialMerged: SparkPlan = {
+ val partialMerged2: SparkPlan = {
val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge))
val aggregateAttributes = aggregateExpressions.map(_.resultAttribute)
- createAggregateExec(
+ createAggregate(
requiredChildDistributionExpressions =
Some(groupingAttributes),
groupingExpressions = groupingAttributes,
@@ -294,7 +314,7 @@ object AggUtils {
// Note: stateId and returnAllStates are filled in later with preparation rules
// in IncrementalExecution.
val saved = StateStoreSaveExec(
- groupingAttributes, stateId = None, returnAllStates = None, partialMerged)
+ groupingAttributes, stateId = None, returnAllStates = None, partialMerged2)
val finalAndCompleteAggregate: SparkPlan = {
val finalAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Final))
@@ -302,7 +322,7 @@ object AggUtils {
// projection:
val finalAggregateAttributes = finalAggregateExpressions.map(_.resultAttribute)
- createAggregateExec(
+ createAggregate(
requiredChildDistributionExpressions = Some(groupingAttributes),
groupingExpressions = groupingAttributes,
aggregateExpressions = finalAggregateExpressions,
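
Taken together, the restored AggUtils builds fixed chains of aggregate stages up front instead of relying on EnsureRequirements to add a map-side aggregate later. The sketch below models only the simplest chain (planAggregateWithoutDistinct) with toy types; the one-distinct path adds PartialMerge stages and the streaming path inserts the StateStore operators, as shown in the hunks above.

```scala
// Toy model of the two-phase plan rebuilt by planAggregateWithoutDistinct.
// Stage and Mode are placeholders, not Spark's SparkPlan or AggregateMode.
object TwoPhaseAggSketch {
  sealed trait Mode
  case object Partial extends Mode
  case object Final extends Mode

  final case class Stage(
      mode: Mode,
      groupBy: Seq[String],
      requiredChildDistribution: Option[Seq[String]],
      child: Option[Stage])

  def planWithoutDistinct(grouping: Seq[String]): Stage = {
    // 1. Partial aggregation directly over the input: no distribution
    //    requirement, so no shuffle is forced below it.
    val partial = Stage(Partial, grouping, requiredChildDistribution = None, child = None)
    // 2. Final aggregation: requires clustering on the grouping attributes,
    //    which EnsureRequirements satisfies with a shuffle between the stages.
    Stage(Final, grouping, requiredChildDistribution = Some(grouping), child = Some(partial))
  }

  def main(args: Array[String]): Unit =
    println(planWithoutDistinct(Seq("key")))
}
```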
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateExec.scala
deleted file mode 100644
index b88a8aa3da..0000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateExec.scala
+++ /dev/null
@@ -1,56 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.aggregate
-
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
-import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.execution.UnaryExecNode
-
-/**
- * A base class for aggregate implementation.
- */
-abstract class AggregateExec extends UnaryExecNode {
-
- def requiredChildDistributionExpressions: Option[Seq[Expression]]
- def groupingExpressions: Seq[NamedExpression]
- def aggregateExpressions: Seq[AggregateExpression]
- def aggregateAttributes: Seq[Attribute]
- def initialInputBufferOffset: Int
- def resultExpressions: Seq[NamedExpression]
-
- protected[this] val aggregateBufferAttributes = {
- aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
- }
-
- override def producedAttributes: AttributeSet =
- AttributeSet(aggregateAttributes) ++
- AttributeSet(resultExpressions.diff(groupingExpressions).map(_.toAttribute)) ++
- AttributeSet(aggregateBufferAttributes)
-
- override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
-
- override def requiredChildDistribution: List[Distribution] = {
- requiredChildDistributionExpressions match {
- case Some(exprs) if exprs.isEmpty => AllTuples :: Nil
- case Some(exprs) if exprs.nonEmpty => ClusteredDistribution(exprs) :: Nil
- case None => UnspecifiedDistribution :: Nil
- }
- }
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index 525c7e301a..bd7efa606e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.types.{DecimalType, StringType, StructType}
@@ -41,7 +42,11 @@ case class HashAggregateExec(
initialInputBufferOffset: Int,
resultExpressions: Seq[NamedExpression],
child: SparkPlan)
- extends AggregateExec with CodegenSupport {
+ extends UnaryExecNode with CodegenSupport {
+
+ private[this] val aggregateBufferAttributes = {
+ aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
+ }
require(HashAggregateExec.supportsAggregate(aggregateBufferAttributes))
@@ -55,6 +60,21 @@ case class HashAggregateExec(
"spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size"),
"aggTime" -> SQLMetrics.createTimingMetric(sparkContext, "aggregate time"))
+ override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
+
+ override def producedAttributes: AttributeSet =
+ AttributeSet(aggregateAttributes) ++
+ AttributeSet(resultExpressions.diff(groupingExpressions).map(_.toAttribute)) ++
+ AttributeSet(aggregateBufferAttributes)
+
+ override def requiredChildDistribution: List[Distribution] = {
+ requiredChildDistributionExpressions match {
+ case Some(exprs) if exprs.isEmpty => AllTuples :: Nil
+ case Some(exprs) if exprs.nonEmpty => ClusteredDistribution(exprs) :: Nil
+ case None => UnspecifiedDistribution :: Nil
+ }
+ }
+
// This is for testing. We force TungstenAggregationIterator to fall back to the unsafe row hash
// map and/or the sort-based aggregation once it has processed a given number of input rows.
private val testFallbackStartsAt: Option[(Int, Int)] = {
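
The requiredChildDistribution override re-inlined above (previously on the removed AggregateExec base class, and re-added to SortAggregateExec in the next file) maps the planner's hint to a physical distribution. A self-contained restatement, with a stand-in Distribution ADT instead of Spark's catalyst classes:

```scala
// Standalone sketch of the requiredChildDistribution rule; the ADT below is a
// placeholder for org.apache.spark.sql.catalyst.plans.physical, defined here
// only so the snippet compiles on its own.
object RequiredDistributionSketch {
  sealed trait Distribution
  case object AllTuples extends Distribution
  case object UnspecifiedDistribution extends Distribution
  final case class ClusteredDistribution(exprs: Seq[String]) extends Distribution

  def requiredChildDistribution(
      requiredChildDistributionExpressions: Option[Seq[String]]): List[Distribution] =
    requiredChildDistributionExpressions match {
      // Some(Nil): a global aggregation with no grouping keys -> all rows in one partition.
      case Some(exprs) if exprs.isEmpty => AllTuples :: Nil
      // Some(non-empty): cluster the input on the grouping attributes.
      case Some(exprs) => ClusteredDistribution(exprs) :: Nil
      // None: the map-side (partial) aggregate accepts any partitioning.
      case None => UnspecifiedDistribution :: Nil
    }

  def main(args: Array[String]): Unit = {
    println(requiredChildDistribution(Some(Nil)))        // AllTuples
    println(requiredChildDistribution(Some(Seq("key")))) // ClusteredDistribution
    println(requiredChildDistribution(None))             // UnspecifiedDistribution
  }
}
```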
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala
index 68f86fca80..2a81a823c4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala
@@ -22,7 +22,8 @@ import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, UnspecifiedDistribution}
+import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.util.Utils
@@ -37,11 +38,30 @@ case class SortAggregateExec(
initialInputBufferOffset: Int,
resultExpressions: Seq[NamedExpression],
child: SparkPlan)
- extends AggregateExec {
+ extends UnaryExecNode {
+
+ private[this] val aggregateBufferAttributes = {
+ aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
+ }
+
+ override def producedAttributes: AttributeSet =
+ AttributeSet(aggregateAttributes) ++
+ AttributeSet(resultExpressions.diff(groupingExpressions).map(_.toAttribute)) ++
+ AttributeSet(aggregateBufferAttributes)
override lazy val metrics = Map(
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
+ override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
+
+ override def requiredChildDistribution: List[Distribution] = {
+ requiredChildDistributionExpressions match {
+ case Some(exprs) if exprs.isEmpty => AllTuples :: Nil
+ case Some(exprs) if exprs.nonEmpty => ClusteredDistribution(exprs) :: Nil
+ case None => UnspecifiedDistribution :: Nil
+ }
+ }
+
override def requiredChildOrdering: Seq[Seq[SortOrder]] = {
groupingExpressions.map(SortOrder(_, Ascending)) :: Nil
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
index 66e99ded24..f17049949a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
@@ -21,8 +21,6 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.aggregate.AggUtils
-import org.apache.spark.sql.execution.aggregate.PartialAggregate
import org.apache.spark.sql.internal.SQLConf
/**
@@ -153,31 +151,18 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
private def ensureDistributionAndOrdering(operator: SparkPlan): SparkPlan = {
val requiredChildDistributions: Seq[Distribution] = operator.requiredChildDistribution
val requiredChildOrderings: Seq[Seq[SortOrder]] = operator.requiredChildOrdering
- assert(requiredChildDistributions.length == operator.children.length)
- assert(requiredChildOrderings.length == operator.children.length)
+ var children: Seq[SparkPlan] = operator.children
+ assert(requiredChildDistributions.length == children.length)
+ assert(requiredChildOrderings.length == children.length)
- def createShuffleExchange(dist: Distribution, child: SparkPlan) =
- ShuffleExchange(createPartitioning(dist, defaultNumPreShufflePartitions), child)
-
- var (parent, children) = operator match {
- case PartialAggregate(childDist) if !operator.outputPartitioning.satisfies(childDist) =>
- // If an aggregation needs a shuffle and support partial aggregations, a map-side partial
- // aggregation and a shuffle are added as children.
- val (mergeAgg, mapSideAgg) = AggUtils.createMapMergeAggregatePair(operator)
- (mergeAgg, createShuffleExchange(
- requiredChildDistributions.head, ensureDistributionAndOrdering(mapSideAgg)) :: Nil)
- case _ =>
- // Ensure that the operator's children satisfy their output distribution requirements:
- val childrenWithDist = operator.children.zip(requiredChildDistributions)
- val newChildren = childrenWithDist.map {
- case (child, distribution) if child.outputPartitioning.satisfies(distribution) =>
- child
- case (child, BroadcastDistribution(mode)) =>
- BroadcastExchangeExec(mode, child)
- case (child, distribution) =>
- createShuffleExchange(distribution, child)
- }
- (operator, newChildren)
+ // Ensure that the operator's children satisfy their output distribution requirements:
+ children = children.zip(requiredChildDistributions).map {
+ case (child, distribution) if child.outputPartitioning.satisfies(distribution) =>
+ child
+ case (child, BroadcastDistribution(mode)) =>
+ BroadcastExchangeExec(mode, child)
+ case (child, distribution) =>
+ ShuffleExchange(createPartitioning(distribution, defaultNumPreShufflePartitions), child)
}
// If the operator has multiple children and specifies child output distributions (e.g. join),
@@ -270,7 +255,7 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
}
}
- parent.withNewChildren(children)
+ operator.withNewChildren(children)
}
def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
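
With the PartialAggregate special case removed, EnsureRequirements is back to a per-child rule: keep a child whose output partitioning already satisfies the required distribution, otherwise insert an exchange. Below is a toy model of that rule with invented Node and Distribution types, not Spark's SparkPlan hierarchy.

```scala
// Simplified per-child distribution check; an "Exchange" node stands in for
// ShuffleExchange(createPartitioning(...), child).
object EnsureRequirementsSketch {
  sealed trait Distribution
  case object Unspecified extends Distribution
  final case class Clustered(keys: Seq[String]) extends Distribution

  final case class Node(name: String, partitionedBy: Seq[String], children: Seq[Node] = Nil)

  private def satisfies(node: Node, dist: Distribution): Boolean = dist match {
    case Unspecified     => true
    case Clustered(keys) => node.partitionedBy == keys
  }

  def ensureDistribution(operator: Node, required: Seq[Distribution]): Node = {
    val newChildren = operator.children.zip(required).map {
      case (child, dist) if satisfies(child, dist) => child
      case (child, Clustered(keys)) =>
        Node(s"Exchange hashpartitioning(${keys.mkString(", ")})", keys, Seq(child))
      case (child, _) => child
    }
    operator.copy(children = newChildren)
  }

  def main(args: Array[String]): Unit = {
    val scan = Node("Scan", partitionedBy = Nil)
    val agg = Node("FinalAggregate", partitionedBy = Nil, children = Seq(scan))
    println(ensureDistribution(agg, Seq(Clustered(Seq("key")))))
  }
}
```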
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index ce0b92a461..f89951760f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -1248,17 +1248,17 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
}
/**
- * Verifies that there is a single Aggregation for `df`
+ * Verifies that there is no Exchange between the Aggregations for `df`
*/
- private def verifyNonExchangingSingleAgg(df: DataFrame) = {
+ private def verifyNonExchangingAgg(df: DataFrame) = {
var atFirstAgg: Boolean = false
df.queryExecution.executedPlan.foreach {
case agg: HashAggregateExec =>
+ atFirstAgg = !atFirstAgg
+ case _ =>
if (atFirstAgg) {
- fail("Should not have back to back Aggregates")
+ fail("Should not have operators between the two aggregations")
}
- atFirstAgg = true
- case _ =>
}
}
@@ -1292,10 +1292,9 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
// Group by the column we are distributed by. This should generate a plan with no exchange
// between the aggregates
val df3 = testData.repartition($"key").groupBy("key").count()
- verifyNonExchangingSingleAgg(df3)
- verifyNonExchangingSingleAgg(testData.repartition($"key", $"value")
+ verifyNonExchangingAgg(df3)
+ verifyNonExchangingAgg(testData.repartition($"key", $"value")
.groupBy("key", "value").count())
- verifyNonExchangingSingleAgg(testData.repartition($"key").groupBy("key", "value").count())
// Grouping by just the first distributeBy expr, need to exchange.
verifyExchangingAgg(testData.repartition($"key", $"value")
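
The rewritten verifyNonExchangingAgg helper above toggles a flag while walking the executed plan top-down. A minimal standalone rendering of that check, with operators modelled as plain strings rather than SparkPlan nodes:

```scala
// Toy version of the check: toggle on each aggregate, fail on anything seen
// strictly between the two aggregates.
object NoExchangeBetweenAggsSketch {
  def verifyNonExchangingAgg(operators: Seq[String]): Unit = {
    var atFirstAgg = false
    operators.foreach {
      case "HashAggregate" =>
        atFirstAgg = !atFirstAgg
      case other =>
        require(!atFirstAgg, s"Should not have operators between the two aggregations: $other")
    }
  }

  def main(args: Array[String]): Unit = {
    // Back-to-back aggregates: passes.
    verifyNonExchangingAgg(Seq("HashAggregate", "HashAggregate", "Project", "Scan"))
    // An Exchange between the aggregates: the check rejects the plan.
    try verifyNonExchangingAgg(Seq("HashAggregate", "Exchange", "HashAggregate", "Scan"))
    catch { case e: IllegalArgumentException => println(s"rejected: ${e.getMessage}") }
  }
}
```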
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index b0aa3378e5..375da224aa 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -24,7 +24,6 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition}
import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.execution.aggregate.SortAggregateExec
import org.apache.spark.sql.execution.columnar.InMemoryRelation
import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReusedExchangeExec, ReuseExchange, ShuffleExchange}
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec}
@@ -38,84 +37,36 @@ class PlannerSuite extends SharedSQLContext {
setupTestData()
- private def testPartialAggregationPlan(query: LogicalPlan): Seq[SparkPlan] = {
+ private def testPartialAggregationPlan(query: LogicalPlan): Unit = {
val planner = spark.sessionState.planner
import planner._
- val ensureRequirements = EnsureRequirements(spark.sessionState.conf)
- val planned = Aggregation(query).headOption.map(ensureRequirements(_))
- .getOrElse(fail(s"Could query play aggregation query $query. Is it an aggregation query?"))
- planned.collect { case n if n.nodeName contains "Aggregate" => n }
+ val plannedOption = Aggregation(query).headOption
+ val planned =
+ plannedOption.getOrElse(
+ fail(s"Could query play aggregation query $query. Is it an aggregation query?"))
+ val aggregations = planned.collect { case n if n.nodeName contains "Aggregate" => n }
+
+ // For the new aggregation code path, there will be four aggregate operator for
+ // distinct aggregations.
+ assert(
+ aggregations.size == 2 || aggregations.size == 4,
+ s"The plan of query $query does not have partial aggregations.")
}
test("count is partially aggregated") {
val query = testData.groupBy('value).agg(count('key)).queryExecution.analyzed
- assert(testPartialAggregationPlan(query).size == 2,
- s"The plan of query $query does not have partial aggregations.")
+ testPartialAggregationPlan(query)
}
test("count distinct is partially aggregated") {
val query = testData.groupBy('value).agg(countDistinct('key)).queryExecution.analyzed
testPartialAggregationPlan(query)
- // For the new aggregation code path, there will be four aggregate operator for distinct
- // aggregations.
- assert(testPartialAggregationPlan(query).size == 4,
- s"The plan of query $query does not have partial aggregations.")
}
test("mixed aggregates are partially aggregated") {
val query =
testData.groupBy('value).agg(count('value), countDistinct('key)).queryExecution.analyzed
- // For the new aggregation code path, there will be four aggregate operator for distinct
- // aggregations.
- assert(testPartialAggregationPlan(query).size == 4,
- s"The plan of query $query does not have partial aggregations.")
- }
-
- test("SPARK-17289 sort-based partial aggregation needs a sort operator as a child") {
- withTempView("testSortBasedPartialAggregation") {
- val schema = StructType(
- StructField(s"key", IntegerType, true) :: StructField(s"value", StringType, true) :: Nil)
- val rowRDD = sparkContext.parallelize((0 until 1000).map(d => Row(d % 2, d.toString)))
- spark.createDataFrame(rowRDD, schema)
- .createOrReplaceTempView("testSortBasedPartialAggregation")
-
- // This test assumes a query below uses sort-based aggregations
- val planned = sql("SELECT MAX(value) FROM testSortBasedPartialAggregation GROUP BY key")
- .queryExecution.executedPlan
- // This line extracts both SortAggregate and Sort operators
- val extractedOps = planned.collect { case n if n.nodeName contains "Sort" => n }
- val aggOps = extractedOps.collect { case n if n.nodeName contains "SortAggregate" => n }
- assert(extractedOps.size == 4 && aggOps.size == 2,
- s"The plan $planned does not have correct sort-based partial aggregate pairs.")
- }
- }
-
- test("non-partial aggregation for aggregates") {
- withTempView("testNonPartialAggregation") {
- val schema = StructType(StructField(s"value", IntegerType, true) :: Nil)
- val row = Row.fromSeq(Seq.fill(1)(null))
- val rowRDD = sparkContext.parallelize(row :: Nil)
- spark.createDataFrame(rowRDD, schema).repartition($"value")
- .createOrReplaceTempView("testNonPartialAggregation")
-
- val planned1 = sql("SELECT SUM(value) FROM testNonPartialAggregation GROUP BY value")
- .queryExecution.executedPlan
-
- // If input data are already partitioned and the same columns are used in grouping keys and
- // aggregation values, no partial aggregation exist in query plans.
- val aggOps1 = planned1.collect { case n if n.nodeName contains "Aggregate" => n }
- assert(aggOps1.size == 1, s"The plan $planned1 has partial aggregations.")
-
- val planned2 = sql(
- """
- |SELECT t.value, SUM(DISTINCT t.value)
- |FROM (SELECT * FROM testNonPartialAggregation ORDER BY value) t
- |GROUP BY t.value
- """.stripMargin).queryExecution.executedPlan
-
- val aggOps2 = planned1.collect { case n if n.nodeName contains "Aggregate" => n }
- assert(aggOps2.size == 1, s"The plan $planned2 has partial aggregations.")
- }
+ testPartialAggregationPlan(query)
}
test("sizeInBytes estimation of limit operator for broadcast hash join optimization") {