author    Yin Huai <yhuai@databricks.com>    2015-08-06 15:04:44 -0700
committer    Reynold Xin <rxin@databricks.com>    2015-08-06 15:04:44 -0700
commit    3504bf3aa9f7b75c0985f04ce2944833d8c5b5bd (patch)
tree    8968e5ec73c6d139a974e5874e0087d9eba9575e
parent    346209097e88fe79015359e40b49c32cc0bdc439 (diff)
[SPARK-9630] [SQL] Clean up new aggregate operators (SPARK-9240 follow up)
This is the follow-up of https://github.com/apache/spark/pull/7813. It renames `UnsafeHybridAggregationIterator` to `TungstenAggregationIterator` and makes it work only with `UnsafeRow`s. It also adds a `TungstenAggregate` operator that uses `TungstenAggregationIterator`, and makes `SortBasedAggregate` (renamed from `Aggregate`) work only with safe rows.

Author: Yin Huai <yhuai@databricks.com>

Closes #7954 from yhuai/agg-followUp and squashes the following commits:

4d2f4fc [Yin Huai] Add comments and free map.
0d7ddb9 [Yin Huai] Add TungstenAggregationQueryWithControlledFallbackSuite to test the fallback process.
91d69c2 [Yin Huai] Rename UnsafeHybridAggregationIterator to TungstenAggregationIterator and make it only work with UnsafeRow.
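At a high level, the patch splits the old hybrid operator into two: TungstenAggregate, which consumes and produces UnsafeRows and only handles AlgebraicAggregates, and SortBasedAggregate, which works on safe rows and requires sorted input. The sketch below illustrates that division of labor; the predicate names are assumptions for illustration, not the actual planner code in aggregate/utils.scala (whose full diff is not shown here):

    // Hypothetical sketch of the operator choice after this patch. The real
    // logic lives in aggregate/utils.scala; these condition names are
    // illustrative only.
    def chooseAggregateOperator(
        allFunctionsAreAlgebraic: Boolean,
        schemaSupportsUnsafe: Boolean): String = {
      if (allFunctionsAreAlgebraic && schemaSupportsUnsafe) {
        "TungstenAggregate"   // operates on UnsafeRows via TungstenAggregationIterator
      } else {
        "SortBasedAggregate"  // operates on safe rows, requires sorted input
      }
    }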
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala | 14
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala | 3
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala | 20
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/Aggregate.scala | 182
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala | 103
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala | 26
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala | 102
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala | 667
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UnsafeHybridAggregationIterator.scala | 372
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala | 260
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 2
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala | 104
12 files changed, 1192 insertions, 663 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
index 88fb516e64..a73024d6ad 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
@@ -31,8 +31,11 @@ case class Average(child: Expression) extends AlgebraicAggregate {
override def dataType: DataType = resultType
// Expected input data type.
- // TODO: Once we remove the old code path, we can use our analyzer to cast NullType
- // to the default data type of the NumericType.
+ // TODO: Right now, we replace old aggregate functions (based on AggregateExpression1) with the
+ // new version at planning time (after the analysis phase). For now, NullType is added here
+ // so that this expression is resolved for cases like `select avg(null)`.
+ // Once we remove the old aggregate functions, we can use our analyzer to cast NullType to
+ // the default data type of the NumericType, and we will no longer need NullType here.
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType, NullType))
private val resultType = child.dataType match {
@@ -256,12 +259,19 @@ case class Sum(child: Expression) extends AlgebraicAggregate {
override def dataType: DataType = resultType
// Expected input data type.
+ // TODO: Right now, we replace old aggregate functions (based on AggregateExpression1) with the
+ // new version at planning time (after the analysis phase). For now, NullType is added here
+ // so that this expression is resolved for cases like `select sum(null)`.
+ // Once we remove the old aggregate functions, we can use our analyzer to cast NullType to
+ // the default data type of the NumericType, and we will no longer need NullType here.
override def inputTypes: Seq[AbstractDataType] =
Seq(TypeCollection(LongType, DoubleType, DecimalType, NullType))
private val resultType = child.dataType match {
case DecimalType.Fixed(precision, scale) =>
DecimalType.bounded(precision + 10, scale)
+ // TODO: Remove this line once we remove the NullType from inputTypes.
+ case NullType => IntegerType
case _ => child.dataType
}
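The added `case NullType => IntegerType` line means a query like `select sum(null)` resolves to an IntegerType result instead of failing. A minimal standalone sketch of that mapping (illustrative only; the real resultType match also handles DecimalType, as shown above):

    // Standalone illustration of Sum's resultType behavior for NullType inputs.
    sealed trait DataType
    case object NullType extends DataType
    case object IntegerType extends DataType
    case object DoubleType extends DataType

    def sumResultType(childType: DataType): DataType = childType match {
      // Mirrors the TODO above: remove once NullType is dropped from inputTypes.
      case NullType => IntegerType
      case other => other
    }

    assert(sumResultType(NullType) == IntegerType)
    assert(sumResultType(DoubleType) == DoubleType)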
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index a730ffbb21..c5aaebe673 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -191,8 +191,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
// aggregate function to the corresponding attribute of the function.
val aggregateFunctionMap = aggregateExpressions.map { agg =>
val aggregateFunction = agg.aggregateFunction
+ val attribute = Alias(aggregateFunction, aggregateFunction.toString)().toAttribute
(aggregateFunction, agg.isDistinct) ->
- Alias(aggregateFunction, aggregateFunction.toString)().toAttribute
+ (aggregateFunction -> attribute)
}.toMap
val (functionsWithDistinct, functionsWithoutDistinct) =
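The map value now carries the aggregate function itself alongside the attribute standing for its result, so a single lookup recovers both. A toy illustration of the reshaped structure (strings stand in for the real expression types; not the actual planner code):

    // Toy example of the new map shape: the key is (function, isDistinct),
    // the value is (function, resultAttribute).
    val aggregateFunctionMap = Map(("sum(x)", false) -> ("sum(x)" -> "sum(x)#1"))
    val (function, attribute) = aggregateFunctionMap(("sum(x)", false))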
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala
index 16498da080..39f8f992a9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.execution
-import java.io.{DataInputStream, DataOutputStream, OutputStream, InputStream}
+import java.io._
import java.nio.ByteBuffer
import scala.reflect.ClassTag
@@ -58,11 +58,26 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst
*/
override def serializeStream(out: OutputStream): SerializationStream = new SerializationStream {
private[this] var writeBuffer: Array[Byte] = new Array[Byte](4096)
+ // When `out` is backed by a ChainedBufferOutputStream, we will get an
+ // UnsupportedOperationException when we call dOut.writeInt, because it internally calls
+ // ChainedBufferOutputStream's write(b: Int), which is not supported.
+ // To work around this issue, we create an array for storing the int value.
+ // To reproduce the problem, use dOut.writeInt(row.getSizeInBytes) and
+ // run SparkSqlSerializer2SortMergeShuffleSuite.
+ private[this] var intBuffer: Array[Byte] = new Array[Byte](4)
private[this] val dOut: DataOutputStream = new DataOutputStream(out)
override def writeValue[T: ClassTag](value: T): SerializationStream = {
val row = value.asInstanceOf[UnsafeRow]
- dOut.writeInt(row.getSizeInBytes)
+ val size = row.getSizeInBytes
+ // This part is based on DataOutputStream's writeInt.
+ // It is equivalent to dOut.writeInt(row.getSizeInBytes).
+ intBuffer(0) = ((size >>> 24) & 0xFF).toByte
+ intBuffer(1) = ((size >>> 16) & 0xFF).toByte
+ intBuffer(2) = ((size >>> 8) & 0xFF).toByte
+ intBuffer(3) = ((size >>> 0) & 0xFF).toByte
+ dOut.write(intBuffer, 0, 4)
+
row.writeToStream(out, writeBuffer)
this
}
@@ -90,6 +105,7 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst
override def close(): Unit = {
writeBuffer = null
+ intBuffer = null
dOut.writeInt(EOF)
dOut.close()
}
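The four shift-and-mask lines reproduce DataOutputStream.writeInt's big-endian encoding while bypassing the underlying stream's unsupported write(b: Int). A quick standalone check of that equivalence (a sanity sketch, not part of the patch):

    import java.io.{ByteArrayOutputStream, DataOutputStream}

    // Manual big-endian encoding, as in the patch above.
    def encodeManually(size: Int): Array[Byte] = {
      val intBuffer = new Array[Byte](4)
      intBuffer(0) = ((size >>> 24) & 0xFF).toByte
      intBuffer(1) = ((size >>> 16) & 0xFF).toByte
      intBuffer(2) = ((size >>> 8) & 0xFF).toByte
      intBuffer(3) = ((size >>> 0) & 0xFF).toByte
      intBuffer
    }

    // Reference encoding via DataOutputStream.writeInt.
    def encodeWithWriteInt(size: Int): Array[Byte] = {
      val bytes = new ByteArrayOutputStream()
      new DataOutputStream(bytes).writeInt(size)
      bytes.toByteArray
    }

    assert(encodeManually(12345).sameElements(encodeWithWriteInt(12345)))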
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/Aggregate.scala
deleted file mode 100644
index cf568dc048..0000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/Aggregate.scala
+++ /dev/null
@@ -1,182 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.aggregate
-
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.errors._
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.catalyst.plans.physical.{UnspecifiedDistribution, ClusteredDistribution, AllTuples, Distribution}
-import org.apache.spark.sql.execution.{UnsafeFixedWidthAggregationMap, SparkPlan, UnaryNode}
-import org.apache.spark.sql.types.StructType
-
-/**
- * An Aggregate Operator used to evaluate [[AggregateFunction2]]. Based on the data types
- * of the grouping expressions and aggregate functions, it determines if it uses
- * sort-based aggregation and hybrid (hash-based with sort-based as the fallback) to
- * process input rows.
- */
-case class Aggregate(
- requiredChildDistributionExpressions: Option[Seq[Expression]],
- groupingExpressions: Seq[NamedExpression],
- nonCompleteAggregateExpressions: Seq[AggregateExpression2],
- nonCompleteAggregateAttributes: Seq[Attribute],
- completeAggregateExpressions: Seq[AggregateExpression2],
- completeAggregateAttributes: Seq[Attribute],
- initialInputBufferOffset: Int,
- resultExpressions: Seq[NamedExpression],
- child: SparkPlan)
- extends UnaryNode {
-
- private[this] val allAggregateExpressions =
- nonCompleteAggregateExpressions ++ completeAggregateExpressions
-
- private[this] val hasNonAlgebricAggregateFunctions =
- !allAggregateExpressions.forall(_.aggregateFunction.isInstanceOf[AlgebraicAggregate])
-
- // Use the hybrid iterator if (1) unsafe is enabled, (2) the schemata of
- // grouping key and aggregation buffer is supported; and (3) all
- // aggregate functions are algebraic.
- private[this] val supportsHybridIterator: Boolean = {
- val aggregationBufferSchema: StructType =
- StructType.fromAttributes(
- allAggregateExpressions.flatMap(_.aggregateFunction.bufferAttributes))
- val groupKeySchema: StructType =
- StructType.fromAttributes(groupingExpressions.map(_.toAttribute))
-
- val schemaSupportsUnsafe: Boolean =
- UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(aggregationBufferSchema) &&
- UnsafeProjection.canSupport(groupKeySchema)
-
- // TODO: Use the hybrid iterator for non-algebric aggregate functions.
- sqlContext.conf.unsafeEnabled && schemaSupportsUnsafe && !hasNonAlgebricAggregateFunctions
- }
-
- // We need to use sorted input if we have grouping expressions, and
- // we cannot use the hybrid iterator or the hybrid is disabled.
- private[this] val requiresSortedInput: Boolean = {
- groupingExpressions.nonEmpty && !supportsHybridIterator
- }
-
- override def canProcessUnsafeRows: Boolean = !hasNonAlgebricAggregateFunctions
-
- // If result expressions' data types are all fixed length, we generate unsafe rows
- // (We have this requirement instead of check the result of UnsafeProjection.canSupport
- // is because we use a mutable projection to generate the result).
- override def outputsUnsafeRows: Boolean = {
- // resultExpressions.map(_.dataType).forall(UnsafeRow.isFixedLength)
- // TODO: Supports generating UnsafeRows. We can just re-enable the line above and fix
- // any issue we get.
- false
- }
-
- override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
-
- override def requiredChildDistribution: List[Distribution] = {
- requiredChildDistributionExpressions match {
- case Some(exprs) if exprs.length == 0 => AllTuples :: Nil
- case Some(exprs) if exprs.length > 0 => ClusteredDistribution(exprs) :: Nil
- case None => UnspecifiedDistribution :: Nil
- }
- }
-
- override def requiredChildOrdering: Seq[Seq[SortOrder]] = {
- if (requiresSortedInput) {
- // TODO: We should not sort the input rows if they are just in reversed order.
- groupingExpressions.map(SortOrder(_, Ascending)) :: Nil
- } else {
- Seq.fill(children.size)(Nil)
- }
- }
-
- override def outputOrdering: Seq[SortOrder] = {
- if (requiresSortedInput) {
- // It is possible that the child.outputOrdering starts with the required
- // ordering expressions (e.g. we require [a] as the sort expression and the
- // child's outputOrdering is [a, b]). We can only guarantee the output rows
- // are sorted by values of groupingExpressions.
- groupingExpressions.map(SortOrder(_, Ascending))
- } else {
- Nil
- }
- }
-
- protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
- child.execute().mapPartitions { iter =>
- // Because the constructor of an aggregation iterator will read at least the first row,
- // we need to get the value of iter.hasNext first.
- val hasInput = iter.hasNext
- val useHybridIterator =
- hasInput &&
- supportsHybridIterator &&
- groupingExpressions.nonEmpty
- if (useHybridIterator) {
- UnsafeHybridAggregationIterator.createFromInputIterator(
- groupingExpressions,
- nonCompleteAggregateExpressions,
- nonCompleteAggregateAttributes,
- completeAggregateExpressions,
- completeAggregateAttributes,
- initialInputBufferOffset,
- resultExpressions,
- newMutableProjection _,
- child.output,
- iter,
- outputsUnsafeRows)
- } else {
- if (!hasInput && groupingExpressions.nonEmpty) {
- // This is a grouped aggregate and the input iterator is empty,
- // so return an empty iterator.
- Iterator[InternalRow]()
- } else {
- val outputIter = SortBasedAggregationIterator.createFromInputIterator(
- groupingExpressions,
- nonCompleteAggregateExpressions,
- nonCompleteAggregateAttributes,
- completeAggregateExpressions,
- completeAggregateAttributes,
- initialInputBufferOffset,
- resultExpressions,
- newMutableProjection _ ,
- newProjection _,
- child.output,
- iter,
- outputsUnsafeRows)
- if (!hasInput && groupingExpressions.isEmpty) {
- // There is no input and there is no grouping expressions.
- // We need to output a single row as the output.
- Iterator[InternalRow](outputIter.outputForEmptyGroupingKeyWithoutInput())
- } else {
- outputIter
- }
- }
- }
- }
- }
-
- override def simpleString: String = {
- val iterator = if (supportsHybridIterator && groupingExpressions.nonEmpty) {
- classOf[UnsafeHybridAggregationIterator].getSimpleName
- } else {
- classOf[SortBasedAggregationIterator].getSimpleName
- }
-
- s"""NewAggregate with $iterator ${groupingExpressions} ${allAggregateExpressions}"""
- }
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala
new file mode 100644
index 0000000000..ad428ad663
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.aggregate
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.errors._
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.catalyst.plans.physical.{UnspecifiedDistribution, ClusteredDistribution, AllTuples, Distribution}
+import org.apache.spark.sql.execution.{UnsafeFixedWidthAggregationMap, SparkPlan, UnaryNode}
+import org.apache.spark.sql.types.StructType
+
+case class SortBasedAggregate(
+ requiredChildDistributionExpressions: Option[Seq[Expression]],
+ groupingExpressions: Seq[NamedExpression],
+ nonCompleteAggregateExpressions: Seq[AggregateExpression2],
+ nonCompleteAggregateAttributes: Seq[Attribute],
+ completeAggregateExpressions: Seq[AggregateExpression2],
+ completeAggregateAttributes: Seq[Attribute],
+ initialInputBufferOffset: Int,
+ resultExpressions: Seq[NamedExpression],
+ child: SparkPlan)
+ extends UnaryNode {
+
+ override def outputsUnsafeRows: Boolean = false
+
+ override def canProcessUnsafeRows: Boolean = false
+
+ override def canProcessSafeRows: Boolean = true
+
+ override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
+
+ override def requiredChildDistribution: List[Distribution] = {
+ requiredChildDistributionExpressions match {
+ case Some(exprs) if exprs.length == 0 => AllTuples :: Nil
+ case Some(exprs) if exprs.length > 0 => ClusteredDistribution(exprs) :: Nil
+ case None => UnspecifiedDistribution :: Nil
+ }
+ }
+
+ override def requiredChildOrdering: Seq[Seq[SortOrder]] = {
+ groupingExpressions.map(SortOrder(_, Ascending)) :: Nil
+ }
+
+ override def outputOrdering: Seq[SortOrder] = {
+ groupingExpressions.map(SortOrder(_, Ascending))
+ }
+
+ protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
+ child.execute().mapPartitions { iter =>
+ // Because the constructor of an aggregation iterator will read at least the first row,
+ // we need to get the value of iter.hasNext first.
+ val hasInput = iter.hasNext
+ if (!hasInput && groupingExpressions.nonEmpty) {
+ // This is a grouped aggregate and the input iterator is empty,
+ // so return an empty iterator.
+ Iterator[InternalRow]()
+ } else {
+ val outputIter = SortBasedAggregationIterator.createFromInputIterator(
+ groupingExpressions,
+ nonCompleteAggregateExpressions,
+ nonCompleteAggregateAttributes,
+ completeAggregateExpressions,
+ completeAggregateAttributes,
+ initialInputBufferOffset,
+ resultExpressions,
+ newMutableProjection _,
+ newProjection _,
+ child.output,
+ iter,
+ outputsUnsafeRows)
+ if (!hasInput && groupingExpressions.isEmpty) {
+ // There is no input and there are no grouping expressions.
+ // We need to output a single row.
+ Iterator[InternalRow](outputIter.outputForEmptyGroupingKeyWithoutInput())
+ } else {
+ outputIter
+ }
+ }
+ }
+ }
+
+ override def simpleString: String = {
+ val allAggregateExpressions = nonCompleteAggregateExpressions ++ completeAggregateExpressions
+ s"""SortBasedAggregate ${groupingExpressions} ${allAggregateExpressions}"""
+ }
+}
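The empty-input handling in doExecute above distinguishes grouped from global aggregates. A compact sketch of the expected row counts (illustrative only, not Spark source):

    // Grouped aggregates over empty input yield no rows; global aggregates
    // (no grouping expressions) yield exactly one row, e.g. SELECT count(*).
    def expectedOutputRows(hasInput: Boolean, hasGrouping: Boolean): Option[Int] =
      (hasInput, hasGrouping) match {
        case (false, true)  => Some(0) // grouped aggregate, empty input
        case (false, false) => Some(1) // global aggregate, empty input
        case (true, _)      => None    // depends on the data
      }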
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala
index 40f6bff53d..67ebafde25 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala
@@ -204,31 +204,5 @@ object SortBasedAggregationIterator {
newMutableProjection,
outputsUnsafeRows)
}
-
- def createFromKVIterator(
- groupingKeyAttributes: Seq[Attribute],
- valueAttributes: Seq[Attribute],
- inputKVIterator: KVIterator[InternalRow, InternalRow],
- nonCompleteAggregateExpressions: Seq[AggregateExpression2],
- nonCompleteAggregateAttributes: Seq[Attribute],
- completeAggregateExpressions: Seq[AggregateExpression2],
- completeAggregateAttributes: Seq[Attribute],
- initialInputBufferOffset: Int,
- resultExpressions: Seq[NamedExpression],
- newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
- outputsUnsafeRows: Boolean): SortBasedAggregationIterator = {
- new SortBasedAggregationIterator(
- groupingKeyAttributes,
- valueAttributes,
- inputKVIterator,
- nonCompleteAggregateExpressions,
- nonCompleteAggregateAttributes,
- completeAggregateExpressions,
- completeAggregateAttributes,
- initialInputBufferOffset,
- resultExpressions,
- newMutableProjection,
- outputsUnsafeRows)
- }
// scalastyle:on
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
new file mode 100644
index 0000000000..5a0b4d47d6
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.aggregate
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.errors._
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.physical.{UnspecifiedDistribution, ClusteredDistribution, AllTuples, Distribution}
+import org.apache.spark.sql.execution.{UnaryNode, SparkPlan}
+
+case class TungstenAggregate(
+ requiredChildDistributionExpressions: Option[Seq[Expression]],
+ groupingExpressions: Seq[NamedExpression],
+ nonCompleteAggregateExpressions: Seq[AggregateExpression2],
+ completeAggregateExpressions: Seq[AggregateExpression2],
+ initialInputBufferOffset: Int,
+ resultExpressions: Seq[NamedExpression],
+ child: SparkPlan)
+ extends UnaryNode {
+
+ override def outputsUnsafeRows: Boolean = true
+
+ override def canProcessUnsafeRows: Boolean = true
+
+ override def canProcessSafeRows: Boolean = false
+
+ override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
+
+ override def requiredChildDistribution: List[Distribution] = {
+ requiredChildDistributionExpressions match {
+ case Some(exprs) if exprs.length == 0 => AllTuples :: Nil
+ case Some(exprs) if exprs.length > 0 => ClusteredDistribution(exprs) :: Nil
+ case None => UnspecifiedDistribution :: Nil
+ }
+ }
+
+ // This is for testing. We force TungstenAggregationIterator to fall back to sort-based
+ // aggregation once it has processed a given number of input rows.
+ private val testFallbackStartsAt: Option[Int] = {
+ sqlContext.getConf("spark.sql.TungstenAggregate.testFallbackStartsAt", null) match {
+ case null | "" => None
+ case fallbackStartsAt => Some(fallbackStartsAt.toInt)
+ }
+ }
+
+ protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
+ child.execute().mapPartitions { iter =>
+ val hasInput = iter.hasNext
+ if (!hasInput && groupingExpressions.nonEmpty) {
+ // This is a grouped aggregate and the input iterator is empty,
+ // so return an empty iterator.
+ Iterator.empty.asInstanceOf[Iterator[UnsafeRow]]
+ } else {
+ val aggregationIterator =
+ new TungstenAggregationIterator(
+ groupingExpressions,
+ nonCompleteAggregateExpressions,
+ completeAggregateExpressions,
+ initialInputBufferOffset,
+ resultExpressions,
+ newMutableProjection,
+ child.output,
+ iter.asInstanceOf[Iterator[UnsafeRow]],
+ testFallbackStartsAt)
+
+ if (!hasInput && groupingExpressions.isEmpty) {
+ Iterator.single[UnsafeRow](aggregationIterator.outputForEmptyGroupingKeyWithoutInput())
+ } else {
+ aggregationIterator
+ }
+ }
+ }
+ }
+
+ override def simpleString: String = {
+ val allAggregateExpressions = nonCompleteAggregateExpressions ++ completeAggregateExpressions
+
+ testFallbackStartsAt match {
+ case None => s"TungstenAggregate ${groupingExpressions} ${allAggregateExpressions}"
+ case Some(fallbackStartsAt) =>
+ s"TungstenAggregateWithControlledFallback ${groupingExpressions} " +
+ s"${allAggregateExpressions} fallbackStartsAt=$fallbackStartsAt"
+ }
+ }
+}
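The spark.sql.TungstenAggregate.testFallbackStartsAt configuration defined above lets a test force the fallback path. A hedged sketch of how a suite might use it (the real suite is TungstenAggregationQueryWithControlledFallbackSuite, per the commit message; `sqlContext` and table `t` are assumed to exist):

    // Force fallback to sort-based aggregation after 2 input rows per partition.
    sqlContext.setConf("spark.sql.TungstenAggregate.testFallbackStartsAt", "2")
    try {
      sqlContext.sql("SELECT key, avg(value) FROM t GROUP BY key").collect()
    } finally {
      // An empty value is treated as "no forced fallback" (see the match above).
      sqlContext.setConf("spark.sql.TungstenAggregate.testFallbackStartsAt", "")
    }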
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
new file mode 100644
index 0000000000..b9d44aace1
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
@@ -0,0 +1,667 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.aggregate
+
+import org.apache.spark.unsafe.KVIterator
+import org.apache.spark.{Logging, SparkEnv, TaskContext}
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner
+import org.apache.spark.sql.execution.{UnsafeKVExternalSorter, UnsafeFixedWidthAggregationMap}
+import org.apache.spark.sql.types.StructType
+
+/**
+ * An iterator used to evaluate aggregate functions. It operates on [[UnsafeRow]]s.
+ *
+ * This iterator first uses hash-based aggregation to process input rows. It uses
+ * a hash map to store groups and their corresponding aggregation buffers. If this
+ * map cannot allocate memory from [[org.apache.spark.shuffle.ShuffleMemoryManager]],
+ * it switches to sort-based aggregation. The switch has the following steps:
+ * - Step 1: Sort all entries of the hash map based on values of grouping expressions and
+ * spill them to disk.
+ * - Step 2: Create an external sorter based on the spilled sorted map entries.
+ * - Step 3: Redirect all input rows to the external sorter.
+ * - Step 4: Get a sorted [[KVIterator]] from the external sorter.
+ * - Step 5: Initialize sort-based aggregation.
+ * Then, this iterator works in the way of sort-based aggregation.
+ *
+ * The code of this class is organized as follows:
+ * - Part 1: Initializing aggregate functions.
+ * - Part 2: Methods and fields used when setting aggregation buffer values,
+ *   processing input rows from inputIter, and generating output rows.
+ * - Part 3: Methods and fields used by hash-based aggregation.
+ * - Part 4: The function used to switch this iterator from hash-based
+ * aggregation to sort-based aggregation.
+ * - Part 5: Methods and fields used by sort-based aggregation.
+ * - Part 6: Loads and processes input rows.
+ * - Part 7: Public methods of this iterator.
+ * - Part 8: A utility function used to generate a result when there is no
+ * input and there is no grouping expression.
+ *
+ * @param groupingExpressions
+ * expressions for grouping keys
+ * @param nonCompleteAggregateExpressions
+ * [[AggregateExpression2]] containing [[AggregateFunction2]]s with mode [[Partial]],
+ * [[PartialMerge]], or [[Final]].
+ * @param completeAggregateExpressions
+ * [[AggregateExpression2]] containing [[AggregateFunction2]]s with mode [[Complete]].
+ * @param initialInputBufferOffset
+ * If this iterator is used to handle functions with mode [[PartialMerge]] or [[Final]],
+ * the input rows have the format of `grouping keys + aggregation buffer`.
+ * This offset indicates the starting position of the aggregation buffer in an input row.
+ * @param resultExpressions
+ * expressions for generating output rows.
+ * @param newMutableProjection
+ * the function used to create mutable projections.
+ * @param originalInputAttributes
+ * attributes representing input rows from `inputIter`.
+ * @param inputIter
+ * the iterator containing input [[UnsafeRow]]s.
+ */
+class TungstenAggregationIterator(
+ groupingExpressions: Seq[NamedExpression],
+ nonCompleteAggregateExpressions: Seq[AggregateExpression2],
+ completeAggregateExpressions: Seq[AggregateExpression2],
+ initialInputBufferOffset: Int,
+ resultExpressions: Seq[NamedExpression],
+ newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
+ originalInputAttributes: Seq[Attribute],
+ inputIter: Iterator[UnsafeRow],
+ testFallbackStartsAt: Option[Int])
+ extends Iterator[UnsafeRow] with Logging {
+
+ ///////////////////////////////////////////////////////////////////////////
+ // Part 1: Initializing aggregate functions.
+ ///////////////////////////////////////////////////////////////////////////
+
+ // A Seq containing all AggregateExpressions.
+ // It is important that all AggregateExpressions with the mode Partial, PartialMerge or Final
+ // are at the beginning of allAggregateExpressions.
+ private[this] val allAggregateExpressions: Seq[AggregateExpression2] =
+ nonCompleteAggregateExpressions ++ completeAggregateExpressions
+
+ // Check to make sure we do not have more than two modes in our AggregateExpressions.
+ // If we do, users are hitting a bug and we throw an IllegalStateException.
+ if (allAggregateExpressions.map(_.mode).distinct.length > 2) {
+ throw new IllegalStateException(
+ s"$allAggregateExpressions should have no more than 2 kinds of modes.")
+ }
+
+ //
+ // The modes of AggregateExpressions. Right now, we can handle the following modes:
+ // - Partial-only:
+ // All AggregateExpressions have the mode of Partial.
+ // For this case, aggregationMode is (Some(Partial), None).
+ // - PartialMerge-only:
+ // All AggregateExpressions have the mode of PartialMerge.
+ // For this case, aggregationMode is (Some(PartialMerge), None).
+ // - Final-only:
+ // All AggregateExpressions have the mode of Final.
+ // For this case, aggregationMode is (Some(Final), None).
+ // - Final-Complete:
+ // Some AggregateExpressions have the mode of Final and
+ // others have the mode of Complete. For this case,
+ // aggregationMode is (Some(Final), Some(Complete)).
+ // - Complete-only:
+ // nonCompleteAggregateExpressions is empty and we have AggregateExpressions
+ // with mode Complete in completeAggregateExpressions. For this case,
+ // aggregationMode is (None, Some(Complete)).
+ // - Grouping-only:
+ // There is no AggregateExpression. For this case, aggregationMode is (None, None).
+ //
+ private[this] var aggregationMode: (Option[AggregateMode], Option[AggregateMode]) = {
+ nonCompleteAggregateExpressions.map(_.mode).distinct.headOption ->
+ completeAggregateExpressions.map(_.mode).distinct.headOption
+ }
+
+ // All aggregate functions. TungstenAggregationIterator only handles AlgebraicAggregates.
+ // If there is any function that is not an AlgebraicAggregate, we throw an
+ // IllegalStateException.
+ private[this] val allAggregateFunctions: Array[AlgebraicAggregate] = {
+ if (!allAggregateExpressions.forall(_.aggregateFunction.isInstanceOf[AlgebraicAggregate])) {
+ throw new IllegalStateException(
+ "Only AlgebraicAggregates should be passed in TungstenAggregationIterator.")
+ }
+
+ allAggregateExpressions
+ .map(_.aggregateFunction.asInstanceOf[AlgebraicAggregate])
+ .toArray
+ }
+
+ ///////////////////////////////////////////////////////////////////////////
+ // Part 2: Methods and fields used when setting aggregation buffer values,
+ // processing input rows from inputIter, and generating output rows.
+ ///////////////////////////////////////////////////////////////////////////
+
+ // The projection used to initialize buffer values.
+ private[this] val algebraicInitialProjection: MutableProjection = {
+ val initExpressions = allAggregateFunctions.flatMap(_.initialValues)
+ newMutableProjection(initExpressions, Nil)()
+ }
+
+ // Creates a new aggregation buffer and initializes buffer values.
+ // This function should be called at most three times (when we create the hash map,
+ // when we switch to sort-based aggregation, and when we create the re-used buffer for
+ // sort-based aggregation).
+ private def createNewAggregationBuffer(): UnsafeRow = {
+ val bufferSchema = allAggregateFunctions.flatMap(_.bufferAttributes)
+ val bufferRowSize: Int = bufferSchema.length
+
+ val genericMutableBuffer = new GenericMutableRow(bufferRowSize)
+ val unsafeProjection =
+ UnsafeProjection.create(bufferSchema.map(_.dataType))
+ val buffer = unsafeProjection.apply(genericMutableBuffer)
+ algebraicInitialProjection.target(buffer)(EmptyRow)
+ buffer
+ }
+
+ // Creates a function used to process a row based on the given inputAttributes.
+ private def generateProcessRow(
+ inputAttributes: Seq[Attribute]): (UnsafeRow, UnsafeRow) => Unit = {
+
+ val aggregationBufferAttributes = allAggregateFunctions.flatMap(_.bufferAttributes)
+ val aggregationBufferSchema = StructType.fromAttributes(aggregationBufferAttributes)
+ val inputSchema = StructType.fromAttributes(inputAttributes)
+ val unsafeRowJoiner =
+ GenerateUnsafeRowJoiner.create(aggregationBufferSchema, inputSchema)
+
+ aggregationMode match {
+ // Partial-only
+ case (Some(Partial), None) =>
+ val updateExpressions = allAggregateFunctions.flatMap(_.updateExpressions)
+ val algebraicUpdateProjection =
+ newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)()
+
+ (currentBuffer: UnsafeRow, row: UnsafeRow) => {
+ algebraicUpdateProjection.target(currentBuffer)
+ algebraicUpdateProjection(unsafeRowJoiner.join(currentBuffer, row))
+ }
+
+ // PartialMerge-only or Final-only
+ case (Some(PartialMerge), None) | (Some(Final), None) =>
+ val mergeExpressions = allAggregateFunctions.flatMap(_.mergeExpressions)
+ // This projection is used to merge buffer values for all AlgebraicAggregates.
+ val algebraicMergeProjection =
+ newMutableProjection(
+ mergeExpressions,
+ aggregationBufferAttributes ++ inputAttributes)()
+
+ (currentBuffer: UnsafeRow, row: UnsafeRow) => {
+ // Process all algebraic aggregate functions.
+ algebraicMergeProjection.target(currentBuffer)
+ algebraicMergeProjection(unsafeRowJoiner.join(currentBuffer, row))
+ }
+
+ // Final-Complete
+ case (Some(Final), Some(Complete)) =>
+ val nonCompleteAggregateFunctions: Array[AlgebraicAggregate] =
+ allAggregateFunctions.take(nonCompleteAggregateExpressions.length)
+ val completeAggregateFunctions: Array[AlgebraicAggregate] =
+ allAggregateFunctions.takeRight(completeAggregateExpressions.length)
+
+ val completeOffsetExpressions =
+ Seq.fill(completeAggregateFunctions.map(_.bufferAttributes.length).sum)(NoOp)
+ val mergeExpressions =
+ nonCompleteAggregateFunctions.flatMap(_.mergeExpressions) ++ completeOffsetExpressions
+ val finalAlgebraicMergeProjection =
+ newMutableProjection(
+ mergeExpressions,
+ aggregationBufferAttributes ++ inputAttributes)()
+
+ // We do not touch buffer values of aggregate functions with the Final mode.
+ val finalOffsetExpressions =
+ Seq.fill(nonCompleteAggregateFunctions.map(_.bufferAttributes.length).sum)(NoOp)
+ val updateExpressions =
+ finalOffsetExpressions ++ completeAggregateFunctions.flatMap(_.updateExpressions)
+ val completeAlgebraicUpdateProjection =
+ newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)()
+
+ (currentBuffer: UnsafeRow, row: UnsafeRow) => {
+ val input = unsafeRowJoiner.join(currentBuffer, row)
+ // For all aggregate functions with mode Complete, update the given currentBuffer.
+ completeAlgebraicUpdateProjection.target(currentBuffer)(input)
+
+ // For all aggregate functions with mode Final, merge buffer values in row to
+ // currentBuffer.
+ finalAlgebraicMergeProjection.target(currentBuffer)(input)
+ }
+
+ // Complete-only
+ case (None, Some(Complete)) =>
+ val completeAggregateFunctions: Array[AlgebraicAggregate] =
+ allAggregateFunctions.takeRight(completeAggregateExpressions.length)
+
+ val updateExpressions =
+ completeAggregateFunctions.flatMap(_.updateExpressions)
+ val completeAlgebraicUpdateProjection =
+ newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)()
+
+ (currentBuffer: UnsafeRow, row: UnsafeRow) => {
+ completeAlgebraicUpdateProjection.target(currentBuffer)
+ // For all aggregate functions with mode Complete, update the given currentBuffer.
+ completeAlgebraicUpdateProjection(unsafeRowJoiner.join(currentBuffer, row))
+ }
+
+ // Grouping only.
+ case (None, None) => (currentBuffer: UnsafeRow, row: UnsafeRow) => {}
+
+ case other =>
+ throw new IllegalStateException(
+ s"${aggregationMode} should not be passed into TungstenAggregationIterator.")
+ }
+ }
+
+ // Creates a function used to generate output rows.
+ private def generateResultProjection(): (UnsafeRow, UnsafeRow) => UnsafeRow = {
+
+ val groupingAttributes = groupingExpressions.map(_.toAttribute)
+ val groupingKeySchema = StructType.fromAttributes(groupingAttributes)
+ val bufferAttributes = allAggregateFunctions.flatMap(_.bufferAttributes)
+ val bufferSchema = StructType.fromAttributes(bufferAttributes)
+ val unsafeRowJoiner = GenerateUnsafeRowJoiner.create(groupingKeySchema, bufferSchema)
+
+ aggregationMode match {
+ // Partial-only or PartialMerge-only: every output row is basically the values of
+ // the grouping expressions and the corresponding aggregation buffer.
+ case (Some(Partial), None) | (Some(PartialMerge), None) =>
+ (currentGroupingKey: UnsafeRow, currentBuffer: UnsafeRow) => {
+ unsafeRowJoiner.join(currentGroupingKey, currentBuffer)
+ }
+
+ // Final-only, Complete-only and Final-Complete: an output row is generated based on
+ // resultExpressions.
+ case (Some(Final), None) | (Some(Final) | None, Some(Complete)) =>
+ val resultProjection =
+ UnsafeProjection.create(resultExpressions, groupingAttributes ++ bufferAttributes)
+
+ (currentGroupingKey: UnsafeRow, currentBuffer: UnsafeRow) => {
+ resultProjection(unsafeRowJoiner.join(currentGroupingKey, currentBuffer))
+ }
+
+ // Grouping-only: an output row is generated from values of grouping expressions.
+ case (None, None) =>
+ val resultProjection =
+ UnsafeProjection.create(resultExpressions, groupingAttributes)
+
+ (currentGroupingKey: UnsafeRow, currentBuffer: UnsafeRow) => {
+ resultProjection(currentGroupingKey)
+ }
+
+ case other =>
+ throw new IllegalStateException(
+ s"${aggregationMode} should not be passed into TungstenAggregationIterator.")
+ }
+ }
+
+ // An UnsafeProjection used to extract grouping keys from the input rows.
+ private[this] val groupProjection =
+ UnsafeProjection.create(groupingExpressions, originalInputAttributes)
+
+ // A function used to process an input row. Its first argument is the aggregation buffer
+ // and the second argument is the input row.
+ private[this] var processRow: (UnsafeRow, UnsafeRow) => Unit =
+ generateProcessRow(originalInputAttributes)
+
+ // A function used to generate output rows based on the grouping keys (first argument)
+ // and the corresponding aggregation buffer (second argument).
+ private[this] var generateOutput: (UnsafeRow, UnsafeRow) => UnsafeRow =
+ generateResultProjection()
+
+ // An aggregation buffer containing initial buffer values. It is used to
+ // initialize other aggregation buffers.
+ private[this] val initialAggregationBuffer: UnsafeRow = createNewAggregationBuffer()
+
+ ///////////////////////////////////////////////////////////////////////////
+ // Part 3: Methods and fields used by hash-based aggregation.
+ ///////////////////////////////////////////////////////////////////////////
+
+ // This is the hash map used for hash-based aggregation. It is an
+ // UnsafeFixedWidthAggregationMap that stores all groups and their
+ // corresponding aggregation buffers.
+ private[this] val hashMap = new UnsafeFixedWidthAggregationMap(
+ initialAggregationBuffer,
+ StructType.fromAttributes(allAggregateFunctions.flatMap(_.bufferAttributes)),
+ StructType.fromAttributes(groupingExpressions.map(_.toAttribute)),
+ TaskContext.get.taskMemoryManager(),
+ SparkEnv.get.shuffleMemoryManager,
+ 1024 * 16, // initial capacity
+ SparkEnv.get.conf.getSizeAsBytes("spark.buffer.pageSize", "64m"),
+ false // disable tracking of performance metrics
+ )
+
+ // The function used to read and process input rows. When processing input rows,
+ // it first uses hash-based aggregation by putting groups and their buffers in
+ // hashMap. If we cannot allocate more memory for the map, we switch to
+ // sort-based aggregation (by calling switchToSortBasedAggregation).
+ private def processInputs(): Unit = {
+ while (!sortBased && inputIter.hasNext) {
+ val newInput = inputIter.next()
+ val groupingKey = groupProjection.apply(newInput)
+ val buffer: UnsafeRow = hashMap.getAggregationBuffer(groupingKey)
+ if (buffer == null) {
+ // buffer == null means that we could not allocate more memory.
+ // Now, we need to spill the map and switch to sort-based aggregation.
+ switchToSortBasedAggregation(groupingKey, newInput)
+ } else {
+ processRow(buffer, newInput)
+ }
+ }
+ }
+
+ // This function is only used for testing. It is basically the same as processInputs except
+ // that it switches to sort-based aggregation after `fallbackStartsAt` input rows have
+ // been processed.
+ private def processInputsWithControlledFallback(fallbackStartsAt: Int): Unit = {
+ var i = 0
+ while (!sortBased && inputIter.hasNext) {
+ val newInput = inputIter.next()
+ val groupingKey = groupProjection.apply(newInput)
+ val buffer: UnsafeRow = if (i < fallbackStartsAt) {
+ hashMap.getAggregationBuffer(groupingKey)
+ } else {
+ null
+ }
+ if (buffer == null) {
+ // buffer == null means that we could not allocate more memory.
+ // Now, we need to spill the map and switch to sort-based aggregation.
+ switchToSortBasedAggregation(groupingKey, newInput)
+ } else {
+ processRow(buffer, newInput)
+ }
+ i += 1
+ }
+ }
+
+ // The iterator created from hashMap. It is used to generate output rows when we
+ // are using hash-based aggregation.
+ private[this] var aggregationBufferMapIterator: KVIterator[UnsafeRow, UnsafeRow] = null
+
+ // Indicates if aggregationBufferMapIterator still has key-value pairs.
+ private[this] var mapIteratorHasNext: Boolean = false
+
+ ///////////////////////////////////////////////////////////////////////////
+ // Part 4: The function used to switch this iterator from hash-based
+ // aggregation to sort-based aggregation.
+ ///////////////////////////////////////////////////////////////////////////
+
+ private def switchToSortBasedAggregation(firstKey: UnsafeRow, firstInput: UnsafeRow): Unit = {
+ logInfo("falling back to sort based aggregation.")
+ // Step 1: Get the ExternalSorter containing sorted entries of the map.
+ val externalSorter: UnsafeKVExternalSorter = hashMap.destructAndCreateExternalSorter()
+
+ // Step 2: Free the memory used by the map.
+ hashMap.free()
+
+ // Step 3: If we have aggregate functions with mode Partial or Complete,
+ // we need to process input rows to get aggregation buffers, so that
+ // later, in the sort-based aggregation, we can merge them.
+ // If the aggregate functions have mode Final or PartialMerge,
+ // we just need to project the aggregation buffer from each input row.
+ val needsProcess = aggregationMode match {
+ case (Some(Partial), None) => true
+ case (None, Some(Complete)) => true
+ case (Some(Final), Some(Complete)) => true
+ case _ => false
+ }
+
+ if (needsProcess) {
+ // First, we create a buffer.
+ val buffer = createNewAggregationBuffer()
+
+ // Process firstKey and firstInput.
+ // Initialize buffer.
+ buffer.copyFrom(initialAggregationBuffer)
+ processRow(buffer, firstInput)
+ externalSorter.insertKV(firstKey, buffer)
+
+ // Process the rest of input rows.
+ while (inputIter.hasNext) {
+ val newInput = inputIter.next()
+ val groupingKey = groupProjection.apply(newInput)
+ buffer.copyFrom(initialAggregationBuffer)
+ processRow(buffer, newInput)
+ externalSorter.insertKV(groupingKey, buffer)
+ }
+ } else {
+ // When needsProcess is false, the format of input rows is groupingKey + aggregation buffer.
+ // We need to project the aggregation buffer part from an input row.
+ val buffer = createNewAggregationBuffer()
+ // The originalInputAttributes are using cloneBufferAttributes. So, we need to use
+ // allAggregateFunctions.flatMap(_.cloneBufferAttributes).
+ val bufferExtractor = newMutableProjection(
+ allAggregateFunctions.flatMap(_.cloneBufferAttributes),
+ originalInputAttributes)()
+ bufferExtractor.target(buffer)
+
+ // Insert firstKey and its buffer.
+ bufferExtractor(firstInput)
+ externalSorter.insertKV(firstKey, buffer)
+
+ // Insert the rest of input rows.
+ while (inputIter.hasNext) {
+ val newInput = inputIter.next()
+ val groupingKey = groupProjection.apply(newInput)
+ bufferExtractor(newInput)
+ externalSorter.insertKV(groupingKey, buffer)
+ }
+ }
+
+ // Step 4: Set aggregationMode, processRow, and generateOutput for sort-based aggregation.
+ val newAggregationMode = aggregationMode match {
+ case (Some(Partial), None) => (Some(PartialMerge), None)
+ case (None, Some(Complete)) => (Some(Final), None)
+ case (Some(Final), Some(Complete)) => (Some(Final), None)
+ case other => other
+ }
+ aggregationMode = newAggregationMode
+
+ // The value of the KVIterator returned by externalSorter will just be
+ // the aggregation buffer. So, here we use cloneBufferAttributes.
+ val newInputAttributes: Seq[Attribute] =
+ allAggregateFunctions.flatMap(_.cloneBufferAttributes)
+
+ // Set up new processRow and generateOutput.
+ processRow = generateProcessRow(newInputAttributes)
+ generateOutput = generateResultProjection()
+
+ // Step 5: Get the sorted iterator from the externalSorter.
+ sortedKVIterator = externalSorter.sortedIterator()
+
+ // Step 6: Pre-load the first key-value pair from the sorted iterator to make
+ // hasNext idempotent.
+ sortedInputHasNewGroup = sortedKVIterator.next()
+
+ // Copy the first key and value (aggregation buffer).
+ if (sortedInputHasNewGroup) {
+ val key = sortedKVIterator.getKey
+ val value = sortedKVIterator.getValue
+ nextGroupingKey = key.copy()
+ currentGroupingKey = key.copy()
+ firstRowInNextGroup = value.copy()
+ }
+
+ // Step 7: set sortBased to true.
+ sortBased = true
+ }
+
+ ///////////////////////////////////////////////////////////////////////////
+ // Part 5: Methods and fields used by sort-based aggregation.
+ ///////////////////////////////////////////////////////////////////////////
+
+ // Indicates if we are using sort-based aggregation. Because we first try to use
+ // hash-based aggregation, its initial value is false.
+ private[this] var sortBased: Boolean = false
+
+ // The KVIterator containing input rows for the sort-based aggregation. It will be
+ // set in switchToSortBasedAggregation when we switch to sort-based aggregation.
+ private[this] var sortedKVIterator: UnsafeKVExternalSorter#KVSorterIterator = null
+
+ // The grouping key of the current group.
+ private[this] var currentGroupingKey: UnsafeRow = null
+
+ // The grouping key of next group.
+ private[this] var nextGroupingKey: UnsafeRow = null
+
+ // The first row of next group.
+ private[this] var firstRowInNextGroup: UnsafeRow = null
+
+ // Indicates if we have a new group of rows from the sorted input iterator.
+ private[this] var sortedInputHasNewGroup: Boolean = false
+
+ // The aggregation buffer used by the sort-based aggregation.
+ private[this] val sortBasedAggregationBuffer: UnsafeRow = createNewAggregationBuffer()
+
+ // Processes rows in the current group. It will stop when it finds a new group.
+ private def processCurrentSortedGroup(): Unit = {
+ // First, we need to copy nextGroupingKey to currentGroupingKey.
+ currentGroupingKey.copyFrom(nextGroupingKey)
+ // Now, we will start to find all rows belonging to this group.
+ // We create a variable to track if we see the next group.
+ var findNextPartition = false
+ // firstRowInNextGroup is the first row of this group. We first process it.
+ processRow(sortBasedAggregationBuffer, firstRowInNextGroup)
+
+ // The search will stop when we see the next group or there are no
+ // input rows left in the iterator.
+ // Pre-load the first key-value pair so that evaluating the while loop's
+ // condition has no side effect (we do not trigger loading a new
+ // key-value pair when we evaluate the condition).
+ var hasNext = sortedKVIterator.next()
+ while (!findNextPartition && hasNext) {
+ // Get the grouping key and value (aggregation buffer).
+ val groupingKey = sortedKVIterator.getKey
+ val inputAggregationBuffer = sortedKVIterator.getValue
+
+ // Check if this key-value pair belongs to the current group.
+ if (currentGroupingKey.equals(groupingKey)) {
+ processRow(sortBasedAggregationBuffer, inputAggregationBuffer)
+
+ hasNext = sortedKVIterator.next()
+ } else {
+ // We find a new group.
+ findNextPartition = true
+ nextGroupingKey.copyFrom(groupingKey) // equivalent to groupingKey.copy()
+ firstRowInNextGroup.copyFrom(inputAggregationBuffer) // equivalent to inputAggregationBuffer.copy()
+ }
+ }
+ // If we have not seen a new group, there are no rows left in the input
+ // iterator, and the current group is the last group of the sortedKVIterator.
+ if (!findNextPartition) {
+ sortedInputHasNewGroup = false
+ sortedKVIterator.close()
+ }
+ }
+
+ ///////////////////////////////////////////////////////////////////////////
+ // Part 6: Loads input rows and sets up aggregationBufferMapIterator if we
+ // have not switched to sort-based aggregation.
+ ///////////////////////////////////////////////////////////////////////////
+
+ // Starts to process input rows.
+ testFallbackStartsAt match {
+ case None =>
+ processInputs()
+ case Some(fallbackStartsAt) =>
+ // This is the testing path. processInputsWithControlledFallback is the same as
+ // processInputs except that it switches to sort-based aggregation after
+ // `fallbackStartsAt` input rows have been processed.
+ processInputsWithControlledFallback(fallbackStartsAt)
+ }
+
+ // If we did not switch to sort-based aggregation in processInputs,
+ // we pre-load the first key-value pair from the map (to make hasNext idempotent).
+ if (!sortBased) {
+ // First, set aggregationBufferMapIterator.
+ aggregationBufferMapIterator = hashMap.iterator()
+ // Pre-load the first key-value pair from the aggregationBufferMapIterator.
+ mapIteratorHasNext = aggregationBufferMapIterator.next()
+ // If the map is empty, we just free it.
+ if (!mapIteratorHasNext) {
+ hashMap.free()
+ }
+ }
+
+ ///////////////////////////////////////////////////////////////////////////
+ // Part 7: Iterator's public methods.
+ ///////////////////////////////////////////////////////////////////////////
+
+ override final def hasNext: Boolean = {
+ (sortBased && sortedInputHasNewGroup) || (!sortBased && mapIteratorHasNext)
+ }
+
+ override final def next(): UnsafeRow = {
+ if (hasNext) {
+ if (sortBased) {
+ // Process the current group.
+ processCurrentSortedGroup()
+ // Generate output row for the current group.
+ val outputRow = generateOutput(currentGroupingKey, sortBasedAggregationBuffer)
+ // Initialize buffer values for the next group.
+ sortBasedAggregationBuffer.copyFrom(initialAggregationBuffer)
+
+ outputRow
+ } else {
+ // We did not fall back to sort-based aggregation.
+ val result =
+ generateOutput(
+ aggregationBufferMapIterator.getKey,
+ aggregationBufferMapIterator.getValue)
+
+ // Pre-load the next key-value pair from aggregationBufferMapIterator to make hasNext
+ // idempotent.
+ mapIteratorHasNext = aggregationBufferMapIterator.next()
+
+ if (!mapIteratorHasNext) {
+ // If there are no more entries in aggregationBufferMapIterator, we copy the current result.
+ val resultCopy = result.copy()
+ // Then, we free the map.
+ hashMap.free()
+
+ resultCopy
+ } else {
+ result
+ }
+ }
+ } else {
+ // no more result
+ throw new NoSuchElementException
+ }
+ }
+
+ ///////////////////////////////////////////////////////////////////////////
+ // Part 8: A utility function used to generate an output row when there is no
+ // input and there is no grouping expression.
+ ///////////////////////////////////////////////////////////////////////////
+ def outputForEmptyGroupingKeyWithoutInput(): UnsafeRow = {
+ if (groupingExpressions.isEmpty) {
+ sortBasedAggregationBuffer.copyFrom(initialAggregationBuffer)
+ // We create an output row and copy it so that we can free the map.
+ val resultCopy =
+ generateOutput(UnsafeRow.createFromByteArray(0, 0), sortBasedAggregationBuffer).copy()
+ hashMap.free()
+ resultCopy
+ } else {
+ throw new IllegalStateException(
+ "This method should not be called when groupingExpressions is not empty.")
+ }
+ }
+}
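The mode transition applied in switchToSortBasedAggregation (Step 4 above) can be isolated into a small standalone sketch: once hash-map entries have been spilled as partial buffers, Partial becomes PartialMerge and Complete becomes Final (illustrative code, not Spark source):

    // Standalone illustration of the fallback mode transition.
    sealed trait Mode
    case object Partial extends Mode
    case object PartialMerge extends Mode
    case object Final extends Mode
    case object Complete extends Mode

    def fallbackMode(m: (Option[Mode], Option[Mode])): (Option[Mode], Option[Mode]) =
      m match {
        case (Some(Partial), None)         => (Some(PartialMerge), None)
        case (None, Some(Complete))        => (Some(Final), None)
        case (Some(Final), Some(Complete)) => (Some(Final), None)
        case other                         => other
      }

    assert(fallbackMode((Some(Partial), None)) == (Some(PartialMerge), None))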
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UnsafeHybridAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UnsafeHybridAggregationIterator.scala
deleted file mode 100644
index b465787fe8..0000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UnsafeHybridAggregationIterator.scala
+++ /dev/null
@@ -1,372 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.aggregate
-
-import org.apache.spark.unsafe.KVIterator
-import org.apache.spark.{SparkEnv, TaskContext}
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.execution.{UnsafeKVExternalSorter, UnsafeFixedWidthAggregationMap}
-import org.apache.spark.sql.types.StructType
-
-/**
- * An iterator used to evaluate [[AggregateFunction2]].
- * It first tries to use in-memory hash-based aggregation. If we cannot allocate more
- * space for the hash map, we spill the sorted map entries, free the map, and then
- * switch to sort-based aggregation.
- */
-class UnsafeHybridAggregationIterator(
- groupingKeyAttributes: Seq[Attribute],
- valueAttributes: Seq[Attribute],
- inputKVIterator: KVIterator[UnsafeRow, InternalRow],
- nonCompleteAggregateExpressions: Seq[AggregateExpression2],
- nonCompleteAggregateAttributes: Seq[Attribute],
- completeAggregateExpressions: Seq[AggregateExpression2],
- completeAggregateAttributes: Seq[Attribute],
- initialInputBufferOffset: Int,
- resultExpressions: Seq[NamedExpression],
- newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
- outputsUnsafeRows: Boolean)
- extends AggregationIterator(
- groupingKeyAttributes,
- valueAttributes,
- nonCompleteAggregateExpressions,
- nonCompleteAggregateAttributes,
- completeAggregateExpressions,
- completeAggregateAttributes,
- initialInputBufferOffset,
- resultExpressions,
- newMutableProjection,
- outputsUnsafeRows) {
-
- require(groupingKeyAttributes.nonEmpty)
-
- ///////////////////////////////////////////////////////////////////////////
- // Unsafe Aggregation buffers
- ///////////////////////////////////////////////////////////////////////////
-
- // This is the Unsafe Aggregation Map used to store all buffers.
- private[this] val buffers = new UnsafeFixedWidthAggregationMap(
- newBuffer,
- StructType.fromAttributes(allAggregateFunctions.flatMap(_.bufferAttributes)),
- StructType.fromAttributes(groupingKeyAttributes),
- TaskContext.get.taskMemoryManager(),
- SparkEnv.get.shuffleMemoryManager,
- 1024 * 16, // initial capacity
- SparkEnv.get.conf.getSizeAsBytes("spark.buffer.pageSize", "64m"),
- false // disable tracking of performance metrics
- )
-
- override protected def newBuffer: UnsafeRow = {
- val bufferSchema = allAggregateFunctions.flatMap(_.bufferAttributes)
- val bufferRowSize: Int = bufferSchema.length
-
- val genericMutableBuffer = new GenericMutableRow(bufferRowSize)
- val unsafeProjection =
- UnsafeProjection.create(bufferSchema.map(_.dataType))
- val buffer = unsafeProjection.apply(genericMutableBuffer)
- initializeBuffer(buffer)
- buffer
- }
-
- ///////////////////////////////////////////////////////////////////////////
- // Methods and variables related to switching to sort-based aggregation
- ///////////////////////////////////////////////////////////////////////////
- private[this] var sortBased = false
-
- private[this] var sortBasedAggregationIterator: SortBasedAggregationIterator = _
-
- // The value part of the input KV iterator is used to store original input values of
- // aggregate functions, we need to convert them to aggregation buffers.
- private def processOriginalInput(
- firstKey: UnsafeRow,
- firstValue: InternalRow): KVIterator[UnsafeRow, UnsafeRow] = {
- new KVIterator[UnsafeRow, UnsafeRow] {
- private[this] var isFirstRow = true
-
- private[this] var groupingKey: UnsafeRow = _
-
- private[this] val buffer: UnsafeRow = newBuffer
-
- override def next(): Boolean = {
- initializeBuffer(buffer)
- if (isFirstRow) {
- isFirstRow = false
- groupingKey = firstKey
- processRow(buffer, firstValue)
-
- true
- } else if (inputKVIterator.next()) {
- groupingKey = inputKVIterator.getKey()
- val value = inputKVIterator.getValue()
- processRow(buffer, value)
-
- true
- } else {
- false
- }
- }
-
- override def getKey(): UnsafeRow = {
- groupingKey
- }
-
- override def getValue(): UnsafeRow = {
- buffer
- }
-
- override def close(): Unit = {
- // Do nothing.
- }
- }
- }
-
- // The value of the input KV Iterator has the format of groupingExprs + aggregation buffer.
- // We need to project the aggregation buffer out.
- private def projectInputBufferToUnsafe(
- firstKey: UnsafeRow,
- firstValue: InternalRow): KVIterator[UnsafeRow, UnsafeRow] = {
- new KVIterator[UnsafeRow, UnsafeRow] {
- private[this] var isFirstRow = true
-
- private[this] var groupingKey: UnsafeRow = _
-
- private[this] val bufferSchema = allAggregateFunctions.flatMap(_.bufferAttributes)
-
- private[this] val value: UnsafeRow = {
- val genericMutableRow = new GenericMutableRow(bufferSchema.length)
- UnsafeProjection.create(bufferSchema.map(_.dataType)).apply(genericMutableRow)
- }
-
- private[this] val projectInputBuffer = {
- newMutableProjection(bufferSchema, valueAttributes)().target(value)
- }
-
- override def next(): Boolean = {
- if (isFirstRow) {
- isFirstRow = false
- groupingKey = firstKey
- projectInputBuffer(firstValue)
-
- true
- } else if (inputKVIterator.next()) {
- groupingKey = inputKVIterator.getKey()
- projectInputBuffer(inputKVIterator.getValue())
-
- true
- } else {
- false
- }
- }
-
- override def getKey(): UnsafeRow = {
- groupingKey
- }
-
- override def getValue(): UnsafeRow = {
- value
- }
-
- override def close(): Unit = {
- // Do nothing.
- }
- }
- }
-
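processOriginalInput and projectInputBufferToUnsafe handle the two value layouts this iterator can receive: raw input rows that must be folded into fresh buffers (Partial/Complete modes), and rows that already carry a buffer after the grouping columns (PartialMerge/Final modes), where only a projection is needed. A toy sketch of the distinction, assuming a flat Vector-based row:

    type Row = Vector[Any]

    // Partial/Complete: fold a raw input row into a fresh buffer
    // (here: a count buffer, so every row contributes 1).
    def processOriginal(value: Row): Row = Vector(1L)

    // PartialMerge/Final: the value is (grouping columns ++ buffer);
    // slice the buffer fields out of the wider row.
    def projectBuffer(value: Row, groupingWidth: Int): Row =
      value.drop(groupingWidth)

    assert(processOriginal(Vector("a", 42)) == Vector(1L))
    assert(projectBuffer(Vector("a", 7L), groupingWidth = 1) == Vector(7L))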
- /**
- * We need to fall back to sort based aggregation because we do not have enough memory
- * for our in-memory hash map (i.e. `buffers`).
- */
- private def switchToSortBasedAggregation(
- currentGroupingKey: UnsafeRow,
- currentRow: InternalRow): Unit = {
- logInfo("falling back to sort based aggregation.")
-
- // Step 1: Get the ExternalSorter containing entries of the map.
- val externalSorter = buffers.destructAndCreateExternalSorter()
-
- // Step 2: Free the memory used by the map.
- buffers.free()
-
- // Step 3: If we have aggregate function with mode Partial or Complete,
- // we need to process them to get aggregation buffer.
- // So, later in the sort-based aggregation iterator, we can do merge.
- // If aggregate functions are with mode Final and PartialMerge,
- // we just need to project the aggregation buffer from the input.
- val needsProcess = aggregationMode match {
- case (Some(Partial), None) => true
- case (None, Some(Complete)) => true
- case (Some(Final), Some(Complete)) => true
- case _ => false
- }
-
- val processedIterator = if (needsProcess) {
- processOriginalInput(currentGroupingKey, currentRow)
- } else {
- // The input value's format is groupingExprs + buffer.
- // We need to project the buffer part out.
- projectInputBufferToUnsafe(currentGroupingKey, currentRow)
- }
-
- // Step 4: Redirect processedIterator to externalSorter.
- while (processedIterator.next()) {
- externalSorter.insertKV(processedIterator.getKey(), processedIterator.getValue())
- }
-
- // Step 5: Get the sorted iterator from the externalSorter.
- val sortedKVIterator: UnsafeKVExternalSorter#KVSorterIterator = externalSorter.sortedIterator()
-
- // Step 6: We now create a SortBasedAggregationIterator based on sortedKVIterator.
- // For an aggregate function with mode Partial, its mode in the SortBasedAggregationIterator
- // will be PartialMerge. For an aggregate function with mode Complete,
- // its mode in the SortBasedAggregationIterator will be Final.
- val newNonCompleteAggregateExpressions = allAggregateExpressions.map {
- case AggregateExpression2(func, Partial, isDistinct) =>
- AggregateExpression2(func, PartialMerge, isDistinct)
- case AggregateExpression2(func, Complete, isDistinct) =>
- AggregateExpression2(func, Final, isDistinct)
- case other => other
- }
- val newNonCompleteAggregateAttributes =
- nonCompleteAggregateAttributes ++ completeAggregateAttributes
-
- val newValueAttributes =
- allAggregateExpressions.flatMap(_.aggregateFunction.cloneBufferAttributes)
-
- sortBasedAggregationIterator = SortBasedAggregationIterator.createFromKVIterator(
- groupingKeyAttributes = groupingKeyAttributes,
- valueAttributes = newValueAttributes,
- inputKVIterator = sortedKVIterator.asInstanceOf[KVIterator[InternalRow, InternalRow]],
- nonCompleteAggregateExpressions = newNonCompleteAggregateExpressions,
- nonCompleteAggregateAttributes = newNonCompleteAggregateAttributes,
- completeAggregateExpressions = Nil,
- completeAggregateAttributes = Nil,
- initialInputBufferOffset = 0,
- resultExpressions = resultExpressions,
- newMutableProjection = newMutableProjection,
- outputsUnsafeRows = outputsUnsafeRows)
- }
-
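The six steps above implement the classic hash-then-spill strategy: aggregate into a hash map while memory allows, and once the map cannot grow, dump its entries into an external sorter and finish the job merge-style. A self-contained toy version of the same strategy (a size cap stands in for running out of memory, and an in-memory sort stands in for UnsafeKVExternalSorter):

    import scala.collection.mutable

    def aggregateWithFallback(
        input: Iterator[(String, Long)],
        maxMapSize: Int): Seq[(String, Long)] = {
      val map = mutable.HashMap.empty[String, Long]
      val spills = mutable.ArrayBuffer.empty[Seq[(String, Long)]]
      input.foreach { case (key, value) =>
        if (!map.contains(key) && map.size >= maxMapSize) {
          spills += map.toSeq.sortBy(_._1) // spill a sorted run, then free the map
          map.clear()
        }
        map(key) = map.getOrElse(key, 0L) + value
      }
      spills += map.toSeq.sortBy(_._1)
      // Merge the sorted runs; the real code streams this from the sorter
      // instead of materializing everything in memory.
      spills.flatten.groupBy(_._1).toSeq
        .map { case (key, runs) => (key, runs.map(_._2).sum) }
        .sortBy(_._1)
    }

    val input = Seq(("a", 1L), ("b", 2L), ("c", 3L), ("a", 4L))
    assert(aggregateWithFallback(input.iterator, maxMapSize = 2) ==
      Seq(("a", 5L), ("b", 2L), ("c", 3L)))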
- ///////////////////////////////////////////////////////////////////////////
- // Methods used to initialize this iterator.
- ///////////////////////////////////////////////////////////////////////////
-
- /** Starts to read input rows and falls back to sort-based aggregation if necessary. */
- protected def initialize(): Unit = {
- var hasNext = inputKVIterator.next()
- while (!sortBased && hasNext) {
- val groupingKey = inputKVIterator.getKey()
- val currentRow = inputKVIterator.getValue()
- val buffer = buffers.getAggregationBuffer(groupingKey)
- if (buffer == null) {
- // buffer == null means that we could not allocate more memory.
- // Now, we need to spill the map and switch to sort-based aggregation.
- switchToSortBasedAggregation(groupingKey, currentRow)
- sortBased = true
- } else {
- processRow(buffer, currentRow)
- hasNext = inputKVIterator.next()
- }
- }
- }
-
- // This is the starting point of this iterator.
- initialize()
-
- // Creates the iterator for the Hash Aggregation Map after we have populated
- // contents of that map.
- private[this] val aggregationBufferMapIterator = buffers.iterator()
-
- private[this] var _mapIteratorHasNext = false
-
- // Pre-load the first key-value pair from the map to make hasNext idempotent.
- if (!sortBased) {
- _mapIteratorHasNext = aggregationBufferMapIterator.next()
- // If the map is empty, we just free it.
- if (!_mapIteratorHasNext) {
- buffers.free()
- }
- }
-
- ///////////////////////////////////////////////////////////////////////////
- // Iterator's public methods
- ///////////////////////////////////////////////////////////////////////////
-
- override final def hasNext: Boolean = {
- (sortBased && sortBasedAggregationIterator.hasNext) || (!sortBased && _mapIteratorHasNext)
- }
-
-
- override final def next(): InternalRow = {
- if (hasNext) {
- if (sortBased) {
- sortBasedAggregationIterator.next()
- } else {
- // We did not fall back to the sort-based aggregation.
- val result =
- generateOutput(
- aggregationBufferMapIterator.getKey,
- aggregationBufferMapIterator.getValue)
- // Pre-load the next key-value pair from aggregationBufferMapIterator.
- _mapIteratorHasNext = aggregationBufferMapIterator.next()
-
- if (!_mapIteratorHasNext) {
- val resultCopy = result.copy()
- buffers.free()
- resultCopy
- } else {
- result
- }
- }
- } else {
- // no more result
- throw new NoSuchElementException
- }
- }
-}
-
-object UnsafeHybridAggregationIterator {
- // scalastyle:off
- def createFromInputIterator(
- groupingExprs: Seq[NamedExpression],
- nonCompleteAggregateExpressions: Seq[AggregateExpression2],
- nonCompleteAggregateAttributes: Seq[Attribute],
- completeAggregateExpressions: Seq[AggregateExpression2],
- completeAggregateAttributes: Seq[Attribute],
- initialInputBufferOffset: Int,
- resultExpressions: Seq[NamedExpression],
- newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
- inputAttributes: Seq[Attribute],
- inputIter: Iterator[InternalRow],
- outputsUnsafeRows: Boolean): UnsafeHybridAggregationIterator = {
- new UnsafeHybridAggregationIterator(
- groupingExprs.map(_.toAttribute),
- inputAttributes,
- AggregationIterator.unsafeKVIterator(groupingExprs, inputAttributes, inputIter),
- nonCompleteAggregateExpressions,
- nonCompleteAggregateAttributes,
- completeAggregateExpressions,
- completeAggregateAttributes,
- initialInputBufferOffset,
- resultExpressions,
- newMutableProjection,
- outputsUnsafeRows)
- }
- // scalastyle:on
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
index 960be08f84..80816a095e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
@@ -17,20 +17,41 @@
package org.apache.spark.sql.execution.aggregate
+import scala.collection.mutable
+
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.{UnsafeFixedWidthAggregationMap, SparkPlan}
+import org.apache.spark.sql.types.StructType
/**
* Utility functions used by the query planner to convert our plan to new aggregation code path.
*/
object Utils {
+ def supportsTungstenAggregate(
+ groupingExpressions: Seq[Expression],
+ aggregateBufferAttributes: Seq[Attribute]): Boolean = {
+ val aggregationBufferSchema = StructType.fromAttributes(aggregateBufferAttributes)
+
+ UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(aggregationBufferSchema) &&
+ UnsafeProjection.canSupport(groupingExpressions)
+ }
+
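supportsTungstenAggregate gates the fast path on two independent capabilities: the unsafe map must support the aggregation buffer schema, and unsafe projection must support the grouping expressions. A stand-in sketch of this all-or-nothing gating (the Field type and predicates are illustrative, not the real checks):

    final case class Field(name: String, fixedWidth: Boolean)

    // Stand-in for UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema.
    def mapSupportsBuffer(buffer: Seq[Field]): Boolean =
      buffer.forall(_.fixedWidth) // e.g. mutable fixed-width fields only

    // Stand-in for UnsafeProjection.canSupport.
    def projectionSupports(grouping: Seq[Field]): Boolean = true

    def choosePlan(grouping: Seq[Field], buffer: Seq[Field]): String =
      if (mapSupportsBuffer(buffer) && projectionSupports(grouping)) "TungstenAggregate"
      else "SortBasedAggregate"

    assert(choosePlan(Seq(Field("key", true)), Seq(Field("sum", true))) == "TungstenAggregate")
    assert(choosePlan(Seq(Field("key", true)), Seq(Field("buf", false))) == "SortBasedAggregate")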
def planAggregateWithoutDistinct(
groupingExpressions: Seq[Expression],
aggregateExpressions: Seq[AggregateExpression2],
- aggregateFunctionMap: Map[(AggregateFunction2, Boolean), Attribute],
+ aggregateFunctionMap: Map[(AggregateFunction2, Boolean), (AggregateFunction2, Attribute)],
resultExpressions: Seq[NamedExpression],
child: SparkPlan): Seq[SparkPlan] = {
+ // Check if we can use TungstenAggregate.
+ val usesTungstenAggregate =
+ child.sqlContext.conf.unsafeEnabled &&
+ aggregateExpressions.forall(_.aggregateFunction.isInstanceOf[AlgebraicAggregate]) &&
+ supportsTungstenAggregate(
+ groupingExpressions,
+ aggregateExpressions.flatMap(_.aggregateFunction.bufferAttributes))
+
+
// 1. Create an Aggregate Operator for partial aggregations.
val namedGroupingExpressions = groupingExpressions.map {
case ne: NamedExpression => ne -> ne
@@ -44,11 +65,23 @@ object Utils {
val groupExpressionMap = namedGroupingExpressions.toMap
val namedGroupingAttributes = namedGroupingExpressions.map(_._2.toAttribute)
val partialAggregateExpressions = aggregateExpressions.map(_.copy(mode = Partial))
- val partialAggregateAttributes = partialAggregateExpressions.flatMap { agg =>
- agg.aggregateFunction.bufferAttributes
- }
- val partialAggregate =
- Aggregate(
+ val partialAggregateAttributes =
+ partialAggregateExpressions.flatMap(_.aggregateFunction.bufferAttributes)
+ val partialResultExpressions =
+ namedGroupingAttributes ++
+ partialAggregateExpressions.flatMap(_.aggregateFunction.cloneBufferAttributes)
+
+ val partialAggregate = if (usesTungstenAggregate) {
+ TungstenAggregate(
+ requiredChildDistributionExpressions = None: Option[Seq[Expression]],
+ groupingExpressions = namedGroupingExpressions.map(_._2),
+ nonCompleteAggregateExpressions = partialAggregateExpressions,
+ completeAggregateExpressions = Nil,
+ initialInputBufferOffset = 0,
+ resultExpressions = partialResultExpressions,
+ child = child)
+ } else {
+ SortBasedAggregate(
requiredChildDistributionExpressions = None: Option[Seq[Expression]],
groupingExpressions = namedGroupingExpressions.map(_._2),
nonCompleteAggregateExpressions = partialAggregateExpressions,
@@ -56,29 +89,57 @@ object Utils {
completeAggregateExpressions = Nil,
completeAggregateAttributes = Nil,
initialInputBufferOffset = 0,
- resultExpressions = namedGroupingAttributes ++ partialAggregateAttributes,
+ resultExpressions = partialResultExpressions,
child = child)
+ }
// 2. Create an Aggregate Operator for final aggregations.
val finalAggregateExpressions = aggregateExpressions.map(_.copy(mode = Final))
val finalAggregateAttributes =
finalAggregateExpressions.map {
- expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct)
+ expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct)._2
}
- val rewrittenResultExpressions = resultExpressions.map { expr =>
- expr.transformDown {
- case agg: AggregateExpression2 =>
- aggregateFunctionMap(agg.aggregateFunction, agg.isDistinct).toAttribute
- case expression =>
- // We do not rely on the equality check here since attributes may
- // differ cosmetically. Instead, we use semanticEquals.
- groupExpressionMap.collectFirst {
- case (expr, ne) if expr semanticEquals expression => ne.toAttribute
- }.getOrElse(expression)
- }.asInstanceOf[NamedExpression]
- }
- val finalAggregate =
- Aggregate(
+
+ val finalAggregate = if (usesTungstenAggregate) {
+ val rewrittenResultExpressions = resultExpressions.map { expr =>
+ expr.transformDown {
+ case agg: AggregateExpression2 =>
+ // aggregateFunctionMap contains unique aggregate functions.
+ val aggregateFunction =
+ aggregateFunctionMap(agg.aggregateFunction, agg.isDistinct)._1
+ aggregateFunction.asInstanceOf[AlgebraicAggregate].evaluateExpression
+ case expression =>
+ // We do not rely on the equality check here since attributes may
+ // differ cosmetically. Instead, we use semanticEquals.
+ groupExpressionMap.collectFirst {
+ case (expr, ne) if expr semanticEquals expression => ne.toAttribute
+ }.getOrElse(expression)
+ }.asInstanceOf[NamedExpression]
+ }
+
+ TungstenAggregate(
+ requiredChildDistributionExpressions = Some(namedGroupingAttributes),
+ groupingExpressions = namedGroupingAttributes,
+ nonCompleteAggregateExpressions = finalAggregateExpressions,
+ completeAggregateExpressions = Nil,
+ initialInputBufferOffset = namedGroupingAttributes.length,
+ resultExpressions = rewrittenResultExpressions,
+ child = partialAggregate)
+ } else {
+ val rewrittenResultExpressions = resultExpressions.map { expr =>
+ expr.transformDown {
+ case agg: AggregateExpression2 =>
+ aggregateFunctionMap(agg.aggregateFunction, agg.isDistinct)._2
+ case expression =>
+ // We do not rely on the equality check here since attributes may
+ // differ cosmetically. Instead, we use semanticEquals.
+ groupExpressionMap.collectFirst {
+ case (expr, ne) if expr semanticEquals expression => ne.toAttribute
+ }.getOrElse(expression)
+ }.asInstanceOf[NamedExpression]
+ }
+
+ SortBasedAggregate(
requiredChildDistributionExpressions = Some(namedGroupingAttributes),
groupingExpressions = namedGroupingAttributes,
nonCompleteAggregateExpressions = finalAggregateExpressions,
@@ -88,6 +149,7 @@ object Utils {
initialInputBufferOffset = namedGroupingAttributes.length,
resultExpressions = rewrittenResultExpressions,
child = partialAggregate)
+ }
finalAggregate :: Nil
}
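planAggregateWithoutDistinct always produces the same two-operator shape: a partial aggregate that emits the grouping columns plus raw buffers, and a final aggregate that merges buffers after a shuffle and only then evaluates the result expressions. A self-contained sketch of that split for avg (toy AvgBuffer type; the real buffers are UnsafeRows):

    case class AvgBuffer(sum: Double, count: Long)

    // Partial: runs once per input partition, before the shuffle.
    def partial(rows: Iterator[Double]): AvgBuffer =
      rows.foldLeft(AvgBuffer(0.0, 0L))((b, v) => AvgBuffer(b.sum + v, b.count + 1))

    // Final: merges the partial buffers for one group after the shuffle.
    def merge(buffers: Iterator[AvgBuffer]): AvgBuffer =
      buffers.foldLeft(AvgBuffer(0.0, 0L))((a, b) => AvgBuffer(a.sum + b.sum, a.count + b.count))

    // The result expression is evaluated only in the final operator;
    // an empty group yields SQL null (None here).
    def evaluate(b: AvgBuffer): Option[Double] =
      if (b.count == 0) None else Some(b.sum / b.count)

    val partials = Seq(Seq(1.0, 2.0), Seq(3.0)).map(p => partial(p.iterator))
    assert(evaluate(merge(partials.iterator)) == Some(2.0))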
@@ -96,10 +158,18 @@ object Utils {
groupingExpressions: Seq[Expression],
functionsWithDistinct: Seq[AggregateExpression2],
functionsWithoutDistinct: Seq[AggregateExpression2],
- aggregateFunctionMap: Map[(AggregateFunction2, Boolean), Attribute],
+ aggregateFunctionMap: Map[(AggregateFunction2, Boolean), (AggregateFunction2, Attribute)],
resultExpressions: Seq[NamedExpression],
child: SparkPlan): Seq[SparkPlan] = {
+ val aggregateExpressions = functionsWithDistinct ++ functionsWithoutDistinct
+ val usesTungstenAggregate =
+ child.sqlContext.conf.unsafeEnabled &&
+ aggregateExpressions.forall(_.aggregateFunction.isInstanceOf[AlgebraicAggregate]) &&
+ supportsTungstenAggregate(
+ groupingExpressions,
+ aggregateExpressions.flatMap(_.aggregateFunction.bufferAttributes))
+
// 1. Create an Aggregate Operator for partial aggregations.
// The grouping expressions are original groupingExpressions and
// distinct columns. For example, for avg(distinct value) ... group by key
@@ -129,19 +199,26 @@ object Utils {
val distinctColumnExpressionMap = namedDistinctColumnExpressions.toMap
val distinctColumnAttributes = namedDistinctColumnExpressions.map(_._2.toAttribute)
- val partialAggregateExpressions = functionsWithoutDistinct.map {
- case AggregateExpression2(aggregateFunction, mode, _) =>
- AggregateExpression2(aggregateFunction, Partial, false)
- }
- val partialAggregateAttributes = partialAggregateExpressions.flatMap { agg =>
- agg.aggregateFunction.bufferAttributes
- }
+ val partialAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial))
+ val partialAggregateAttributes =
+ partialAggregateExpressions.flatMap(_.aggregateFunction.bufferAttributes)
val partialAggregateGroupingExpressions =
(namedGroupingExpressions ++ namedDistinctColumnExpressions).map(_._2)
val partialAggregateResult =
- namedGroupingAttributes ++ distinctColumnAttributes ++ partialAggregateAttributes
- val partialAggregate =
- Aggregate(
+ namedGroupingAttributes ++
+ distinctColumnAttributes ++
+ partialAggregateExpressions.flatMap(_.aggregateFunction.cloneBufferAttributes)
+ val partialAggregate = if (usesTungstenAggregate) {
+ TungstenAggregate(
+ requiredChildDistributionExpressions = None: Option[Seq[Expression]],
+ groupingExpressions = partialAggregateGroupingExpressions,
+ nonCompleteAggregateExpressions = partialAggregateExpressions,
+ completeAggregateExpressions = Nil,
+ initialInputBufferOffset = 0,
+ resultExpressions = partialAggregateResult,
+ child = child)
+ } else {
+ SortBasedAggregate(
requiredChildDistributionExpressions = None: Option[Seq[Expression]],
groupingExpressions = partialAggregateGroupingExpressions,
nonCompleteAggregateExpressions = partialAggregateExpressions,
@@ -151,20 +228,27 @@ object Utils {
initialInputBufferOffset = 0,
resultExpressions = partialAggregateResult,
child = child)
+ }
// 2. Create an Aggregate Operator for partial merge aggregations.
- val partialMergeAggregateExpressions = functionsWithoutDistinct.map {
- case AggregateExpression2(aggregateFunction, mode, _) =>
- AggregateExpression2(aggregateFunction, PartialMerge, false)
- }
+ val partialMergeAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge))
val partialMergeAggregateAttributes =
- partialMergeAggregateExpressions.flatMap { agg =>
- agg.aggregateFunction.bufferAttributes
- }
+ partialMergeAggregateExpressions.flatMap(_.aggregateFunction.bufferAttributes)
val partialMergeAggregateResult =
- namedGroupingAttributes ++ distinctColumnAttributes ++ partialMergeAggregateAttributes
- val partialMergeAggregate =
- Aggregate(
+ namedGroupingAttributes ++
+ distinctColumnAttributes ++
+ partialMergeAggregateExpressions.flatMap(_.aggregateFunction.cloneBufferAttributes)
+ val partialMergeAggregate = if (usesTungstenAggregate) {
+ TungstenAggregate(
+ requiredChildDistributionExpressions = Some(namedGroupingAttributes),
+ groupingExpressions = namedGroupingAttributes ++ distinctColumnAttributes,
+ nonCompleteAggregateExpressions = partialMergeAggregateExpressions,
+ completeAggregateExpressions = Nil,
+ initialInputBufferOffset = (namedGroupingAttributes ++ distinctColumnAttributes).length,
+ resultExpressions = partialMergeAggregateResult,
+ child = partialAggregate)
+ } else {
+ SortBasedAggregate(
requiredChildDistributionExpressions = Some(namedGroupingAttributes),
groupingExpressions = namedGroupingAttributes ++ distinctColumnAttributes,
nonCompleteAggregateExpressions = partialMergeAggregateExpressions,
@@ -174,48 +258,91 @@ object Utils {
initialInputBufferOffset = (namedGroupingAttributes ++ distinctColumnAttributes).length,
resultExpressions = partialMergeAggregateResult,
child = partialAggregate)
+ }
// 3. Create an Aggregate Operator for the final and complete aggregations.
- val finalAggregateExpressions = functionsWithoutDistinct.map {
- case AggregateExpression2(aggregateFunction, mode, _) =>
- AggregateExpression2(aggregateFunction, Final, false)
- }
+ val finalAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Final))
val finalAggregateAttributes =
finalAggregateExpressions.map {
- expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct)
+ expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct)._2
}
+ // Create a map to store the rewritten aggregate functions. We always need to use
+ // both the function and its corresponding isDistinct flag as the key, because the
+ // function itself does not know whether it has the DISTINCT keyword or not.
+ val rewrittenAggregateFunctions =
+ mutable.Map.empty[(AggregateFunction2, Boolean), AggregateFunction2]
val (completeAggregateExpressions, completeAggregateAttributes) = functionsWithDistinct.map {
// Children of an AggregateFunction with the DISTINCT keyword have already
// been evaluated. Here, we need to replace the original children
// with AttributeReferences.
- case agg @ AggregateExpression2(aggregateFunction, mode, isDistinct) =>
+ case agg @ AggregateExpression2(aggregateFunction, mode, true) =>
val rewrittenAggregateFunction = aggregateFunction.transformDown {
case expr if distinctColumnExpressionMap.contains(expr) =>
distinctColumnExpressionMap(expr).toAttribute
}.asInstanceOf[AggregateFunction2]
+ // Because we have rewritten the aggregate function, we use rewrittenAggregateFunctions
+ // to track the old version and the new version of this function.
+ rewrittenAggregateFunctions += (aggregateFunction, true) -> rewrittenAggregateFunction
// We rewrite the aggregate function to a non-distinct aggregation because
// its input will have distinct arguments.
+ // We keep the isDistinct setting as true, so that when users look at the query plan,
+ // they can still see the distinct aggregations.
val rewrittenAggregateExpression =
- AggregateExpression2(rewrittenAggregateFunction, Complete, false)
+ AggregateExpression2(rewrittenAggregateFunction, Complete, true)
- val aggregateFunctionAttribute = aggregateFunctionMap(agg.aggregateFunction, isDistinct)
+ val aggregateFunctionAttribute =
+ aggregateFunctionMap(agg.aggregateFunction, true)._2
(rewrittenAggregateExpression -> aggregateFunctionAttribute)
}.unzip
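The rewrite above swaps each reference to the distinct column inside an aggregate function for the attribute produced by the earlier operators. A toy sketch of that transformDown-style child replacement, over a hypothetical three-node expression tree (the attribute name is illustrative):

    sealed trait Expr
    case class Col(name: String) extends Expr
    case class Upper(child: Expr) extends Expr
    case class Sum(child: Expr) extends Expr

    // Top-down transform: apply the rule at a node, then recurse into children.
    def transformDown(e: Expr)(rule: PartialFunction[Expr, Expr]): Expr =
      rule.applyOrElse(e, (x: Expr) => x) match {
        case Upper(c) => Upper(transformDown(c)(rule))
        case Sum(c)   => Sum(transformDown(c)(rule))
        case leaf     => leaf
      }

    // sum(upper(value)), where upper(value) was already evaluated upstream
    // and exposed as the (hypothetical) attribute "_gen_distinct_1":
    val rewritten = transformDown(Sum(Upper(Col("value")))) {
      case Upper(Col("value")) => Col("_gen_distinct_1")
    }
    assert(rewritten == Sum(Col("_gen_distinct_1")))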
- val rewrittenResultExpressions = resultExpressions.map { expr =>
- expr.transform {
- case agg: AggregateExpression2 =>
- aggregateFunctionMap(agg.aggregateFunction, agg.isDistinct).toAttribute
- case expression =>
- // We do not rely on the equality check here since attributes may
- // differ cosmetically. Instead, we use semanticEquals.
- groupExpressionMap.collectFirst {
- case (expr, ne) if expr semanticEquals expression => ne.toAttribute
- }.getOrElse(expression)
- }.asInstanceOf[NamedExpression]
- }
- val finalAndCompleteAggregate =
- Aggregate(
+ val finalAndCompleteAggregate = if (usesTungstenAggregate) {
+ val rewrittenResultExpressions = resultExpressions.map { expr =>
+ expr.transform {
+ case agg: AggregateExpression2 =>
+ val function = agg.aggregateFunction
+ val isDistinct = agg.isDistinct
+ val aggregateFunction =
+ if (rewrittenAggregateFunctions.contains(function, isDistinct)) {
+ // If this function has been rewritten, we get the rewritten version from
+ // rewrittenAggregateFunctions.
+ rewrittenAggregateFunctions(function, isDistinct)
+ } else {
+ // Otherwise, we get it from aggregateFunctionMap, which contains unique
+ // aggregate functions that have not been rewritten.
+ aggregateFunctionMap(function, isDistinct)._1
+ }
+ aggregateFunction.asInstanceOf[AlgebraicAggregate].evaluateExpression
+ case expression =>
+ // We do not rely on the equality check here since attributes may
+ // differ cosmetically. Instead, we use semanticEquals.
+ groupExpressionMap.collectFirst {
+ case (expr, ne) if expr semanticEquals expression => ne.toAttribute
+ }.getOrElse(expression)
+ }.asInstanceOf[NamedExpression]
+ }
+
+ TungstenAggregate(
+ requiredChildDistributionExpressions = Some(namedGroupingAttributes),
+ groupingExpressions = namedGroupingAttributes,
+ nonCompleteAggregateExpressions = finalAggregateExpressions,
+ completeAggregateExpressions = completeAggregateExpressions,
+ initialInputBufferOffset = (namedGroupingAttributes ++ distinctColumnAttributes).length,
+ resultExpressions = rewrittenResultExpressions,
+ child = partialMergeAggregate)
+ } else {
+ val rewrittenResultExpressions = resultExpressions.map { expr =>
+ expr.transform {
+ case agg: AggregateExpression2 =>
+ aggregateFunctionMap(agg.aggregateFunction, agg.isDistinct)._2
+ case expression =>
+ // We do not rely on the equality check here since attributes may
+ // differ cosmetically. Instead, we use semanticEquals.
+ groupExpressionMap.collectFirst {
+ case (expr, ne) if expr semanticEquals expression => ne.toAttribute
+ }.getOrElse(expression)
+ }.asInstanceOf[NamedExpression]
+ }
+ SortBasedAggregate(
requiredChildDistributionExpressions = Some(namedGroupingAttributes),
groupingExpressions = namedGroupingAttributes,
nonCompleteAggregateExpressions = finalAggregateExpressions,
@@ -225,6 +352,7 @@ object Utils {
initialInputBufferOffset = (namedGroupingAttributes ++ distinctColumnAttributes).length,
resultExpressions = rewrittenResultExpressions,
child = partialMergeAggregate)
+ }
finalAndCompleteAggregate :: Nil
}
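End to end, the distinct plan de-duplicates first and aggregates second: phases 1 and 2 group on (grouping key, distinct column), so the final phase's "complete" aggregates see each distinct value exactly once per key. A collection-level sketch of the same idea for count(DISTINCT value):

    def countDistinct[K, V](rows: Seq[(K, V)]): Map[K, Int] =
      rows.distinct            // phases 1-2: group by (key, distinct value)
        .groupBy(_._1)         // phase 3: aggregate per original key
        .map { case (key, kvs) => (key, kvs.size) }

    assert(countDistinct(Seq((1, "a"), (1, "a"), (1, "b"), (2, "a"))) ==
      Map(1 -> 2, 2 -> 1))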
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index cef40dd324..c64aa7a07d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -262,7 +262,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
val df = sql(sqlText)
// First, check if we have GeneratedAggregate.
val hasGeneratedAgg = df.queryExecution.executedPlan
- .collect { case _: aggregate.Aggregate => true }
+ .collect { case _: aggregate.TungstenAggregate => true }
.nonEmpty
if (!hasGeneratedAgg) {
fail(
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
index 4b35c8fd83..7b5aa4763f 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
@@ -21,9 +21,9 @@ import org.apache.spark.sql.execution.aggregate
import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
-import org.apache.spark.sql.{SQLConf, AnalysisException, QueryTest, Row}
+import org.apache.spark.sql._
import org.scalatest.BeforeAndAfterAll
-import test.org.apache.spark.sql.hive.aggregate.{MyDoubleAvg, MyDoubleSum}
+import _root_.test.org.apache.spark.sql.hive.aggregate.{MyDoubleAvg, MyDoubleSum}
abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with BeforeAndAfterAll {
@@ -141,6 +141,22 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Be
Nil)
}
+ test("null literal") {
+ checkAnswer(
+ sqlContext.sql(
+ """
+ |SELECT
+ | AVG(null),
+ | COUNT(null),
+ | FIRST(null),
+ | LAST(null),
+ | MAX(null),
+ | MIN(null),
+ | SUM(null)
+ """.stripMargin),
+ Row(null, 0, null, null, null, null, null) :: Nil)
+ }
+
test("only do grouping") {
checkAnswer(
sqlContext.sql(
@@ -266,13 +282,6 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Be
|SELECT avg(value) FROM agg1
""".stripMargin),
Row(11.125) :: Nil)
-
- checkAnswer(
- sqlContext.sql(
- """
- |SELECT avg(null)
- """.stripMargin),
- Row(null) :: Nil)
}
test("udaf") {
@@ -364,7 +373,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Be
| max(distinct value1)
|FROM agg2
""".stripMargin),
- Row(-60, 70.0, 101.0/9.0, 5.6, 100.0))
+ Row(-60, 70.0, 101.0/9.0, 5.6, 100))
checkAnswer(
sqlContext.sql(
@@ -402,6 +411,23 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Be
Row(2, 100.0, 3.0, 0.0, 100.0, 1.0/3.0 + 100.0) ::
Row(3, null, 3.0, null, null, null) ::
Row(null, 110.0, 60.0, 30.0, 110.0, 110.0) :: Nil)
+
+ checkAnswer(
+ sqlContext.sql(
+ """
+ |SELECT
+ | count(value1),
+ | count(*),
+ | count(1),
+ | count(DISTINCT value1),
+ | key
+ |FROM agg2
+ |GROUP BY key
+ """.stripMargin),
+ Row(3, 3, 3, 2, 1) ::
+ Row(3, 4, 4, 2, 2) ::
+ Row(0, 2, 2, 0, 3) ::
+ Row(3, 4, 4, 3, null) :: Nil)
}
test("test count") {
@@ -496,7 +522,8 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Be
|FROM agg1
|GROUP BY key
""".stripMargin).queryExecution.executedPlan.collect {
- case agg: aggregate.Aggregate => agg
+ case agg: aggregate.SortBasedAggregate => agg
+ case agg: aggregate.TungstenAggregate => agg
}
val message =
"We should fallback to the old aggregation code path if " +
@@ -537,3 +564,58 @@ class TungstenAggregationQuerySuite extends AggregationQuerySuite {
sqlContext.setConf(SQLConf.UNSAFE_ENABLED.key, originalUnsafeEnabled.toString)
}
}
+
+class TungstenAggregationQueryWithControlledFallbackSuite extends AggregationQuerySuite {
+
+ var originalUnsafeEnabled: Boolean = _
+
+ override def beforeAll(): Unit = {
+ originalUnsafeEnabled = sqlContext.conf.unsafeEnabled
+ sqlContext.setConf(SQLConf.UNSAFE_ENABLED.key, "true")
+ super.beforeAll()
+ }
+
+ override def afterAll(): Unit = {
+ super.afterAll()
+ sqlContext.setConf(SQLConf.UNSAFE_ENABLED.key, originalUnsafeEnabled.toString)
+ sqlContext.conf.unsetConf("spark.sql.TungstenAggregate.testFallbackStartsAt")
+ }
+
+ override protected def checkAnswer(actual: DataFrame, expectedAnswer: Seq[Row]): Unit = {
+ (0 to 2).foreach { fallbackStartsAt =>
+ sqlContext.setConf(
+ "spark.sql.TungstenAggregate.testFallbackStartsAt",
+ fallbackStartsAt.toString)
+
+ // Create a new df to make sure its physical operator picks up
+ // spark.sql.TungstenAggregate.testFallbackStartsAt.
+ val newActual = DataFrame(sqlContext, actual.logicalPlan)
+
+ QueryTest.checkAnswer(newActual, expectedAnswer) match {
+ case Some(errorMessage) =>
+ val newErrorMessage =
+ s"""
+ |The following aggregation query failed when using TungstenAggregate with
+ |controlled fallback (it falls back to sort-based aggregation once it has processed
+ |$fallbackStartsAt input rows). The query is
+ |${actual.queryExecution}
+ |
+ |$errorMessage
+ """.stripMargin
+
+ fail(newErrorMessage)
+ case None =>
+ }
+ }
+ }
+
+ // Override this method to make sure it calls the checkAnswer overridden above.
+ override protected def checkAnswer(df: DataFrame, expectedAnswer: Row): Unit = {
+ checkAnswer(df, Seq(expectedAnswer))
+ }
+
+ // Override this method to make sure it calls the checkAnswer overridden above.
+ override protected def checkAnswer(df: DataFrame, expectedAnswer: DataFrame): Unit = {
+ checkAnswer(df, expectedAnswer.collect())
+ }
+}
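The pattern in this suite is worth noting: a test-only knob (spark.sql.TungstenAggregate.testFallbackStartsAt) lets one assertion exercise the hash-only, mixed, and fallback-only executions of the same query and demand identical answers. A minimal sketch of that idea with illustrative names, where the knob caps the number of distinct keys the "hash" path may hold:

    import scala.collection.mutable

    def aggregate(input: Seq[(String, Long)], fallbackStartsAt: Int): Map[String, Long] = {
      val hashed = mutable.HashMap.empty[String, Long]
      val fallback = mutable.ArrayBuffer.empty[(String, Long)] // simulated sort-based path
      input.foreach { case (k, v) =>
        if (hashed.contains(k) || hashed.size < fallbackStartsAt) {
          hashed(k) = hashed.getOrElse(k, 0L) + v
        } else {
          fallback += ((k, v))
        }
      }
      val sorted = fallback.sortBy(_._1).groupBy(_._1)
        .map { case (k, vs) => (k, vs.map(_._2).sum) }
      (hashed.keySet ++ sorted.keySet)
        .map(k => k -> (hashed.getOrElse(k, 0L) + sorted.getOrElse(k, 0L)))
        .toMap
    }

    val data = Seq(("a", 1L), ("b", 2L), ("a", 3L))
    val expected = Map("a" -> 4L, "b" -> 2L)
    (0 to 2).foreach(n => assert(aggregate(data, n) == expected)) // all paths agree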