authorYin Huai <yhuai@databricks.com>2015-08-03 00:23:08 -0700
committerReynold Xin <rxin@databricks.com>2015-08-03 00:23:08 -0700
commit1ebd41b141a95ec264bd2dd50f0fe24cd459035d (patch)
tree34452a2922a4bcd00d600ad8fa0b79df09f243bb
parent98d6d9c7a996f5456eb2653bb96985a1a05f4ce1 (diff)
downloadspark-1ebd41b141a95ec264bd2dd50f0fe24cd459035d.tar.gz
spark-1ebd41b141a95ec264bd2dd50f0fe24cd459035d.tar.bz2
spark-1ebd41b141a95ec264bd2dd50f0fe24cd459035d.zip
[SPARK-9240] [SQL] Hybrid aggregate operator using unsafe row
This PR adds a base aggregation iterator `AggregationIterator`, which is used to create `SortBasedAggregationIterator` (for sort-based aggregation) and `UnsafeHybridAggregationIterator` (first it tries hash-based aggregation and falls back to sort-based aggregation, using an external sorter, if we cannot allocate memory for the map). With these two iterators, we no longer need the existing iterators and I am removing those. Also, we can use a single physical `Aggregate` operator that internally determines which iterator to use.

https://issues.apache.org/jira/browse/SPARK-9240

Author: Yin Huai <yhuai@databricks.com>

Closes #7813 from yhuai/AggregateOperator and squashes the following commits:

e317e2b [Yin Huai] Remove unnecessary change.
74d93c5 [Yin Huai] Merge remote-tracking branch 'upstream/master' into AggregateOperator
ba6afbc [Yin Huai] Add a little bit more comments.
c9cf3b6 [Yin Huai] update
0f1b06f [Yin Huai] Remove unnecessary code.
21fd15f [Yin Huai] Remove unnecessary change.
964f88b [Yin Huai] Implement fallback strategy.
b1ea5cf [Yin Huai] wip
7fcbd87 [Yin Huai] Add a flag to control which iterator to use.
533d5b2 [Yin Huai] Prepare for fallback!
33b7022 [Yin Huai] wip
bd9282b [Yin Huai] UDAFs now support UnsafeRow.
f52ee53 [Yin Huai] wip
3171f44 [Yin Huai] wip
d2c45a0 [Yin Huai] wip
f60cc83 [Yin Huai] Also check input schema.
af32210 [Yin Huai] Check iter.hasNext before we create an iterator because the constructor of the iterator will read at least one row from a non-empty input iter.
299008c [Yin Huai] First round cleanup.
3915bac [Yin Huai] Create a base iterator class for aggregation iterators and add the initial version of the hybrid iterator.
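To make the fallback strategy above concrete, here is a minimal, self-contained Scala sketch (all names are illustrative; this is not the operator's actual code): aggregate into a bounded hash map first, and finish the keys that no longer fit with sort-based grouping.

import scala.collection.mutable

object HybridAggregationSketch {
  // Hash-based first; rows that arrive once the map is "full" go to an
  // overflow buffer and are finished by sorting (a toy stand-in for spilling).
  def aggregate(input: Seq[(String, Long)], maxKeys: Int): Seq[(String, Long)] = {
    val hashed = mutable.HashMap.empty[String, Long]
    val overflow = mutable.ArrayBuffer.empty[(String, Long)]
    input.foreach { case (k, v) =>
      if (hashed.contains(k) || hashed.size < maxKeys) {
        hashed(k) = hashed.getOrElse(k, 0L) + v // hash-based update
      } else {
        overflow += ((k, v)) // in Spark, this is where the fallback triggers
      }
    }
    // Sort-based phase: merge the map's partial results with the overflow rows.
    (hashed.toSeq ++ overflow)
      .sortBy(_._1) // sorted input allows one-pass grouping
      .foldLeft(List.empty[(String, Long)]) {
        case ((pk, pv) :: rest, (k, v)) if pk == k => (pk, pv + v) :: rest
        case (acc, kv) => kv :: acc
      }
      .reverse
  }

  def main(args: Array[String]): Unit = {
    val rows = Seq("a" -> 1L, "b" -> 2L, "a" -> 3L, "c" -> 4L, "b" -> 5L)
    println(aggregate(rows, maxKeys = 2)) // List((a,4), (b,7), (c,4))
  }
}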
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala | 19
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/Aggregate.scala | 182
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala | 490
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala | 236
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UnsafeHybridAggregationIterator.scala | 398
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/aggregateOperators.scala | 175
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala | 664
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala | 269
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala | 99
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala | 1
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 10
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala | 9
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala | 118
13 files changed, 1697 insertions, 973 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
index d08f553cef..4abfdfe87d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
@@ -110,7 +110,11 @@ abstract class AggregateFunction2
* buffer value of `avg(x)` will be 0 and the position of the first buffer value of `avg(y)`
* will be 2.
*/
- var mutableBufferOffset: Int = 0
+ protected var mutableBufferOffset: Int = 0
+
+ def withNewMutableBufferOffset(newMutableBufferOffset: Int): Unit = {
+ mutableBufferOffset = newMutableBufferOffset
+ }
/**
* The offset of this function's start buffer value in the
@@ -126,7 +130,11 @@ abstract class AggregateFunction2
* buffer value of `avg(x)` will be 1 and the position of the first buffer value of `avg(y)`
* will be 3 (position 0 is used for the value of the grouping key).
*/
- var inputBufferOffset: Int = 0
+ protected var inputBufferOffset: Int = 0
+
+ def withNewInputBufferOffset(newInputBufferOffset: Int): Unit = {
+ inputBufferOffset = newInputBufferOffset
+ }
/** The schema of the aggregation buffer. */
def bufferSchema: StructType
@@ -195,11 +203,8 @@ abstract class AlgebraicAggregate extends AggregateFunction2 with Serializable w
override def bufferSchema: StructType = StructType.fromAttributes(bufferAttributes)
override def initialize(buffer: MutableRow): Unit = {
- var i = 0
- while (i < bufferAttributes.size) {
- buffer(i + mutableBufferOffset) = initialValues(i).eval()
- i += 1
- }
+ throw new UnsupportedOperationException(
+ "AlgebraicAggregate's initialize should not be called directly")
}
override final def update(buffer: MutableRow, input: InternalRow): Unit = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/Aggregate.scala
new file mode 100644
index 0000000000..cf568dc048
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/Aggregate.scala
@@ -0,0 +1,182 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.aggregate
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.errors._
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.catalyst.plans.physical.{UnspecifiedDistribution, ClusteredDistribution, AllTuples, Distribution}
+import org.apache.spark.sql.execution.{UnsafeFixedWidthAggregationMap, SparkPlan, UnaryNode}
+import org.apache.spark.sql.types.StructType
+
+/**
+ * An Aggregate operator used to evaluate [[AggregateFunction2]]. Based on the data types
+ * of the grouping expressions and aggregate functions, it determines whether to use
+ * sort-based aggregation or hybrid aggregation (hash-based with sort-based as the
+ * fallback) to process input rows.
+ */
+case class Aggregate(
+ requiredChildDistributionExpressions: Option[Seq[Expression]],
+ groupingExpressions: Seq[NamedExpression],
+ nonCompleteAggregateExpressions: Seq[AggregateExpression2],
+ nonCompleteAggregateAttributes: Seq[Attribute],
+ completeAggregateExpressions: Seq[AggregateExpression2],
+ completeAggregateAttributes: Seq[Attribute],
+ initialInputBufferOffset: Int,
+ resultExpressions: Seq[NamedExpression],
+ child: SparkPlan)
+ extends UnaryNode {
+
+ private[this] val allAggregateExpressions =
+ nonCompleteAggregateExpressions ++ completeAggregateExpressions
+
+ private[this] val hasNonAlgebraicAggregateFunctions =
+ !allAggregateExpressions.forall(_.aggregateFunction.isInstanceOf[AlgebraicAggregate])
+
+ // Use the hybrid iterator if (1) unsafe is enabled, (2) the schemas of the
+ // grouping key and the aggregation buffer are supported, and (3) all
+ // aggregate functions are algebraic.
+ private[this] val supportsHybridIterator: Boolean = {
+ val aggregationBufferSchema: StructType =
+ StructType.fromAttributes(
+ allAggregateExpressions.flatMap(_.aggregateFunction.bufferAttributes))
+ val groupKeySchema: StructType =
+ StructType.fromAttributes(groupingExpressions.map(_.toAttribute))
+
+ val schemaSupportsUnsafe: Boolean =
+ UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(aggregationBufferSchema) &&
+ UnsafeProjection.canSupport(groupKeySchema)
+
+ // TODO: Use the hybrid iterator for non-algebraic aggregate functions.
+ sqlContext.conf.unsafeEnabled && schemaSupportsUnsafe && !hasNonAlgebraicAggregateFunctions
+ }
+
+ // We need sorted input if we have grouping expressions and we cannot use the
+ // hybrid iterator (either because it is unsupported or because it is disabled).
+ private[this] val requiresSortedInput: Boolean = {
+ groupingExpressions.nonEmpty && !supportsHybridIterator
+ }
+
+ override def canProcessUnsafeRows: Boolean = !hasNonAlgebraicAggregateFunctions
+
+ // If the result expressions' data types are all fixed-length, we generate unsafe rows.
+ // (We have this requirement instead of checking the result of UnsafeProjection.canSupport
+ // because we use a mutable projection to generate the result.)
+ override def outputsUnsafeRows: Boolean = {
+ // resultExpressions.map(_.dataType).forall(UnsafeRow.isFixedLength)
+ // TODO: Supports generating UnsafeRows. We can just re-enable the line above and fix
+ // any issue we get.
+ false
+ }
+
+ override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
+
+ override def requiredChildDistribution: List[Distribution] = {
+ requiredChildDistributionExpressions match {
+ case Some(exprs) if exprs.length == 0 => AllTuples :: Nil
+ case Some(exprs) if exprs.length > 0 => ClusteredDistribution(exprs) :: Nil
+ case None => UnspecifiedDistribution :: Nil
+ }
+ }
+
+ override def requiredChildOrdering: Seq[Seq[SortOrder]] = {
+ if (requiresSortedInput) {
+ // TODO: We should not sort the input rows if they are just in reversed order.
+ groupingExpressions.map(SortOrder(_, Ascending)) :: Nil
+ } else {
+ Seq.fill(children.size)(Nil)
+ }
+ }
+
+ override def outputOrdering: Seq[SortOrder] = {
+ if (requiresSortedInput) {
+ // It is possible that the child.outputOrdering starts with the required
+ // ordering expressions (e.g. we require [a] as the sort expression and the
+ // child's outputOrdering is [a, b]). We can only guarantee the output rows
+ // are sorted by values of groupingExpressions.
+ groupingExpressions.map(SortOrder(_, Ascending))
+ } else {
+ Nil
+ }
+ }
+
+ protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
+ child.execute().mapPartitions { iter =>
+ // Because the constructor of an aggregation iterator will read at least the first row,
+ // we need to get the value of iter.hasNext first.
+ val hasInput = iter.hasNext
+ val useHybridIterator =
+ hasInput &&
+ supportsHybridIterator &&
+ groupingExpressions.nonEmpty
+ if (useHybridIterator) {
+ UnsafeHybridAggregationIterator.createFromInputIterator(
+ groupingExpressions,
+ nonCompleteAggregateExpressions,
+ nonCompleteAggregateAttributes,
+ completeAggregateExpressions,
+ completeAggregateAttributes,
+ initialInputBufferOffset,
+ resultExpressions,
+ newMutableProjection _,
+ child.output,
+ iter,
+ outputsUnsafeRows)
+ } else {
+ if (!hasInput && groupingExpressions.nonEmpty) {
+ // This is a grouped aggregate and the input iterator is empty,
+ // so return an empty iterator.
+ Iterator[InternalRow]()
+ } else {
+ val outputIter = SortBasedAggregationIterator.createFromInputIterator(
+ groupingExpressions,
+ nonCompleteAggregateExpressions,
+ nonCompleteAggregateAttributes,
+ completeAggregateExpressions,
+ completeAggregateAttributes,
+ initialInputBufferOffset,
+ resultExpressions,
+ newMutableProjection _,
+ newProjection _,
+ child.output,
+ iter,
+ outputsUnsafeRows)
+ if (!hasInput && groupingExpressions.isEmpty) {
+ // There is no input and there is no grouping expressions.
+ // We need to output a single row as the output.
+ Iterator[InternalRow](outputIter.outputForEmptyGroupingKeyWithoutInput())
+ } else {
+ outputIter
+ }
+ }
+ }
+ }
+ }
+
+ override def simpleString: String = {
+ val iterator = if (supportsHybridIterator && groupingExpressions.nonEmpty) {
+ classOf[UnsafeHybridAggregationIterator].getSimpleName
+ } else {
+ classOf[SortBasedAggregationIterator].getSimpleName
+ }
+
+ s"""NewAggregate with $iterator ${groupingExpressions} ${allAggregateExpressions}"""
+ }
+}
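The empty-input branches in doExecute above encode standard SQL semantics: a grouped aggregate over no rows yields no rows, while a global aggregate yields exactly one default row. A hedged, plain-Scala sketch of just that rule (illustrative, using plain collections rather than SparkPlan):

object EmptyInputSemantics {
  //   SELECT k, count(*) FROM t GROUP BY k   on empty t => zero rows
  //   SELECT count(*)    FROM t              on empty t => one row (count = 0)
  def aggregate(rows: Seq[Int], grouped: Boolean): Seq[Long] =
    if (rows.isEmpty && grouped) Seq.empty // grouped: no groups, no output
    else if (rows.isEmpty) Seq(0L)         // global: single default row
    else Seq(rows.size.toLong)

  def main(args: Array[String]): Unit = {
    println(aggregate(Nil, grouped = true))  // List()
    println(aggregate(Nil, grouped = false)) // List(0)
  }
}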
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
new file mode 100644
index 0000000000..abca373b0c
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
@@ -0,0 +1,490 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.aggregate
+
+import org.apache.spark.Logging
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.unsafe.KVIterator
+
+import scala.collection.mutable.ArrayBuffer
+
+/**
+ * The base class of [[SortBasedAggregationIterator]] and [[UnsafeHybridAggregationIterator]].
+ * It mainly contains two parts:
+ * 1. It initializes aggregate functions.
+ * 2. It creates two functions, `processRow` and `generateOutput`, based on the
+ * [[AggregateMode]] of its aggregate functions. `processRow` is the function that
+ * handles an input row. `generateOutput` is used to generate the result row.
+ */
+abstract class AggregationIterator(
+ groupingKeyAttributes: Seq[Attribute],
+ valueAttributes: Seq[Attribute],
+ nonCompleteAggregateExpressions: Seq[AggregateExpression2],
+ nonCompleteAggregateAttributes: Seq[Attribute],
+ completeAggregateExpressions: Seq[AggregateExpression2],
+ completeAggregateAttributes: Seq[Attribute],
+ initialInputBufferOffset: Int,
+ resultExpressions: Seq[NamedExpression],
+ newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
+ outputsUnsafeRows: Boolean)
+ extends Iterator[InternalRow] with Logging {
+
+ ///////////////////////////////////////////////////////////////////////////
+ // Initializing functions.
+ ///////////////////////////////////////////////////////////////////////////
+
+ // A Seq of all AggregateExpressions.
+ // It is important that all AggregateExpressions with the mode Partial, PartialMerge or Final
+ // are at the beginning of allAggregateExpressions.
+ protected val allAggregateExpressions =
+ nonCompleteAggregateExpressions ++ completeAggregateExpressions
+
+ require(
+ allAggregateExpressions.map(_.mode).distinct.length <= 2,
+ s"$allAggregateExpressions are not supported becuase they have more than 2 distinct modes.")
+
+ /**
+ * The distinct modes of AggregateExpressions. Right now, we can handle the following modes:
+ * - Partial-only: all AggregateExpressions have the mode of Partial;
+ * - PartialMerge-only: all AggregateExpressions have the mode of PartialMerge;
+ * - Final-only: all AggregateExpressions have the mode of Final;
+ * - Final-Complete: some AggregateExpressions have the mode of Final and
+ * others have the mode of Complete;
+ * - Complete-only: nonCompleteAggregateExpressions is empty and we have AggregateExpressions
+ * with mode Complete in completeAggregateExpressions; and
+ * - Grouping-only: there is no AggregateExpression.
+ */
+ protected val aggregationMode: (Option[AggregateMode], Option[AggregateMode]) =
+ nonCompleteAggregateExpressions.map(_.mode).distinct.headOption ->
+ completeAggregateExpressions.map(_.mode).distinct.headOption
+
+ // Initialize all AggregateFunctions by binding references if necessary,
+ // and set inputBufferOffset and mutableBufferOffset.
+ protected val allAggregateFunctions: Array[AggregateFunction2] = {
+ var mutableBufferOffset = 0
+ var inputBufferOffset: Int = initialInputBufferOffset
+ val functions = new Array[AggregateFunction2](allAggregateExpressions.length)
+ var i = 0
+ while (i < allAggregateExpressions.length) {
+ val func = allAggregateExpressions(i).aggregateFunction
+ val funcWithBoundReferences = allAggregateExpressions(i).mode match {
+ case Partial | Complete if !func.isInstanceOf[AlgebraicAggregate] =>
+ // We need to create BoundReferences if the function is not an
+ // AlgebraicAggregate (it does not support code-gen) and the mode of
+ // this function is Partial or Complete because we will call eval of this
+ // function's children in the update method of this aggregate function.
+ // Those eval calls require BoundReferences to work.
+ BindReferences.bindReference(func, valueAttributes)
+ case _ =>
+ // We only need to set inputBufferOffset for aggregate functions with mode
+ // PartialMerge and Final.
+ func.withNewInputBufferOffset(inputBufferOffset)
+ inputBufferOffset += func.bufferSchema.length
+ func
+ }
+ // Set mutableBufferOffset for this function. It is important that setting
+ // mutableBufferOffset happens after all potential bindReference operations
+ // because bindReference will create a new instance of the function.
+ funcWithBoundReferences.withNewMutableBufferOffset(mutableBufferOffset)
+ mutableBufferOffset += funcWithBoundReferences.bufferSchema.length
+ functions(i) = funcWithBoundReferences
+ i += 1
+ }
+ functions
+ }
+
+ // Positions of the non-algebraic aggregate functions in allAggregateFunctions.
+ // For example, if we have func1, func2, func3, func4 in allAggregateFunctions, and
+ // func2 and func3 are non-algebraic aggregate functions, then
+ // allNonAlgebraicAggregateFunctionPositions will be [1, 2].
+ private[this] val allNonAlgebraicAggregateFunctionPositions: Array[Int] = {
+ val positions = new ArrayBuffer[Int]()
+ var i = 0
+ while (i < allAggregateFunctions.length) {
+ allAggregateFunctions(i) match {
+ case agg: AlgebraicAggregate =>
+ case _ => positions += i
+ }
+ i += 1
+ }
+ positions.toArray
+ }
+
+ // All aggregate functions with mode Partial, PartialMerge, or Final.
+ private[this] val nonCompleteAggregateFunctions: Array[AggregateFunction2] =
+ allAggregateFunctions.take(nonCompleteAggregateExpressions.length)
+
+ // All non-algebraic aggregate functions with mode Partial, PartialMerge, or Final.
+ private[this] val nonCompleteNonAlgebraicAggregateFunctions: Array[AggregateFunction2] =
+ nonCompleteAggregateFunctions.collect {
+ case func: AggregateFunction2 if !func.isInstanceOf[AlgebraicAggregate] => func
+ }
+
+ // The projection used to initialize buffer values for all AlgebraicAggregates.
+ private[this] val algebraicInitialProjection = {
+ val initExpressions = allAggregateFunctions.flatMap {
+ case ae: AlgebraicAggregate => ae.initialValues
+ case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp)
+ }
+ newMutableProjection(initExpressions, Nil)()
+ }
+
+ // All non-Algebraic AggregateFunctions.
+ private[this] val allNonAlgebraicAggregateFunctions =
+ allNonAlgebraicAggregateFunctionPositions.map(allAggregateFunctions)
+
+ ///////////////////////////////////////////////////////////////////////////
+ // Methods and fields used by sub-classes.
+ ///////////////////////////////////////////////////////////////////////////
+
+ // Initialize the function used to process an input row.
+ protected val processRow: (MutableRow, InternalRow) => Unit = {
+ val rowToBeProcessed = new JoinedRow
+ val aggregationBufferSchema = allAggregateFunctions.flatMap(_.bufferAttributes)
+ aggregationMode match {
+ // Partial-only
+ case (Some(Partial), None) =>
+ val updateExpressions = nonCompleteAggregateFunctions.flatMap {
+ case ae: AlgebraicAggregate => ae.updateExpressions
+ case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp)
+ }
+ val algebraicUpdateProjection =
+ newMutableProjection(updateExpressions, aggregationBufferSchema ++ valueAttributes)()
+
+ (currentBuffer: MutableRow, row: InternalRow) => {
+ algebraicUpdateProjection.target(currentBuffer)
+ // Process all algebraic aggregate functions.
+ algebraicUpdateProjection(rowToBeProcessed(currentBuffer, row))
+ // Process all non-algebraic aggregate functions.
+ var i = 0
+ while (i < nonCompleteNonAlgebraicAggregateFunctions.length) {
+ nonCompleteNonAlgebraicAggregateFunctions(i).update(currentBuffer, row)
+ i += 1
+ }
+ }
+
+ // PartialMerge-only or Final-only
+ case (Some(PartialMerge), None) | (Some(Final), None) =>
+ val inputAggregationBufferSchema = if (initialInputBufferOffset == 0) {
+ // If initialInputBufferOffset is 0, the input values do not contain
+ // grouping keys.
+ // This part is pretty hacky.
+ allAggregateFunctions.flatMap(_.cloneBufferAttributes).toSeq
+ } else {
+ groupingKeyAttributes ++ allAggregateFunctions.flatMap(_.cloneBufferAttributes)
+ }
+ val mergeExpressions = nonCompleteAggregateFunctions.flatMap {
+ case ae: AlgebraicAggregate => ae.mergeExpressions
+ case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp)
+ }
+ // This projection is used to merge buffer values for all AlgebraicAggregates.
+ val algebraicMergeProjection =
+ newMutableProjection(
+ mergeExpressions,
+ aggregationBufferSchema ++ inputAggregationBufferSchema)()
+
+ (currentBuffer: MutableRow, row: InternalRow) => {
+ // Process all algebraic aggregate functions.
+ algebraicMergeProjection.target(currentBuffer)(rowToBeProcessed(currentBuffer, row))
+ // Process all non-algebraic aggregate functions.
+ var i = 0
+ while (i < nonCompleteNonAlgebraicAggregateFunctions.length) {
+ nonCompleteNonAlgebraicAggregateFunctions(i).merge(currentBuffer, row)
+ i += 1
+ }
+ }
+
+ // Final-Complete
+ case (Some(Final), Some(Complete)) =>
+ val completeAggregateFunctions: Array[AggregateFunction2] =
+ allAggregateFunctions.takeRight(completeAggregateExpressions.length)
+ // All non-algebraic aggregate functions with mode Complete.
+ val completeNonAlgebraicAggregateFunctions: Array[AggregateFunction2] =
+ completeAggregateFunctions.collect {
+ case func: AggregateFunction2 if !func.isInstanceOf[AlgebraicAggregate] => func
+ }
+
+ // The first initialInputBufferOffset values of the input aggregation buffer is
+ // for grouping expressions and distinct columns.
+ val groupingAttributesAndDistinctColumns = valueAttributes.take(initialInputBufferOffset)
+
+ val completeOffsetExpressions =
+ Seq.fill(completeAggregateFunctions.map(_.bufferAttributes.length).sum)(NoOp)
+ // We do not touch buffer values of aggregate functions with the Final mode.
+ val finalOffsetExpressions =
+ Seq.fill(nonCompleteAggregateFunctions.map(_.bufferAttributes.length).sum)(NoOp)
+
+ val mergeInputSchema =
+ aggregationBufferSchema ++
+ groupingAttributesAndDistinctColumns ++
+ nonCompleteAggregateFunctions.flatMap(_.cloneBufferAttributes)
+ val mergeExpressions =
+ nonCompleteAggregateFunctions.flatMap {
+ case ae: AlgebraicAggregate => ae.mergeExpressions
+ case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp)
+ } ++ completeOffsetExpressions
+ val finalAlgebraicMergeProjection =
+ newMutableProjection(mergeExpressions, mergeInputSchema)()
+
+ val updateExpressions =
+ finalOffsetExpressions ++ completeAggregateFunctions.flatMap {
+ case ae: AlgebraicAggregate => ae.updateExpressions
+ case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp)
+ }
+ val completeAlgebraicUpdateProjection =
+ newMutableProjection(updateExpressions, aggregationBufferSchema ++ valueAttributes)()
+
+ (currentBuffer: MutableRow, row: InternalRow) => {
+ val input = rowToBeProcessed(currentBuffer, row)
+ // For all aggregate functions with mode Complete, update buffers.
+ completeAlgebraicUpdateProjection.target(currentBuffer)(input)
+ var i = 0
+ while (i < completeNonAlgebraicAggregateFunctions.length) {
+ completeNonAlgebraicAggregateFunctions(i).update(currentBuffer, row)
+ i += 1
+ }
+
+ // For all aggregate functions with mode Final, merge buffers.
+ finalAlgebraicMergeProjection.target(currentBuffer)(input)
+ i = 0
+ while (i < nonCompleteNonAlgebraicAggregateFunctions.length) {
+ nonCompleteNonAlgebraicAggregateFunctions(i).merge(currentBuffer, row)
+ i += 1
+ }
+ }
+
+ // Complete-only
+ case (None, Some(Complete)) =>
+ val completeAggregateFunctions: Array[AggregateFunction2] =
+ allAggregateFunctions.takeRight(completeAggregateExpressions.length)
+ // All non-algebraic aggregate functions with mode Complete.
+ val completeNonAlgebraicAggregateFunctions: Array[AggregateFunction2] =
+ completeAggregateFunctions.collect {
+ case func: AggregateFunction2 if !func.isInstanceOf[AlgebraicAggregate] => func
+ }
+
+ val updateExpressions =
+ completeAggregateFunctions.flatMap {
+ case ae: AlgebraicAggregate => ae.updateExpressions
+ case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp)
+ }
+ val completeAlgebraicUpdateProjection =
+ newMutableProjection(updateExpressions, aggregationBufferSchema ++ valueAttributes)()
+
+ (currentBuffer: MutableRow, row: InternalRow) => {
+ val input = rowToBeProcessed(currentBuffer, row)
+ // For all aggregate functions with mode Complete, update buffers.
+ completeAlgebraicUpdateProjection.target(currentBuffer)(input)
+ var i = 0
+ while (i < completeNonAlgebraicAggregateFunctions.length) {
+ completeNonAlgebraicAggregateFunctions(i).update(currentBuffer, row)
+ i += 1
+ }
+ }
+
+ // Grouping only.
+ case (None, None) => (currentBuffer: MutableRow, row: InternalRow) => {}
+
+ case other =>
+ sys.error(
+ s"Could not evaluate ${nonCompleteAggregateExpressions} because we do not " +
+ s"support evaluate modes $other in this iterator.")
+ }
+ }
+
+ // Initialize the function used to generate the output row.
+ protected val generateOutput: (InternalRow, MutableRow) => InternalRow = {
+ val rowToBeEvaluated = new JoinedRow
+ val safeOutputRow = new GenericMutableRow(resultExpressions.length)
+ val mutableOutput = if (outputsUnsafeRows) {
+ UnsafeProjection.create(resultExpressions.map(_.dataType).toArray).apply(safeOutputRow)
+ } else {
+ safeOutputRow
+ }
+
+ aggregationMode match {
+ // Partial-only or PartialMerge-only: every output row is basically the values of
+ // the grouping expressions and the corresponding aggregation buffer.
+ case (Some(Partial), None) | (Some(PartialMerge), None) =>
+ // Because we cannot copy a JoinedRow containing an UnsafeRow (UnsafeRow does not
+ // support the generic getter), we create a mutable projection to output the
+ // JoinedRow(currentGroupingKey, currentBuffer).
+ val bufferSchema = nonCompleteAggregateFunctions.flatMap(_.bufferAttributes)
+ val resultProjection =
+ newMutableProjection(
+ groupingKeyAttributes ++ bufferSchema,
+ groupingKeyAttributes ++ bufferSchema)()
+ resultProjection.target(mutableOutput)
+
+ (currentGroupingKey: InternalRow, currentBuffer: MutableRow) => {
+ resultProjection(rowToBeEvaluated(currentGroupingKey, currentBuffer))
+ }
+
+ // Final-only, Complete-only and Final-Complete: every output row contains values representing
+ // resultExpressions.
+ case (Some(Final), None) | (Some(Final) | None, Some(Complete)) =>
+ val bufferSchemata =
+ allAggregateFunctions.flatMap(_.bufferAttributes)
+ val evalExpressions = allAggregateFunctions.map {
+ case ae: AlgebraicAggregate => ae.evaluateExpression
+ case agg: AggregateFunction2 => NoOp
+ }
+ val algebraicEvalProjection = newMutableProjection(evalExpressions, bufferSchemata)()
+ val aggregateResultSchema = nonCompleteAggregateAttributes ++ completeAggregateAttributes
+ // TODO: Use unsafe row.
+ val aggregateResult = new GenericMutableRow(aggregateResultSchema.length)
+ val resultProjection =
+ newMutableProjection(
+ resultExpressions, groupingKeyAttributes ++ aggregateResultSchema)()
+ resultProjection.target(mutableOutput)
+
+ (currentGroupingKey: InternalRow, currentBuffer: MutableRow) => {
+ // Generate results for all algebraic aggregate functions.
+ algebraicEvalProjection.target(aggregateResult)(currentBuffer)
+ // Generate results for all non-algebraic aggregate functions.
+ var i = 0
+ while (i < allNonAlgebraicAggregateFunctions.length) {
+ aggregateResult.update(
+ allNonAlgebraicAggregateFunctionPositions(i),
+ allNonAlgebraicAggregateFunctions(i).eval(currentBuffer))
+ i += 1
+ }
+ resultProjection(rowToBeEvaluated(currentGroupingKey, aggregateResult))
+ }
+
+ // Grouping-only: we only output values of grouping expressions.
+ case (None, None) =>
+ val resultProjection =
+ newMutableProjection(resultExpressions, groupingKeyAttributes)()
+ resultProjection.target(mutableOutput)
+
+ (currentGroupingKey: InternalRow, currentBuffer: MutableRow) => {
+ resultProjection(currentGroupingKey)
+ }
+
+ case other =>
+ sys.error(
+ s"Could not evaluate ${nonCompleteAggregateExpressions} because we do not " +
+ s"support evaluate modes $other in this iterator.")
+ }
+ }
+
+ /** Initializes buffer values for all aggregate functions. */
+ protected def initializeBuffer(buffer: MutableRow): Unit = {
+ algebraicInitialProjection.target(buffer)(EmptyRow)
+ var i = 0
+ while (i < allNonAlgebraicAggregateFunctions.length) {
+ allNonAlgebraicAggregateFunctions(i).initialize(buffer)
+ i += 1
+ }
+ }
+
+ /**
+ * Creates a new aggregation buffer and initializes buffer values
+ * for all aggregate functions.
+ */
+ protected def newBuffer: MutableRow
+}
+
+object AggregationIterator {
+ def kvIterator(
+ groupingExpressions: Seq[NamedExpression],
+ newProjection: (Seq[Expression], Seq[Attribute]) => Projection,
+ inputAttributes: Seq[Attribute],
+ inputIter: Iterator[InternalRow]): KVIterator[InternalRow, InternalRow] = {
+ new KVIterator[InternalRow, InternalRow] {
+ private[this] val groupingKeyGenerator = newProjection(groupingExpressions, inputAttributes)
+
+ private[this] var groupingKey: InternalRow = _
+
+ private[this] var value: InternalRow = _
+
+ override def next(): Boolean = {
+ if (inputIter.hasNext) {
+ // Read the next input row.
+ val inputRow = inputIter.next()
+ // Get groupingKey based on groupingExpressions.
+ groupingKey = groupingKeyGenerator(inputRow)
+ // The value is the inputRow.
+ value = inputRow
+ true
+ } else {
+ false
+ }
+ }
+
+ override def getKey(): InternalRow = {
+ groupingKey
+ }
+
+ override def getValue(): InternalRow = {
+ value
+ }
+
+ override def close(): Unit = {
+ // Do nothing
+ }
+ }
+ }
+
+ def unsafeKVIterator(
+ groupingExpressions: Seq[NamedExpression],
+ inputAttributes: Seq[Attribute],
+ inputIter: Iterator[InternalRow]): KVIterator[UnsafeRow, InternalRow] = {
+ new KVIterator[UnsafeRow, InternalRow] {
+ private[this] val groupingKeyGenerator =
+ UnsafeProjection.create(groupingExpressions, inputAttributes)
+
+ private[this] var groupingKey: UnsafeRow = _
+
+ private[this] var value: InternalRow = _
+
+ override def next(): Boolean = {
+ if (inputIter.hasNext) {
+ // Read the next input row.
+ val inputRow = inputIter.next()
+ // Get groupingKey based on groupingExpressions.
+ groupingKey = groupingKeyGenerator.apply(inputRow)
+ // The value is the inputRow.
+ value = inputRow
+ true
+ } else {
+ false
+ }
+ }
+
+ override def getKey(): UnsafeRow = {
+ groupingKey
+ }
+
+ override def getValue(): InternalRow = {
+ value
+ }
+
+ override def close(): Unit = {
+ // Do nothing
+ }
+ }
+ }
+}
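The KVIterator contract used above (and by both concrete iterators below) differs from Scala's Iterator: next() both advances and reports success, and getKey()/getValue() read the current pair, which the implementation may reuse, so callers that retain a key across calls must copy it (as the grouping code below does). A small stand-in showing the consumption pattern (SimpleKVIterator is hypothetical; org.apache.spark.unsafe.KVIterator is the real interface):

abstract class SimpleKVIterator[K, V] {
  def next(): Boolean
  def getKey(): K
  def getValue(): V
}

object KVIteratorDemo {
  def fromSeq[K, V](kvs: Seq[(K, V)]): SimpleKVIterator[K, V] = new SimpleKVIterator[K, V] {
    private val it = kvs.iterator
    private var current: (K, V) = _
    def next(): Boolean = { if (it.hasNext) { current = it.next(); true } else false }
    def getKey(): K = current._1
    def getValue(): V = current._2
  }

  def main(args: Array[String]): Unit = {
    val iter = fromSeq(Seq("a" -> 1, "b" -> 2))
    // Advance first, then read the current key and value.
    while (iter.next()) println(s"${iter.getKey()} -> ${iter.getValue()}")
  }
}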
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala
new file mode 100644
index 0000000000..78bcee16c9
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala
@@ -0,0 +1,236 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.aggregate
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression2, AggregateFunction2}
+import org.apache.spark.sql.execution.UnsafeFixedWidthAggregationMap
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.unsafe.KVIterator
+
+/**
+ * An iterator used to evaluate [[AggregateFunction2]]. It assumes the input rows have been
+ * sorted by values of [[groupingKeyAttributes]].
+ */
+class SortBasedAggregationIterator(
+ groupingKeyAttributes: Seq[Attribute],
+ valueAttributes: Seq[Attribute],
+ inputKVIterator: KVIterator[InternalRow, InternalRow],
+ nonCompleteAggregateExpressions: Seq[AggregateExpression2],
+ nonCompleteAggregateAttributes: Seq[Attribute],
+ completeAggregateExpressions: Seq[AggregateExpression2],
+ completeAggregateAttributes: Seq[Attribute],
+ initialInputBufferOffset: Int,
+ resultExpressions: Seq[NamedExpression],
+ newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
+ outputsUnsafeRows: Boolean)
+ extends AggregationIterator(
+ groupingKeyAttributes,
+ valueAttributes,
+ nonCompleteAggregateExpressions,
+ nonCompleteAggregateAttributes,
+ completeAggregateExpressions,
+ completeAggregateAttributes,
+ initialInputBufferOffset,
+ resultExpressions,
+ newMutableProjection,
+ outputsUnsafeRows) {
+
+ override protected def newBuffer: MutableRow = {
+ val bufferSchema = allAggregateFunctions.flatMap(_.bufferAttributes)
+ val bufferRowSize: Int = bufferSchema.length
+
+ val genericMutableBuffer = new GenericMutableRow(bufferRowSize)
+ val useUnsafeBuffer = bufferSchema.map(_.dataType).forall(UnsafeRow.isFixedLength)
+
+ val buffer = if (useUnsafeBuffer) {
+ val unsafeProjection =
+ UnsafeProjection.create(bufferSchema.map(_.dataType))
+ unsafeProjection.apply(genericMutableBuffer)
+ } else {
+ genericMutableBuffer
+ }
+ initializeBuffer(buffer)
+ buffer
+ }
+
+ ///////////////////////////////////////////////////////////////////////////
+ // Mutable states for sort based aggregation.
+ ///////////////////////////////////////////////////////////////////////////
+
+ // The grouping key of the current group.
+ private[this] var currentGroupingKey: InternalRow = _
+
+ // The grouping key of the next group.
+ private[this] var nextGroupingKey: InternalRow = _
+
+ // The first row of the next group.
+ private[this] var firstRowInNextGroup: InternalRow = _
+
+ // Indicates whether we have a new group of rows from the sorted input iterator.
+ private[this] var sortedInputHasNewGroup: Boolean = false
+
+ // The aggregation buffer used by the sort-based aggregation.
+ private[this] val sortBasedAggregationBuffer: MutableRow = newBuffer
+
+ /** Processes rows in the current group. It stops when it finds a new group. */
+ protected def processCurrentSortedGroup(): Unit = {
+ currentGroupingKey = nextGroupingKey
+ // Now, we will start to find all rows belonging to this group.
+ // We create a variable to track whether we have seen the next group.
+ var foundNewGroup = false
+ // firstRowInNextGroup is the first row of this group. We first process it.
+ processRow(sortBasedAggregationBuffer, firstRowInNextGroup)
+
+ // The search will stop when we see the next group or there is no
+ // input row left in the iter.
+ var hasNext = inputKVIterator.next()
+ while (!foundNewGroup && hasNext) {
+ // Get the grouping key.
+ val groupingKey = inputKVIterator.getKey
+ val currentRow = inputKVIterator.getValue
+
+ // Check if the current row belongs to the current group.
+ if (currentGroupingKey == groupingKey) {
+ processRow(sortBasedAggregationBuffer, currentRow)
+
+ hasNext = inputKVIterator.next()
+ } else {
+ // We have found a new group.
+ foundNewGroup = true
+ nextGroupingKey = groupingKey.copy()
+ firstRowInNextGroup = currentRow.copy()
+ }
+ }
+ // We have not seen a new group, which means there are no rows left in the input
+ // iterator; the current group is the last group of the iterator.
+ if (!foundNewGroup) {
+ sortedInputHasNewGroup = false
+ }
+ }
+
+ ///////////////////////////////////////////////////////////////////////////
+ // Iterator's public methods
+ ///////////////////////////////////////////////////////////////////////////
+
+ override final def hasNext: Boolean = sortedInputHasNewGroup
+
+ override final def next(): InternalRow = {
+ if (hasNext) {
+ // Process the current group.
+ processCurrentSortedGroup()
+ // Generate output row for the current group.
+ val outputRow = generateOutput(currentGroupingKey, sortBasedAggregationBuffer)
+ // Initialize buffer values for the next group.
+ initializeBuffer(sortBasedAggregationBuffer)
+
+ outputRow
+ } else {
+ // No more results.
+ throw new NoSuchElementException
+ }
+ }
+
+ protected def initialize(): Unit = {
+ if (inputKVIterator.next()) {
+ initializeBuffer(sortBasedAggregationBuffer)
+
+ nextGroupingKey = inputKVIterator.getKey().copy()
+ firstRowInNextGroup = inputKVIterator.getValue().copy()
+
+ sortedInputHasNewGroup = true
+ } else {
+ // The input iterator is empty.
+ sortedInputHasNewGroup = false
+ }
+ }
+
+ initialize()
+
+ def outputForEmptyGroupingKeyWithoutInput(): InternalRow = {
+ initializeBuffer(sortBasedAggregationBuffer)
+ generateOutput(new GenericInternalRow(0), sortBasedAggregationBuffer)
+ }
+}
+
+object SortBasedAggregationIterator {
+ // scalastyle:off
+ def createFromInputIterator(
+ groupingExprs: Seq[NamedExpression],
+ nonCompleteAggregateExpressions: Seq[AggregateExpression2],
+ nonCompleteAggregateAttributes: Seq[Attribute],
+ completeAggregateExpressions: Seq[AggregateExpression2],
+ completeAggregateAttributes: Seq[Attribute],
+ initialInputBufferOffset: Int,
+ resultExpressions: Seq[NamedExpression],
+ newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
+ newProjection: (Seq[Expression], Seq[Attribute]) => Projection,
+ inputAttributes: Seq[Attribute],
+ inputIter: Iterator[InternalRow],
+ outputsUnsafeRows: Boolean): SortBasedAggregationIterator = {
+ val kvIterator = if (UnsafeProjection.canSupport(groupingExprs)) {
+ AggregationIterator.unsafeKVIterator(
+ groupingExprs,
+ inputAttributes,
+ inputIter).asInstanceOf[KVIterator[InternalRow, InternalRow]]
+ } else {
+ AggregationIterator.kvIterator(groupingExprs, newProjection, inputAttributes, inputIter)
+ }
+
+ new SortBasedAggregationIterator(
+ groupingExprs.map(_.toAttribute),
+ inputAttributes,
+ kvIterator,
+ nonCompleteAggregateExpressions,
+ nonCompleteAggregateAttributes,
+ completeAggregateExpressions,
+ completeAggregateAttributes,
+ initialInputBufferOffset,
+ resultExpressions,
+ newMutableProjection,
+ outputsUnsafeRows)
+ }
+
+ def createFromKVIterator(
+ groupingKeyAttributes: Seq[Attribute],
+ valueAttributes: Seq[Attribute],
+ inputKVIterator: KVIterator[InternalRow, InternalRow],
+ nonCompleteAggregateExpressions: Seq[AggregateExpression2],
+ nonCompleteAggregateAttributes: Seq[Attribute],
+ completeAggregateExpressions: Seq[AggregateExpression2],
+ completeAggregateAttributes: Seq[Attribute],
+ initialInputBufferOffset: Int,
+ resultExpressions: Seq[NamedExpression],
+ newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
+ outputsUnsafeRows: Boolean): SortBasedAggregationIterator = {
+ new SortBasedAggregationIterator(
+ groupingKeyAttributes,
+ valueAttributes,
+ inputKVIterator,
+ nonCompleteAggregateExpressions,
+ nonCompleteAggregateAttributes,
+ completeAggregateExpressions,
+ completeAggregateAttributes,
+ initialInputBufferOffset,
+ resultExpressions,
+ newMutableProjection,
+ outputsUnsafeRows)
+ }
+ // scalastyle:on
+}
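processCurrentSortedGroup above relies on the input being sorted by the grouping key, so one linear pass can close a group as soon as the key changes. A sketch of the same loop over a plain sorted Seq (illustrative types only, not the iterator's actual code):

import scala.collection.mutable.ArrayBuffer

object SortedGrouping {
  // One-pass grouping over key-sorted input: keep aggregating while the key
  // is unchanged; a key change closes the current group and starts the next.
  def sumByKey(sorted: Seq[(String, Int)]): List[(String, Int)] = {
    val out = ArrayBuffer.empty[(String, Int)]
    var currentKey: Option[String] = None
    var buffer = 0
    for ((k, v) <- sorted) {
      if (currentKey.contains(k)) {
        buffer += v // same group: keep aggregating into the buffer
      } else {
        currentKey.foreach(key => out += ((key, buffer))) // emit the finished group
        currentKey = Some(k)
        buffer = v // initialize the buffer for the new group
      }
    }
    currentKey.foreach(key => out += ((key, buffer))) // emit the last group
    out.toList
  }

  def main(args: Array[String]): Unit =
    println(sumByKey(Seq("a" -> 1, "a" -> 2, "b" -> 3))) // List((a,3), (b,3))
}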
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UnsafeHybridAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UnsafeHybridAggregationIterator.scala
new file mode 100644
index 0000000000..37d34eb7cc
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UnsafeHybridAggregationIterator.scala
@@ -0,0 +1,398 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.aggregate
+
+import org.apache.spark.sql.execution.{UnsafeKeyValueSorter, UnsafeFixedWidthAggregationMap}
+import org.apache.spark.unsafe.KVIterator
+import org.apache.spark.{SparkEnv, TaskContext}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.types.StructType
+
+/**
+ * An iterator used to evaluate [[AggregateFunction2]].
+ * It first tries to use in-memory hash-based aggregation. If we cannot allocate more
+ * space for the hash map, we spill the sorted map entries, free the map, and then
+ * switch to sort-based aggregation.
+ */
+class UnsafeHybridAggregationIterator(
+ groupingKeyAttributes: Seq[Attribute],
+ valueAttributes: Seq[Attribute],
+ inputKVIterator: KVIterator[UnsafeRow, InternalRow],
+ nonCompleteAggregateExpressions: Seq[AggregateExpression2],
+ nonCompleteAggregateAttributes: Seq[Attribute],
+ completeAggregateExpressions: Seq[AggregateExpression2],
+ completeAggregateAttributes: Seq[Attribute],
+ initialInputBufferOffset: Int,
+ resultExpressions: Seq[NamedExpression],
+ newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
+ outputsUnsafeRows: Boolean)
+ extends AggregationIterator(
+ groupingKeyAttributes,
+ valueAttributes,
+ nonCompleteAggregateExpressions,
+ nonCompleteAggregateAttributes,
+ completeAggregateExpressions,
+ completeAggregateAttributes,
+ initialInputBufferOffset,
+ resultExpressions,
+ newMutableProjection,
+ outputsUnsafeRows) {
+
+ require(groupingKeyAttributes.nonEmpty)
+
+ ///////////////////////////////////////////////////////////////////////////
+ // Unsafe Aggregation buffers
+ ///////////////////////////////////////////////////////////////////////////
+
+ // This is the Unsafe Aggregation Map used to store all buffers.
+ private[this] val buffers = new UnsafeFixedWidthAggregationMap(
+ newBuffer,
+ StructType.fromAttributes(allAggregateFunctions.flatMap(_.bufferAttributes)),
+ StructType.fromAttributes(groupingKeyAttributes),
+ TaskContext.get.taskMemoryManager(),
+ SparkEnv.get.shuffleMemoryManager,
+ 1024 * 16, // initial capacity
+ SparkEnv.get.conf.getSizeAsBytes("spark.buffer.pageSize", "64m"),
+ false // disable tracking of performance metrics
+ )
+
+ override protected def newBuffer: UnsafeRow = {
+ val bufferSchema = allAggregateFunctions.flatMap(_.bufferAttributes)
+ val bufferRowSize: Int = bufferSchema.length
+
+ val genericMutableBuffer = new GenericMutableRow(bufferRowSize)
+ val unsafeProjection =
+ UnsafeProjection.create(bufferSchema.map(_.dataType))
+ val buffer = unsafeProjection.apply(genericMutableBuffer)
+ initializeBuffer(buffer)
+ buffer
+ }
+
+ ///////////////////////////////////////////////////////////////////////////
+ // Methods and variables related to switching to sort-based aggregation
+ ///////////////////////////////////////////////////////////////////////////
+ private[this] var sortBased = false
+
+ private[this] var sortBasedAggregationIterator: SortBasedAggregationIterator = _
+
+ // The value part of the input KV iterator stores the original input values of the
+ // aggregate functions, so we need to convert them to aggregation buffers.
+ private def processOriginalInput(
+ firstKey: UnsafeRow,
+ firstValue: InternalRow): KVIterator[UnsafeRow, UnsafeRow] = {
+ new KVIterator[UnsafeRow, UnsafeRow] {
+ private[this] var isFirstRow = true
+
+ private[this] var groupingKey: UnsafeRow = _
+
+ private[this] val buffer: UnsafeRow = newBuffer
+
+ override def next(): Boolean = {
+ initializeBuffer(buffer)
+ if (isFirstRow) {
+ isFirstRow = false
+ groupingKey = firstKey
+ processRow(buffer, firstValue)
+
+ true
+ } else if (inputKVIterator.next()) {
+ groupingKey = inputKVIterator.getKey()
+ val value = inputKVIterator.getValue()
+ processRow(buffer, value)
+
+ true
+ } else {
+ false
+ }
+ }
+
+ override def getKey(): UnsafeRow = {
+ groupingKey
+ }
+
+ override def getValue(): UnsafeRow = {
+ buffer
+ }
+
+ override def close(): Unit = {
+ // Do nothing.
+ }
+ }
+ }
+
+ // The value of the input KV iterator has the format groupingExprs + aggregation buffer.
+ // We need to project the aggregation buffer out.
+ private def projectInputBufferToUnsafe(
+ firstKey: UnsafeRow,
+ firstValue: InternalRow): KVIterator[UnsafeRow, UnsafeRow] = {
+ new KVIterator[UnsafeRow, UnsafeRow] {
+ private[this] var isFirstRow = true
+
+ private[this] var groupingKey: UnsafeRow = _
+
+ private[this] val bufferSchema = allAggregateFunctions.flatMap(_.bufferAttributes)
+
+ private[this] val value: UnsafeRow = {
+ val genericMutableRow = new GenericMutableRow(bufferSchema.length)
+ UnsafeProjection.create(bufferSchema.map(_.dataType)).apply(genericMutableRow)
+ }
+
+ private[this] val projectInputBuffer = {
+ newMutableProjection(bufferSchema, valueAttributes)().target(value)
+ }
+
+ override def next(): Boolean = {
+ if (isFirstRow) {
+ isFirstRow = false
+ groupingKey = firstKey
+ projectInputBuffer(firstValue)
+
+ true
+ } else if (inputKVIterator.next()) {
+ groupingKey = inputKVIterator.getKey()
+ projectInputBuffer(inputKVIterator.getValue())
+
+ true
+ } else {
+ false
+ }
+ }
+
+ override def getKey(): UnsafeRow = {
+ groupingKey
+ }
+
+ override def getValue(): UnsafeRow = {
+ value
+ }
+
+ override def close(): Unit = {
+ // Do nothing.
+ }
+ }
+ }
+
+ /**
+ * We need to fall back to sort-based aggregation because we do not have enough memory
+ * for our in-memory hash map (i.e. `buffers`).
+ */
+ private def switchToSortBasedAggregation(
+ currentGroupingKey: UnsafeRow,
+ currentRow: InternalRow): Unit = {
+ logInfo("falling back to sort based aggregation.")
+
+ // Step 1: Get the ExternalSorter containing entries of the map.
+ val externalSorter = buffers.destructAndCreateExternalSorter()
+
+ // Step 2: Free the memory used by the map.
+ buffers.free()
+
+ // Step 3: If we have aggregate functions with mode Partial or Complete,
+ // we need to process the input rows to get aggregation buffers,
+ // so that the sort-based aggregation iterator can later merge them.
+ // If all aggregate functions have mode Final or PartialMerge,
+ // we just need to project the aggregation buffer from the input.
+ val needsProcess = aggregationMode match {
+ case (Some(Partial), None) => true
+ case (None, Some(Complete)) => true
+ case (Some(Final), Some(Complete)) => true
+ case _ => false
+ }
+
+ val processedIterator = if (needsProcess) {
+ processOriginalInput(currentGroupingKey, currentRow)
+ } else {
+ // The input value's format is groupingExprs + buffer.
+ // We need to project the buffer part out.
+ projectInputBufferToUnsafe(currentGroupingKey, currentRow)
+ }
+
+ // Step 4: Redirect processedIterator to externalSorter.
+ while (processedIterator.next()) {
+ externalSorter.insertKV(processedIterator.getKey(), processedIterator.getValue())
+ }
+
+ // Step 5: Get the sorted iterator from the externalSorter.
+ val sortedKVIterator: KVIterator[UnsafeRow, UnsafeRow] = externalSorter.sortedIterator()
+
+ // Step 6: We now create a SortBasedAggregationIterator based on sortedKVIterator.
+ // For an aggregate function with mode Partial, its mode in the SortBasedAggregationIterator
+ // will be PartialMerge. For an aggregate function with mode Complete,
+ // its mode in the SortBasedAggregationIterator will be Final.
+ val newNonCompleteAggregateExpressions = allAggregateExpressions.map {
+ case AggregateExpression2(func, Partial, isDistinct) =>
+ AggregateExpression2(func, PartialMerge, isDistinct)
+ case AggregateExpression2(func, Complete, isDistinct) =>
+ AggregateExpression2(func, Final, isDistinct)
+ case other => other
+ }
+ val newNonCompleteAggregateAttributes =
+ nonCompleteAggregateAttributes ++ completeAggregateAttributes
+
+ val newValueAttributes =
+ allAggregateExpressions.flatMap(_.aggregateFunction.cloneBufferAttributes)
+
+ sortBasedAggregationIterator = SortBasedAggregationIterator.createFromKVIterator(
+ groupingKeyAttributes = groupingKeyAttributes,
+ valueAttributes = newValueAttributes,
+ inputKVIterator = sortedKVIterator.asInstanceOf[KVIterator[InternalRow, InternalRow]],
+ nonCompleteAggregateExpressions = newNonCompleteAggregateExpressions,
+ nonCompleteAggregateAttributes = newNonCompleteAggregateAttributes,
+ completeAggregateExpressions = Nil,
+ completeAggregateAttributes = Nil,
+ initialInputBufferOffset = 0,
+ resultExpressions = resultExpressions,
+ newMutableProjection = newMutableProjection,
+ outputsUnsafeRows = outputsUnsafeRows)
+ }
+
+ ///////////////////////////////////////////////////////////////////////////
+ // Methods used to initialize this iterator.
+ ///////////////////////////////////////////////////////////////////////////
+
+ /** Starts to read input rows and falls back to sort-based aggregation if necessary. */
+ protected def initialize(): Unit = {
+ var hasNext = inputKVIterator.next()
+ while (!sortBased && hasNext) {
+ val groupingKey = inputKVIterator.getKey()
+ val currentRow = inputKVIterator.getValue()
+ val buffer = buffers.getAggregationBuffer(groupingKey)
+ if (buffer == null) {
+ // buffer == null means that we could not allocate more memory.
+ // Now, we need to spill the map and switch to sort-based aggregation.
+ switchToSortBasedAggregation(groupingKey, currentRow)
+ sortBased = true
+ } else {
+ processRow(buffer, currentRow)
+ hasNext = inputKVIterator.next()
+ }
+ }
+ }
+
+ // This is the starting point of this iterator.
+ initialize()
+
+ // Creates the iterator for the hash aggregation map after we have populated
+ // the contents of that map.
+ private[this] val aggregationBufferMapIterator = buffers.iterator()
+
+ private[this] var _mapIteratorHasNext = false
+
+ // Pre-load the first key-value pair from the map to make hasNext idempotent.
+ if (!sortBased) {
+ _mapIteratorHasNext = aggregationBufferMapIterator.next()
+ // If the map is empty, we just free it.
+ if (!_mapIteratorHasNext) {
+ buffers.free()
+ }
+ }
+
+ ///////////////////////////////////////////////////////////////////////////
+ // Iterator's public methods
+ ///////////////////////////////////////////////////////////////////////////
+
+ override final def hasNext: Boolean = {
+ (sortBased && sortBasedAggregationIterator.hasNext) || (!sortBased && _mapIteratorHasNext)
+ }
+
+ override final def next(): InternalRow = {
+ if (hasNext) {
+ if (sortBased) {
+ sortBasedAggregationIterator.next()
+ } else {
+ // We did not fall back to the sort-based aggregation.
+ val result =
+ generateOutput(
+ aggregationBufferMapIterator.getKey,
+ aggregationBufferMapIterator.getValue)
+ // Pre-load the next key-value pair from aggregationBufferMapIterator.
+ _mapIteratorHasNext = aggregationBufferMapIterator.next()
+
+ if (!_mapIteratorHasNext) {
+ val resultCopy = result.copy()
+ buffers.free()
+ resultCopy
+ } else {
+ result
+ }
+ }
+ } else {
+ // No more results.
+ throw new NoSuchElementException
+ }
+ }
+}
+
+object UnsafeHybridAggregationIterator {
+ // scalastyle:off
+ def createFromInputIterator(
+ groupingExprs: Seq[NamedExpression],
+ nonCompleteAggregateExpressions: Seq[AggregateExpression2],
+ nonCompleteAggregateAttributes: Seq[Attribute],
+ completeAggregateExpressions: Seq[AggregateExpression2],
+ completeAggregateAttributes: Seq[Attribute],
+ initialInputBufferOffset: Int,
+ resultExpressions: Seq[NamedExpression],
+ newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
+ inputAttributes: Seq[Attribute],
+ inputIter: Iterator[InternalRow],
+ outputsUnsafeRows: Boolean): UnsafeHybridAggregationIterator = {
+ new UnsafeHybridAggregationIterator(
+ groupingExprs.map(_.toAttribute),
+ inputAttributes,
+ AggregationIterator.unsafeKVIterator(groupingExprs, inputAttributes, inputIter),
+ nonCompleteAggregateExpressions,
+ nonCompleteAggregateAttributes,
+ completeAggregateExpressions,
+ completeAggregateAttributes,
+ initialInputBufferOffset,
+ resultExpressions,
+ newMutableProjection,
+ outputsUnsafeRows)
+ }
+
+ def createFromKVIterator(
+ groupingKeyAttributes: Seq[Attribute],
+ valueAttributes: Seq[Attribute],
+ inputKVIterator: KVIterator[UnsafeRow, InternalRow],
+ nonCompleteAggregateExpressions: Seq[AggregateExpression2],
+ nonCompleteAggregateAttributes: Seq[Attribute],
+ completeAggregateExpressions: Seq[AggregateExpression2],
+ completeAggregateAttributes: Seq[Attribute],
+ initialInputBufferOffset: Int,
+ resultExpressions: Seq[NamedExpression],
+ newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
+ outputsUnsafeRows: Boolean): UnsafeHybridAggregationIterator = {
+ new UnsafeHybridAggregationIterator(
+ groupingKeyAttributes,
+ valueAttributes,
+ inputKVIterator,
+ nonCompleteAggregateExpressions,
+ nonCompleteAggregateAttributes,
+ completeAggregateExpressions,
+ completeAggregateAttributes,
+ initialInputBufferOffset,
+ resultExpressions,
+ newMutableProjection,
+ outputsUnsafeRows)
+ }
+ // scalastyle:on
+}
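The mode rewrite in switchToSortBasedAggregation is the crux of the fallback: rows consumed before the spill were already folded into partial buffers, so the sort-based phase must merge buffers instead of re-processing raw inputs. A toy model of that rewrite (a standalone enum, not Spark's AggregateMode):

sealed trait Mode
case object Partial extends Mode
case object PartialMerge extends Mode
case object Final extends Mode
case object Complete extends Mode

object FallbackModeRewrite {
  def rewrite(m: Mode): Mode = m match {
    case Partial  => PartialMerge // inputs were already partially aggregated
    case Complete => Final        // likewise: only the final merge remains
    case other    => other
  }

  def main(args: Array[String]): Unit =
    println(Seq(Partial, Complete, Final).map(rewrite)) // List(PartialMerge, Final, Final)
}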
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/aggregateOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/aggregateOperators.scala
deleted file mode 100644
index 98538c462b..0000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/aggregateOperators.scala
+++ /dev/null
@@ -1,175 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.aggregate
-
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.errors._
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, UnspecifiedDistribution}
-import org.apache.spark.sql.execution.{SparkPlan, UnaryNode}
-
-case class Aggregate2Sort(
- requiredChildDistributionExpressions: Option[Seq[Expression]],
- groupingExpressions: Seq[NamedExpression],
- aggregateExpressions: Seq[AggregateExpression2],
- aggregateAttributes: Seq[Attribute],
- resultExpressions: Seq[NamedExpression],
- child: SparkPlan)
- extends UnaryNode {
-
- override def canProcessUnsafeRows: Boolean = true
-
- override def references: AttributeSet = {
- val referencesInResults =
- AttributeSet(resultExpressions.flatMap(_.references)) -- AttributeSet(aggregateAttributes)
-
- AttributeSet(
- groupingExpressions.flatMap(_.references) ++
- aggregateExpressions.flatMap(_.references) ++
- referencesInResults)
- }
-
- override def requiredChildDistribution: List[Distribution] = {
- requiredChildDistributionExpressions match {
- case Some(exprs) if exprs.length == 0 => AllTuples :: Nil
- case Some(exprs) if exprs.length > 0 => ClusteredDistribution(exprs) :: Nil
- case None => UnspecifiedDistribution :: Nil
- }
- }
-
- override def requiredChildOrdering: Seq[Seq[SortOrder]] = {
- // TODO: We should not sort the input rows if they are just in reversed order.
- groupingExpressions.map(SortOrder(_, Ascending)) :: Nil
- }
-
- override def outputOrdering: Seq[SortOrder] = {
- // It is possible that the child.outputOrdering starts with the required
- // ordering expressions (e.g. we require [a] as the sort expression and the
- // child's outputOrdering is [a, b]). We can only guarantee the output rows
- // are sorted by values of groupingExpressions.
- groupingExpressions.map(SortOrder(_, Ascending))
- }
-
- override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
-
- protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
- child.execute().mapPartitions { iter =>
- if (aggregateExpressions.length == 0) {
- new FinalSortAggregationIterator(
- groupingExpressions,
- Nil,
- Nil,
- resultExpressions,
- newMutableProjection,
- child.output,
- iter)
- } else {
- val aggregationIterator: SortAggregationIterator = {
- aggregateExpressions.map(_.mode).distinct.toList match {
- case Partial :: Nil =>
- new PartialSortAggregationIterator(
- groupingExpressions,
- aggregateExpressions,
- newMutableProjection,
- child.output,
- iter)
- case PartialMerge :: Nil =>
- new PartialMergeSortAggregationIterator(
- groupingExpressions,
- aggregateExpressions,
- newMutableProjection,
- child.output,
- iter)
- case Final :: Nil =>
- new FinalSortAggregationIterator(
- groupingExpressions,
- aggregateExpressions,
- aggregateAttributes,
- resultExpressions,
- newMutableProjection,
- child.output,
- iter)
- case other =>
- sys.error(
- s"Could not evaluate ${aggregateExpressions} because we do not support evaluate " +
- s"modes $other in this operator.")
- }
- }
-
- aggregationIterator
- }
- }
- }
-}
-
-case class FinalAndCompleteAggregate2Sort(
- previousGroupingExpressions: Seq[NamedExpression],
- groupingExpressions: Seq[NamedExpression],
- finalAggregateExpressions: Seq[AggregateExpression2],
- finalAggregateAttributes: Seq[Attribute],
- completeAggregateExpressions: Seq[AggregateExpression2],
- completeAggregateAttributes: Seq[Attribute],
- resultExpressions: Seq[NamedExpression],
- child: SparkPlan)
- extends UnaryNode {
- override def references: AttributeSet = {
- val referencesInResults =
- AttributeSet(resultExpressions.flatMap(_.references)) --
- AttributeSet(finalAggregateExpressions) --
- AttributeSet(completeAggregateExpressions)
-
- AttributeSet(
- groupingExpressions.flatMap(_.references) ++
- finalAggregateExpressions.flatMap(_.references) ++
- completeAggregateExpressions.flatMap(_.references) ++
- referencesInResults)
- }
-
- override def requiredChildDistribution: List[Distribution] = {
- if (groupingExpressions.isEmpty) {
- AllTuples :: Nil
- } else {
- ClusteredDistribution(groupingExpressions) :: Nil
- }
- }
-
- override def requiredChildOrdering: Seq[Seq[SortOrder]] =
- groupingExpressions.map(SortOrder(_, Ascending)) :: Nil
-
- override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
-
- protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
- child.execute().mapPartitions { iter =>
-
- new FinalAndCompleteSortAggregationIterator(
- previousGroupingExpressions.length,
- groupingExpressions,
- finalAggregateExpressions,
- finalAggregateAttributes,
- completeAggregateExpressions,
- completeAggregateAttributes,
- resultExpressions,
- newMutableProjection,
- child.output,
- iter)
- }
- }
-
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala
deleted file mode 100644
index 2ca0cb82c1..0000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala
+++ /dev/null
@@ -1,664 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.aggregate
-
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.types.NullType
-
-import scala.collection.mutable.ArrayBuffer
-
-/**
- * An iterator used to evaluate aggregate functions. It assumes that input rows
- * are already grouped by values of `groupingExpressions`.
- */
-private[sql] abstract class SortAggregationIterator(
- groupingExpressions: Seq[NamedExpression],
- aggregateExpressions: Seq[AggregateExpression2],
- newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
- inputAttributes: Seq[Attribute],
- inputIter: Iterator[InternalRow])
- extends Iterator[InternalRow] {
-
- ///////////////////////////////////////////////////////////////////////////
- // Static fields for this iterator
- ///////////////////////////////////////////////////////////////////////////
-
- protected val aggregateFunctions: Array[AggregateFunction2] = {
- var mutableBufferOffset = 0
- var inputBufferOffset: Int = initialInputBufferOffset
- val functions = new Array[AggregateFunction2](aggregateExpressions.length)
- var i = 0
- while (i < aggregateExpressions.length) {
- val func = aggregateExpressions(i).aggregateFunction
- val funcWithBoundReferences = aggregateExpressions(i).mode match {
- case Partial | Complete if !func.isInstanceOf[AlgebraicAggregate] =>
- // We need to create BoundReferences if the function is not an
- // AlgebraicAggregate (it does not support code-gen) and the mode of
- // this function is Partial or Complete because we will call eval of this
- // function's children in the update method of this aggregate function.
- // Those eval calls require BoundReferences to work.
- BindReferences.bindReference(func, inputAttributes)
- case _ =>
- // We only need to set inputBufferOffset for aggregate functions with mode
- // PartialMerge and Final.
- func.inputBufferOffset = inputBufferOffset
- inputBufferOffset += func.bufferSchema.length
- func
- }
- // Set mutableBufferOffset for this function. It is important that setting
- // mutableBufferOffset happens after all potential bindReference operations
- // because bindReference will create a new instance of the function.
- funcWithBoundReferences.mutableBufferOffset = mutableBufferOffset
- mutableBufferOffset += funcWithBoundReferences.bufferSchema.length
- functions(i) = funcWithBoundReferences
- i += 1
- }
- functions
- }
-
- // Positions of those non-algebraic aggregate functions in aggregateFunctions.
- // For example, we have func1, func2, func3, func4 in aggregateFunctions, and
- // func2 and func3 are non-algebraic aggregate functions.
- // nonAlgebraicAggregateFunctionPositions will be [1, 2].
- protected val nonAlgebraicAggregateFunctionPositions: Array[Int] = {
- val positions = new ArrayBuffer[Int]()
- var i = 0
- while (i < aggregateFunctions.length) {
- aggregateFunctions(i) match {
- case agg: AlgebraicAggregate =>
- case _ => positions += i
- }
- i += 1
- }
- positions.toArray
- }
-
- // All non-algebraic aggregate functions.
- protected val nonAlgebraicAggregateFunctions: Array[AggregateFunction2] =
- nonAlgebraicAggregateFunctionPositions.map(aggregateFunctions)
-
- // This projection is used to evaluate the grouping expressions.
- protected val groupGenerator =
- newMutableProjection(groupingExpressions, inputAttributes)()
-
- // The underlying buffer shared by all aggregate functions.
- protected val buffer: MutableRow = {
- // The number of elements in the underlying buffer of this operator.
- // All aggregate functions share this underlying buffer and find their
- // buffer values through bufferOffset.
- new GenericMutableRow(aggregateFunctions.map(_.bufferSchema.length).sum)
- }
-
- protected val joinedRow = new JoinedRow
-
- // This projection is used to initialize buffer values for all AlgebraicAggregates.
- protected val algebraicInitialProjection = {
- val initExpressions = aggregateFunctions.flatMap {
- case ae: AlgebraicAggregate => ae.initialValues
- case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp)
- }
-
- newMutableProjection(initExpressions, Nil)().target(buffer)
- }
-
- ///////////////////////////////////////////////////////////////////////////
- // Mutable states
- ///////////////////////////////////////////////////////////////////////////
-
- // The grouping key of the current group.
- protected var currentGroupingKey: InternalRow = _
- // The grouping key of the next group.
- protected var nextGroupingKey: InternalRow = _
- // The first row of the next group.
- protected var firstRowInNextGroup: InternalRow = _
- // Indicates if we have a new group of rows to process.
- protected var hasNewGroup: Boolean = true
-
- /** Initializes buffer values for all aggregate functions. */
- protected def initializeBuffer(): Unit = {
- algebraicInitialProjection(EmptyRow)
- var i = 0
- while (i < nonAlgebraicAggregateFunctions.length) {
- nonAlgebraicAggregateFunctions(i).initialize(buffer)
- i += 1
- }
- }
-
- protected def initialize(): Unit = {
- if (inputIter.hasNext) {
- initializeBuffer()
- val currentRow = inputIter.next().copy()
- // groupGenerator is a mutable projection. Since we need to track nextGroupingKey,
- // we make a copy of its output here.
- nextGroupingKey = groupGenerator(currentRow).copy()
- firstRowInNextGroup = currentRow
- } else {
- // The input iterator is empty.
- hasNewGroup = false
- }
- }
-
- ///////////////////////////////////////////////////////////////////////////
- // Private methods
- ///////////////////////////////////////////////////////////////////////////
-
- /** Processes rows in the current group. It will stop when it finds a new group. */
- private def processCurrentGroup(): Unit = {
- currentGroupingKey = nextGroupingKey
- // Now, we will start to find all rows belonging to this group.
- // We create a variable to track if we see the next group.
- var findNextPartition = false
- // firstRowInNextGroup is the first row of this group. We first process it.
- processRow(firstRowInNextGroup)
- // The search will stop when we see the next group or there is no
- // input row left in the iter.
- while (inputIter.hasNext && !findNextPartition) {
- val currentRow = inputIter.next()
- // Get the grouping key based on the grouping expressions.
- // For the comparison below, we do not need to make a copy of groupingKey.
- val groupingKey = groupGenerator(currentRow)
- // Check if the current row belongs to the current group.
- if (currentGroupingKey == groupingKey) {
- processRow(currentRow)
- } else {
- // We find a new group.
- findNextPartition = true
- nextGroupingKey = groupingKey.copy()
- firstRowInNextGroup = currentRow.copy()
- }
- }
- // We have not seen a new group, which means there are no input rows left.
- // The current group is the last group of the iterator.
- if (!findNextPartition) {
- hasNewGroup = false
- }
- }
-
- ///////////////////////////////////////////////////////////////////////////
- // Public methods
- ///////////////////////////////////////////////////////////////////////////
-
- override final def hasNext: Boolean = hasNewGroup
-
- override final def next(): InternalRow = {
- if (hasNext) {
- // Process the current group.
- processCurrentGroup()
- // Generate output row for the current group.
- val outputRow = generateOutput()
- // Initialize buffer values for the next group.
- initializeBuffer()
-
- outputRow
- } else {
- // No more results.
- throw new NoSuchElementException
- }
- }
-
- ///////////////////////////////////////////////////////////////////////////
- // Methods that need to be implemented
- ///////////////////////////////////////////////////////////////////////////
-
- /** The initial input buffer offset for `inputBufferOffset` of an [[AggregateFunction2]]. */
- protected def initialInputBufferOffset: Int
-
- /** The function used to process an input row. */
- protected def processRow(row: InternalRow): Unit
-
- /** The function used to generate the result row. */
- protected def generateOutput(): InternalRow
-
- ///////////////////////////////////////////////////////////////////////////
- // Initialize this iterator
- ///////////////////////////////////////////////////////////////////////////
-
- initialize()
-}
-
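
The grouping protocol implemented by processCurrentGroup above (buffer the first row of the next group, consume rows until the grouping key changes, emit one output row per group) can be shown with a small self-contained sketch over pre-sorted input; the names are illustrative and the aggregate is a plain sum:

  // Illustrative sketch of sort-based grouping; like the iterators above,
  // it assumes the input is already sorted by the grouping key.
  final class GroupedSumIterator(input: Iterator[(String, Long)])
    extends Iterator[(String, Long)] {

    private var nextKey: String = _
    private var firstValueInNextGroup: Long = _
    private var hasNewGroup: Boolean = true

    if (input.hasNext) {
      val (k, v) = input.next()
      nextKey = k
      firstValueInNextGroup = v
    } else {
      hasNewGroup = false
    }

    override def hasNext: Boolean = hasNewGroup

    override def next(): (String, Long) = {
      if (!hasNext) throw new NoSuchElementException
      val currentKey = nextKey
      var buffer = firstValueInNextGroup  // initialize and process first row
      var foundNextGroup = false
      while (input.hasNext && !foundNextGroup) {
        val (k, v) = input.next()
        if (k == currentKey) buffer += v  // same group: update the buffer
        else {                            // group boundary: remember the row
          nextKey = k
          firstValueInNextGroup = v
          foundNextGroup = true
        }
      }
      if (!foundNextGroup) hasNewGroup = false
      (currentKey, buffer)
    }
  }

For example, new GroupedSumIterator(Iterator("a" -> 1L, "a" -> 2L, "b" -> 3L)).toList returns List(("a", 3), ("b", 3)).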
-/**
- * An iterator used to do partial aggregations (for those aggregate functions with mode Partial).
- * It assumes that input rows are already grouped by values of `groupingExpressions`.
- * The format of its output rows is:
- * |groupingExpr1|...|groupingExprN|aggregationBuffer1|...|aggregationBufferN|
- */
-class PartialSortAggregationIterator(
- groupingExpressions: Seq[NamedExpression],
- aggregateExpressions: Seq[AggregateExpression2],
- newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
- inputAttributes: Seq[Attribute],
- inputIter: Iterator[InternalRow])
- extends SortAggregationIterator(
- groupingExpressions,
- aggregateExpressions,
- newMutableProjection,
- inputAttributes,
- inputIter) {
-
- // This projection is used to update buffer values for all AlgebraicAggregates.
- private val algebraicUpdateProjection = {
- val bufferSchema = aggregateFunctions.flatMap(_.bufferAttributes)
- val updateExpressions = aggregateFunctions.flatMap {
- case ae: AlgebraicAggregate => ae.updateExpressions
- case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp)
- }
- newMutableProjection(updateExpressions, bufferSchema ++ inputAttributes)().target(buffer)
- }
-
- override protected def initialInputBufferOffset: Int = 0
-
- override protected def processRow(row: InternalRow): Unit = {
- // Process all algebraic aggregate functions.
- algebraicUpdateProjection(joinedRow(buffer, row))
- // Process all non-algebraic aggregate functions.
- var i = 0
- while (i < nonAlgebraicAggregateFunctions.length) {
- nonAlgebraicAggregateFunctions(i).update(buffer, row)
- i += 1
- }
- }
-
- override protected def generateOutput(): InternalRow = {
- // We just output the grouping expressions and the underlying buffer.
- joinedRow(currentGroupingKey, buffer).copy()
- }
-}
-
-/**
- * An iterator used to do partial merge aggregations (for those aggregate functions with mode
- * PartialMerge). It assumes that input rows are already grouped by values of
- * `groupingExpressions`.
- * The format of its input rows is:
- * |groupingExpr1|...|groupingExprN|aggregationBuffer1|...|aggregationBufferN|
- *
- * The format of its internal buffer is:
- * |aggregationBuffer1|...|aggregationBufferN|
- *
- * The format of its output rows is:
- * |groupingExpr1|...|groupingExprN|aggregationBuffer1|...|aggregationBufferN|
- */
-class PartialMergeSortAggregationIterator(
- groupingExpressions: Seq[NamedExpression],
- aggregateExpressions: Seq[AggregateExpression2],
- newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
- inputAttributes: Seq[Attribute],
- inputIter: Iterator[InternalRow])
- extends SortAggregationIterator(
- groupingExpressions,
- aggregateExpressions,
- newMutableProjection,
- inputAttributes,
- inputIter) {
-
- // This projection is used to merge buffer values for all AlgebraicAggregates.
- private val algebraicMergeProjection = {
- val mergeInputSchema =
- aggregateFunctions.flatMap(_.bufferAttributes) ++
- groupingExpressions.map(_.toAttribute) ++
- aggregateFunctions.flatMap(_.cloneBufferAttributes)
- val mergeExpressions = aggregateFunctions.flatMap {
- case ae: AlgebraicAggregate => ae.mergeExpressions
- case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp)
- }
-
- newMutableProjection(mergeExpressions, mergeInputSchema)()
- }
-
- override protected def initialInputBufferOffset: Int = groupingExpressions.length
-
- override protected def processRow(row: InternalRow): Unit = {
- // Process all algebraic aggregate functions.
- algebraicMergeProjection.target(buffer)(joinedRow(buffer, row))
- // Process all non-algebraic aggregate functions.
- var i = 0
- while (i < nonAlgebraicAggregateFunctions.length) {
- nonAlgebraicAggregateFunctions(i).merge(buffer, row)
- i += 1
- }
- }
-
- override protected def generateOutput(): InternalRow = {
- // We output grouping expressions and aggregation buffers.
- joinedRow(currentGroupingKey, buffer).copy()
- }
-}
-
-/**
- * An iterator used to do final aggregations (for those aggregate functions with mode
- * Final). It assumes that input rows are already grouped by values of
- * `groupingExpressions`.
- * The format of its input rows is:
- * |groupingExpr1|...|groupingExprN|aggregationBuffer1|...|aggregationBufferN|
- *
- * The format of its internal buffer is:
- * |aggregationBuffer1|...|aggregationBufferN|
- *
- * The format of its output rows is represented by the schema of `resultExpressions`.
- */
-class FinalSortAggregationIterator(
- groupingExpressions: Seq[NamedExpression],
- aggregateExpressions: Seq[AggregateExpression2],
- aggregateAttributes: Seq[Attribute],
- resultExpressions: Seq[NamedExpression],
- newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
- inputAttributes: Seq[Attribute],
- inputIter: Iterator[InternalRow])
- extends SortAggregationIterator(
- groupingExpressions,
- aggregateExpressions,
- newMutableProjection,
- inputAttributes,
- inputIter) {
-
- // The result of aggregate functions.
- private val aggregateResult: MutableRow = new GenericMutableRow(aggregateAttributes.length)
-
- // The projection used to generate the output rows of this operator.
- // This is only used when we are generating final results of aggregate functions.
- private val resultProjection =
- newMutableProjection(
- resultExpressions, groupingExpressions.map(_.toAttribute) ++ aggregateAttributes)()
-
- // This projection is used to merge buffer values for all AlgebraicAggregates.
- private val algebraicMergeProjection = {
- val mergeInputSchema =
- aggregateFunctions.flatMap(_.bufferAttributes) ++
- groupingExpressions.map(_.toAttribute) ++
- aggregateFunctions.flatMap(_.cloneBufferAttributes)
- val mergeExpressions = aggregateFunctions.flatMap {
- case ae: AlgebraicAggregate => ae.mergeExpressions
- case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp)
- }
-
- newMutableProjection(mergeExpressions, mergeInputSchema)()
- }
-
- // This projection is used to evaluate all AlgebraicAggregates.
- private val algebraicEvalProjection = {
- val bufferSchemata = aggregateFunctions.flatMap(_.bufferAttributes)
- val evalExpressions = aggregateFunctions.map {
- case ae: AlgebraicAggregate => ae.evaluateExpression
- case agg: AggregateFunction2 => NoOp
- }
-
- newMutableProjection(evalExpressions, bufferSchemata)()
- }
-
- override protected def initialInputBufferOffset: Int = groupingExpressions.length
-
- override def initialize(): Unit = {
- if (inputIter.hasNext) {
- initializeBuffer()
- val currentRow = inputIter.next().copy()
- // groupGenerator is a mutable projection. Since we need to track nextGroupingKey,
- // we make a copy of its output here.
- nextGroupingKey = groupGenerator(currentRow).copy()
- firstRowInNextGroup = currentRow
- } else {
- if (groupingExpressions.isEmpty) {
- // If there is no grouping expression, we need to generate a single row as the output.
- initializeBuffer()
- // Right now, the buffer only contains initial buffer values. Because merging
- // two buffers that hold initial values still yields initial values, we set
- // currentRow to a copy of the current buffer. Because the input aggregation
- // buffer has initialInputBufferOffset extra values at the beginning, we create
- // a dummy row for that part.
- val currentRow =
- joinedRow(new GenericInternalRow(initialInputBufferOffset), buffer).copy()
- nextGroupingKey = groupGenerator(currentRow).copy()
- firstRowInNextGroup = currentRow
- } else {
- // The input iterator is empty.
- hasNewGroup = false
- }
- }
- }
-
- override protected def processRow(row: InternalRow): Unit = {
- // Process all algebraic aggregate functions.
- algebraicMergeProjection.target(buffer)(joinedRow(buffer, row))
- // Process all non-algebraic aggregate functions.
- var i = 0
- while (i < nonAlgebraicAggregateFunctions.length) {
- nonAlgebraicAggregateFunctions(i).merge(buffer, row)
- i += 1
- }
- }
-
- override protected def generateOutput(): InternalRow = {
- // Generate results for all algebraic aggregate functions.
- algebraicEvalProjection.target(aggregateResult)(buffer)
- // Generate results for all non-algebraic aggregate functions.
- var i = 0
- while (i < nonAlgebraicAggregateFunctions.length) {
- aggregateResult.update(
- nonAlgebraicAggregateFunctionPositions(i),
- nonAlgebraicAggregateFunctions(i).eval(buffer))
- i += 1
- }
- resultProjection(joinedRow(currentGroupingKey, aggregateResult))
- }
-}
-
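
The special case in initialize() above encodes a SQL requirement: an aggregate with no grouping expressions must produce exactly one output row even when its input is empty, built from the initial buffer values (0 for COUNT, null for SUM). A tiny sketch of that behavior, using a hypothetical globalCount helper:

  object GlobalAgg {
    // With no GROUP BY, exactly one row comes out even for empty input,
    // built from the initial buffer value.
    def globalCount(input: Iterator[Long]): Iterator[Long] = {
      var count = 0L                 // initial buffer value
      input.foreach(_ => count += 1) // update phase; a no-op on empty input
      Iterator.single(count)         // always exactly one output row
    }
  }

  // GlobalAgg.globalCount(Iterator.empty).toList == List(0L)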
-/**
- * An iterator used to do both final aggregations (for those aggregate functions with mode
- * Final) and complete aggregations (for those aggregate functions with mode Complete).
- * It assumes that input rows are already grouped by values of `groupingExpressions`.
- * The format of its input rows is:
- * |groupingExpr1|...|groupingExprN|col1|...|colM|aggregationBuffer1|...|aggregationBufferN|
- * col1 to colM are columns used by aggregate functions with Complete mode.
- * aggregationBuffer1 to aggregationBufferN are buffers used by aggregate functions with
- * Final mode.
- *
- * The format of its internal buffer is:
- * |aggregationBuffer1|...|aggregationBuffer(N+M)|
- * For aggregation buffers, first N aggregation buffers are used by N aggregate functions with
- * mode Final. Then, the last M aggregation buffers are used by M aggregate functions with mode
- * Complete.
- *
- * The format of its output rows is represented by the schema of `resultExpressions`.
- */
-class FinalAndCompleteSortAggregationIterator(
- override protected val initialInputBufferOffset: Int,
- groupingExpressions: Seq[NamedExpression],
- finalAggregateExpressions: Seq[AggregateExpression2],
- finalAggregateAttributes: Seq[Attribute],
- completeAggregateExpressions: Seq[AggregateExpression2],
- completeAggregateAttributes: Seq[Attribute],
- resultExpressions: Seq[NamedExpression],
- newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
- inputAttributes: Seq[Attribute],
- inputIter: Iterator[InternalRow])
- extends SortAggregationIterator(
- groupingExpressions,
- // TODO: document the ordering
- finalAggregateExpressions ++ completeAggregateExpressions,
- newMutableProjection,
- inputAttributes,
- inputIter) {
-
- // The result of aggregate functions.
- private val aggregateResult: MutableRow =
- new GenericMutableRow(completeAggregateAttributes.length + finalAggregateAttributes.length)
-
- // The projection used to generate the output rows of this operator.
- // This is only used when we are generating final results of aggregate functions.
- private val resultProjection = {
- val inputSchema =
- groupingExpressions.map(_.toAttribute) ++
- finalAggregateAttributes ++
- completeAggregateAttributes
- newMutableProjection(resultExpressions, inputSchema)()
- }
-
- // All aggregate functions with mode Final.
- private val finalAggregateFunctions: Array[AggregateFunction2] = {
- val functions = new Array[AggregateFunction2](finalAggregateExpressions.length)
- var i = 0
- while (i < finalAggregateExpressions.length) {
- functions(i) = aggregateFunctions(i)
- i += 1
- }
- functions
- }
-
- // All non-algebraic aggregate functions with mode Final.
- private val finalNonAlgebraicAggregateFunctions: Array[AggregateFunction2] =
- finalAggregateFunctions.collect {
- case func: AggregateFunction2 if !func.isInstanceOf[AlgebraicAggregate] => func
- }
-
- // All aggregate functions with mode Complete.
- private val completeAggregateFunctions: Array[AggregateFunction2] = {
- val functions = new Array[AggregateFunction2](completeAggregateExpressions.length)
- var i = 0
- while (i < completeAggregateExpressions.length) {
- functions(i) = aggregateFunctions(finalAggregateFunctions.length + i)
- i += 1
- }
- functions
- }
-
- // All non-algebraic aggregate functions with mode Complete.
- private val completeNonAlgebraicAggregateFunctions: Array[AggregateFunction2] =
- completeAggregateFunctions.collect {
- case func: AggregateFunction2 if !func.isInstanceOf[AlgebraicAggregate] => func
- }
-
- // This projection is used to merge buffer values for all AlgebraicAggregates with mode
- // Final.
- private val finalAlgebraicMergeProjection = {
- // The first initialInputBufferOffset values of the input aggregation buffer are
- // for grouping expressions and distinct columns.
- val groupingAttributesAndDistinctColumns = inputAttributes.take(initialInputBufferOffset)
-
- val completeOffsetExpressions =
- Seq.fill(completeAggregateFunctions.map(_.bufferAttributes.length).sum)(NoOp)
-
- val mergeInputSchema =
- finalAggregateFunctions.flatMap(_.bufferAttributes) ++
- completeAggregateFunctions.flatMap(_.bufferAttributes) ++
- groupingAttributesAndDistinctColumns ++
- finalAggregateFunctions.flatMap(_.cloneBufferAttributes)
- val mergeExpressions =
- finalAggregateFunctions.flatMap {
- case ae: AlgebraicAggregate => ae.mergeExpressions
- case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp)
- } ++ completeOffsetExpressions
- newMutableProjection(mergeExpressions, mergeInputSchema)()
- }
-
- // This projection is used to update buffer values for all AlgebraicAggregates with mode
- // Complete.
- private val completeAlgebraicUpdateProjection = {
- // We do not touch buffer values of aggregate functions with the Final mode.
- val finalOffsetExpressions =
- Seq.fill(finalAggregateFunctions.map(_.bufferAttributes.length).sum)(NoOp)
-
- val bufferSchema =
- finalAggregateFunctions.flatMap(_.bufferAttributes) ++
- completeAggregateFunctions.flatMap(_.bufferAttributes)
- val updateExpressions =
- finalOffsetExpressions ++ completeAggregateFunctions.flatMap {
- case ae: AlgebraicAggregate => ae.updateExpressions
- case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp)
- }
- newMutableProjection(updateExpressions, bufferSchema ++ inputAttributes)().target(buffer)
- }
-
- // This projection is used to evaluate all AlgebraicAggregates.
- private val algebraicEvalProjection = {
- val bufferSchemata = aggregateFunctions.flatMap(_.bufferAttributes)
- val evalExpressions = aggregateFunctions.map {
- case ae: AlgebraicAggregate => ae.evaluateExpression
- case agg: AggregateFunction2 => NoOp
- }
-
- newMutableProjection(evalExpressions, bufferSchemata)()
- }
-
- override def initialize(): Unit = {
- if (inputIter.hasNext) {
- initializeBuffer()
- val currentRow = inputIter.next().copy()
- // groupGenerator is a mutable projection. Since we need to track nextGroupingKey,
- // we make a copy of its output here.
- nextGroupingKey = groupGenerator(currentRow).copy()
- firstRowInNextGroup = currentRow
- } else {
- if (groupingExpressions.isEmpty) {
- // If there is no grouping expression, we need to generate a single row as the output.
- initializeBuffer()
- // Right now, the buffer only contains initial buffer values. Because merging
- // two buffers that hold initial values still yields initial values, we set
- // currentRow to a copy of the current buffer. Because the input aggregation
- // buffer has initialInputBufferOffset extra values at the beginning, we create
- // a dummy row for that part.
- val currentRow =
- joinedRow(new GenericInternalRow(initialInputBufferOffset), buffer).copy()
- nextGroupingKey = groupGenerator(currentRow).copy()
- firstRowInNextGroup = currentRow
- } else {
- // The input iterator is empty.
- hasNewGroup = false
- }
- }
- }
-
- override protected def processRow(row: InternalRow): Unit = {
- val input = joinedRow(buffer, row)
- // For all aggregate functions with mode Complete, update buffers.
- completeAlgebraicUpdateProjection(input)
- var i = 0
- while (i < completeNonAlgebraicAggregateFunctions.length) {
- completeNonAlgebraicAggregateFunctions(i).update(buffer, row)
- i += 1
- }
-
- // For all aggregate functions with mode Final, merge buffers.
- finalAlgebraicMergeProjection.target(buffer)(input)
- i = 0
- while (i < finalNonAlgebraicAggregateFunctions.length) {
- finalNonAlgebraicAggregateFunctions(i).merge(buffer, row)
- i += 1
- }
- }
-
- override protected def generateOutput(): InternalRow = {
- // Generate results for all algebraic aggregate functions.
- algebraicEvalProjection.target(aggregateResult)(buffer)
- // Generate results for all non-algebraic aggregate functions.
- var i = 0
- while (i < nonAlgebraicAggregateFunctions.length) {
- aggregateResult.update(
- nonAlgebraicAggregateFunctionPositions(i),
- nonAlgebraicAggregateFunctions(i).eval(buffer))
- i += 1
- }
-
- resultProjection(joinedRow(currentGroupingKey, aggregateResult))
- }
-}
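
The buffer layout described in the scaladoc above, where Final-mode buffers occupy the first slots of one shared flat buffer and Complete-mode buffers follow, amounts to assigning each function a contiguous slice starting at its mutableBufferOffset. A simplified sketch of that offset assignment (AggFunc and assignOffsets are illustrative, not Spark code):

  object BufferLayout {
    final case class AggFunc(name: String, bufferLength: Int)

    // Final-mode buffers first, Complete-mode buffers after them, matching
    // |aggregationBuffer1|...|aggregationBuffer(N+M)| described above.
    def assignOffsets(
        finalFuncs: Seq[AggFunc],
        completeFuncs: Seq[AggFunc]): Seq[(AggFunc, Int)] = {
      val all = finalFuncs ++ completeFuncs
      all.zip(all.scanLeft(0)(_ + _.bufferLength))
    }
  }

  // assignOffsets(Seq(AggFunc("sum", 1)), Seq(AggFunc("avg", 2)))
  //   == Seq((AggFunc("sum", 1), 0), (AggFunc("avg", 2), 1))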
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
index cc54319171..5fafc916bf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
@@ -24,7 +24,154 @@ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjecti
import org.apache.spark.sql.catalyst.expressions.{MutableRow, InterpretedMutableProjection, AttributeReference, Expression}
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction2
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
-import org.apache.spark.sql.types.{Metadata, StructField, StructType, DataType}
+import org.apache.spark.sql.types._
+
+/**
+ * A helper trait used to create specialized setters and getters for types supported by
+ * [[org.apache.spark.sql.execution.UnsafeFixedWidthAggregationMap]]'s buffer.
+ * (see UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema).
+ */
+sealed trait BufferSetterGetterUtils {
+
+ def createGetters(schema: StructType): Array[(InternalRow, Int) => Any] = {
+ val dataTypes = schema.fields.map(_.dataType)
+ val getters = new Array[(InternalRow, Int) => Any](dataTypes.length)
+
+ var i = 0
+ while (i < getters.length) {
+ getters(i) = dataTypes(i) match {
+ case BooleanType =>
+ (row: InternalRow, ordinal: Int) =>
+ if (row.isNullAt(ordinal)) null else row.getBoolean(ordinal)
+
+ case ByteType =>
+ (row: InternalRow, ordinal: Int) =>
+ if (row.isNullAt(ordinal)) null else row.getByte(ordinal)
+
+ case ShortType =>
+ (row: InternalRow, ordinal: Int) =>
+ if (row.isNullAt(ordinal)) null else row.getShort(ordinal)
+
+ case IntegerType =>
+ (row: InternalRow, ordinal: Int) =>
+ if (row.isNullAt(ordinal)) null else row.getInt(ordinal)
+
+ case LongType =>
+ (row: InternalRow, ordinal: Int) =>
+ if (row.isNullAt(ordinal)) null else row.getLong(ordinal)
+
+ case FloatType =>
+ (row: InternalRow, ordinal: Int) =>
+ if (row.isNullAt(ordinal)) null else row.getFloat(ordinal)
+
+ case DoubleType =>
+ (row: InternalRow, ordinal: Int) =>
+ if (row.isNullAt(ordinal)) null else row.getDouble(ordinal)
+
+ case dt: DecimalType =>
+ val precision = dt.precision
+ val scale = dt.scale
+ (row: InternalRow, ordinal: Int) =>
+ if (row.isNullAt(ordinal)) null else row.getDecimal(ordinal, precision, scale)
+
+ case other =>
+ (row: InternalRow, ordinal: Int) =>
+ if (row.isNullAt(ordinal)) null else row.get(ordinal, other)
+ }
+
+ i += 1
+ }
+
+ getters
+ }
+
+ def createSetters(schema: StructType): Array[((MutableRow, Int, Any) => Unit)] = {
+ val dataTypes = schema.fields.map(_.dataType)
+ val setters = new Array[(MutableRow, Int, Any) => Unit](dataTypes.length)
+
+ var i = 0
+ while (i < setters.length) {
+ setters(i) = dataTypes(i) match {
+ case b: BooleanType =>
+ (row: MutableRow, ordinal: Int, value: Any) =>
+ if (value != null) {
+ row.setBoolean(ordinal, value.asInstanceOf[Boolean])
+ } else {
+ row.setNullAt(ordinal)
+ }
+
+ case ByteType =>
+ (row: MutableRow, ordinal: Int, value: Any) =>
+ if (value != null) {
+ row.setByte(ordinal, value.asInstanceOf[Byte])
+ } else {
+ row.setNullAt(ordinal)
+ }
+
+ case ShortType =>
+ (row: MutableRow, ordinal: Int, value: Any) =>
+ if (value != null) {
+ row.setShort(ordinal, value.asInstanceOf[Short])
+ } else {
+ row.setNullAt(ordinal)
+ }
+
+ case IntegerType =>
+ (row: MutableRow, ordinal: Int, value: Any) =>
+ if (value != null) {
+ row.setInt(ordinal, value.asInstanceOf[Int])
+ } else {
+ row.setNullAt(ordinal)
+ }
+
+ case LongType =>
+ (row: MutableRow, ordinal: Int, value: Any) =>
+ if (value != null) {
+ row.setLong(ordinal, value.asInstanceOf[Long])
+ } else {
+ row.setNullAt(ordinal)
+ }
+
+ case FloatType =>
+ (row: MutableRow, ordinal: Int, value: Any) =>
+ if (value != null) {
+ row.setFloat(ordinal, value.asInstanceOf[Float])
+ } else {
+ row.setNullAt(ordinal)
+ }
+
+ case DoubleType =>
+ (row: MutableRow, ordinal: Int, value: Any) =>
+ if (value != null) {
+ row.setDouble(ordinal, value.asInstanceOf[Double])
+ } else {
+ row.setNullAt(ordinal)
+ }
+
+ case dt: DecimalType =>
+ val precision = dt.precision
+ (row: MutableRow, ordinal: Int, value: Any) =>
+ if (value != null) {
+ row.setDecimal(ordinal, value.asInstanceOf[Decimal], precision)
+ } else {
+ row.setNullAt(ordinal)
+ }
+
+ case other =>
+ (row: MutableRow, ordinal: Int, value: Any) =>
+ if (value != null) {
+ row.update(ordinal, value)
+ } else {
+ row.setNullAt(ordinal)
+ }
+ }
+
+ i += 1
+ }
+
+ setters
+ }
+}
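
The trait above applies a specialize-once, call-many pattern: the match on DataType runs once per schema field while the getter and setter arrays are built, so per-row buffer access never re-dispatches on the type. A standalone sketch of the same idea over a toy row type (ToyRow and ToyType are hypothetical stand-ins for InternalRow and DataType):

  object SpecializedAccess {
    // A toy row with typed accessors, standing in for InternalRow.
    final class ToyRow(values: Array[Any]) {
      def isNullAt(i: Int): Boolean = values(i) == null
      def getInt(i: Int): Int = values(i).asInstanceOf[Int]
      def getString(i: Int): String = values(i).asInstanceOf[String]
    }

    sealed trait ToyType
    case object ToyIntType extends ToyType
    case object ToyStringType extends ToyType

    // One closure per field, built once; per-row reads skip type dispatch.
    def createGetters(schema: Array[ToyType]): Array[(ToyRow, Int) => Any] =
      schema.map {
        case ToyIntType =>
          (row: ToyRow, ordinal: Int) =>
            if (row.isNullAt(ordinal)) null else row.getInt(ordinal)
        case ToyStringType =>
          (row: ToyRow, ordinal: Int) =>
            if (row.isNullAt(ordinal)) null else row.getString(ordinal)
      }
  }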
/**
* A mutable [[Row]] representing a mutable aggregation buffer.
@@ -35,7 +182,7 @@ private[sql] class MutableAggregationBufferImpl (
toScalaConverters: Array[Any => Any],
bufferOffset: Int,
var underlyingBuffer: MutableRow)
- extends MutableAggregationBuffer {
+ extends MutableAggregationBuffer with BufferSetterGetterUtils {
private[this] val offsets: Array[Int] = {
val newOffsets = new Array[Int](length)
@@ -47,6 +194,10 @@ private[sql] class MutableAggregationBufferImpl (
newOffsets
}
+ private[this] val bufferValueGetters = createGetters(schema)
+
+ private[this] val bufferValueSetters = createSetters(schema)
+
override def length: Int = toCatalystConverters.length
override def get(i: Int): Any = {
@@ -54,7 +205,7 @@ private[sql] class MutableAggregationBufferImpl (
throw new IllegalArgumentException(
s"Could not access ${i}th value in this buffer because it only has $length values.")
}
- toScalaConverters(i)(underlyingBuffer.get(offsets(i), schema(i).dataType))
+ toScalaConverters(i)(bufferValueGetters(i)(underlyingBuffer, offsets(i)))
}
def update(i: Int, value: Any): Unit = {
@@ -62,7 +213,15 @@ private[sql] class MutableAggregationBufferImpl (
throw new IllegalArgumentException(
s"Could not update ${i}th value in this buffer because it only has $length values.")
}
- underlyingBuffer.update(offsets(i), toCatalystConverters(i)(value))
+
+ bufferValueSetters(i)(underlyingBuffer, offsets(i), toCatalystConverters(i)(value))
+ }
+
+ // Because the get method calls a specialized getter based on the schema, we cannot
+ // use the default implementation of isNullAt (which checks get(i) == null).
+ // We have to override it to call isNullAt on the underlyingBuffer.
+ override def isNullAt(i: Int): Boolean = {
+ underlyingBuffer.isNullAt(offsets(i))
}
override def copy(): MutableAggregationBufferImpl = {
@@ -84,7 +243,7 @@ private[sql] class InputAggregationBuffer private[sql] (
toScalaConverters: Array[Any => Any],
bufferOffset: Int,
var underlyingInputBuffer: InternalRow)
- extends Row {
+ extends Row with BufferSetterGetterUtils {
private[this] val offsets: Array[Int] = {
val newOffsets = new Array[Int](length)
@@ -96,6 +255,10 @@ private[sql] class InputAggregationBuffer private[sql] (
newOffsets
}
+ private[this] val bufferValueGetters = createGetters(schema)
+
+ def getBufferOffset: Int = bufferOffset
+
override def length: Int = toCatalystConverters.length
override def get(i: Int): Any = {
@@ -103,8 +266,14 @@ private[sql] class InputAggregationBuffer private[sql] (
throw new IllegalArgumentException(
s"Could not access ${i}th value in this buffer because it only has $length values.")
}
- // TODO: Use buffer schema to avoid using generic getter.
- toScalaConverters(i)(underlyingInputBuffer.get(offsets(i), schema(i).dataType))
+ toScalaConverters(i)(bufferValueGetters(i)(underlyingInputBuffer, offsets(i)))
+ }
+
+ // Because the get method calls a specialized getter based on the schema, we cannot
+ // use the default implementation of isNullAt (which checks get(i) == null).
+ // We have to override it to call isNullAt on the underlyingInputBuffer.
+ override def isNullAt(i: Int): Boolean = {
+ underlyingInputBuffer.isNullAt(offsets(i))
}
override def copy(): InputAggregationBuffer = {
@@ -147,7 +316,7 @@ private[sql] case class ScalaUDAF(
override lazy val cloneBufferAttributes = bufferAttributes.map(_.newInstance())
- val childrenSchema: StructType = {
+ private[this] val childrenSchema: StructType = {
val inputFields = children.zipWithIndex.map {
case (child, index) =>
StructField(s"input$index", child.dataType, child.nullable, Metadata.empty)
@@ -155,7 +324,7 @@ private[sql] case class ScalaUDAF(
StructType(inputFields)
}
- lazy val inputProjection = {
+ private lazy val inputProjection = {
val inputAttributes = childrenSchema.toAttributes
log.debug(
s"Creating MutableProj: $children, inputSchema: $inputAttributes.")
@@ -168,40 +337,68 @@ private[sql] case class ScalaUDAF(
}
}
- val inputToScalaConverters: Any => Any =
+ private[this] val inputToScalaConverters: Any => Any =
CatalystTypeConverters.createToScalaConverter(childrenSchema)
- val bufferValuesToCatalystConverters: Array[Any => Any] = bufferSchema.fields.map { field =>
- CatalystTypeConverters.createToCatalystConverter(field.dataType)
+ private[this] val bufferValuesToCatalystConverters: Array[Any => Any] = {
+ bufferSchema.fields.map { field =>
+ CatalystTypeConverters.createToCatalystConverter(field.dataType)
+ }
}
- val bufferValuesToScalaConverters: Array[Any => Any] = bufferSchema.fields.map { field =>
- CatalystTypeConverters.createToScalaConverter(field.dataType)
+ private[this] val bufferValuesToScalaConverters: Array[Any => Any] = {
+ bufferSchema.fields.map { field =>
+ CatalystTypeConverters.createToScalaConverter(field.dataType)
+ }
}
- lazy val inputAggregateBuffer: InputAggregationBuffer =
- new InputAggregationBuffer(
- bufferSchema,
- bufferValuesToCatalystConverters,
- bufferValuesToScalaConverters,
- inputBufferOffset,
- null)
-
- lazy val mutableAggregateBuffer: MutableAggregationBufferImpl =
- new MutableAggregationBufferImpl(
- bufferSchema,
- bufferValuesToCatalystConverters,
- bufferValuesToScalaConverters,
- mutableBufferOffset,
- null)
+ // This buffer is only used at executor side.
+ private[this] var inputAggregateBuffer: InputAggregationBuffer = null
+
+ // This buffer is only used at executor side.
+ private[this] var mutableAggregateBuffer: MutableAggregationBufferImpl = null
+
+ // This buffer is only used at executor side.
+ private[this] var evalAggregateBuffer: InputAggregationBuffer = null
+
+ /**
+ * Sets the inputBufferOffset to newInputBufferOffset and then creates a new instance of
+ * `inputAggregateBuffer` based on this new inputBufferOffset.
+ */
+ override def withNewInputBufferOffset(newInputBufferOffset: Int): Unit = {
+ super.withNewInputBufferOffset(newInputBufferOffset)
+ // inputBufferOffset has been updated.
+ inputAggregateBuffer =
+ new InputAggregationBuffer(
+ bufferSchema,
+ bufferValuesToCatalystConverters,
+ bufferValuesToScalaConverters,
+ inputBufferOffset,
+ null)
+ }
- lazy val evalAggregateBuffer: InputAggregationBuffer =
- new InputAggregationBuffer(
- bufferSchema,
- bufferValuesToCatalystConverters,
- bufferValuesToScalaConverters,
- mutableBufferOffset,
- null)
+ /**
+ * Sets the mutableBufferOffset to newMutableBufferOffset and then creates new instances of
+ * `mutableAggregateBuffer` and `evalAggregateBuffer` based on this new mutableBufferOffset.
+ */
+ override def withNewMutableBufferOffset(newMutableBufferOffset: Int): Unit = {
+ super.withNewMutableBufferOffset(newMutableBufferOffset)
+ // mutableBufferOffset has been updated.
+ mutableAggregateBuffer =
+ new MutableAggregationBufferImpl(
+ bufferSchema,
+ bufferValuesToCatalystConverters,
+ bufferValuesToScalaConverters,
+ mutableBufferOffset,
+ null)
+ evalAggregateBuffer =
+ new InputAggregationBuffer(
+ bufferSchema,
+ bufferValuesToCatalystConverters,
+ bufferValuesToScalaConverters,
+ mutableBufferOffset,
+ null)
+ }
override def initialize(buffer: MutableRow): Unit = {
mutableAggregateBuffer.underlyingBuffer = buffer
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
index 03635baae4..960be08f84 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
@@ -17,13 +17,9 @@
package org.apache.spark.sql.execution.aggregate
-import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.types.{StructType, MapType, ArrayType}
/**
* Utility functions used by the query planner to convert our plan to the new aggregation code path.
@@ -52,13 +48,16 @@ object Utils {
agg.aggregateFunction.bufferAttributes
}
val partialAggregate =
- Aggregate2Sort(
- None: Option[Seq[Expression]],
- namedGroupingExpressions.map(_._2),
- partialAggregateExpressions,
- partialAggregateAttributes,
- namedGroupingAttributes ++ partialAggregateAttributes,
- child)
+ Aggregate(
+ requiredChildDistributionExpressions = None: Option[Seq[Expression]],
+ groupingExpressions = namedGroupingExpressions.map(_._2),
+ nonCompleteAggregateExpressions = partialAggregateExpressions,
+ nonCompleteAggregateAttributes = partialAggregateAttributes,
+ completeAggregateExpressions = Nil,
+ completeAggregateAttributes = Nil,
+ initialInputBufferOffset = 0,
+ resultExpressions = namedGroupingAttributes ++ partialAggregateAttributes,
+ child = child)
// 2. Create an Aggregate Operator for final aggregations.
val finalAggregateExpressions = aggregateExpressions.map(_.copy(mode = Final))
@@ -78,13 +77,17 @@ object Utils {
}.getOrElse(expression)
}.asInstanceOf[NamedExpression]
}
- val finalAggregate = Aggregate2Sort(
- Some(namedGroupingAttributes),
- namedGroupingAttributes,
- finalAggregateExpressions,
- finalAggregateAttributes,
- rewrittenResultExpressions,
- partialAggregate)
+ val finalAggregate =
+ Aggregate(
+ requiredChildDistributionExpressions = Some(namedGroupingAttributes),
+ groupingExpressions = namedGroupingAttributes,
+ nonCompleteAggregateExpressions = finalAggregateExpressions,
+ nonCompleteAggregateAttributes = finalAggregateAttributes,
+ completeAggregateExpressions = Nil,
+ completeAggregateAttributes = Nil,
+ initialInputBufferOffset = namedGroupingAttributes.length,
+ resultExpressions = rewrittenResultExpressions,
+ child = partialAggregate)
finalAggregate :: Nil
}
@@ -133,14 +136,21 @@ object Utils {
val partialAggregateAttributes = partialAggregateExpressions.flatMap { agg =>
agg.aggregateFunction.bufferAttributes
}
+ val partialAggregateGroupingExpressions =
+ (namedGroupingExpressions ++ namedDistinctColumnExpressions).map(_._2)
+ val partialAggregateResult =
+ namedGroupingAttributes ++ distinctColumnAttributes ++ partialAggregateAttributes
val partialAggregate =
- Aggregate2Sort(
- None: Option[Seq[Expression]],
- (namedGroupingExpressions ++ namedDistinctColumnExpressions).map(_._2),
- partialAggregateExpressions,
- partialAggregateAttributes,
- namedGroupingAttributes ++ distinctColumnAttributes ++ partialAggregateAttributes,
- child)
+ Aggregate(
+ requiredChildDistributionExpressions = None: Option[Seq[Expression]],
+ groupingExpressions = partialAggregateGroupingExpressions,
+ nonCompleteAggregateExpressions = partialAggregateExpressions,
+ nonCompleteAggregateAttributes = partialAggregateAttributes,
+ completeAggregateExpressions = Nil,
+ completeAggregateAttributes = Nil,
+ initialInputBufferOffset = 0,
+ resultExpressions = partialAggregateResult,
+ child = child)
// 2. Create an Aggregate Operator for partial merge aggregations.
val partialMergeAggregateExpressions = functionsWithoutDistinct.map {
@@ -151,14 +161,19 @@ object Utils {
partialMergeAggregateExpressions.flatMap { agg =>
agg.aggregateFunction.bufferAttributes
}
+ val partialMergeAggregateResult =
+ namedGroupingAttributes ++ distinctColumnAttributes ++ partialMergeAggregateAttributes
val partialMergeAggregate =
- Aggregate2Sort(
- Some(namedGroupingAttributes),
- namedGroupingAttributes ++ distinctColumnAttributes,
- partialMergeAggregateExpressions,
- partialMergeAggregateAttributes,
- namedGroupingAttributes ++ distinctColumnAttributes ++ partialMergeAggregateAttributes,
- partialAggregate)
+ Aggregate(
+ requiredChildDistributionExpressions = Some(namedGroupingAttributes),
+ groupingExpressions = namedGroupingAttributes ++ distinctColumnAttributes,
+ nonCompleteAggregateExpressions = partialMergeAggregateExpressions,
+ nonCompleteAggregateAttributes = partialMergeAggregateAttributes,
+ completeAggregateExpressions = Nil,
+ completeAggregateAttributes = Nil,
+ initialInputBufferOffset = (namedGroupingAttributes ++ distinctColumnAttributes).length,
+ resultExpressions = partialMergeAggregateResult,
+ child = partialAggregate)
// 3. Create an Aggregate Operator for the final and complete aggregations.
val finalAggregateExpressions = functionsWithoutDistinct.map {
@@ -199,15 +214,17 @@ object Utils {
}.getOrElse(expression)
}.asInstanceOf[NamedExpression]
}
- val finalAndCompleteAggregate = FinalAndCompleteAggregate2Sort(
- namedGroupingAttributes ++ distinctColumnAttributes,
- namedGroupingAttributes,
- finalAggregateExpressions,
- finalAggregateAttributes,
- completeAggregateExpressions,
- completeAggregateAttributes,
- rewrittenResultExpressions,
- partialMergeAggregate)
+ val finalAndCompleteAggregate =
+ Aggregate(
+ requiredChildDistributionExpressions = Some(namedGroupingAttributes),
+ groupingExpressions = namedGroupingAttributes,
+ nonCompleteAggregateExpressions = finalAggregateExpressions,
+ nonCompleteAggregateAttributes = finalAggregateAttributes,
+ completeAggregateExpressions = completeAggregateExpressions,
+ completeAggregateAttributes = completeAggregateAttributes,
+ initialInputBufferOffset = (namedGroupingAttributes ++ distinctColumnAttributes).length,
+ resultExpressions = rewrittenResultExpressions,
+ child = partialMergeAggregate)
finalAndCompleteAggregate :: Nil
}
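
The rewrites above split one logical aggregation into physical steps that communicate through aggregation buffers: a partial operator computes buffers per partition, and a final operator merges buffers per key (with an extra partial-merge step when distinct columns are involved). A minimal sketch of that algebra for AVG, independent of Spark's operators:

  object TwoPhaseAvg {
    type Buffer = (Double, Long) // (sum, count)

    // Partial phase: each input partition aggregates locally into buffers.
    def partial(rows: Seq[(String, Double)]): Map[String, Buffer] =
      rows.groupBy(_._1).map { case (k, vs) =>
        val buf: Buffer = (vs.map(_._2).sum, vs.size.toLong)
        k -> buf
      }

    // Merge is the buffer-level operation run after the shuffle.
    def merge(a: Buffer, b: Buffer): Buffer = (a._1 + b._1, a._2 + b._2)

    // Final phase: merge all partial buffers per key, then evaluate the
    // result expression (sum / count) from the merged buffer.
    def finalPhase(partials: Seq[Map[String, Buffer]]): Map[String, Double] =
      partials.flatten.groupBy(_._1).map { case (k, kvs) =>
        val (sum, count) = kvs.map(_._2).reduce(merge)
        k -> sum / count
      }
  }

  // TwoPhaseAvg.finalPhase(Seq(TwoPhaseAvg.partial(Seq("a" -> 1.0, "a" -> 3.0))))
  //   == Map("a" -> 2.0)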
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 2294a670c7..5a1b000e89 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -220,7 +220,6 @@ case class TakeOrderedAndProject(
override def outputOrdering: Seq[SortOrder] = sortOrder
}
-
/**
* :: DeveloperApi ::
* Return a new RDD that has exactly `numPartitions` partitions.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 51fe9d9d98..bbadc202a4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -17,14 +17,14 @@
package org.apache.spark.sql
-import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
-import org.scalatest.BeforeAndAfterAll
-
import java.sql.Timestamp
+import org.scalatest.BeforeAndAfterAll
+
+import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
import org.apache.spark.sql.catalyst.DefaultParserDialect
import org.apache.spark.sql.catalyst.errors.DialectException
-import org.apache.spark.sql.execution.aggregate.Aggregate2Sort
+import org.apache.spark.sql.execution.aggregate
import org.apache.spark.sql.execution.GeneratedAggregate
import org.apache.spark.sql.functions._
import org.apache.spark.sql.TestData._
@@ -273,7 +273,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
var hasGeneratedAgg = false
df.queryExecution.executedPlan.foreach {
case generatedAgg: GeneratedAggregate => hasGeneratedAgg = true
- case newAggregate: Aggregate2Sort => hasGeneratedAgg = true
+ case newAggregate: aggregate.Aggregate => hasGeneratedAgg = true
case _ =>
}
if (!hasGeneratedAgg) {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala
index 54f82f89ed..7978ed57a9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala
@@ -138,7 +138,14 @@ abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll
s"Expected $expectedSerializerClass as the serializer of Exchange. " +
s"However, the serializer was not set."
val serializer = dependency.serializer.getOrElse(fail(serializerNotSetMessage))
- assert(serializer.getClass === expectedSerializerClass)
+ val isExpectedSerializer =
+ serializer.getClass == expectedSerializerClass ||
+ serializer.getClass == classOf[UnsafeRowSerializer]
+ val wrongSerializerErrorMessage =
+ s"Expected ${expectedSerializerClass.getCanonicalName} or " +
+ s"${classOf[UnsafeRowSerializer].getCanonicalName}. But " +
+ s"${serializer.getClass.getCanonicalName} is used."
+ assert(isExpectedSerializer, wrongSerializerErrorMessage)
case _ => // Ignore other nodes.
}
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
index 0375eb79ad..6f0db27775 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
@@ -17,15 +17,15 @@
package org.apache.spark.sql.hive.execution
-import org.apache.spark.sql.execution.aggregate.Aggregate2Sort
+import org.apache.spark.sql.execution.aggregate
import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
-import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
+import org.apache.spark.sql.{SQLConf, AnalysisException, QueryTest, Row}
import org.scalatest.BeforeAndAfterAll
import test.org.apache.spark.sql.hive.aggregate.{MyDoubleAvg, MyDoubleSum}
-class AggregationQuerySuite extends QueryTest with SQLTestUtils with BeforeAndAfterAll {
+abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with BeforeAndAfterAll {
override val sqlContext = TestHive
import sqlContext.implicits._
@@ -34,7 +34,7 @@ class AggregationQuerySuite extends QueryTest with SQLTestUtils with BeforeAndAf
override def beforeAll(): Unit = {
originalUseAggregate2 = sqlContext.conf.useSqlAggregate2
- sqlContext.sql("set spark.sql.useAggregate2=true")
+ sqlContext.setConf(SQLConf.USE_SQL_AGGREGATE2.key, "true")
val data1 = Seq[(Integer, Integer)](
(1, 10),
(null, -60),
@@ -81,7 +81,7 @@ class AggregationQuerySuite extends QueryTest with SQLTestUtils with BeforeAndAf
sqlContext.sql("DROP TABLE IF EXISTS agg1")
sqlContext.sql("DROP TABLE IF EXISTS agg2")
sqlContext.dropTempTable("emptyTable")
- sqlContext.sql(s"set spark.sql.useAggregate2=$originalUseAggregate2")
+ sqlContext.setConf(SQLConf.USE_SQL_AGGREGATE2.key, originalUseAggregate2.toString)
}
test("empty table") {
@@ -454,54 +454,86 @@ class AggregationQuerySuite extends QueryTest with SQLTestUtils with BeforeAndAf
}
test("error handling") {
- sqlContext.sql(s"set spark.sql.useAggregate2=false")
- var errorMessage = intercept[AnalysisException] {
- sqlContext.sql(
- """
- |SELECT
- | key,
- | sum(value + 1.5 * key),
- | mydoublesum(value),
- | mydoubleavg(value)
- |FROM agg1
- |GROUP BY key
- """.stripMargin).collect()
- }.getMessage
- assert(errorMessage.contains("implemented based on the new Aggregate Function interface"))
+ withSQLConf("spark.sql.useAggregate2" -> "false") {
+ val errorMessage = intercept[AnalysisException] {
+ sqlContext.sql(
+ """
+ |SELECT
+ | key,
+ | sum(value + 1.5 * key),
+ | mydoublesum(value),
+ | mydoubleavg(value)
+ |FROM agg1
+ |GROUP BY key
+ """.stripMargin).collect()
+ }.getMessage
+ assert(errorMessage.contains("implemented based on the new Aggregate Function interface"))
+ }
// TODO: once we support Hive UDAF in the new interface,
// we can remove the following two tests.
- sqlContext.sql(s"set spark.sql.useAggregate2=true")
- errorMessage = intercept[AnalysisException] {
- sqlContext.sql(
+ withSQLConf("spark.sql.useAggregate2" -> "true") {
+ val errorMessage = intercept[AnalysisException] {
+ sqlContext.sql(
+ """
+ |SELECT
+ | key,
+ | mydoublesum(value + 1.5 * key),
+ | stddev_samp(value)
+ |FROM agg1
+ |GROUP BY key
+ """.stripMargin).collect()
+ }.getMessage
+ assert(errorMessage.contains("implemented based on the new Aggregate Function interface"))
+
+ // This will fall back to the old aggregate
+ val newAggregateOperators = sqlContext.sql(
"""
|SELECT
| key,
- | mydoublesum(value + 1.5 * key),
+ | sum(value + 1.5 * key),
| stddev_samp(value)
|FROM agg1
|GROUP BY key
- """.stripMargin).collect()
- }.getMessage
- assert(errorMessage.contains("implemented based on the new Aggregate Function interface"))
-
- // This will fall back to the old aggregate
- val newAggregateOperators = sqlContext.sql(
- """
- |SELECT
- | key,
- | sum(value + 1.5 * key),
- | stddev_samp(value)
- |FROM agg1
- |GROUP BY key
- """.stripMargin).queryExecution.executedPlan.collect {
- case agg: Aggregate2Sort => agg
+ """.stripMargin).queryExecution.executedPlan.collect {
+ case agg: aggregate.Aggregate => agg
+ }
+ val message =
+ "We should fallback to the old aggregation code path if " +
+ "there is any aggregate function that cannot be converted to the new interface."
+ assert(newAggregateOperators.isEmpty, message)
}
- val message =
- "We should fallback to the old aggregation code path if there is any aggregate function " +
- "that cannot be converted to the new interface."
- assert(newAggregateOperators.isEmpty, message)
- sqlContext.sql(s"set spark.sql.useAggregate2=true")
+ }
+}
+
+class SortBasedAggregationQuerySuite extends AggregationQuerySuite {
+ var originalUnsafeEnabled: Boolean = _
+
+ override def beforeAll(): Unit = {
+ originalUnsafeEnabled = sqlContext.conf.unsafeEnabled
+ sqlContext.setConf(SQLConf.UNSAFE_ENABLED.key, "false")
+ super.beforeAll()
+ }
+
+ override def afterAll(): Unit = {
+ super.afterAll()
+ sqlContext.setConf(SQLConf.UNSAFE_ENABLED.key, originalUnsafeEnabled.toString)
+ }
+}
+
+class TungstenAggregationQuerySuite extends AggregationQuerySuite {
+
+ var originalUnsafeEnabled: Boolean = _
+
+ override def beforeAll(): Unit = {
+ originalUnsafeEnabled = sqlContext.conf.unsafeEnabled
+ sqlContext.setConf(SQLConf.UNSAFE_ENABLED.key, "true")
+ super.beforeAll()
+ }
+
+ override def afterAll(): Unit = {
+ super.afterAll()
+ sqlContext.setConf(SQLConf.UNSAFE_ENABLED.key, originalUnsafeEnabled.toString)
}
}