path: root/sql/core
author     Cheng Lian <lian@databricks.com>    2016-11-03 09:34:51 -0700
committer  Yin Huai <yhuai@databricks.com>     2016-11-03 09:34:51 -0700
commit     27daf6bcde782ed3e0f0d951c90c8040fd47e985 (patch)
tree       d56b006e8f954af73ee46d5084e1c5b855334cfb /sql/core
parent     66a99f4a411ee7dc94ff1070a8fd6865fd004093 (diff)
download   spark-27daf6bcde782ed3e0f0d951c90c8040fd47e985.tar.gz
           spark-27daf6bcde782ed3e0f0d951c90c8040fd47e985.tar.bz2
           spark-27daf6bcde782ed3e0f0d951c90c8040fd47e985.zip
[SPARK-17949][SQL] A JVM object based aggregate operator
## What changes were proposed in this pull request?

This PR adds a new hash-based aggregate operator named `ObjectHashAggregateExec` that supports `TypedImperativeAggregate`, which may use arbitrary Java objects as aggregation states. Please refer to the [design doc](https://issues.apache.org/jira/secure/attachment/12834260/%5BDesign%20Doc%5D%20Support%20for%20Arbitrary%20Aggregation%20States.pdf) attached to [SPARK-17949](https://issues.apache.org/jira/browse/SPARK-17949) for more details.

The major benefit of this operator is better performance when evaluating `TypedImperativeAggregate` functions, especially when there are relatively few distinct groups. Functions like Hive UDAFs, `collect_list`, and `collect_set` may also benefit from this after being migrated to `TypedImperativeAggregate`.

The following feature flag is introduced to enable or disable the new aggregate operator:

- Name: `spark.sql.execution.useObjectHashAggregateExec`
- Default value: `true`

The fallback threshold can be configured with the following SQL option:

- Name: `spark.sql.objectHashAggregate.sortBased.fallbackThreshold`
- Default value: 128

  The operator falls back to sort-based aggregation when more than 128 distinct groups are accumulated in the aggregation hash map. This number is intentionally kept small to avoid GC problems, since the aggregation buffers of this operator may contain arbitrary Java objects. It could be raised once size tracking is implemented for this operator. Code generation and size tracking are planned for follow-up PRs.

## Benchmark results

### `ObjectHashAggregateExec` vs `SortAggregateExec`

The first benchmark compares `ObjectHashAggregateExec` and `SortAggregateExec` by evaluating `typed_count`, a testing `TypedImperativeAggregate` version of the SQL `count` function.

```
Java HotSpot(TM) 64-Bit Server VM 1.8.0_92-b14 on Mac OS X 10.10.5
Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz

object agg v.s. sort agg:                Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
sort agg w/ group by                        31251 / 31908          3.4         298.0       1.0X
object agg w/ group by w/o fallback          6903 / 7141          15.2          65.8       4.5X
object agg w/ group by w/ fallback          20945 / 21613          5.0         199.7       1.5X
sort agg w/o group by                        4734 / 5463          22.1          45.2       6.6X
object agg w/o group by w/o fallback         4310 / 4529          24.3          41.1       7.3X
```

The next benchmark compares `ObjectHashAggregateExec` and `SortAggregateExec` by evaluating the Spark native version of `percentile_approx`. Note that `percentile_approx` is such a heavy aggregate function that the benchmark bottleneck is evaluating the function itself rather than the aggregate operator (I couldn't run a large-scale benchmark on my laptop). That's why the results are so close and look counter-intuitive: aggregation with grouping even appears faster than aggregation without grouping.

```
Java HotSpot(TM) 64-Bit Server VM 1.8.0_92-b14 on Mac OS X 10.10.5
Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz

object agg v.s. sort agg:                Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
sort agg w/ group by                         3418 / 3530           0.6        1630.0       1.0X
object agg w/ group by w/o fallback          3210 / 3314           0.7        1530.7       1.1X
object agg w/ group by w/ fallback           3419 / 3511           0.6        1630.1       1.0X
sort agg w/o group by                        4336 / 4499           0.5        2067.3       0.8X
object agg w/o group by w/o fallback         4271 / 4372           0.5        2036.7       0.8X
```

### Hive UDAF vs Spark AF

This benchmark compares the following two kinds of aggregate functions:

- "hive udaf": the Hive implementation of `percentile_approx`, without partial aggregation support, evaluated using `SortAggregateExec`.
- "spark af": the Spark native implementation of `percentile_approx`, with partial aggregation support, evaluated using `ObjectHashAggregateExec`.

The performance differences are mostly due to the faster implementation and the partial aggregation support of the Spark native version of `percentile_approx`. This benchmark essentially shows the gap between the worst case, where an aggregate function without partial aggregation support is evaluated using `SortAggregateExec`, and the best case, where a `TypedImperativeAggregate` with partial aggregation support is evaluated using `ObjectHashAggregateExec`.

```
Java HotSpot(TM) 64-Bit Server VM 1.8.0_92-b14 on Mac OS X 10.10.5
Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz

hive udaf vs spark af:                   Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
hive udaf w/o group by                       5326 / 5408           0.0       81264.2       1.0X
spark af w/o group by                          93 / 111            0.7        1415.6      57.4X
hive udaf w/ group by                        3804 / 3946           0.0       58050.1       1.4X
spark af w/ group by w/o fallback              71 / 90             0.9        1085.7      74.8X
spark af w/ group by w/ fallback               98 / 111            0.7        1501.6      54.1X
```

### Real world benchmark

We also ran a relatively large benchmark using a real-world query involving `percentile_approx`:

- Hive UDAF implementation, sort-based aggregation, w/o partial aggregation support: 24.77 minutes
- Native implementation, sort-based aggregation, w/ partial aggregation support: 4.64 minutes
- Native implementation, object hash aggregator, w/ partial aggregation support: 1.80 minutes

## How was this patch tested?

New unit tests and randomized test cases are added in `ObjectAggregateFunctionSuite`.

Author: Cheng Lian <lian@databricks.com>

Closes #15590 from liancheng/obj-hash-agg.
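The two configurations above can be set at runtime through the public conf API. The following is a minimal, illustrative sketch that is not part of this patch: the local `SparkSession`, the sample DataFrame, and the use of `collect_list` are assumptions for demonstration only, and whether `ObjectHashAggregate` actually appears in the explained plan depends on the aggregate function being implemented as a `TypedImperativeAggregate` in the Spark version at hand.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.collect_list

// Hypothetical example, not part of this patch.
object ObjectHashAggregateConfExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("object-hash-agg-conf-example")
      .master("local[*]")
      .getOrCreate()

    // Enable (the default) or disable the new ObjectHashAggregateExec operator.
    spark.conf.set("spark.sql.execution.useObjectHashAggregateExec", "true")

    // Raise the sort-based fallback threshold from its default of 128 distinct groups.
    spark.conf.set("spark.sql.objectHashAggregate.sortBased.fallbackThreshold", "256")

    import spark.implicits._
    val df = Seq((1, "a"), (1, "b"), (2, "c")).toDF("key", "value")

    // collect_list is named in the PR description as a candidate for migration to
    // TypedImperativeAggregate; explain() shows which aggregate operator gets planned.
    df.groupBy($"key").agg(collect_list($"value")).explain()

    spark.stop()
  }
}
```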
Diffstat (limited to 'sql/core')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala                        31
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala       323
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationMap.scala            110
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala         155
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala                                     22
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala                         6
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala  141
7 files changed, 777 insertions, 11 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
index 4fbb9d554c..3c8ef1ad84 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
@@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.streaming.{StateStoreRestoreExec, StateStoreSaveExec}
+import org.apache.spark.sql.internal.SQLConf
/**
* Utility functions used by the query planner to convert our plan to new aggregation code path.
@@ -66,14 +67,28 @@ object AggUtils {
resultExpressions = resultExpressions,
child = child)
} else {
- SortAggregateExec(
- requiredChildDistributionExpressions = requiredChildDistributionExpressions,
- groupingExpressions = groupingExpressions,
- aggregateExpressions = aggregateExpressions,
- aggregateAttributes = aggregateAttributes,
- initialInputBufferOffset = initialInputBufferOffset,
- resultExpressions = resultExpressions,
- child = child)
+ val objectHashEnabled = child.sqlContext.conf.useObjectHashAggregation
+ val useObjectHash = ObjectHashAggregateExec.supportsAggregate(aggregateExpressions)
+
+ if (objectHashEnabled && useObjectHash) {
+ ObjectHashAggregateExec(
+ requiredChildDistributionExpressions = requiredChildDistributionExpressions,
+ groupingExpressions = groupingExpressions,
+ aggregateExpressions = aggregateExpressions,
+ aggregateAttributes = aggregateAttributes,
+ initialInputBufferOffset = initialInputBufferOffset,
+ resultExpressions = resultExpressions,
+ child = child)
+ } else {
+ SortAggregateExec(
+ requiredChildDistributionExpressions = requiredChildDistributionExpressions,
+ groupingExpressions = groupingExpressions,
+ aggregateExpressions = aggregateExpressions,
+ aggregateAttributes = aggregateAttributes,
+ initialInputBufferOffset = initialInputBufferOffset,
+ resultExpressions = resultExpressions,
+ child = child)
+ }
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala
new file mode 100644
index 0000000000..3c7b9ee317
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala
@@ -0,0 +1,323 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.aggregate
+
+import org.apache.spark.{SparkEnv, TaskContext}
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.catalyst.expressions.codegen.{BaseOrdering, GenerateOrdering}
+import org.apache.spark.sql.execution.UnsafeKVExternalSorter
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.unsafe.KVIterator
+import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter
+
+class ObjectAggregationIterator(
+ outputAttributes: Seq[Attribute],
+ groupingExpressions: Seq[NamedExpression],
+ aggregateExpressions: Seq[AggregateExpression],
+ aggregateAttributes: Seq[Attribute],
+ initialInputBufferOffset: Int,
+ resultExpressions: Seq[NamedExpression],
+ newMutableProjection: (Seq[Expression], Seq[Attribute]) => MutableProjection,
+ originalInputAttributes: Seq[Attribute],
+ inputRows: Iterator[InternalRow],
+ fallbackCountThreshold: Int)
+ extends AggregationIterator(
+ groupingExpressions,
+ originalInputAttributes,
+ aggregateExpressions,
+ aggregateAttributes,
+ initialInputBufferOffset,
+ resultExpressions,
+ newMutableProjection) with Logging {
+
+ // Indicates whether we have fallen back to sort-based aggregation or not.
+ private[this] var sortBased: Boolean = false
+
+ private[this] var aggBufferIterator: Iterator[AggregationBufferEntry] = _
+
+ // Hacking the aggregation mode to call AggregateFunction.merge to merge two aggregation buffers
+ private val mergeAggregationBuffers: (InternalRow, InternalRow) => Unit = {
+ val newExpressions = aggregateExpressions.map {
+ case agg @ AggregateExpression(_, Partial, _, _) =>
+ agg.copy(mode = PartialMerge)
+ case agg @ AggregateExpression(_, Complete, _, _) =>
+ agg.copy(mode = Final)
+ case other => other
+ }
+ val newFunctions = initializeAggregateFunctions(newExpressions, 0)
+ val newInputAttributes = newFunctions.flatMap(_.inputAggBufferAttributes)
+ generateProcessRow(newExpressions, newFunctions, newInputAttributes)
+ }
+
+ // A safe projection used to do deep clone of input rows to prevent false sharing.
+ private[this] val safeProjection: Projection =
+ FromUnsafeProjection(outputAttributes.map(_.dataType))
+
+ /**
+ * Start processing input rows.
+ */
+ processInputs()
+
+ override final def hasNext: Boolean = {
+ aggBufferIterator.hasNext
+ }
+
+ override final def next(): UnsafeRow = {
+ val entry = aggBufferIterator.next()
+ generateOutput(entry.groupingKey, entry.aggregationBuffer)
+ }
+
+ /**
+ * Generate an output row when there is no input and there is no grouping expression.
+ */
+ def outputForEmptyGroupingKeyWithoutInput(): UnsafeRow = {
+ if (groupingExpressions.isEmpty) {
+ val defaultAggregationBuffer = createNewAggregationBuffer()
+ generateOutput(UnsafeRow.createFromByteArray(0, 0), defaultAggregationBuffer)
+ } else {
+ throw new IllegalStateException(
+ "This method should not be called when groupingExpressions is not empty.")
+ }
+ }
+
+ // Creates a new aggregation buffer and initializes buffer values. This function should only be
+ // called in two cases:
+ //
+ // - when creating the aggregation buffer for a new group in the hash map, and
+ // - when creating the re-used buffer for sort-based aggregation
+ private def createNewAggregationBuffer(): SpecificInternalRow = {
+ val bufferFieldTypes = aggregateFunctions.flatMap(_.aggBufferAttributes.map(_.dataType))
+ val buffer = new SpecificInternalRow(bufferFieldTypes)
+ initAggregationBuffer(buffer)
+ buffer
+ }
+
+ private def initAggregationBuffer(buffer: SpecificInternalRow): Unit = {
+ // Initializes declarative aggregates' buffer values
+ expressionAggInitialProjection.target(buffer)(EmptyRow)
+ // Initializes imperative aggregates' buffer values
+ aggregateFunctions.collect { case f: ImperativeAggregate => f }.foreach(_.initialize(buffer))
+ }
+
+ private def getAggregationBufferByKey(
+ hashMap: ObjectAggregationMap, groupingKey: UnsafeRow): InternalRow = {
+ var aggBuffer = hashMap.getAggregationBuffer(groupingKey)
+
+ if (aggBuffer == null) {
+ aggBuffer = createNewAggregationBuffer()
+ hashMap.putAggregationBuffer(groupingKey.copy(), aggBuffer)
+ }
+
+ aggBuffer
+ }
+
+ // Reads and processes input rows. It starts with hash-based aggregation, putting groups and
+ // their buffers into `hashMap`. If `hashMap` grows beyond the fallback threshold, its contents
+ // are sorted and spilled into an external sorter, and all remaining input rows are fed into
+ // sort-based aggregation, which merges them with the spilled buffers.
+ private def processInputs(): Unit = {
+ // In-memory map to store aggregation buffer for hash-based aggregation.
+ val hashMap = new ObjectAggregationMap()
+
+ // If the in-memory map cannot hold all aggregation buffers, fall back to sort-based
+ // aggregation backed by sorted physical storage.
+ var sortBasedAggregationStore: SortBasedAggregator = null
+
+ if (groupingExpressions.isEmpty) {
+ // If there are no grouping expressions, we can just reuse the same buffer over and over again.
+ val groupingKey = groupingProjection.apply(null)
+ val buffer: InternalRow = getAggregationBufferByKey(hashMap, groupingKey)
+ while (inputRows.hasNext) {
+ val newInput = safeProjection(inputRows.next())
+ processRow(buffer, newInput)
+ }
+ } else {
+ while (inputRows.hasNext && !sortBased) {
+ val newInput = safeProjection(inputRows.next())
+ val groupingKey = groupingProjection.apply(newInput)
+ val buffer: InternalRow = getAggregationBufferByKey(hashMap, groupingKey)
+ processRow(buffer, newInput)
+
+ // If the hash map gets too large, make a sorted spill and clear the map.
+ if (hashMap.size >= fallbackCountThreshold) {
+ logInfo(
+ s"Aggregation hash map reaches threshold " +
+ s"capacity ($fallbackCountThreshold entries), spilling and falling back to sort" +
+ s" based aggregation. You may change the threshold by adjust option " +
+ SQLConf.OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD.key
+ )
+
+ // Falls back to sort-based aggregation
+ sortBased = true
+
+ }
+ }
+
+ if (sortBased) {
+ val sortIteratorFromHashMap = hashMap
+ .dumpToExternalSorter(groupingAttributes, aggregateFunctions)
+ .sortedIterator()
+ sortBasedAggregationStore = new SortBasedAggregator(
+ sortIteratorFromHashMap,
+ StructType.fromAttributes(originalInputAttributes),
+ StructType.fromAttributes(groupingAttributes),
+ processRow,
+ mergeAggregationBuffers,
+ createNewAggregationBuffer())
+
+ while (inputRows.hasNext) {
+ // NOTE: The input row is always UnsafeRow
+ val unsafeInputRow = inputRows.next().asInstanceOf[UnsafeRow]
+ val groupingKey = groupingProjection.apply(unsafeInputRow)
+ sortBasedAggregationStore.addInput(groupingKey, unsafeInputRow)
+ }
+ }
+ }
+
+ if (sortBased) {
+ aggBufferIterator = sortBasedAggregationStore.destructiveIterator()
+ } else {
+ aggBufferIterator = hashMap.iterator
+ }
+ }
+}
+
+/**
+ * A class used to handle sort-based aggregation, used together with [[ObjectHashAggregateExec]].
+ *
+ * @param initialAggBufferIterator iterator that points to sorted input aggregation buffers
+ * @param inputSchema The schema of input row
+ * @param groupingSchema The schema of grouping key
+ * @param processRow Function to update the aggregation buffer with input rows
+ * @param mergeAggregationBuffers Function used to merge the input aggregation buffers into existing
+ * aggregation buffers
+ * @param makeEmptyAggregationBuffer Creates an empty aggregation buffer
+ *
+ * @todo Try to eliminate this class by refactoring and reusing code paths in [[SortAggregateExec]].
+ */
+class SortBasedAggregator(
+ initialAggBufferIterator: KVIterator[UnsafeRow, UnsafeRow],
+ inputSchema: StructType,
+ groupingSchema: StructType,
+ processRow: (InternalRow, InternalRow) => Unit,
+ mergeAggregationBuffers: (InternalRow, InternalRow) => Unit,
+ makeEmptyAggregationBuffer: => InternalRow) {
+
+ // external sorter to sort the input (grouping key + input row) with grouping key.
+ private val inputSorter = createExternalSorterForInput()
+ private val groupingKeyOrdering: BaseOrdering = GenerateOrdering.create(groupingSchema)
+
+ def addInput(groupingKey: UnsafeRow, inputRow: UnsafeRow): Unit = {
+ inputSorter.insertKV(groupingKey, inputRow)
+ }
+
+ /**
+ * Returns a destructive iterator of AggregationBufferEntry.
+ * Notice: it is illegal to call any method after `destructiveIterator()` has been called.
+ */
+ def destructiveIterator(): Iterator[AggregationBufferEntry] = {
+ new Iterator[AggregationBufferEntry] {
+ val inputIterator = inputSorter.sortedIterator()
+ var hasNextInput: Boolean = inputIterator.next()
+ var hasNextAggBuffer: Boolean = initialAggBufferIterator.next()
+ private var result: AggregationBufferEntry = _
+ private var groupingKey: UnsafeRow = _
+
+ override def hasNext(): Boolean = {
+ result != null || findNextSortedGroup()
+ }
+
+ override def next(): AggregationBufferEntry = {
+ val returnResult = result
+ result = null
+ returnResult
+ }
+
+ // Two-way merges initialAggBufferIterator and inputIterator
+ private def findNextSortedGroup(): Boolean = {
+ if (hasNextInput || hasNextAggBuffer) {
+ // Find the smaller key between inputIterator and initialAggBufferIterator
+ groupingKey = findGroupingKey()
+ result = new AggregationBufferEntry(groupingKey, makeEmptyAggregationBuffer)
+
+ // Firstly, update the aggregation buffer with input rows.
+ while (hasNextInput &&
+ groupingKeyOrdering.compare(inputIterator.getKey, groupingKey) == 0) {
+ processRow(result.aggregationBuffer, inputIterator.getValue)
+ hasNextInput = inputIterator.next()
+ }
+
+ // Secondly, merge the aggregation buffer with existing aggregation buffers.
+ // NOTE: the ordering of these two while-blocks matters; mergeAggregationBuffers() should
+ // be called after processRow.
+ while (hasNextAggBuffer &&
+ groupingKeyOrdering.compare(initialAggBufferIterator.getKey, groupingKey) == 0) {
+ mergeAggregationBuffers(result.aggregationBuffer, initialAggBufferIterator.getValue)
+ hasNextAggBuffer = initialAggBufferIterator.next()
+ }
+
+ true
+ } else {
+ false
+ }
+ }
+
+ private def findGroupingKey(): UnsafeRow = {
+ var newGroupingKey: UnsafeRow = null
+ if (!hasNextInput) {
+ newGroupingKey = initialAggBufferIterator.getKey
+ } else if (!hasNextAggBuffer) {
+ newGroupingKey = inputIterator.getKey
+ } else {
+ val compareResult =
+ groupingKeyOrdering.compare(inputIterator.getKey, initialAggBufferIterator.getKey)
+ if (compareResult <= 0) {
+ newGroupingKey = inputIterator.getKey
+ } else {
+ newGroupingKey = initialAggBufferIterator.getKey
+ }
+ }
+
+ if (groupingKey == null) {
+ groupingKey = newGroupingKey.copy()
+ } else {
+ groupingKey.copyFrom(newGroupingKey)
+ }
+ groupingKey
+ }
+ }
+ }
+
+ private def createExternalSorterForInput(): UnsafeKVExternalSorter = {
+ new UnsafeKVExternalSorter(
+ groupingSchema,
+ inputSchema,
+ SparkEnv.get.blockManager,
+ SparkEnv.get.serializerManager,
+ TaskContext.get().taskMemoryManager().pageSizeBytes,
+ SparkEnv.get.conf.getLong(
+ "spark.shuffle.spill.numElementsForceSpillThreshold",
+ UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD),
+ null
+ )
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationMap.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationMap.scala
new file mode 100644
index 0000000000..f2d4f6c6eb
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationMap.scala
@@ -0,0 +1,110 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.aggregate
+
+import java.{util => ju}
+
+import org.apache.spark.{SparkEnv, TaskContext}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateFunction, TypedImperativeAggregate}
+import org.apache.spark.sql.execution.UnsafeKVExternalSorter
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter
+
+/**
+ * An aggregation map that supports using safe `SpecificInternalRow`s as aggregation buffers, so that
+ * we can support storing arbitrary Java objects as aggregate function states in the aggregation
+ * buffers. This class is only used together with [[ObjectHashAggregateExec]].
+ */
+class ObjectAggregationMap() {
+ private[this] val hashMap = new ju.LinkedHashMap[UnsafeRow, InternalRow]
+
+ def getAggregationBuffer(groupingKey: UnsafeRow): InternalRow = {
+ hashMap.get(groupingKey)
+ }
+
+ def putAggregationBuffer(groupingKey: UnsafeRow, aggBuffer: InternalRow): Unit = {
+ hashMap.put(groupingKey, aggBuffer)
+ }
+
+ def size: Int = hashMap.size()
+
+ def iterator: Iterator[AggregationBufferEntry] = {
+ val iter = hashMap.entrySet().iterator()
+ new Iterator[AggregationBufferEntry] {
+
+ override def hasNext: Boolean = {
+ iter.hasNext
+ }
+ override def next(): AggregationBufferEntry = {
+ val entry = iter.next()
+ new AggregationBufferEntry(entry.getKey, entry.getValue)
+ }
+ }
+ }
+
+ /**
+ * Dumps all entries into a newly created external sorter, clears the hash map, and returns the
+ * external sorter.
+ */
+ def dumpToExternalSorter(
+ groupingAttributes: Seq[Attribute],
+ aggregateFunctions: Seq[AggregateFunction]): UnsafeKVExternalSorter = {
+ val aggBufferAttributes = aggregateFunctions.flatMap(_.aggBufferAttributes)
+ val sorter = new UnsafeKVExternalSorter(
+ StructType.fromAttributes(groupingAttributes),
+ StructType.fromAttributes(aggBufferAttributes),
+ SparkEnv.get.blockManager,
+ SparkEnv.get.serializerManager,
+ TaskContext.get().taskMemoryManager().pageSizeBytes,
+ SparkEnv.get.conf.getLong(
+ "spark.shuffle.spill.numElementsForceSpillThreshold",
+ UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD),
+ null
+ )
+
+ val mapIterator = iterator
+ val unsafeAggBufferProjection =
+ UnsafeProjection.create(aggBufferAttributes.map(_.dataType).toArray)
+
+ while (mapIterator.hasNext) {
+ val entry = mapIterator.next()
+ aggregateFunctions.foreach {
+ case agg: TypedImperativeAggregate[_] =>
+ agg.serializeAggregateBufferInPlace(entry.aggregationBuffer)
+ case _ =>
+ }
+
+ sorter.insertKV(
+ entry.groupingKey,
+ unsafeAggBufferProjection(entry.aggregationBuffer)
+ )
+ }
+
+ hashMap.clear()
+ sorter
+ }
+
+ def clear(): Unit = {
+ hashMap.clear()
+ }
+}
+
+// Stores the grouping key and aggregation buffer
+class AggregationBufferEntry(var groupingKey: UnsafeRow, var aggregationBuffer: InternalRow)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala
new file mode 100644
index 0000000000..3fcb7ec9a6
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala
@@ -0,0 +1,155 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.aggregate
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.errors._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.execution._
+import org.apache.spark.sql.execution.metric.SQLMetrics
+import org.apache.spark.util.Utils
+
+/**
+ * A hash-based aggregate operator that supports [[TypedImperativeAggregate]] functions that may
+ * use arbitrary JVM objects as aggregation states.
+ *
+ * Similar to [[HashAggregateExec]], this operator also falls back to sort-based aggregation when
+ * the size of the internal hash map exceeds the threshold. The differences are:
+ *
+ * - It uses safe rows as aggregation buffer since it must support JVM objects as aggregation
+ * states.
+ *
+ * - It tracks the entry count of the hash map instead of the byte size to decide when to fall
+ * back. This is because it's hard to estimate the size of arbitrary JVM objects accurately in a
+ * lightweight way.
+ *
+ * - Whenever it falls back to sort-based aggregation, this operator feeds all of the remaining
+ * input rows into external sorters instead of building more hash map(s) as [[HashAggregateExec]]
+ * does. This is because having too many JVM object aggregation states floating around can be
+ * dangerous for GC.
+ *
+ * - CodeGen is not supported yet.
+ *
+ * This operator may be turned off by setting the following SQL configuration to `false`:
+ * {{{
+ * spark.sql.execution.useObjectHashAggregateExec
+ * }}}
+ * The fallback threshold can be configured by tuning:
+ * {{{
+ * spark.sql.objectHashAggregate.sortBased.fallbackThreshold
+ * }}}
+ */
+case class ObjectHashAggregateExec(
+ requiredChildDistributionExpressions: Option[Seq[Expression]],
+ groupingExpressions: Seq[NamedExpression],
+ aggregateExpressions: Seq[AggregateExpression],
+ aggregateAttributes: Seq[Attribute],
+ initialInputBufferOffset: Int,
+ resultExpressions: Seq[NamedExpression],
+ child: SparkPlan)
+ extends UnaryExecNode {
+
+ private[this] val aggregateBufferAttributes = {
+ aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
+ }
+
+ override lazy val allAttributes: AttributeSeq =
+ child.output ++ aggregateBufferAttributes ++ aggregateAttributes ++
+ aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes)
+
+ override lazy val metrics = Map(
+ "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")
+ )
+
+ override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
+
+ override def producedAttributes: AttributeSet =
+ AttributeSet(aggregateAttributes) ++
+ AttributeSet(resultExpressions.diff(groupingExpressions).map(_.toAttribute)) ++
+ AttributeSet(aggregateBufferAttributes)
+
+ override def requiredChildDistribution: List[Distribution] = {
+ requiredChildDistributionExpressions match {
+ case Some(exprs) if exprs.isEmpty => AllTuples :: Nil
+ case Some(exprs) if exprs.nonEmpty => ClusteredDistribution(exprs) :: Nil
+ case None => UnspecifiedDistribution :: Nil
+ }
+ }
+
+ protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
+ val numOutputRows = longMetric("numOutputRows")
+ val fallbackCountThreshold = sqlContext.conf.objectAggSortBasedFallbackThreshold
+
+ child.execute().mapPartitionsInternal { iter =>
+ val hasInput = iter.hasNext
+ if (!hasInput && groupingExpressions.nonEmpty) {
+ // This is a grouped aggregate and the input kvIterator is empty,
+ // so return an empty kvIterator.
+ Iterator.empty
+ } else {
+ val aggregationIterator =
+ new ObjectAggregationIterator(
+ child.output,
+ groupingExpressions,
+ aggregateExpressions,
+ aggregateAttributes,
+ initialInputBufferOffset,
+ resultExpressions,
+ (expressions, inputSchema) =>
+ newMutableProjection(expressions, inputSchema, subexpressionEliminationEnabled),
+ child.output,
+ iter,
+ fallbackCountThreshold)
+ if (!hasInput && groupingExpressions.isEmpty) {
+ numOutputRows += 1
+ Iterator.single[UnsafeRow](aggregationIterator.outputForEmptyGroupingKeyWithoutInput())
+ } else {
+ aggregationIterator
+ }
+ }
+ }
+ }
+
+ override def verboseString: String = toString(verbose = true)
+
+ override def simpleString: String = toString(verbose = false)
+
+ private def toString(verbose: Boolean): String = {
+ val allAggregateExpressions = aggregateExpressions
+ val keyString = Utils.truncatedString(groupingExpressions, "[", ", ", "]")
+ val functionString = Utils.truncatedString(allAggregateExpressions, "[", ", ", "]")
+ val outputString = Utils.truncatedString(output, "[", ", ", "]")
+ if (verbose) {
+ s"ObjectHashAggregate(keys=$keyString, functions=$functionString, output=$outputString)"
+ } else {
+ s"ObjectHashAggregate(keys=$keyString, functions=$functionString)"
+ }
+ }
+}
+
+object ObjectHashAggregateExec {
+ def supportsAggregate(aggregateExpressions: Seq[AggregateExpression]): Boolean = {
+ aggregateExpressions.map(_.aggregateFunction).exists {
+ case _: TypedImperativeAggregate[_] => true
+ case _ => false
+ }
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 7b8ed65054..71f3a67d0d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -526,6 +526,24 @@ object SQLConf {
.stringConf
.createWithDefault(classOf[ManifestFileCommitProtocol].getName)
+ val OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD =
+ SQLConfigBuilder("spark.sql.objectHashAggregate.sortBased.fallbackThreshold")
+ .internal()
+ .doc("In the case of ObjectHashAggregateExec, when the size of the in-memory hash map " +
+ "grows too large, we will fall back to sort-based aggregation. This option sets a row " +
+ "count threshold for the size of the hash map.")
+ .intConf
+ // We are trying to be conservative and use a relatively small default count threshold here
+ // since the state object of some TypedImperativeAggregate function can be quite large (e.g.
+ // percentile_approx).
+ .createWithDefault(128)
+
+ val USE_OBJECT_HASH_AGG = SQLConfigBuilder("spark.sql.execution.useObjectHashAggregateExec")
+ .internal()
+ .doc("Decides if we use ObjectHashAggregateExec")
+ .booleanConf
+ .createWithDefault(true)
+
val FILE_SINK_LOG_DELETION = SQLConfigBuilder("spark.sql.streaming.fileSink.log.deletion")
.internal()
.doc("Whether to delete the expired log files in file stream sink.")
@@ -769,6 +787,10 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging {
def enableTwoLevelAggMap: Boolean = getConf(ENABLE_TWOLEVEL_AGG_MAP)
+ def useObjectHashAggregation: Boolean = getConf(USE_OBJECT_HASH_AGG)
+
+ def objectAggSortBasedFallbackThreshold: Int = getConf(OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD)
+
def variableSubstituteEnabled: Boolean = getConf(VARIABLE_SUBSTITUTE_ENABLED)
def variableSubstituteDepth: Int = getConf(VARIABLE_SUBSTITUTE_DEPTH)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala
index ffa26f1f82..07599152e2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.TypedImperativeAggregateSuite.TypedMax
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression, GenericInternalRow, SpecificInternalRow}
import org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate
-import org.apache.spark.sql.execution.aggregate.SortAggregateExec
+import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext
@@ -87,11 +87,11 @@ class TypedImperativeAggregateSuite extends QueryTest with SharedSQLContext {
test("dataframe aggregate with object aggregate buffer, should not use HashAggregate") {
val df = data.toDF("a", "b")
- val max = new TypedMax($"a".expr)
+ val max = TypedMax($"a".expr)
// Always uses SortAggregateExec
val sparkPlan = df.select(Column(max.toAggregateExpression())).queryExecution.sparkPlan
- assert(sparkPlan.isInstanceOf[SortAggregateExec])
+ assert(!sparkPlan.isInstanceOf[HashAggregateExec])
}
test("dataframe aggregate with object aggregate buffer, no group by") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala
new file mode 100644
index 0000000000..bc9cb6ec2e
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala
@@ -0,0 +1,141 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.aggregate
+
+import java.util.Properties
+
+import scala.collection.mutable
+
+import org.apache.spark._
+import org.apache.spark.memory.{TaskMemoryManager, TestMemoryManager}
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
+import org.apache.spark.unsafe.KVIterator
+
+class SortBasedAggregationStoreSuite extends SparkFunSuite with LocalSparkContext {
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+ val conf = new SparkConf()
+ sc = new SparkContext("local[2, 4]", "test", conf)
+ val taskManager = new TaskMemoryManager(new TestMemoryManager(conf), 0)
+ TaskContext.setTaskContext(new TaskContextImpl(0, 0, 0, 0, taskManager, new Properties, null))
+ }
+
+ override def afterAll(): Unit = TaskContext.unset()
+
+ private val rand = new java.util.Random()
+
+ // In this test, the aggregator is XOR checksum.
+ test("merge input kv iterator and aggregation buffer iterator") {
+
+ val inputSchema = StructType(Seq(StructField("a", IntegerType), StructField("b", IntegerType)))
+ val groupingSchema = StructType(Seq(StructField("b", IntegerType)))
+
+ // Schema: a: Int, b: Int
+ val inputRow: UnsafeRow = createUnsafeRow(2)
+
+ // Schema: group: Int
+ val group: UnsafeRow = createUnsafeRow(1)
+
+ val expected = new mutable.HashMap[Int, Int]()
+ val hashMap = new ObjectAggregationMap
+ (0 to 5000).foreach { _ =>
+ randomKV(inputRow, group)
+
+ // XOR aggregate on first column of input row
+ expected.put(group.getInt(0), expected.getOrElse(group.getInt(0), 0) ^ inputRow.getInt(0))
+ if (hashMap.getAggregationBuffer(group) == null) {
+ hashMap.putAggregationBuffer(group.copy, createNewAggregationBuffer())
+ }
+ updateInputRow(hashMap.getAggregationBuffer(group), inputRow)
+ }
+
+ val store = new SortBasedAggregator(
+ createSortedAggBufferIterator(hashMap),
+ inputSchema,
+ groupingSchema,
+ updateInputRow,
+ mergeAggBuffer,
+ createNewAggregationBuffer)
+
+ (5000 to 100000).foreach { _ =>
+ randomKV(inputRow, group)
+ // XOR aggregate on first column of input row
+ expected.put(group.getInt(0), expected.getOrElse(group.getInt(0), 0) ^ inputRow.getInt(0))
+ store.addInput(group, inputRow)
+ }
+
+ val iter = store.destructiveIterator()
+ while(iter.hasNext) {
+ val agg = iter.next()
+ assert(agg.aggregationBuffer.getInt(0) == expected(agg.groupingKey.getInt(0)))
+ }
+ }
+
+ private def createNewAggregationBuffer(): InternalRow = {
+ val buffer = createUnsafeRow(1)
+ buffer.setInt(0, 0)
+ buffer
+ }
+
+ private def updateInputRow: (InternalRow, InternalRow) => Unit = {
+ (buffer: InternalRow, input: InternalRow) => {
+ buffer.setInt(0, buffer.getInt(0) ^ input.getInt(0))
+ }
+ }
+
+ private def mergeAggBuffer: (InternalRow, InternalRow) => Unit = updateInputRow
+
+ private def createUnsafeRow(numOfField: Int): UnsafeRow = {
+ val buffer: Array[Byte] = new Array(1024)
+ val row: UnsafeRow = new UnsafeRow(numOfField)
+ row.pointTo(buffer, 1024)
+ row
+ }
+
+ private def randomKV(inputRow: UnsafeRow, group: UnsafeRow): Unit = {
+ inputRow.setInt(0, rand.nextInt(100000))
+ inputRow.setInt(1, rand.nextInt(10000))
+ group.setInt(0, inputRow.getInt(1) % 100)
+ }
+
+ def createSortedAggBufferIterator(
+ hashMap: ObjectAggregationMap): KVIterator[UnsafeRow, UnsafeRow] = {
+
+ val sortedIterator = hashMap.iterator.toList.sortBy(_.groupingKey.getInt(0)).iterator
+ new KVIterator[UnsafeRow, UnsafeRow] {
+ var key: UnsafeRow = null
+ var value: UnsafeRow = null
+ override def next: Boolean = {
+ if (sortedIterator.hasNext) {
+ val kv = sortedIterator.next()
+ key = kv.groupingKey
+ value = kv.aggregationBuffer.asInstanceOf[UnsafeRow]
+ true
+ } else {
+ false
+ }
+ }
+ override def getKey(): UnsafeRow = key
+ override def getValue(): UnsafeRow = value
+ override def close(): Unit = Unit
+ }
+ }
+}