 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala                         |   7
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala |  11
 sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala                                           |  15
 sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala           | 417
 sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala            |  29
 sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala  |  47
 sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala             |  25
 sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala   | 439
 sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala                         | 280
 sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala              | 142
 10 files changed, 422 insertions(+), 990 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala
index 7c2b8a9407..2c7c58e66b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala
@@ -19,8 +19,6 @@ package org.apache.spark.sql.catalyst
private[spark] trait CatalystConf {
def caseSensitiveAnalysis: Boolean
-
- protected[spark] def specializeSingleDistinctAggPlanning: Boolean
}
/**
@@ -31,13 +29,8 @@ object EmptyConf extends CatalystConf {
override def caseSensitiveAnalysis: Boolean = {
throw new UnsupportedOperationException
}
-
- protected[spark] override def specializeSingleDistinctAggPlanning: Boolean = {
- throw new UnsupportedOperationException
- }
}
/** A CatalystConf that can be used for local testing. */
case class SimpleCatalystConf(caseSensitiveAnalysis: Boolean) extends CatalystConf {
- protected[spark] override def specializeSingleDistinctAggPlanning: Boolean = true
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala
index 9c78f6d4cc..4e7d134102 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala
@@ -123,15 +123,8 @@ case class DistinctAggregationRewriter(conf: CatalystConf) extends Rule[LogicalP
.filter(_.isDistinct)
.groupBy(_.aggregateFunction.children.toSet)
- val shouldRewrite = if (conf.specializeSingleDistinctAggPlanning) {
- // When the flag is set to specialize single distinct agg planning,
- // we will rely on our Aggregation strategy to handle queries with a single
- // distinct column.
- distinctAggGroups.size > 1
- } else {
- distinctAggGroups.size >= 1
- }
- if (shouldRewrite) {
+ // The Aggregation strategy can handle queries with a single distinct group.
+ if (distinctAggGroups.size > 1) {
// Create the attributes for the grouping id and the group by clause.
val gid = new AttributeReference("gid", IntegerType, false)()
val groupByMap = a.groupingExpressions.collect {
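The rewrite now only fires when there is more than one distinct column set. As a hedged illustration of the counting logic above, here is a minimal Scala sketch with a toy Agg case class standing in for Catalyst's AggregateExpression (the names are illustrative, not the real API):

    // Each distinct aggregate is keyed by the set of columns under DISTINCT.
    // COUNT(DISTINCT a) and SUM(DISTINCT a) share one group; COUNT(DISTINCT b)
    // adds a second group, which is what triggers the rewrite.
    case class Agg(fn: String, cols: Set[String], isDistinct: Boolean)
    val aggs = Seq(
      Agg("COUNT", Set("a"), isDistinct = true),
      Agg("SUM", Set("a"), isDistinct = true),
      Agg("COUNT", Set("b"), isDistinct = true))
    val distinctAggGroups = aggs.filter(_.isDistinct).groupBy(_.cols)
    assert(distinctAggGroups.size == 2) // > 1, so the rewrite applies; with only
                                        // the first two aggregates it would be 1
                                        // and the Aggregation strategy plans it.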
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
index 58adf64e49..3d81926285 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -449,18 +449,6 @@ private[spark] object SQLConf {
doc = "When true, we could use `datasource`.`path` as table in SQL query"
)
- val SPECIALIZE_SINGLE_DISTINCT_AGG_PLANNING =
- booleanConf("spark.sql.specializeSingleDistinctAggPlanning",
- defaultValue = Some(false),
- isPublic = false,
- doc = "When true, if a query only has a single distinct column and it has " +
- "grouping expressions, we will use our planner rule to handle this distinct " +
- "column (other cases are handled by DistinctAggregationRewriter). " +
- "When false, we will always use DistinctAggregationRewriter to plan " +
- "aggregation queries with DISTINCT keyword. This is an internal flag that is " +
- "used to benchmark the performance impact of using DistinctAggregationRewriter to " +
- "plan aggregation queries with a single distinct column.")
-
object Deprecated {
val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks"
val EXTERNAL_SORT = "spark.sql.planner.externalSort"
@@ -579,9 +567,6 @@ private[sql] class SQLConf extends Serializable with CatalystConf {
private[spark] def runSQLOnFile: Boolean = getConf(RUN_SQL_ON_FILES)
- protected[spark] override def specializeSingleDistinctAggPlanning: Boolean =
- getConf(SPECIALIZE_SINGLE_DISTINCT_AGG_PLANNING)
-
/** ********************** SQLConf functionality methods ************ */
/** Set Spark SQL configuration properties. */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
index 008478a6a0..0c74df0aa5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
@@ -17,15 +17,15 @@
package org.apache.spark.sql.execution.aggregate
+import scala.collection.mutable.ArrayBuffer
+
import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
-import scala.collection.mutable.ArrayBuffer
-
/**
- * The base class of [[SortBasedAggregationIterator]].
+ * The base class of [[SortBasedAggregationIterator]] and [[TungstenAggregationIterator]].
* It mainly contains two parts:
* 1. It initializes aggregate functions.
* 2. It creates two functions, `processRow` and `generateOutput` based on [[AggregateMode]] of
@@ -33,64 +33,58 @@ import scala.collection.mutable.ArrayBuffer
* is used to generate result.
*/
abstract class AggregationIterator(
- groupingKeyAttributes: Seq[Attribute],
- valueAttributes: Seq[Attribute],
- nonCompleteAggregateExpressions: Seq[AggregateExpression],
- nonCompleteAggregateAttributes: Seq[Attribute],
- completeAggregateExpressions: Seq[AggregateExpression],
- completeAggregateAttributes: Seq[Attribute],
+ groupingExpressions: Seq[NamedExpression],
+ inputAttributes: Seq[Attribute],
+ aggregateExpressions: Seq[AggregateExpression],
+ aggregateAttributes: Seq[Attribute],
initialInputBufferOffset: Int,
resultExpressions: Seq[NamedExpression],
- newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
- outputsUnsafeRows: Boolean)
- extends Iterator[InternalRow] with Logging {
+ newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection))
+ extends Iterator[UnsafeRow] with Logging {
///////////////////////////////////////////////////////////////////////////
// Initializing functions.
///////////////////////////////////////////////////////////////////////////
- // An Seq of all AggregateExpressions.
- // It is important that all AggregateExpressions with the mode Partial, PartialMerge or Final
- // are at the beginning of the allAggregateExpressions.
- protected val allAggregateExpressions =
- nonCompleteAggregateExpressions ++ completeAggregateExpressions
-
- require(
- allAggregateExpressions.map(_.mode).distinct.length <= 2,
- s"$allAggregateExpressions are not supported becuase they have more than 2 distinct modes.")
-
/**
- * The distinct modes of AggregateExpressions. Right now, we can handle the following mode:
- * - Partial-only: all AggregateExpressions have the mode of Partial;
- * - PartialMerge-only: all AggregateExpressions have the mode of PartialMerge);
- * - Final-only: all AggregateExpressions have the mode of Final;
- * - Final-Complete: some AggregateExpressions have the mode of Final and
- * others have the mode of Complete;
- * - Complete-only: nonCompleteAggregateExpressions is empty and we have AggregateExpressions
- * with mode Complete in completeAggregateExpressions; and
- * - Grouping-only: there is no AggregateExpression.
- */
- protected val aggregationMode: (Option[AggregateMode], Option[AggregateMode]) =
- nonCompleteAggregateExpressions.map(_.mode).distinct.headOption ->
- completeAggregateExpressions.map(_.mode).distinct.headOption
+ * The following combinations of AggregationMode are supported:
+ * - Partial
+ * - PartialMerge (for single distinct)
+ * - Partial and PartialMerge (for single distinct)
+ * - Final
+ * - Complete (for SortBasedAggregate with functions that do not support Partial)
+ * - Final and Complete (currently not used)
+ *
+ * TODO: AggregateMode should have only two modes: Update and Merge; AggregateExpression
+ * could have a flag to tell whether it is final or not.
+ */
+ {
+ val modes = aggregateExpressions.map(_.mode).distinct.toSet
+ require(modes.size <= 2,
+ s"$aggregateExpressions are not supported because they have more than 2 distinct modes.")
+ require(modes.subsetOf(Set(Partial, PartialMerge)) || modes.subsetOf(Set(Final, Complete)),
+ s"$aggregateExpressions can't have Partial/PartialMerge and Final/Complete in the same time.")
+ }
// Initialize all AggregateFunctions by binding references if necessary,
// and set inputBufferOffset and mutableBufferOffset.
- protected val allAggregateFunctions: Array[AggregateFunction] = {
+ protected def initializeAggregateFunctions(
+ expressions: Seq[AggregateExpression],
+ startingInputBufferOffset: Int): Array[AggregateFunction] = {
var mutableBufferOffset = 0
- var inputBufferOffset: Int = initialInputBufferOffset
- val functions = new Array[AggregateFunction](allAggregateExpressions.length)
+ var inputBufferOffset: Int = startingInputBufferOffset
+ val functions = new Array[AggregateFunction](expressions.length)
var i = 0
- while (i < allAggregateExpressions.length) {
- val func = allAggregateExpressions(i).aggregateFunction
- val funcWithBoundReferences: AggregateFunction = allAggregateExpressions(i).mode match {
+ while (i < expressions.length) {
+ val func = expressions(i).aggregateFunction
+ val funcWithBoundReferences: AggregateFunction = expressions(i).mode match {
case Partial | Complete if func.isInstanceOf[ImperativeAggregate] =>
// We need to create BoundReferences if the function is not an
// expression-based aggregate function (it does not support code-gen) and the mode of
// this function is Partial or Complete because we will call eval of this
// function's children in the update method of this aggregate function.
// Those eval calls require BoundReferences to work.
- BindReferences.bindReference(func, valueAttributes)
+ BindReferences.bindReference(func, inputAttributes)
case _ =>
// We only need to set inputBufferOffset for aggregate functions with mode
// PartialMerge and Final.
@@ -117,15 +111,18 @@ abstract class AggregationIterator(
functions
}
+ protected val aggregateFunctions: Array[AggregateFunction] =
+ initializeAggregateFunctions(aggregateExpressions, initialInputBufferOffset)
+
// Positions of those imperative aggregate functions in allAggregateFunctions.
// For example, we have func1, func2, func3, func4 in aggregateFunctions, and
// func2 and func3 are imperative aggregate functions.
// ImperativeAggregateFunctionPositions will be [1, 2].
- private[this] val allImperativeAggregateFunctionPositions: Array[Int] = {
+ protected[this] val allImperativeAggregateFunctionPositions: Array[Int] = {
val positions = new ArrayBuffer[Int]()
var i = 0
- while (i < allAggregateFunctions.length) {
- allAggregateFunctions(i) match {
+ while (i < aggregateFunctions.length) {
+ aggregateFunctions(i) match {
case agg: DeclarativeAggregate =>
case _ => positions += i
}
@@ -134,17 +131,9 @@ abstract class AggregationIterator(
positions.toArray
}
- // All AggregateFunctions functions with mode Partial, PartialMerge, or Final.
- private[this] val nonCompleteAggregateFunctions: Array[AggregateFunction] =
- allAggregateFunctions.take(nonCompleteAggregateExpressions.length)
-
- // All imperative aggregate functions with mode Partial, PartialMerge, or Final.
- private[this] val nonCompleteImperativeAggregateFunctions: Array[ImperativeAggregate] =
- nonCompleteAggregateFunctions.collect { case func: ImperativeAggregate => func }
-
// The projection used to initialize buffer values for all expression-based aggregates.
- private[this] val expressionAggInitialProjection = {
- val initExpressions = allAggregateFunctions.flatMap {
+ protected[this] val expressionAggInitialProjection = {
+ val initExpressions = aggregateFunctions.flatMap {
case ae: DeclarativeAggregate => ae.initialValues
// For the positions corresponding to imperative aggregate functions, we'll use special
// no-op expressions which are ignored during projection code-generation.
@@ -154,248 +143,112 @@ abstract class AggregationIterator(
}
// All imperative AggregateFunctions.
- private[this] val allImperativeAggregateFunctions: Array[ImperativeAggregate] =
+ protected[this] val allImperativeAggregateFunctions: Array[ImperativeAggregate] =
allImperativeAggregateFunctionPositions
- .map(allAggregateFunctions)
+ .map(aggregateFunctions)
.map(_.asInstanceOf[ImperativeAggregate])
- ///////////////////////////////////////////////////////////////////////////
- // Methods and fields used by sub-classes.
- ///////////////////////////////////////////////////////////////////////////
-
// Initializing functions used to process a row.
- protected val processRow: (MutableRow, InternalRow) => Unit = {
- val rowToBeProcessed = new JoinedRow
- val aggregationBufferSchema = allAggregateFunctions.flatMap(_.aggBufferAttributes)
- aggregationMode match {
- // Partial-only
- case (Some(Partial), None) =>
- val updateExpressions = nonCompleteAggregateFunctions.flatMap {
- case ae: DeclarativeAggregate => ae.updateExpressions
- case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp)
- }
- val expressionAggUpdateProjection =
- newMutableProjection(updateExpressions, aggregationBufferSchema ++ valueAttributes)()
-
- (currentBuffer: MutableRow, row: InternalRow) => {
- expressionAggUpdateProjection.target(currentBuffer)
- // Process all expression-based aggregate functions.
- expressionAggUpdateProjection(rowToBeProcessed(currentBuffer, row))
- // Process all imperative aggregate functions.
- var i = 0
- while (i < nonCompleteImperativeAggregateFunctions.length) {
- nonCompleteImperativeAggregateFunctions(i).update(currentBuffer, row)
- i += 1
- }
- }
-
- // PartialMerge-only or Final-only
- case (Some(PartialMerge), None) | (Some(Final), None) =>
- val inputAggregationBufferSchema = if (initialInputBufferOffset == 0) {
- // If initialInputBufferOffset, the input value does not contain
- // grouping keys.
- // This part is pretty hacky.
- allAggregateFunctions.flatMap(_.inputAggBufferAttributes).toSeq
- } else {
- groupingKeyAttributes ++ allAggregateFunctions.flatMap(_.inputAggBufferAttributes)
- }
- // val inputAggregationBufferSchema =
- // groupingKeyAttributes ++
- // allAggregateFunctions.flatMap(_.cloneBufferAttributes)
- val mergeExpressions = nonCompleteAggregateFunctions.flatMap {
- case ae: DeclarativeAggregate => ae.mergeExpressions
- case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp)
- }
- // This projection is used to merge buffer values for all expression-based aggregates.
- val expressionAggMergeProjection =
- newMutableProjection(
- mergeExpressions,
- aggregationBufferSchema ++ inputAggregationBufferSchema)()
-
- (currentBuffer: MutableRow, row: InternalRow) => {
- // Process all expression-based aggregate functions.
- expressionAggMergeProjection.target(currentBuffer)(rowToBeProcessed(currentBuffer, row))
- // Process all imperative aggregate functions.
- var i = 0
- while (i < nonCompleteImperativeAggregateFunctions.length) {
- nonCompleteImperativeAggregateFunctions(i).merge(currentBuffer, row)
- i += 1
- }
- }
-
- // Final-Complete
- case (Some(Final), Some(Complete)) =>
- val completeAggregateFunctions: Array[AggregateFunction] =
- allAggregateFunctions.takeRight(completeAggregateExpressions.length)
- // All imperative aggregate functions with mode Complete.
- val completeImperativeAggregateFunctions: Array[ImperativeAggregate] =
- completeAggregateFunctions.collect { case func: ImperativeAggregate => func }
-
- // The first initialInputBufferOffset values of the input aggregation buffer is
- // for grouping expressions and distinct columns.
- val groupingAttributesAndDistinctColumns = valueAttributes.take(initialInputBufferOffset)
-
- val completeOffsetExpressions =
- Seq.fill(completeAggregateFunctions.map(_.aggBufferAttributes.length).sum)(NoOp)
- // We do not touch buffer values of aggregate functions with the Final mode.
- val finalOffsetExpressions =
- Seq.fill(nonCompleteAggregateFunctions.map(_.aggBufferAttributes.length).sum)(NoOp)
-
- val mergeInputSchema =
- aggregationBufferSchema ++
- groupingAttributesAndDistinctColumns ++
- nonCompleteAggregateFunctions.flatMap(_.inputAggBufferAttributes)
- val mergeExpressions =
- nonCompleteAggregateFunctions.flatMap {
- case ae: DeclarativeAggregate => ae.mergeExpressions
- case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp)
- } ++ completeOffsetExpressions
- val finalExpressionAggMergeProjection =
- newMutableProjection(mergeExpressions, mergeInputSchema)()
-
- val updateExpressions =
- finalOffsetExpressions ++ completeAggregateFunctions.flatMap {
- case ae: DeclarativeAggregate => ae.updateExpressions
- case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp)
- }
- val completeExpressionAggUpdateProjection =
- newMutableProjection(updateExpressions, aggregationBufferSchema ++ valueAttributes)()
-
- (currentBuffer: MutableRow, row: InternalRow) => {
- val input = rowToBeProcessed(currentBuffer, row)
- // For all aggregate functions with mode Complete, update buffers.
- completeExpressionAggUpdateProjection.target(currentBuffer)(input)
- var i = 0
- while (i < completeImperativeAggregateFunctions.length) {
- completeImperativeAggregateFunctions(i).update(currentBuffer, row)
- i += 1
- }
-
- // For all aggregate functions with mode Final, merge buffers.
- finalExpressionAggMergeProjection.target(currentBuffer)(input)
- i = 0
- while (i < nonCompleteImperativeAggregateFunctions.length) {
- nonCompleteImperativeAggregateFunctions(i).merge(currentBuffer, row)
- i += 1
+ protected def generateProcessRow(
+ expressions: Seq[AggregateExpression],
+ functions: Seq[AggregateFunction],
+ inputAttributes: Seq[Attribute]): (MutableRow, InternalRow) => Unit = {
+ val joinedRow = new JoinedRow
+ if (expressions.nonEmpty) {
+ val mergeExpressions = functions.zipWithIndex.flatMap {
+ case (ae: DeclarativeAggregate, i) =>
+ expressions(i).mode match {
+ case Partial | Complete => ae.updateExpressions
+ case PartialMerge | Final => ae.mergeExpressions
}
- }
-
- // Complete-only
- case (None, Some(Complete)) =>
- val completeAggregateFunctions: Array[AggregateFunction] =
- allAggregateFunctions.takeRight(completeAggregateExpressions.length)
- // All imperative aggregate functions with mode Complete.
- val completeImperativeAggregateFunctions: Array[ImperativeAggregate] =
- completeAggregateFunctions.collect { case func: ImperativeAggregate => func }
-
- val updateExpressions =
- completeAggregateFunctions.flatMap {
- case ae: DeclarativeAggregate => ae.updateExpressions
- case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp)
- }
- val completeExpressionAggUpdateProjection =
- newMutableProjection(updateExpressions, aggregationBufferSchema ++ valueAttributes)()
-
- (currentBuffer: MutableRow, row: InternalRow) => {
- val input = rowToBeProcessed(currentBuffer, row)
- // For all aggregate functions with mode Complete, update buffers.
- completeExpressionAggUpdateProjection.target(currentBuffer)(input)
- var i = 0
- while (i < completeImperativeAggregateFunctions.length) {
- completeImperativeAggregateFunctions(i).update(currentBuffer, row)
- i += 1
+ case (agg: AggregateFunction, _) => Seq.fill(agg.aggBufferAttributes.length)(NoOp)
+ }
+ val updateFunctions = functions.zipWithIndex.collect {
+ case (ae: ImperativeAggregate, i) =>
+ expressions(i).mode match {
+ case Partial | Complete =>
+ (buffer: MutableRow, row: InternalRow) => ae.update(buffer, row)
+ case PartialMerge | Final =>
+ (buffer: MutableRow, row: InternalRow) => ae.merge(buffer, row)
}
+ }
+ // This projection is used to merge buffer values for all expression-based aggregates.
+ val aggregationBufferSchema = functions.flatMap(_.aggBufferAttributes)
+ val updateProjection =
+ newMutableProjection(mergeExpressions, aggregationBufferSchema ++ inputAttributes)()
+
+ (currentBuffer: MutableRow, row: InternalRow) => {
+ // Process all expression-based aggregate functions.
+ updateProjection.target(currentBuffer)(joinedRow(currentBuffer, row))
+ // Process all imperative aggregate functions.
+ var i = 0
+ while (i < updateFunctions.length) {
+ updateFunctions(i)(currentBuffer, row)
+ i += 1
}
-
+ }
+ } else {
// Grouping only.
- case (None, None) => (currentBuffer: MutableRow, row: InternalRow) => {}
-
- case other =>
- sys.error(
- s"Could not evaluate ${nonCompleteAggregateExpressions} because we do not " +
- s"support evaluate modes $other in this iterator.")
+ (currentBuffer: MutableRow, row: InternalRow) => {}
}
}
- // Initializing the function used to generate the output row.
- protected val generateOutput: (InternalRow, MutableRow) => InternalRow = {
- val rowToBeEvaluated = new JoinedRow
- val safeOutputRow = new SpecificMutableRow(resultExpressions.map(_.dataType))
- val mutableOutput = if (outputsUnsafeRows) {
- UnsafeProjection.create(resultExpressions.map(_.dataType).toArray).apply(safeOutputRow)
- } else {
- safeOutputRow
- }
-
- aggregationMode match {
- // Partial-only or PartialMerge-only: every output row is basically the values of
- // the grouping expressions and the corresponding aggregation buffer.
- case (Some(Partial), None) | (Some(PartialMerge), None) =>
- // Because we cannot copy a joinedRow containing a UnsafeRow (UnsafeRow does not
- // support generic getter), we create a mutable projection to output the
- // JoinedRow(currentGroupingKey, currentBuffer)
- val bufferSchema = nonCompleteAggregateFunctions.flatMap(_.aggBufferAttributes)
- val resultProjection =
- newMutableProjection(
- groupingKeyAttributes ++ bufferSchema,
- groupingKeyAttributes ++ bufferSchema)()
- resultProjection.target(mutableOutput)
-
- (currentGroupingKey: InternalRow, currentBuffer: MutableRow) => {
- resultProjection(rowToBeEvaluated(currentGroupingKey, currentBuffer))
- // rowToBeEvaluated(currentGroupingKey, currentBuffer)
- }
+ protected val processRow: (MutableRow, InternalRow) => Unit =
+ generateProcessRow(aggregateExpressions, aggregateFunctions, inputAttributes)
- // Final-only, Complete-only and Final-Complete: every output row contains values representing
- // resultExpressions.
- case (Some(Final), None) | (Some(Final) | None, Some(Complete)) =>
- val bufferSchemata =
- allAggregateFunctions.flatMap(_.aggBufferAttributes)
- val evalExpressions = allAggregateFunctions.map {
- case ae: DeclarativeAggregate => ae.evaluateExpression
- case agg: AggregateFunction => NoOp
- }
- val expressionAggEvalProjection = newMutableProjection(evalExpressions, bufferSchemata)()
- val aggregateResultSchema = nonCompleteAggregateAttributes ++ completeAggregateAttributes
- // TODO: Use unsafe row.
- val aggregateResult = new SpecificMutableRow(aggregateResultSchema.map(_.dataType))
- expressionAggEvalProjection.target(aggregateResult)
- val resultProjection =
- newMutableProjection(
- resultExpressions, groupingKeyAttributes ++ aggregateResultSchema)()
- resultProjection.target(mutableOutput)
+ protected val groupingProjection: UnsafeProjection =
+ UnsafeProjection.create(groupingExpressions, inputAttributes)
+ protected val groupingAttributes = groupingExpressions.map(_.toAttribute)
- (currentGroupingKey: InternalRow, currentBuffer: MutableRow) => {
- // Generate results for all expression-based aggregate functions.
- expressionAggEvalProjection(currentBuffer)
- // Generate results for all imperative aggregate functions.
- var i = 0
- while (i < allImperativeAggregateFunctions.length) {
- aggregateResult.update(
- allImperativeAggregateFunctionPositions(i),
- allImperativeAggregateFunctions(i).eval(currentBuffer))
- i += 1
- }
- resultProjection(rowToBeEvaluated(currentGroupingKey, aggregateResult))
+ // Initializing the function used to generate the output row.
+ protected def generateResultProjection(): (UnsafeRow, MutableRow) => UnsafeRow = {
+ val joinedRow = new JoinedRow
+ val modes = aggregateExpressions.map(_.mode).distinct
+ val bufferAttributes = aggregateFunctions.flatMap(_.aggBufferAttributes)
+ if (modes.contains(Final) || modes.contains(Complete)) {
+ val evalExpressions = aggregateFunctions.map {
+ case ae: DeclarativeAggregate => ae.evaluateExpression
+ case agg: AggregateFunction => NoOp
+ }
+ val aggregateResult = new SpecificMutableRow(aggregateAttributes.map(_.dataType))
+ val expressionAggEvalProjection = newMutableProjection(evalExpressions, bufferAttributes)()
+ expressionAggEvalProjection.target(aggregateResult)
+
+ val resultProjection =
+ UnsafeProjection.create(resultExpressions, groupingAttributes ++ aggregateAttributes)
+
+ (currentGroupingKey: UnsafeRow, currentBuffer: MutableRow) => {
+ // Generate results for all expression-based aggregate functions.
+ expressionAggEvalProjection(currentBuffer)
+ // Generate results for all imperative aggregate functions.
+ var i = 0
+ while (i < allImperativeAggregateFunctions.length) {
+ aggregateResult.update(
+ allImperativeAggregateFunctionPositions(i),
+ allImperativeAggregateFunctions(i).eval(currentBuffer))
+ i += 1
}
-
+ resultProjection(joinedRow(currentGroupingKey, aggregateResult))
+ }
+ } else if (modes.contains(Partial) || modes.contains(PartialMerge)) {
+ val resultProjection = UnsafeProjection.create(
+ groupingAttributes ++ bufferAttributes,
+ groupingAttributes ++ bufferAttributes)
+ (currentGroupingKey: UnsafeRow, currentBuffer: MutableRow) => {
+ resultProjection(joinedRow(currentGroupingKey, currentBuffer))
+ }
+ } else {
// Grouping-only: we only output values of grouping expressions.
- case (None, None) =>
- val resultProjection =
- newMutableProjection(resultExpressions, groupingKeyAttributes)()
- resultProjection.target(mutableOutput)
-
- (currentGroupingKey: InternalRow, currentBuffer: MutableRow) => {
- resultProjection(currentGroupingKey)
- }
-
- case other =>
- sys.error(
- s"Could not evaluate ${nonCompleteAggregateExpressions} because we do not " +
- s"support evaluate modes $other in this iterator.")
+ val resultProjection = UnsafeProjection.create(resultExpressions, groupingAttributes)
+ (currentGroupingKey: UnsafeRow, currentBuffer: MutableRow) => {
+ resultProjection(currentGroupingKey)
+ }
}
}
+ protected val generateOutput: (UnsafeRow, MutableRow) => UnsafeRow =
+ generateResultProjection()
+
/** Initializes buffer values for all aggregate functions. */
protected def initializeBuffer(buffer: MutableRow): Unit = {
expressionAggInitialProjection.target(buffer)(EmptyRow)
@@ -405,10 +258,4 @@ abstract class AggregationIterator(
i += 1
}
}
-
- /**
- * Creates a new aggregation buffer and initializes buffer values
- * for all aggregate functions.
- */
- protected def newBuffer: MutableRow
}
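The per-function dispatch in generateProcessRow above can be pictured with this simplified sketch (Long values instead of MutableRow buffers; illustrative only): Partial/Complete modes apply update logic to raw input, while PartialMerge/Final modes apply merge logic to partial buffers.

    sealed trait Mode
    case object Partial extends Mode
    case object PartialMerge extends Mode
    case object Final extends Mode
    case object Complete extends Mode

    // Chosen once per aggregate from its mode, then applied row by row,
    // mirroring updateFunctions in generateProcessRow.
    def processor(mode: Mode): (Long, Long) => Long = mode match {
      case Partial | Complete   => (buffer, input) => buffer + input // update from raw input
      case PartialMerge | Final => (buffer, other) => buffer + other // merge partial buffers
    }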
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala
index ee982453c3..c5470a6989 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala
@@ -29,10 +29,8 @@ import org.apache.spark.sql.execution.metric.SQLMetrics
case class SortBasedAggregate(
requiredChildDistributionExpressions: Option[Seq[Expression]],
groupingExpressions: Seq[NamedExpression],
- nonCompleteAggregateExpressions: Seq[AggregateExpression],
- nonCompleteAggregateAttributes: Seq[Attribute],
- completeAggregateExpressions: Seq[AggregateExpression],
- completeAggregateAttributes: Seq[Attribute],
+ aggregateExpressions: Seq[AggregateExpression],
+ aggregateAttributes: Seq[Attribute],
initialInputBufferOffset: Int,
resultExpressions: Seq[NamedExpression],
child: SparkPlan)
@@ -42,10 +40,8 @@ case class SortBasedAggregate(
"numInputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of input rows"),
"numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
- override def outputsUnsafeRows: Boolean = false
-
+ override def outputsUnsafeRows: Boolean = true
override def canProcessUnsafeRows: Boolean = false
-
override def canProcessSafeRows: Boolean = true
override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
@@ -76,31 +72,24 @@ case class SortBasedAggregate(
if (!hasInput && groupingExpressions.nonEmpty) {
// This is a grouped aggregate and the input iterator is empty,
// so return an empty iterator.
- Iterator[InternalRow]()
+ Iterator[UnsafeRow]()
} else {
- val groupingKeyProjection =
- UnsafeProjection.create(groupingExpressions, child.output)
-
val outputIter = new SortBasedAggregationIterator(
- groupingKeyProjection,
- groupingExpressions.map(_.toAttribute),
+ groupingExpressions,
child.output,
iter,
- nonCompleteAggregateExpressions,
- nonCompleteAggregateAttributes,
- completeAggregateExpressions,
- completeAggregateAttributes,
+ aggregateExpressions,
+ aggregateAttributes,
initialInputBufferOffset,
resultExpressions,
newMutableProjection,
- outputsUnsafeRows,
numInputRows,
numOutputRows)
if (!hasInput && groupingExpressions.isEmpty) {
// There is no input and there is no grouping expressions.
// We need to output a single row as the output.
numOutputRows += 1
- Iterator[InternalRow](outputIter.outputForEmptyGroupingKeyWithoutInput())
+ Iterator[UnsafeRow](outputIter.outputForEmptyGroupingKeyWithoutInput())
} else {
outputIter
}
@@ -109,7 +98,7 @@ case class SortBasedAggregate(
}
override def simpleString: String = {
- val allAggregateExpressions = nonCompleteAggregateExpressions ++ completeAggregateExpressions
+ val allAggregateExpressions = aggregateExpressions
val keyString = groupingExpressions.mkString("[", ",", "]")
val functionString = allAggregateExpressions.mkString("[", ",", "]")
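The empty-input handling above preserves standard SQL semantics. A hedged sketch in plain Scala (a simplified stand-in, not the operator itself):

    // Grouped aggregation over zero rows yields zero output rows, while a
    // global aggregate still emits exactly one row (e.g. COUNT(*) of an
    // empty table is 0); this is why outputForEmptyGroupingKeyWithoutInput
    // is invoked when there is no input and no grouping expressions.
    def countAll(rows: Seq[Int], grouped: Boolean): Seq[Long] =
      if (rows.isEmpty && grouped) Seq.empty
      else if (rows.isEmpty) Seq(0L)
      else if (grouped) rows.groupBy(identity).values.map(_.size.toLong).toSeq
      else Seq(rows.size.toLong)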
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala
index fe5c3195f8..ac920aa8bc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala
@@ -24,37 +24,34 @@ import org.apache.spark.sql.execution.metric.LongSQLMetric
/**
* An iterator used to evaluate [[AggregateFunction]]. It assumes the input rows have been
- * sorted by values of [[groupingKeyAttributes]].
+ * sorted by values of [[groupingExpressions]].
*/
class SortBasedAggregationIterator(
- groupingKeyProjection: InternalRow => InternalRow,
- groupingKeyAttributes: Seq[Attribute],
+ groupingExpressions: Seq[NamedExpression],
valueAttributes: Seq[Attribute],
inputIterator: Iterator[InternalRow],
- nonCompleteAggregateExpressions: Seq[AggregateExpression],
- nonCompleteAggregateAttributes: Seq[Attribute],
- completeAggregateExpressions: Seq[AggregateExpression],
- completeAggregateAttributes: Seq[Attribute],
+ aggregateExpressions: Seq[AggregateExpression],
+ aggregateAttributes: Seq[Attribute],
initialInputBufferOffset: Int,
resultExpressions: Seq[NamedExpression],
newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
- outputsUnsafeRows: Boolean,
numInputRows: LongSQLMetric,
numOutputRows: LongSQLMetric)
extends AggregationIterator(
- groupingKeyAttributes,
+ groupingExpressions,
valueAttributes,
- nonCompleteAggregateExpressions,
- nonCompleteAggregateAttributes,
- completeAggregateExpressions,
- completeAggregateAttributes,
+ aggregateExpressions,
+ aggregateAttributes,
initialInputBufferOffset,
resultExpressions,
- newMutableProjection,
- outputsUnsafeRows) {
-
- override protected def newBuffer: MutableRow = {
- val bufferSchema = allAggregateFunctions.flatMap(_.aggBufferAttributes)
+ newMutableProjection) {
+
+ /**
+ * Creates a new aggregation buffer and initializes buffer values
+ * for all aggregate functions.
+ */
+ private def newBuffer: MutableRow = {
+ val bufferSchema = aggregateFunctions.flatMap(_.aggBufferAttributes)
val bufferRowSize: Int = bufferSchema.length
val genericMutableBuffer = new GenericMutableRow(bufferRowSize)
@@ -76,10 +73,10 @@ class SortBasedAggregationIterator(
///////////////////////////////////////////////////////////////////////////
// The partition key of the current partition.
- private[this] var currentGroupingKey: InternalRow = _
+ private[this] var currentGroupingKey: UnsafeRow = _
// The partition key of next partition.
- private[this] var nextGroupingKey: InternalRow = _
+ private[this] var nextGroupingKey: UnsafeRow = _
// The first row of next partition.
private[this] var firstRowInNextGroup: InternalRow = _
@@ -94,7 +91,7 @@ class SortBasedAggregationIterator(
if (inputIterator.hasNext) {
initializeBuffer(sortBasedAggregationBuffer)
val inputRow = inputIterator.next()
- nextGroupingKey = groupingKeyProjection(inputRow).copy()
+ nextGroupingKey = groupingProjection(inputRow).copy()
firstRowInNextGroup = inputRow.copy()
numInputRows += 1
sortedInputHasNewGroup = true
@@ -120,7 +117,7 @@ class SortBasedAggregationIterator(
while (!findNextPartition && inputIterator.hasNext) {
// Get the grouping key.
val currentRow = inputIterator.next()
- val groupingKey = groupingKeyProjection(currentRow)
+ val groupingKey = groupingProjection(currentRow)
numInputRows += 1
// Check if the current row belongs the current input row.
@@ -146,7 +143,7 @@ class SortBasedAggregationIterator(
override final def hasNext: Boolean = sortedInputHasNewGroup
- override final def next(): InternalRow = {
+ override final def next(): UnsafeRow = {
if (hasNext) {
// Process the current group.
processCurrentSortedGroup()
@@ -162,8 +159,8 @@ class SortBasedAggregationIterator(
}
}
- def outputForEmptyGroupingKeyWithoutInput(): InternalRow = {
+ def outputForEmptyGroupingKeyWithoutInput(): UnsafeRow = {
initializeBuffer(sortBasedAggregationBuffer)
- generateOutput(new GenericInternalRow(0), sortBasedAggregationBuffer)
+ generateOutput(UnsafeRow.createFromByteArray(0, 0), sortBasedAggregationBuffer)
}
}
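The iteration pattern above relies on the input being sorted by the grouping key: consume rows while the key is unchanged, then emit one result per group. A minimal self-contained sketch with tuples standing in for InternalRow/UnsafeRow (illustrative only):

    import scala.collection.mutable.ArrayBuffer

    def sortedGroups[K, V](sorted: Iterator[(K, V)]): Iterator[(K, Seq[V])] =
      new Iterator[(K, Seq[V])] {
        private val input = sorted.buffered
        def hasNext: Boolean = input.hasNext
        def next(): (K, Seq[V]) = {
          val key = input.head._1
          val group = ArrayBuffer.empty[V]
          // Accumulate until the grouping key changes, like
          // processCurrentSortedGroup above.
          while (input.hasNext && input.head._1 == key) group += input.next()._2
          (key, group)
        }
      }
    // sortedGroups(Iterator("a" -> 1, "a" -> 2, "b" -> 3)).toList
    //   => List(("a", ArrayBuffer(1, 2)), ("b", ArrayBuffer(3)))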
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
index 920de615e1..b8849c8270 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -30,21 +30,18 @@ import org.apache.spark.sql.types.StructType
case class TungstenAggregate(
requiredChildDistributionExpressions: Option[Seq[Expression]],
groupingExpressions: Seq[NamedExpression],
- nonCompleteAggregateExpressions: Seq[AggregateExpression],
- nonCompleteAggregateAttributes: Seq[Attribute],
- completeAggregateExpressions: Seq[AggregateExpression],
- completeAggregateAttributes: Seq[Attribute],
+ aggregateExpressions: Seq[AggregateExpression],
+ aggregateAttributes: Seq[Attribute],
initialInputBufferOffset: Int,
resultExpressions: Seq[NamedExpression],
child: SparkPlan)
extends UnaryNode {
private[this] val aggregateBufferAttributes = {
- (nonCompleteAggregateExpressions ++ completeAggregateExpressions)
- .flatMap(_.aggregateFunction.aggBufferAttributes)
+ aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
}
- require(TungstenAggregate.supportsAggregate(groupingExpressions, aggregateBufferAttributes))
+ require(TungstenAggregate.supportsAggregate(aggregateBufferAttributes))
override private[sql] lazy val metrics = Map(
"numInputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of input rows"),
@@ -53,9 +50,7 @@ case class TungstenAggregate(
"spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size"))
override def outputsUnsafeRows: Boolean = true
-
override def canProcessUnsafeRows: Boolean = true
-
override def canProcessSafeRows: Boolean = true
override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
@@ -94,10 +89,8 @@ case class TungstenAggregate(
val aggregationIterator =
new TungstenAggregationIterator(
groupingExpressions,
- nonCompleteAggregateExpressions,
- nonCompleteAggregateAttributes,
- completeAggregateExpressions,
- completeAggregateAttributes,
+ aggregateExpressions,
+ aggregateAttributes,
initialInputBufferOffset,
resultExpressions,
newMutableProjection,
@@ -119,7 +112,7 @@ case class TungstenAggregate(
}
override def simpleString: String = {
- val allAggregateExpressions = nonCompleteAggregateExpressions ++ completeAggregateExpressions
+ val allAggregateExpressions = aggregateExpressions
testFallbackStartsAt match {
case None =>
@@ -135,9 +128,7 @@ case class TungstenAggregate(
}
object TungstenAggregate {
- def supportsAggregate(
- groupingExpressions: Seq[Expression],
- aggregateBufferAttributes: Seq[Attribute]): Boolean = {
+ def supportsAggregate(aggregateBufferAttributes: Seq[Attribute]): Boolean = {
val aggregationBufferSchema = StructType.fromAttributes(aggregateBufferAttributes)
UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(aggregationBufferSchema)
}
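The simplified supportsAggregate above depends only on the aggregation buffer schema: roughly, every buffer field must be a mutable fixed-width type so it can be updated in place inside an UnsafeRow. A rough approximation of that check (the real one is UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema; the exact type set below is an assumption):

    import org.apache.spark.sql.types._

    // Assumption: primitive numeric/boolean fields are fixed-width and
    // mutable in place; strings, arrays, maps, etc. are not.
    def isFixedWidth(dt: DataType): Boolean = dt match {
      case BooleanType | ByteType | ShortType | IntegerType |
           LongType | FloatType | DoubleType => true
      case _ => false
    }
    def supportsBufferSchema(schema: StructType): Boolean =
      schema.fields.forall(f => isFixedWidth(f.dataType))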
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
index 0439144392..582fdbe547 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
@@ -17,17 +17,15 @@
package org.apache.spark.sql.execution.aggregate
-import scala.collection.mutable.ArrayBuffer
-
-import org.apache.spark.unsafe.KVIterator
-import org.apache.spark.{InternalAccumulator, Logging, TaskContext}
+import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.execution.{UnsafeKVExternalSorter, UnsafeFixedWidthAggregationMap}
import org.apache.spark.sql.execution.metric.LongSQLMetric
+import org.apache.spark.sql.execution.{UnsafeFixedWidthAggregationMap, UnsafeKVExternalSorter}
import org.apache.spark.sql.types.StructType
+import org.apache.spark.unsafe.KVIterator
+import org.apache.spark.{InternalAccumulator, Logging, TaskContext}
/**
* An iterator used to evaluate aggregate functions. It operates on [[UnsafeRow]]s.
@@ -63,15 +61,11 @@ import org.apache.spark.sql.types.StructType
*
* @param groupingExpressions
* expressions for grouping keys
- * @param nonCompleteAggregateExpressions
+ * @param aggregateExpressions
* [[AggregateExpression]] containing [[AggregateFunction]]s with mode [[Partial]],
* [[PartialMerge]], or [[Final]].
- * @param nonCompleteAggregateAttributes the attributes of the nonCompleteAggregateExpressions'
+ * @param aggregateAttributes the attributes of the aggregateExpressions'
* outputs when they are stored in the final aggregation buffer.
- * @param completeAggregateExpressions
- * [[AggregateExpression]] containing [[AggregateFunction]]s with mode [[Complete]].
- * @param completeAggregateAttributes the attributes of completeAggregateExpressions' outputs
- * when they are stored in the final aggregation buffer.
* @param resultExpressions
* expressions for generating output rows.
* @param newMutableProjection
@@ -83,10 +77,8 @@ import org.apache.spark.sql.types.StructType
*/
class TungstenAggregationIterator(
groupingExpressions: Seq[NamedExpression],
- nonCompleteAggregateExpressions: Seq[AggregateExpression],
- nonCompleteAggregateAttributes: Seq[Attribute],
- completeAggregateExpressions: Seq[AggregateExpression],
- completeAggregateAttributes: Seq[Attribute],
+ aggregateExpressions: Seq[AggregateExpression],
+ aggregateAttributes: Seq[Attribute],
initialInputBufferOffset: Int,
resultExpressions: Seq[NamedExpression],
newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
@@ -97,378 +89,62 @@ class TungstenAggregationIterator(
numOutputRows: LongSQLMetric,
dataSize: LongSQLMetric,
spillSize: LongSQLMetric)
- extends Iterator[UnsafeRow] with Logging {
+ extends AggregationIterator(
+ groupingExpressions,
+ originalInputAttributes,
+ aggregateExpressions,
+ aggregateAttributes,
+ initialInputBufferOffset,
+ resultExpressions,
+ newMutableProjection) with Logging {
///////////////////////////////////////////////////////////////////////////
// Part 1: Initializing aggregate functions.
///////////////////////////////////////////////////////////////////////////
- // A Seq containing all AggregateExpressions.
- // It is important that all AggregateExpressions with the mode Partial, PartialMerge or Final
- // are at the beginning of the allAggregateExpressions.
- private[this] val allAggregateExpressions: Seq[AggregateExpression] =
- nonCompleteAggregateExpressions ++ completeAggregateExpressions
-
- // Check to make sure we do not have more than three modes in our AggregateExpressions.
- // If we have, users are hitting a bug and we throw an IllegalStateException.
- if (allAggregateExpressions.map(_.mode).distinct.length > 2) {
- throw new IllegalStateException(
- s"$allAggregateExpressions should have no more than 2 kinds of modes.")
- }
-
// Remember spill data size of this task before execute this operator so that we can
// figure out how many bytes we spilled for this operator.
private val spillSizeBefore = TaskContext.get().taskMetrics().memoryBytesSpilled
- //
- // The modes of AggregateExpressions. Right now, we can handle the following mode:
- // - Partial-only:
- // All AggregateExpressions have the mode of Partial.
- // For this case, aggregationMode is (Some(Partial), None).
- // - PartialMerge-only:
- // All AggregateExpressions have the mode of PartialMerge).
- // For this case, aggregationMode is (Some(PartialMerge), None).
- // - Final-only:
- // All AggregateExpressions have the mode of Final.
- // For this case, aggregationMode is (Some(Final), None).
- // - Final-Complete:
- // Some AggregateExpressions have the mode of Final and
- // others have the mode of Complete. For this case,
- // aggregationMode is (Some(Final), Some(Complete)).
- // - Complete-only:
- // nonCompleteAggregateExpressions is empty and we have AggregateExpressions
- // with mode Complete in completeAggregateExpressions. For this case,
- // aggregationMode is (None, Some(Complete)).
- // - Grouping-only:
- // There is no AggregateExpression. For this case, AggregationMode is (None,None).
- //
- private[this] var aggregationMode: (Option[AggregateMode], Option[AggregateMode]) = {
- nonCompleteAggregateExpressions.map(_.mode).distinct.headOption ->
- completeAggregateExpressions.map(_.mode).distinct.headOption
- }
-
- // Initialize all AggregateFunctions by binding references, if necessary,
- // and setting inputBufferOffset and mutableBufferOffset.
- private def initializeAllAggregateFunctions(
- startingInputBufferOffset: Int): Array[AggregateFunction] = {
- var mutableBufferOffset = 0
- var inputBufferOffset: Int = startingInputBufferOffset
- val functions = new Array[AggregateFunction](allAggregateExpressions.length)
- var i = 0
- while (i < allAggregateExpressions.length) {
- val func = allAggregateExpressions(i).aggregateFunction
- val aggregateExpressionIsNonComplete = i < nonCompleteAggregateExpressions.length
- // We need to use this mode instead of func.mode in order to handle aggregation mode switching
- // when switching to sort-based aggregation:
- val mode = if (aggregateExpressionIsNonComplete) aggregationMode._1 else aggregationMode._2
- val funcWithBoundReferences = mode match {
- case Some(Partial) | Some(Complete) if func.isInstanceOf[ImperativeAggregate] =>
- // We need to create BoundReferences if the function is not an
- // expression-based aggregate function (it does not support code-gen) and the mode of
- // this function is Partial or Complete because we will call eval of this
- // function's children in the update method of this aggregate function.
- // Those eval calls require BoundReferences to work.
- BindReferences.bindReference(func, originalInputAttributes)
- case _ =>
- // We only need to set inputBufferOffset for aggregate functions with mode
- // PartialMerge and Final.
- val updatedFunc = func match {
- case function: ImperativeAggregate =>
- function.withNewInputAggBufferOffset(inputBufferOffset)
- case function => function
- }
- inputBufferOffset += func.aggBufferSchema.length
- updatedFunc
- }
- val funcWithUpdatedAggBufferOffset = funcWithBoundReferences match {
- case function: ImperativeAggregate =>
- // Set mutableBufferOffset for this function. It is important that setting
- // mutableBufferOffset happens after all potential bindReference operations
- // because bindReference will create a new instance of the function.
- function.withNewMutableAggBufferOffset(mutableBufferOffset)
- case function => function
- }
- mutableBufferOffset += funcWithUpdatedAggBufferOffset.aggBufferSchema.length
- functions(i) = funcWithUpdatedAggBufferOffset
- i += 1
- }
- functions
- }
-
- private[this] var allAggregateFunctions: Array[AggregateFunction] =
- initializeAllAggregateFunctions(initialInputBufferOffset)
-
- // Positions of those imperative aggregate functions in allAggregateFunctions.
- // For example, say that we have func1, func2, func3, func4 in aggregateFunctions, and
- // func2 and func3 are imperative aggregate functions. Then
- // allImperativeAggregateFunctionPositions will be [1, 2]. Note that this does not need to be
- // updated when falling back to sort-based aggregation because the positions of the aggregate
- // functions do not change in that case.
- private[this] val allImperativeAggregateFunctionPositions: Array[Int] = {
- val positions = new ArrayBuffer[Int]()
- var i = 0
- while (i < allAggregateFunctions.length) {
- allAggregateFunctions(i) match {
- case agg: DeclarativeAggregate =>
- case _ => positions += i
- }
- i += 1
- }
- positions.toArray
- }
-
///////////////////////////////////////////////////////////////////////////
// Part 2: Methods and fields used by setting aggregation buffer values,
// processing input rows from inputIter, and generating output
// rows.
///////////////////////////////////////////////////////////////////////////
- // The projection used to initialize buffer values for all expression-based aggregates.
- // Note that this projection does not need to be updated when switching to sort-based aggregation
- // because the schema of empty aggregation buffers does not change in that case.
- private[this] val expressionAggInitialProjection: MutableProjection = {
- val initExpressions = allAggregateFunctions.flatMap {
- case ae: DeclarativeAggregate => ae.initialValues
- // For the positions corresponding to imperative aggregate functions, we'll use special
- // no-op expressions which are ignored during projection code-generation.
- case i: ImperativeAggregate => Seq.fill(i.aggBufferAttributes.length)(NoOp)
- }
- newMutableProjection(initExpressions, Nil)()
- }
-
// Creates a new aggregation buffer and initializes buffer values.
- // This function should be only called at most three times (when we create the hash map,
- // when we switch to sort-based aggregation, and when we create the re-used buffer for
- // sort-based aggregation).
+ // This function should only be called at most twice (when we create the hash map,
+ // and when we create the re-used buffer for sort-based aggregation).
private def createNewAggregationBuffer(): UnsafeRow = {
- val bufferSchema = allAggregateFunctions.flatMap(_.aggBufferAttributes)
+ val bufferSchema = aggregateFunctions.flatMap(_.aggBufferAttributes)
val buffer: UnsafeRow = UnsafeProjection.create(bufferSchema.map(_.dataType))
.apply(new GenericMutableRow(bufferSchema.length))
// Initialize declarative aggregates' buffer values
expressionAggInitialProjection.target(buffer)(EmptyRow)
// Initialize imperative aggregates' buffer values
- allAggregateFunctions.collect { case f: ImperativeAggregate => f }.foreach(_.initialize(buffer))
+ aggregateFunctions.collect { case f: ImperativeAggregate => f }.foreach(_.initialize(buffer))
buffer
}
- // Creates a function used to process a row based on the given inputAttributes.
- private def generateProcessRow(
- inputAttributes: Seq[Attribute]): (UnsafeRow, InternalRow) => Unit = {
-
- val aggregationBufferAttributes = allAggregateFunctions.flatMap(_.aggBufferAttributes)
- val joinedRow = new JoinedRow()
-
- aggregationMode match {
- // Partial-only
- case (Some(Partial), None) =>
- val updateExpressions = allAggregateFunctions.flatMap {
- case ae: DeclarativeAggregate => ae.updateExpressions
- case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp)
- }
- val imperativeAggregateFunctions: Array[ImperativeAggregate] =
- allAggregateFunctions.collect { case func: ImperativeAggregate => func}
- val expressionAggUpdateProjection =
- newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)()
-
- (currentBuffer: UnsafeRow, row: InternalRow) => {
- expressionAggUpdateProjection.target(currentBuffer)
- // Process all expression-based aggregate functions.
- expressionAggUpdateProjection(joinedRow(currentBuffer, row))
- // Process all imperative aggregate functions
- var i = 0
- while (i < imperativeAggregateFunctions.length) {
- imperativeAggregateFunctions(i).update(currentBuffer, row)
- i += 1
- }
- }
-
- // PartialMerge-only or Final-only
- case (Some(PartialMerge), None) | (Some(Final), None) =>
- val mergeExpressions = allAggregateFunctions.flatMap {
- case ae: DeclarativeAggregate => ae.mergeExpressions
- case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp)
- }
- val imperativeAggregateFunctions: Array[ImperativeAggregate] =
- allAggregateFunctions.collect { case func: ImperativeAggregate => func}
- // This projection is used to merge buffer values for all expression-based aggregates.
- val expressionAggMergeProjection =
- newMutableProjection(mergeExpressions, aggregationBufferAttributes ++ inputAttributes)()
-
- (currentBuffer: UnsafeRow, row: InternalRow) => {
- // Process all expression-based aggregate functions.
- expressionAggMergeProjection.target(currentBuffer)(joinedRow(currentBuffer, row))
- // Process all imperative aggregate functions.
- var i = 0
- while (i < imperativeAggregateFunctions.length) {
- imperativeAggregateFunctions(i).merge(currentBuffer, row)
- i += 1
- }
- }
-
- // Final-Complete
- case (Some(Final), Some(Complete)) =>
- val completeAggregateFunctions: Array[AggregateFunction] =
- allAggregateFunctions.takeRight(completeAggregateExpressions.length)
- val completeImperativeAggregateFunctions: Array[ImperativeAggregate] =
- completeAggregateFunctions.collect { case func: ImperativeAggregate => func }
- val nonCompleteAggregateFunctions: Array[AggregateFunction] =
- allAggregateFunctions.take(nonCompleteAggregateExpressions.length)
- val nonCompleteImperativeAggregateFunctions: Array[ImperativeAggregate] =
- nonCompleteAggregateFunctions.collect { case func: ImperativeAggregate => func }
-
- val completeOffsetExpressions =
- Seq.fill(completeAggregateFunctions.map(_.aggBufferAttributes.length).sum)(NoOp)
- val mergeExpressions =
- nonCompleteAggregateFunctions.flatMap {
- case ae: DeclarativeAggregate => ae.mergeExpressions
- case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp)
- } ++ completeOffsetExpressions
- val finalMergeProjection =
- newMutableProjection(mergeExpressions, aggregationBufferAttributes ++ inputAttributes)()
-
- // We do not touch buffer values of aggregate functions with the Final mode.
- val finalOffsetExpressions =
- Seq.fill(nonCompleteAggregateFunctions.map(_.aggBufferAttributes.length).sum)(NoOp)
- val updateExpressions = finalOffsetExpressions ++ completeAggregateFunctions.flatMap {
- case ae: DeclarativeAggregate => ae.updateExpressions
- case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp)
- }
- val completeUpdateProjection =
- newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)()
-
- (currentBuffer: UnsafeRow, row: InternalRow) => {
- val input = joinedRow(currentBuffer, row)
- // For all aggregate functions with mode Complete, update buffers.
- completeUpdateProjection.target(currentBuffer)(input)
- var i = 0
- while (i < completeImperativeAggregateFunctions.length) {
- completeImperativeAggregateFunctions(i).update(currentBuffer, row)
- i += 1
- }
-
- // For all aggregate functions with mode Final, merge buffer values in row to
- // currentBuffer.
- finalMergeProjection.target(currentBuffer)(input)
- i = 0
- while (i < nonCompleteImperativeAggregateFunctions.length) {
- nonCompleteImperativeAggregateFunctions(i).merge(currentBuffer, row)
- i += 1
- }
- }
-
- // Complete-only
- case (None, Some(Complete)) =>
- val completeAggregateFunctions: Array[AggregateFunction] =
- allAggregateFunctions.takeRight(completeAggregateExpressions.length)
- // All imperative aggregate functions with mode Complete.
- val completeImperativeAggregateFunctions: Array[ImperativeAggregate] =
- completeAggregateFunctions.collect { case func: ImperativeAggregate => func }
-
- val updateExpressions = completeAggregateFunctions.flatMap {
- case ae: DeclarativeAggregate => ae.updateExpressions
- case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp)
- }
- val completeExpressionAggUpdateProjection =
- newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)()
-
- (currentBuffer: UnsafeRow, row: InternalRow) => {
- // For all aggregate functions with mode Complete, update buffers.
- completeExpressionAggUpdateProjection.target(currentBuffer)(joinedRow(currentBuffer, row))
- var i = 0
- while (i < completeImperativeAggregateFunctions.length) {
- completeImperativeAggregateFunctions(i).update(currentBuffer, row)
- i += 1
- }
- }
-
- // Grouping only.
- case (None, None) => (currentBuffer: UnsafeRow, row: InternalRow) => {}
-
- case other =>
- throw new IllegalStateException(
- s"${aggregationMode} should not be passed into TungstenAggregationIterator.")
- }
- }
-
// Creates a function used to generate output rows.
- private def generateResultProjection(): (UnsafeRow, UnsafeRow) => UnsafeRow = {
-
- val groupingAttributes = groupingExpressions.map(_.toAttribute)
- val bufferAttributes = allAggregateFunctions.flatMap(_.aggBufferAttributes)
-
- aggregationMode match {
- // Partial-only or PartialMerge-only: every output row is basically the values of
- // the grouping expressions and the corresponding aggregation buffer.
- case (Some(Partial), None) | (Some(PartialMerge), None) =>
- val groupingKeySchema = StructType.fromAttributes(groupingAttributes)
- val bufferSchema = StructType.fromAttributes(bufferAttributes)
- val unsafeRowJoiner = GenerateUnsafeRowJoiner.create(groupingKeySchema, bufferSchema)
-
- (currentGroupingKey: UnsafeRow, currentBuffer: UnsafeRow) => {
- unsafeRowJoiner.join(currentGroupingKey, currentBuffer)
- }
-
- // Final-only, Complete-only and Final-Complete: a output row is generated based on
- // resultExpressions.
- case (Some(Final), None) | (Some(Final) | None, Some(Complete)) =>
- val joinedRow = new JoinedRow()
- val evalExpressions = allAggregateFunctions.map {
- case ae: DeclarativeAggregate => ae.evaluateExpression
- case agg: AggregateFunction => NoOp
- }
- val expressionAggEvalProjection = newMutableProjection(evalExpressions, bufferAttributes)()
- // These are the attributes of the row produced by `expressionAggEvalProjection`
- val aggregateResultSchema = nonCompleteAggregateAttributes ++ completeAggregateAttributes
- val aggregateResult = new SpecificMutableRow(aggregateResultSchema.map(_.dataType))
- expressionAggEvalProjection.target(aggregateResult)
- val resultProjection =
- UnsafeProjection.create(resultExpressions, groupingAttributes ++ aggregateResultSchema)
-
- val allImperativeAggregateFunctions: Array[ImperativeAggregate] =
- allAggregateFunctions.collect { case func: ImperativeAggregate => func}
-
- (currentGroupingKey: UnsafeRow, currentBuffer: UnsafeRow) => {
- // Generate results for all expression-based aggregate functions.
- expressionAggEvalProjection(currentBuffer)
- // Generate results for all imperative aggregate functions.
- var i = 0
- while (i < allImperativeAggregateFunctions.length) {
- aggregateResult.update(
- allImperativeAggregateFunctionPositions(i),
- allImperativeAggregateFunctions(i).eval(currentBuffer))
- i += 1
- }
- resultProjection(joinedRow(currentGroupingKey, aggregateResult))
- }
-
- // Grouping-only: a output row is generated from values of grouping expressions.
- case (None, None) =>
- val resultProjection = UnsafeProjection.create(resultExpressions, groupingAttributes)
-
- (currentGroupingKey: UnsafeRow, currentBuffer: UnsafeRow) => {
- resultProjection(currentGroupingKey)
- }
-
- case other =>
- throw new IllegalStateException(
- s"${aggregationMode} should not be passed into TungstenAggregationIterator.")
+ override protected def generateResultProjection(): (UnsafeRow, MutableRow) => UnsafeRow = {
+ val modes = aggregateExpressions.map(_.mode).distinct
+ if (modes.nonEmpty && !modes.contains(Final) && !modes.contains(Complete)) {
+      // Fast path for partial aggregation: UnsafeRowJoiner is usually faster than a projection.
+ val groupingAttributes = groupingExpressions.map(_.toAttribute)
+ val bufferAttributes = aggregateFunctions.flatMap(_.aggBufferAttributes)
+ val groupingKeySchema = StructType.fromAttributes(groupingAttributes)
+ val bufferSchema = StructType.fromAttributes(bufferAttributes)
+ val unsafeRowJoiner = GenerateUnsafeRowJoiner.create(groupingKeySchema, bufferSchema)
+
+ (currentGroupingKey: UnsafeRow, currentBuffer: MutableRow) => {
+ unsafeRowJoiner.join(currentGroupingKey, currentBuffer.asInstanceOf[UnsafeRow])
+ }
+ } else {
+ super.generateResultProjection()
}
}
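
The override above keys the decision purely off the aggregate modes. A minimal stand-alone sketch of that check (plain Scala with stand-in mode objects, not Spark's catalyst types), to make the branch explicit:

    sealed trait AggregateMode
    case object Partial extends AggregateMode
    case object PartialMerge extends AggregateMode
    case object Final extends AggregateMode
    case object Complete extends AggregateMode

    // Partial-only / PartialMerge-only output is just groupingKey ++ buffer,
    // so concatenating the two UnsafeRows beats evaluating a result projection.
    def usesFastPath(modes: Seq[AggregateMode]): Boolean =
      modes.nonEmpty && !modes.contains(Final) && !modes.contains(Complete)

    // usesFastPath(Seq(Partial))      == true
    // usesFastPath(Seq(Final))        == false
    // usesFastPath(Nil)               == false  (grouping-only: defer to super)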
- // An UnsafeProjection used to extract grouping keys from the input rows.
- private[this] val groupProjection =
- UnsafeProjection.create(groupingExpressions, originalInputAttributes)
-
- // A function used to process a input row. Its first argument is the aggregation buffer
- // and the second argument is the input row.
- private[this] var processRow: (UnsafeRow, InternalRow) => Unit =
- generateProcessRow(originalInputAttributes)
-
- // A function used to generate output rows based on the grouping keys (first argument)
- // and the corresponding aggregation buffer (second argument).
- private[this] var generateOutput: (UnsafeRow, UnsafeRow) => UnsafeRow =
- generateResultProjection()
-
// An aggregation buffer containing initial buffer values. It is used to
// initialize other aggregation buffers.
private[this] val initialAggregationBuffer: UnsafeRow = createNewAggregationBuffer()
@@ -482,7 +158,7 @@ class TungstenAggregationIterator(
// all groups and their corresponding aggregation buffers for hash-based aggregation.
private[this] val hashMap = new UnsafeFixedWidthAggregationMap(
initialAggregationBuffer,
- StructType.fromAttributes(allAggregateFunctions.flatMap(_.aggBufferAttributes)),
+ StructType.fromAttributes(aggregateFunctions.flatMap(_.aggBufferAttributes)),
StructType.fromAttributes(groupingExpressions.map(_.toAttribute)),
TaskContext.get().taskMemoryManager(),
1024 * 16, // initial capacity
@@ -499,7 +175,7 @@ class TungstenAggregationIterator(
if (groupingExpressions.isEmpty) {
      // If there are no grouping expressions, we can just reuse the same buffer over and over again.
// Note that it would be better to eliminate the hash map entirely in the future.
- val groupingKey = groupProjection.apply(null)
+ val groupingKey = groupingProjection.apply(null)
val buffer: UnsafeRow = hashMap.getAggregationBufferFromUnsafeRow(groupingKey)
while (inputIter.hasNext) {
val newInput = inputIter.next()
@@ -511,7 +187,7 @@ class TungstenAggregationIterator(
while (inputIter.hasNext) {
val newInput = inputIter.next()
numInputRows += 1
- val groupingKey = groupProjection.apply(newInput)
+ val groupingKey = groupingProjection.apply(newInput)
var buffer: UnsafeRow = null
if (i < fallbackStartsAt) {
buffer = hashMap.getAggregationBufferFromUnsafeRow(groupingKey)
@@ -565,25 +241,18 @@ class TungstenAggregationIterator(
private def switchToSortBasedAggregation(): Unit = {
logInfo("falling back to sort based aggregation.")
- // Set aggregationMode, processRow, and generateOutput for sort-based aggregation.
- val newAggregationMode = aggregationMode match {
- case (Some(Partial), None) => (Some(PartialMerge), None)
- case (None, Some(Complete)) => (Some(Final), None)
- case (Some(Final), Some(Complete)) => (Some(Final), None)
+    // Basically, the value of the KVIterator returned by the externalSorter
+    // will be just the aggregation buffer, so we rewrite the aggregateExpressions to reflect it.
+ val newExpressions = aggregateExpressions.map {
+ case agg @ AggregateExpression(_, Partial, _) =>
+ agg.copy(mode = PartialMerge)
+ case agg @ AggregateExpression(_, Complete, _) =>
+ agg.copy(mode = Final)
case other => other
}
- aggregationMode = newAggregationMode
-
- allAggregateFunctions = initializeAllAggregateFunctions(startingInputBufferOffset = 0)
-
- // Basically the value of the KVIterator returned by externalSorter
- // will just aggregation buffer. At here, we use inputAggBufferAttributes.
- val newInputAttributes: Seq[Attribute] =
- allAggregateFunctions.flatMap(_.inputAggBufferAttributes)
-
- // Set up new processRow and generateOutput.
- processRow = generateProcessRow(newInputAttributes)
- generateOutput = generateResultProjection()
+ val newFunctions = initializeAggregateFunctions(newExpressions, 0)
+ val newInputAttributes = newFunctions.flatMap(_.inputAggBufferAttributes)
+ sortBasedProcessRow = generateProcessRow(newExpressions, newFunctions, newInputAttributes)
// Step 5: Get the sorted iterator from the externalSorter.
sortedKVIterator = externalSorter.sortedIterator()
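
The rewrite above can be illustrated in isolation (reusing the stand-in AggregateMode objects from the earlier sketch; AggExpr is a hypothetical placeholder for AggregateExpression):

    case class AggExpr(func: String, mode: AggregateMode)

    // After spilling, the external sorter hands back rows that are already
    // aggregation buffers, so update-style modes become merge-style modes.
    def rewriteForSortFallback(exprs: Seq[AggExpr]): Seq[AggExpr] = exprs.map {
      case e @ AggExpr(_, Partial)  => e.copy(mode = PartialMerge)
      case e @ AggExpr(_, Complete) => e.copy(mode = Final)
      case other                    => other
    }

    // rewriteForSortFallback(Seq(AggExpr("sum", Partial), AggExpr("avg", Final)))
    //   == Seq(AggExpr("sum", PartialMerge), AggExpr("avg", Final))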
@@ -632,6 +301,9 @@ class TungstenAggregationIterator(
// The aggregation buffer used by the sort-based aggregation.
private[this] val sortBasedAggregationBuffer: UnsafeRow = createNewAggregationBuffer()
+ // The function used to process rows in a group
+ private[this] var sortBasedProcessRow: (MutableRow, InternalRow) => Unit = null
+
  // Processes rows in the current group. It will stop when it finds a new group.
private def processCurrentSortedGroup(): Unit = {
// First, we need to copy nextGroupingKey to currentGroupingKey.
@@ -640,7 +312,7 @@ class TungstenAggregationIterator(
// We create a variable to track if we see the next group.
var findNextPartition = false
    // firstRowInNextGroup is the first row of the group we are now processing; process it first.
- processRow(sortBasedAggregationBuffer, firstRowInNextGroup)
+ sortBasedProcessRow(sortBasedAggregationBuffer, firstRowInNextGroup)
    // The search will stop when we see the next group or there are no
    // input rows left in the iterator.
@@ -655,16 +327,15 @@ class TungstenAggregationIterator(
      // Check if the current row belongs to the current group.
if (currentGroupingKey.equals(groupingKey)) {
- processRow(sortBasedAggregationBuffer, inputAggregationBuffer)
+ sortBasedProcessRow(sortBasedAggregationBuffer, inputAggregationBuffer)
hasNext = sortedKVIterator.next()
} else {
// We find a new group.
findNextPartition = true
// copyFrom will fail when
- nextGroupingKey.copyFrom(groupingKey) // = groupingKey.copy()
- firstRowInNextGroup.copyFrom(inputAggregationBuffer) // = inputAggregationBuffer.copy()
-
+ nextGroupingKey.copyFrom(groupingKey)
+ firstRowInNextGroup.copyFrom(inputAggregationBuffer)
}
}
// We have not seen a new group. It means that there is no new row in the input
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
index 76b938cdb6..83379ae90f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
@@ -42,16 +42,45 @@ object Utils {
SortBasedAggregate(
requiredChildDistributionExpressions = Some(groupingAttributes),
groupingExpressions = groupingAttributes,
- nonCompleteAggregateExpressions = Nil,
- nonCompleteAggregateAttributes = Nil,
- completeAggregateExpressions = completeAggregateExpressions,
- completeAggregateAttributes = completeAggregateAttributes,
+ aggregateExpressions = completeAggregateExpressions,
+ aggregateAttributes = completeAggregateAttributes,
initialInputBufferOffset = 0,
resultExpressions = resultExpressions,
child = child
) :: Nil
}
+ private def createAggregate(
+ requiredChildDistributionExpressions: Option[Seq[Expression]] = None,
+ groupingExpressions: Seq[NamedExpression] = Nil,
+ aggregateExpressions: Seq[AggregateExpression] = Nil,
+ aggregateAttributes: Seq[Attribute] = Nil,
+ initialInputBufferOffset: Int = 0,
+ resultExpressions: Seq[NamedExpression] = Nil,
+ child: SparkPlan): SparkPlan = {
+ val usesTungstenAggregate = TungstenAggregate.supportsAggregate(
+ aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes))
+ if (usesTungstenAggregate) {
+ TungstenAggregate(
+ requiredChildDistributionExpressions = requiredChildDistributionExpressions,
+ groupingExpressions = groupingExpressions,
+ aggregateExpressions = aggregateExpressions,
+ aggregateAttributes = aggregateAttributes,
+ initialInputBufferOffset = initialInputBufferOffset,
+ resultExpressions = resultExpressions,
+ child = child)
+ } else {
+ SortBasedAggregate(
+ requiredChildDistributionExpressions = requiredChildDistributionExpressions,
+ groupingExpressions = groupingExpressions,
+ aggregateExpressions = aggregateExpressions,
+ aggregateAttributes = aggregateAttributes,
+ initialInputBufferOffset = initialInputBufferOffset,
+ resultExpressions = resultExpressions,
+ child = child)
+ }
+ }
+
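
createAggregate centralizes the Tungsten-versus-sort-based choice that was previously duplicated at every call site. The delegated check, TungstenAggregate.supportsAggregate, boils down to whether every aggregation-buffer field is a type that UnsafeRow can update in place; a simplified sketch of that kind of predicate (the exact supported-type set here is an assumption, not the real implementation):

    import org.apache.spark.sql.types._

    def bufferIsMutableFixedWidth(bufferSchema: Seq[DataType]): Boolean =
      bufferSchema.forall {
        case BooleanType | ByteType | ShortType | IntegerType |
             LongType | FloatType | DoubleType => true
        case _: DecimalType => true // assumed: fixed-precision decimals qualify
        case _ => false             // e.g. a StringType buffer forces sort-based
      }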
def planAggregateWithoutDistinct(
groupingExpressions: Seq[NamedExpression],
aggregateExpressions: Seq[AggregateExpression],
@@ -59,9 +88,6 @@ object Utils {
resultExpressions: Seq[NamedExpression],
child: SparkPlan): Seq[SparkPlan] = {
// Check if we can use TungstenAggregate.
- val usesTungstenAggregate = TungstenAggregate.supportsAggregate(
- groupingExpressions,
- aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes))
// 1. Create an Aggregate Operator for partial aggregations.
@@ -73,29 +99,14 @@ object Utils {
groupingAttributes ++
partialAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes)
- val partialAggregate = if (usesTungstenAggregate) {
- TungstenAggregate(
- requiredChildDistributionExpressions = None: Option[Seq[Expression]],
+ val partialAggregate = createAggregate(
+ requiredChildDistributionExpressions = None,
groupingExpressions = groupingExpressions,
- nonCompleteAggregateExpressions = partialAggregateExpressions,
- nonCompleteAggregateAttributes = partialAggregateAttributes,
- completeAggregateExpressions = Nil,
- completeAggregateAttributes = Nil,
+ aggregateExpressions = partialAggregateExpressions,
+ aggregateAttributes = partialAggregateAttributes,
initialInputBufferOffset = 0,
resultExpressions = partialResultExpressions,
child = child)
- } else {
- SortBasedAggregate(
- requiredChildDistributionExpressions = None: Option[Seq[Expression]],
- groupingExpressions = groupingExpressions,
- nonCompleteAggregateExpressions = partialAggregateExpressions,
- nonCompleteAggregateAttributes = partialAggregateAttributes,
- completeAggregateExpressions = Nil,
- completeAggregateAttributes = Nil,
- initialInputBufferOffset = 0,
- resultExpressions = partialResultExpressions,
- child = child)
- }
// 2. Create an Aggregate Operator for final aggregations.
val finalAggregateExpressions = aggregateExpressions.map(_.copy(mode = Final))
@@ -105,29 +116,14 @@ object Utils {
expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct)
}
- val finalAggregate = if (usesTungstenAggregate) {
- TungstenAggregate(
- requiredChildDistributionExpressions = Some(groupingAttributes),
- groupingExpressions = groupingAttributes,
- nonCompleteAggregateExpressions = finalAggregateExpressions,
- nonCompleteAggregateAttributes = finalAggregateAttributes,
- completeAggregateExpressions = Nil,
- completeAggregateAttributes = Nil,
- initialInputBufferOffset = groupingExpressions.length,
- resultExpressions = resultExpressions,
- child = partialAggregate)
- } else {
- SortBasedAggregate(
+ val finalAggregate = createAggregate(
requiredChildDistributionExpressions = Some(groupingAttributes),
groupingExpressions = groupingAttributes,
- nonCompleteAggregateExpressions = finalAggregateExpressions,
- nonCompleteAggregateAttributes = finalAggregateAttributes,
- completeAggregateExpressions = Nil,
- completeAggregateAttributes = Nil,
+ aggregateExpressions = finalAggregateExpressions,
+ aggregateAttributes = finalAggregateAttributes,
initialInputBufferOffset = groupingExpressions.length,
resultExpressions = resultExpressions,
child = partialAggregate)
- }
finalAggregate :: Nil
}
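
With both operators built through createAggregate, the non-distinct plan stays two aggregates deep, and each level independently resolves to TungstenAggregate or SortBasedAggregate. A hypothetical way to inspect it (assuming the agg2 table from the test suite below is registered):

    sqlContext.sql("SELECT key, sum(value2) FROM agg2 GROUP BY key").explain()
    // Expect two aggregate operators: Partial below the exchange, Final above it.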
@@ -140,99 +136,99 @@ object Utils {
resultExpressions: Seq[NamedExpression],
child: SparkPlan): Seq[SparkPlan] = {
- val aggregateExpressions = functionsWithDistinct ++ functionsWithoutDistinct
- val usesTungstenAggregate = TungstenAggregate.supportsAggregate(
- groupingExpressions,
- aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes))
-
// functionsWithDistinct is guaranteed to be non-empty. Even though it may contain more than one
// DISTINCT aggregate function, all of those functions will have the same column expressions.
// For example, it would be valid for functionsWithDistinct to be
// [COUNT(DISTINCT foo), MAX(DISTINCT foo)], but [COUNT(DISTINCT bar), COUNT(DISTINCT foo)] is
// disallowed because those two distinct aggregates have different column expressions.
- val distinctColumnExpressions = functionsWithDistinct.head.aggregateFunction.children
- val namedDistinctColumnExpressions = distinctColumnExpressions.map {
+ val distinctExpressions = functionsWithDistinct.head.aggregateFunction.children
+ val namedDistinctExpressions = distinctExpressions.map {
case ne: NamedExpression => ne
case other => Alias(other, other.toString)()
}
- val distinctColumnAttributes = namedDistinctColumnExpressions.map(_.toAttribute)
+ val distinctAttributes = namedDistinctExpressions.map(_.toAttribute)
val groupingAttributes = groupingExpressions.map(_.toAttribute)
// 1. Create an Aggregate Operator for partial aggregations.
val partialAggregate: SparkPlan = {
- val partialAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial))
- val partialAggregateAttributes =
- partialAggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
+ val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial))
+ val aggregateAttributes = aggregateExpressions.map {
+ expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct)
+ }
      // We will group by the original grouping expressions, plus additional expressions for the
      // DISTINCT columns. For example, for AVG(DISTINCT value) GROUP BY key, the grouping
      // expressions will be [key, value].
- val partialAggregateGroupingExpressions =
- groupingExpressions ++ namedDistinctColumnExpressions
- val partialAggregateResult =
- groupingAttributes ++
- distinctColumnAttributes ++
- partialAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes)
- if (usesTungstenAggregate) {
- TungstenAggregate(
- requiredChildDistributionExpressions = None,
- groupingExpressions = partialAggregateGroupingExpressions,
- nonCompleteAggregateExpressions = partialAggregateExpressions,
- nonCompleteAggregateAttributes = partialAggregateAttributes,
- completeAggregateExpressions = Nil,
- completeAggregateAttributes = Nil,
- initialInputBufferOffset = 0,
- resultExpressions = partialAggregateResult,
- child = child)
- } else {
- SortBasedAggregate(
- requiredChildDistributionExpressions = None,
- groupingExpressions = partialAggregateGroupingExpressions,
- nonCompleteAggregateExpressions = partialAggregateExpressions,
- nonCompleteAggregateAttributes = partialAggregateAttributes,
- completeAggregateExpressions = Nil,
- completeAggregateAttributes = Nil,
- initialInputBufferOffset = 0,
- resultExpressions = partialAggregateResult,
- child = child)
- }
+ createAggregate(
+ groupingExpressions = groupingExpressions ++ namedDistinctExpressions,
+ aggregateExpressions = aggregateExpressions,
+ aggregateAttributes = aggregateAttributes,
+ resultExpressions = groupingAttributes ++ distinctAttributes ++
+ aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes),
+ child = child)
}
// 2. Create an Aggregate Operator for partial merge aggregations.
val partialMergeAggregate: SparkPlan = {
- val partialMergeAggregateExpressions =
- functionsWithoutDistinct.map(_.copy(mode = PartialMerge))
- val partialMergeAggregateAttributes =
- partialMergeAggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
- val partialMergeAggregateResult =
- groupingAttributes ++
- distinctColumnAttributes ++
- partialMergeAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes)
- if (usesTungstenAggregate) {
- TungstenAggregate(
- requiredChildDistributionExpressions = Some(groupingAttributes),
- groupingExpressions = groupingAttributes ++ distinctColumnAttributes,
- nonCompleteAggregateExpressions = partialMergeAggregateExpressions,
- nonCompleteAggregateAttributes = partialMergeAggregateAttributes,
- completeAggregateExpressions = Nil,
- completeAggregateAttributes = Nil,
- initialInputBufferOffset = (groupingAttributes ++ distinctColumnAttributes).length,
- resultExpressions = partialMergeAggregateResult,
- child = partialAggregate)
- } else {
- SortBasedAggregate(
- requiredChildDistributionExpressions = Some(groupingAttributes),
- groupingExpressions = groupingAttributes ++ distinctColumnAttributes,
- nonCompleteAggregateExpressions = partialMergeAggregateExpressions,
- nonCompleteAggregateAttributes = partialMergeAggregateAttributes,
- completeAggregateExpressions = Nil,
- completeAggregateAttributes = Nil,
- initialInputBufferOffset = (groupingAttributes ++ distinctColumnAttributes).length,
- resultExpressions = partialMergeAggregateResult,
- child = partialAggregate)
+ val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge))
+ val aggregateAttributes = aggregateExpressions.map {
+ expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct)
}
+ createAggregate(
+ requiredChildDistributionExpressions =
+ Some(groupingAttributes ++ distinctAttributes),
+ groupingExpressions = groupingAttributes ++ distinctAttributes,
+ aggregateExpressions = aggregateExpressions,
+ aggregateAttributes = aggregateAttributes,
+ initialInputBufferOffset = (groupingAttributes ++ distinctAttributes).length,
+ resultExpressions = groupingAttributes ++ distinctAttributes ++
+ aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes),
+ child = partialAggregate)
+ }
+
+  // 3. Create an Aggregate Operator for the partial aggregation of the distinct functions.
+ val distinctColumnAttributeLookup = distinctExpressions.zip(distinctAttributes).toMap
+ val rewrittenDistinctFunctions = functionsWithDistinct.map {
+    // The children of an AggregateFunction with the DISTINCT keyword have already
+    // been evaluated. Here, we need to replace the original children
+    // with AttributeReferences.
+ case agg @ AggregateExpression(aggregateFunction, mode, true) =>
+ aggregateFunction.transformDown(distinctColumnAttributeLookup)
+ .asInstanceOf[AggregateFunction]
}
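
The transformDown call above relies on a Map being usable as a PartialFunction[Expression, Expression]. A hedged, self-contained illustration of the child rewrite (catalyst constructors invoked with their default argument lists; treat the exact shapes as assumptions):

    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.types.IntegerType

    val value1 = AttributeReference("value1", IntegerType)()
    val distinctChild: Expression = Add(value1, Literal(1))
    // Stage 1 grouped on an aliased copy of the distinct child, so later
    // stages see it as a plain attribute.
    val alias = Alias(distinctChild, distinctChild.toString)()
    val lookup: Map[Expression, Expression] = Map(distinctChild -> alias.toAttribute)

    // Any occurrence of `value1 + 1` is swapped for the attribute that the
    // earlier stages already grouped (and therefore de-duplicated) on.
    val rewritten = Multiply(distinctChild, Literal(2)).transformDown(lookup)
    // rewritten == Multiply(alias.toAttribute, Literal(2))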
- // 3. Create an Aggregate Operator for the final aggregation.
+ val partialDistinctAggregate: SparkPlan = {
+ val mergeAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge))
+    // The output attributes of the non-distinct aggregate functions merged in this stage:
+ val mergeAggregateAttributes = mergeAggregateExpressions.map {
+ expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct)
+ }
+ val (distinctAggregateExpressions, distinctAggregateAttributes) =
+ rewrittenDistinctFunctions.zipWithIndex.map { case (func, i) =>
+      // We rewrite the aggregate function to a non-distinct aggregation because
+      // its input will already contain only distinct values.
+      // We keep isDistinct set to true so that, when users look at the query plan,
+      // they can still see the distinct aggregations.
+ val expr = AggregateExpression(func, Partial, isDistinct = true)
+      // Use the original AggregateFunction to look up the attribute, since
+      // aggregateFunctionToAttribute was built from the original functions.
+ val attr = aggregateFunctionToAttribute(functionsWithDistinct(i).aggregateFunction, true)
+ (expr, attr)
+ }.unzip
+
+ val partialAggregateResult = groupingAttributes ++
+ mergeAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) ++
+ distinctAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes)
+ createAggregate(
+ groupingExpressions = groupingAttributes,
+ aggregateExpressions = mergeAggregateExpressions ++ distinctAggregateExpressions,
+ aggregateAttributes = mergeAggregateAttributes ++ distinctAggregateAttributes,
+ initialInputBufferOffset = (groupingAttributes ++ distinctAttributes).length,
+ resultExpressions = partialAggregateResult,
+ child = partialMergeAggregate)
+ }
+
+ // 4. Create an Aggregate Operator for the final aggregation.
val finalAndCompleteAggregate: SparkPlan = {
val finalAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Final))
// The attributes of the final aggregation buffer, which is presented as input to the result
@@ -241,49 +237,27 @@ object Utils {
expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct)
}
- val distinctColumnAttributeLookup =
- distinctColumnExpressions.zip(distinctColumnAttributes).toMap
- val (completeAggregateExpressions, completeAggregateAttributes) = functionsWithDistinct.map {
- // Children of an AggregateFunction with DISTINCT keyword has already
- // been evaluated. At here, we need to replace original children
- // to AttributeReferences.
- case agg @ AggregateExpression(aggregateFunction, mode, true) =>
- val rewrittenAggregateFunction = aggregateFunction
- .transformDown(distinctColumnAttributeLookup)
- .asInstanceOf[AggregateFunction]
+ val (distinctAggregateExpressions, distinctAggregateAttributes) =
+ rewrittenDistinctFunctions.zipWithIndex.map { case (func, i) =>
        // We rewrite the aggregate function to a non-distinct aggregation because
        // its input will already contain only distinct values.
        // We keep isDistinct set to true so that, when users look at the query plan,
        // they can still see the distinct aggregations.
- val rewrittenAggregateExpression =
- AggregateExpression(rewrittenAggregateFunction, Complete, isDistinct = true)
-
- val aggregateFunctionAttribute = aggregateFunctionToAttribute(agg.aggregateFunction, true)
- (rewrittenAggregateExpression, aggregateFunctionAttribute)
+ val expr = AggregateExpression(func, Final, isDistinct = true)
+      // Use the original AggregateFunction to look up the attribute, since
+      // aggregateFunctionToAttribute was built from the original functions.
+ val attr = aggregateFunctionToAttribute(functionsWithDistinct(i).aggregateFunction, true)
+ (expr, attr)
}.unzip
- if (usesTungstenAggregate) {
- TungstenAggregate(
- requiredChildDistributionExpressions = Some(groupingAttributes),
- groupingExpressions = groupingAttributes,
- nonCompleteAggregateExpressions = finalAggregateExpressions,
- nonCompleteAggregateAttributes = finalAggregateAttributes,
- completeAggregateExpressions = completeAggregateExpressions,
- completeAggregateAttributes = completeAggregateAttributes,
- initialInputBufferOffset = (groupingAttributes ++ distinctColumnAttributes).length,
- resultExpressions = resultExpressions,
- child = partialMergeAggregate)
- } else {
- SortBasedAggregate(
- requiredChildDistributionExpressions = Some(groupingAttributes),
- groupingExpressions = groupingAttributes,
- nonCompleteAggregateExpressions = finalAggregateExpressions,
- nonCompleteAggregateAttributes = finalAggregateAttributes,
- completeAggregateExpressions = completeAggregateExpressions,
- completeAggregateAttributes = completeAggregateAttributes,
- initialInputBufferOffset = (groupingAttributes ++ distinctColumnAttributes).length,
- resultExpressions = resultExpressions,
- child = partialMergeAggregate)
- }
+
+ createAggregate(
+ requiredChildDistributionExpressions = Some(groupingAttributes),
+ groupingExpressions = groupingAttributes,
+ aggregateExpressions = finalAggregateExpressions ++ distinctAggregateExpressions,
+ aggregateAttributes = finalAggregateAttributes ++ distinctAggregateAttributes,
+ initialInputBufferOffset = groupingAttributes.length,
+ resultExpressions = resultExpressions,
+ child = partialDistinctAggregate)
}
finalAndCompleteAggregate :: Nil
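
Taken together, planAggregateWithOneDistinct now emits a four-operator chain itself instead of leaving single-distinct queries to DistinctAggregationRewriter. A hypothetical way to eyeball the shape (again assuming the agg2 table from the test suite below):

    sqlContext.sql(
      """SELECT key, count(DISTINCT value1), sum(value2)
        |FROM agg2
        |GROUP BY key
      """.stripMargin).explain()
    // Expected operators, bottom-up:
    //   1. Partial       -- group by (key, value1), partial sum(value2)
    //   2. PartialMerge  -- group by (key, value1), merge the sum buffers
    //                      (value1 is de-duplicated within each key here)
    //   3. PartialMerge + Partial(distinct) -- group by key, merge the sum
    //                      buffers, partial count(value1)
    //   4. Final         -- group by key, final sum and count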
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
index 064c0004b8..5550198c02 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
@@ -19,7 +19,6 @@ package org.apache.spark.sql.hive.execution
import scala.collection.JavaConverters._
-import org.apache.spark.SparkException
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
@@ -552,80 +551,73 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
}
test("single distinct column set") {
- Seq(true, false).foreach { specializeSingleDistinctAgg =>
- val conf =
- (SQLConf.SPECIALIZE_SINGLE_DISTINCT_AGG_PLANNING.key,
- specializeSingleDistinctAgg.toString)
- withSQLConf(conf) {
- // DISTINCT is not meaningful with Max and Min, so we just ignore the DISTINCT keyword.
- checkAnswer(
- sqlContext.sql(
- """
- |SELECT
- | min(distinct value1),
- | sum(distinct value1),
- | avg(value1),
- | avg(value2),
- | max(distinct value1)
- |FROM agg2
- """.stripMargin),
- Row(-60, 70.0, 101.0/9.0, 5.6, 100))
-
- checkAnswer(
- sqlContext.sql(
- """
- |SELECT
- | mydoubleavg(distinct value1),
- | avg(value1),
- | avg(value2),
- | key,
- | mydoubleavg(value1 - 1),
- | mydoubleavg(distinct value1) * 0.1,
- | avg(value1 + value2)
- |FROM agg2
- |GROUP BY key
- """.stripMargin),
- Row(120.0, 70.0/3.0, -10.0/3.0, 1, 67.0/3.0 + 100.0, 12.0, 20.0) ::
- Row(100.0, 1.0/3.0, 1.0, 2, -2.0/3.0 + 100.0, 10.0, 2.0) ::
- Row(null, null, 3.0, 3, null, null, null) ::
- Row(110.0, 10.0, 20.0, null, 109.0, 11.0, 30.0) :: Nil)
-
- checkAnswer(
- sqlContext.sql(
- """
- |SELECT
- | key,
- | mydoubleavg(distinct value1),
- | mydoublesum(value2),
- | mydoublesum(distinct value1),
- | mydoubleavg(distinct value1),
- | mydoubleavg(value1)
- |FROM agg2
- |GROUP BY key
- """.stripMargin),
- Row(1, 120.0, -10.0, 40.0, 120.0, 70.0/3.0 + 100.0) ::
- Row(2, 100.0, 3.0, 0.0, 100.0, 1.0/3.0 + 100.0) ::
- Row(3, null, 3.0, null, null, null) ::
- Row(null, 110.0, 60.0, 30.0, 110.0, 110.0) :: Nil)
-
- checkAnswer(
- sqlContext.sql(
- """
- |SELECT
- | count(value1),
- | count(*),
- | count(1),
- | count(DISTINCT value1),
- | key
- |FROM agg2
- |GROUP BY key
- """.stripMargin),
- Row(3, 3, 3, 2, 1) ::
- Row(3, 4, 4, 2, 2) ::
- Row(0, 2, 2, 0, 3) ::
- Row(3, 4, 4, 3, null) :: Nil)
- }
- }
+ // DISTINCT is not meaningful with Max and Min, so we just ignore the DISTINCT keyword.
+ checkAnswer(
+ sqlContext.sql(
+ """
+ |SELECT
+ | min(distinct value1),
+ | sum(distinct value1),
+ | avg(value1),
+ | avg(value2),
+ | max(distinct value1)
+ |FROM agg2
+ """.stripMargin),
+ Row(-60, 70.0, 101.0/9.0, 5.6, 100))
+
+ checkAnswer(
+ sqlContext.sql(
+ """
+ |SELECT
+ | mydoubleavg(distinct value1),
+ | avg(value1),
+ | avg(value2),
+ | key,
+ | mydoubleavg(value1 - 1),
+ | mydoubleavg(distinct value1) * 0.1,
+ | avg(value1 + value2)
+ |FROM agg2
+ |GROUP BY key
+ """.stripMargin),
+ Row(120.0, 70.0/3.0, -10.0/3.0, 1, 67.0/3.0 + 100.0, 12.0, 20.0) ::
+ Row(100.0, 1.0/3.0, 1.0, 2, -2.0/3.0 + 100.0, 10.0, 2.0) ::
+ Row(null, null, 3.0, 3, null, null, null) ::
+ Row(110.0, 10.0, 20.0, null, 109.0, 11.0, 30.0) :: Nil)
+
+ checkAnswer(
+ sqlContext.sql(
+ """
+ |SELECT
+ | key,
+ | mydoubleavg(distinct value1),
+ | mydoublesum(value2),
+ | mydoublesum(distinct value1),
+ | mydoubleavg(distinct value1),
+ | mydoubleavg(value1)
+ |FROM agg2
+ |GROUP BY key
+ """.stripMargin),
+ Row(1, 120.0, -10.0, 40.0, 120.0, 70.0/3.0 + 100.0) ::
+ Row(2, 100.0, 3.0, 0.0, 100.0, 1.0/3.0 + 100.0) ::
+ Row(3, null, 3.0, null, null, null) ::
+ Row(null, 110.0, 60.0, 30.0, 110.0, 110.0) :: Nil)
+
+ checkAnswer(
+ sqlContext.sql(
+ """
+ |SELECT
+ | count(value1),
+ | count(*),
+ | count(1),
+ | count(DISTINCT value1),
+ | key
+ |FROM agg2
+ |GROUP BY key
+ """.stripMargin),
+ Row(3, 3, 3, 2, 1) ::
+ Row(3, 4, 4, 2, 2) ::
+ Row(0, 2, 2, 0, 3) ::
+ Row(3, 4, 4, 3, null) :: Nil)
}
test("single distinct multiple columns set") {