From 595012ea8b9c6afcc2fc024d5a5e198df765bd75 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 11 Oct 2015 18:11:08 -0700 Subject: [SPARK-11053] Remove use of KVIterator in SortBasedAggregationIterator SortBasedAggregationIterator uses a KVIterator interface in order to process input rows as key-value pairs, but this use of KVIterator is unnecessary, slightly complicates the code, and might hurt performance. This patch refactors this code to remove the use of this extra layer of iterator wrapping and simplifies other parts of the code in the process. Author: Josh Rosen Closes #9066 from JoshRosen/sort-iterator-cleanup. --- .../execution/aggregate/AggregationIterator.scala | 83 -------------------- .../execution/aggregate/SortBasedAggregate.scala | 20 +++-- .../aggregate/SortBasedAggregationIterator.scala | 89 +++++----------------- 3 files changed, 33 insertions(+), 159 deletions(-) (limited to 'sql') diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala index 5f7341e88c..8e0fbd109b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala @@ -21,7 +21,6 @@ import org.apache.spark.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.unsafe.KVIterator import scala.collection.mutable.ArrayBuffer @@ -412,85 +411,3 @@ abstract class AggregationIterator( */ protected def newBuffer: MutableRow } - -object AggregationIterator { - def kvIterator( - groupingExpressions: Seq[NamedExpression], - newProjection: (Seq[Expression], Seq[Attribute]) => Projection, - inputAttributes: Seq[Attribute], - inputIter: Iterator[InternalRow]): KVIterator[InternalRow, InternalRow] = { - new KVIterator[InternalRow, InternalRow] { - private[this] val groupingKeyGenerator = newProjection(groupingExpressions, inputAttributes) - - private[this] var groupingKey: InternalRow = _ - - private[this] var value: InternalRow = _ - - override def next(): Boolean = { - if (inputIter.hasNext) { - // Read the next input row. - val inputRow = inputIter.next() - // Get groupingKey based on groupingExpressions. - groupingKey = groupingKeyGenerator(inputRow) - // The value is the inputRow. - value = inputRow - true - } else { - false - } - } - - override def getKey(): InternalRow = { - groupingKey - } - - override def getValue(): InternalRow = { - value - } - - override def close(): Unit = { - // Do nothing - } - } - } - - def unsafeKVIterator( - groupingExpressions: Seq[NamedExpression], - inputAttributes: Seq[Attribute], - inputIter: Iterator[InternalRow]): KVIterator[UnsafeRow, InternalRow] = { - new KVIterator[UnsafeRow, InternalRow] { - private[this] val groupingKeyGenerator = - UnsafeProjection.create(groupingExpressions, inputAttributes) - - private[this] var groupingKey: UnsafeRow = _ - - private[this] var value: InternalRow = _ - - override def next(): Boolean = { - if (inputIter.hasNext) { - // Read the next input row. - val inputRow = inputIter.next() - // Get groupingKey based on groupingExpressions. - groupingKey = groupingKeyGenerator.apply(inputRow) - // The value is the inputRow. - value = inputRow - true - } else { - false - } - } - - override def getKey(): UnsafeRow = { - groupingKey - } - - override def getValue(): InternalRow = { - value - } - - override def close(): Unit = { - // Do nothing - } - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala index f4c14a9b35..4d37106e00 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala @@ -23,9 +23,8 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.physical.{UnspecifiedDistribution, ClusteredDistribution, AllTuples, Distribution} -import org.apache.spark.sql.execution.{UnsafeFixedWidthAggregationMap, SparkPlan, UnaryNode} +import org.apache.spark.sql.execution.{SparkPlan, UnaryNode} import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.sql.types.StructType case class SortBasedAggregate( requiredChildDistributionExpressions: Option[Seq[Expression]], @@ -79,18 +78,23 @@ case class SortBasedAggregate( // so return an empty iterator. Iterator[InternalRow]() } else { - val outputIter = SortBasedAggregationIterator.createFromInputIterator( - groupingExpressions, + val groupingKeyProjection = if (UnsafeProjection.canSupport(groupingExpressions)) { + UnsafeProjection.create(groupingExpressions, child.output) + } else { + newMutableProjection(groupingExpressions, child.output)() + } + val outputIter = new SortBasedAggregationIterator( + groupingKeyProjection, + groupingExpressions.map(_.toAttribute), + child.output, + iter, nonCompleteAggregateExpressions, nonCompleteAggregateAttributes, completeAggregateExpressions, completeAggregateAttributes, initialInputBufferOffset, resultExpressions, - newMutableProjection _, - newProjection _, - child.output, - iter, + newMutableProjection, outputsUnsafeRows, numInputRows, numOutputRows) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala index a9e5d175bf..64c673064f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala @@ -21,16 +21,16 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression2, AggregateFunction2} import org.apache.spark.sql.execution.metric.LongSQLMetric -import org.apache.spark.unsafe.KVIterator /** * An iterator used to evaluate [[AggregateFunction2]]. It assumes the input rows have been * sorted by values of [[groupingKeyAttributes]]. */ class SortBasedAggregationIterator( + groupingKeyProjection: InternalRow => InternalRow, groupingKeyAttributes: Seq[Attribute], valueAttributes: Seq[Attribute], - inputKVIterator: KVIterator[InternalRow, InternalRow], + inputIterator: Iterator[InternalRow], nonCompleteAggregateExpressions: Seq[AggregateExpression2], nonCompleteAggregateAttributes: Seq[Attribute], completeAggregateExpressions: Seq[AggregateExpression2], @@ -90,6 +90,22 @@ class SortBasedAggregationIterator( // The aggregation buffer used by the sort-based aggregation. private[this] val sortBasedAggregationBuffer: MutableRow = newBuffer + protected def initialize(): Unit = { + if (inputIterator.hasNext) { + initializeBuffer(sortBasedAggregationBuffer) + val inputRow = inputIterator.next() + nextGroupingKey = groupingKeyProjection(inputRow).copy() + firstRowInNextGroup = inputRow.copy() + numInputRows += 1 + sortedInputHasNewGroup = true + } else { + // This inputIter is empty. + sortedInputHasNewGroup = false + } + } + + initialize() + /** Processes rows in the current group. It will stop when it find a new group. */ protected def processCurrentSortedGroup(): Unit = { currentGroupingKey = nextGroupingKey @@ -101,18 +117,15 @@ class SortBasedAggregationIterator( // The search will stop when we see the next group or there is no // input row left in the iter. - var hasNext = inputKVIterator.next() - while (!findNextPartition && hasNext) { + while (!findNextPartition && inputIterator.hasNext) { // Get the grouping key. - val groupingKey = inputKVIterator.getKey - val currentRow = inputKVIterator.getValue + val currentRow = inputIterator.next() + val groupingKey = groupingKeyProjection(currentRow) numInputRows += 1 // Check if the current row belongs the current input row. if (currentGroupingKey == groupingKey) { processRow(sortBasedAggregationBuffer, currentRow) - - hasNext = inputKVIterator.next() } else { // We find a new group. findNextPartition = true @@ -149,68 +162,8 @@ class SortBasedAggregationIterator( } } - protected def initialize(): Unit = { - if (inputKVIterator.next()) { - initializeBuffer(sortBasedAggregationBuffer) - - nextGroupingKey = inputKVIterator.getKey().copy() - firstRowInNextGroup = inputKVIterator.getValue().copy() - numInputRows += 1 - sortedInputHasNewGroup = true - } else { - // This inputIter is empty. - sortedInputHasNewGroup = false - } - } - - initialize() - def outputForEmptyGroupingKeyWithoutInput(): InternalRow = { initializeBuffer(sortBasedAggregationBuffer) generateOutput(new GenericInternalRow(0), sortBasedAggregationBuffer) } } - -object SortBasedAggregationIterator { - // scalastyle:off - def createFromInputIterator( - groupingExprs: Seq[NamedExpression], - nonCompleteAggregateExpressions: Seq[AggregateExpression2], - nonCompleteAggregateAttributes: Seq[Attribute], - completeAggregateExpressions: Seq[AggregateExpression2], - completeAggregateAttributes: Seq[Attribute], - initialInputBufferOffset: Int, - resultExpressions: Seq[NamedExpression], - newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), - newProjection: (Seq[Expression], Seq[Attribute]) => Projection, - inputAttributes: Seq[Attribute], - inputIter: Iterator[InternalRow], - outputsUnsafeRows: Boolean, - numInputRows: LongSQLMetric, - numOutputRows: LongSQLMetric): SortBasedAggregationIterator = { - val kvIterator = if (UnsafeProjection.canSupport(groupingExprs)) { - AggregationIterator.unsafeKVIterator( - groupingExprs, - inputAttributes, - inputIter).asInstanceOf[KVIterator[InternalRow, InternalRow]] - } else { - AggregationIterator.kvIterator(groupingExprs, newProjection, inputAttributes, inputIter) - } - - new SortBasedAggregationIterator( - groupingExprs.map(_.toAttribute), - inputAttributes, - kvIterator, - nonCompleteAggregateExpressions, - nonCompleteAggregateAttributes, - completeAggregateExpressions, - completeAggregateAttributes, - initialInputBufferOffset, - resultExpressions, - newMutableProjection, - outputsUnsafeRows, - numInputRows, - numOutputRows) - } - // scalastyle:on -} -- cgit v1.2.3