aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
Diffstat (limited to 'sql')
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala2
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala30
2 files changed, 22 insertions, 10 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
index 99f51ba5b6..ba379d358d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -104,7 +104,7 @@ case class TungstenAggregate(
} else {
// This is a grouped aggregate and the input iterator is empty,
// so return an empty iterator.
- Iterator[UnsafeRow]()
+ Iterator.empty
}
} else {
aggregationIterator.start(parentIterator)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
index af7e0fcedb..26fdbc83ef 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
@@ -357,18 +357,30 @@ class TungstenAggregationIterator(
// sort-based aggregation (by calling switchToSortBasedAggregation).
private def processInputs(): Unit = {
assert(inputIter != null, "attempted to process input when iterator was null")
- while (!sortBased && inputIter.hasNext) {
- val newInput = inputIter.next()
- numInputRows += 1
- val groupingKey = groupProjection.apply(newInput)
+ if (groupingExpressions.isEmpty) {
+ // If there is no grouping expressions, we can just reuse the same buffer over and over again.
+ // Note that it would be better to eliminate the hash map entirely in the future.
+ val groupingKey = groupProjection.apply(null)
val buffer: UnsafeRow = hashMap.getAggregationBufferFromUnsafeRow(groupingKey)
- if (buffer == null) {
- // buffer == null means that we could not allocate more memory.
- // Now, we need to spill the map and switch to sort-based aggregation.
- switchToSortBasedAggregation(groupingKey, newInput)
- } else {
+ while (inputIter.hasNext) {
+ val newInput = inputIter.next()
+ numInputRows += 1
processRow(buffer, newInput)
}
+ } else {
+ while (!sortBased && inputIter.hasNext) {
+ val newInput = inputIter.next()
+ numInputRows += 1
+ val groupingKey = groupProjection.apply(newInput)
+ val buffer: UnsafeRow = hashMap.getAggregationBufferFromUnsafeRow(groupingKey)
+ if (buffer == null) {
+ // buffer == null means that we could not allocate more memory.
+ // Now, we need to spill the map and switch to sort-based aggregation.
+ switchToSortBasedAggregation(groupingKey, newInput)
+ } else {
+ processRow(buffer, newInput)
+ }
+ }
}
}