aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorJosh Rosen <joshrosen@databricks.com>2015-10-11 18:11:08 -0700
committerJosh Rosen <joshrosen@databricks.com>2015-10-11 18:11:08 -0700
commit595012ea8b9c6afcc2fc024d5a5e198df765bd75 (patch)
treef645cd4a2be7e8f36da1ca8c391aff8d83ee5f2d /sql
parenta16396df76cc27099011bfb96b28cbdd7f964ca8 (diff)
downloadspark-595012ea8b9c6afcc2fc024d5a5e198df765bd75.tar.gz
spark-595012ea8b9c6afcc2fc024d5a5e198df765bd75.tar.bz2
spark-595012ea8b9c6afcc2fc024d5a5e198df765bd75.zip
[SPARK-11053] Remove use of KVIterator in SortBasedAggregationIterator
SortBasedAggregationIterator uses a KVIterator interface in order to process input rows as key-value pairs, but this use of KVIterator is unnecessary, slightly complicates the code, and might hurt performance. This patch refactors this code to remove the use of this extra layer of iterator wrapping and simplifies other parts of the code in the process. Author: Josh Rosen <joshrosen@databricks.com> Closes #9066 from JoshRosen/sort-iterator-cleanup.
Diffstat (limited to 'sql')
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala83
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala20
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala89
3 files changed, 33 insertions, 159 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
index 5f7341e88c..8e0fbd109b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
@@ -21,7 +21,6 @@ import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.unsafe.KVIterator
import scala.collection.mutable.ArrayBuffer
@@ -412,85 +411,3 @@ abstract class AggregationIterator(
*/
protected def newBuffer: MutableRow
}
-
-object AggregationIterator {
- def kvIterator(
- groupingExpressions: Seq[NamedExpression],
- newProjection: (Seq[Expression], Seq[Attribute]) => Projection,
- inputAttributes: Seq[Attribute],
- inputIter: Iterator[InternalRow]): KVIterator[InternalRow, InternalRow] = {
- new KVIterator[InternalRow, InternalRow] {
- private[this] val groupingKeyGenerator = newProjection(groupingExpressions, inputAttributes)
-
- private[this] var groupingKey: InternalRow = _
-
- private[this] var value: InternalRow = _
-
- override def next(): Boolean = {
- if (inputIter.hasNext) {
- // Read the next input row.
- val inputRow = inputIter.next()
- // Get groupingKey based on groupingExpressions.
- groupingKey = groupingKeyGenerator(inputRow)
- // The value is the inputRow.
- value = inputRow
- true
- } else {
- false
- }
- }
-
- override def getKey(): InternalRow = {
- groupingKey
- }
-
- override def getValue(): InternalRow = {
- value
- }
-
- override def close(): Unit = {
- // Do nothing
- }
- }
- }
-
- def unsafeKVIterator(
- groupingExpressions: Seq[NamedExpression],
- inputAttributes: Seq[Attribute],
- inputIter: Iterator[InternalRow]): KVIterator[UnsafeRow, InternalRow] = {
- new KVIterator[UnsafeRow, InternalRow] {
- private[this] val groupingKeyGenerator =
- UnsafeProjection.create(groupingExpressions, inputAttributes)
-
- private[this] var groupingKey: UnsafeRow = _
-
- private[this] var value: InternalRow = _
-
- override def next(): Boolean = {
- if (inputIter.hasNext) {
- // Read the next input row.
- val inputRow = inputIter.next()
- // Get groupingKey based on groupingExpressions.
- groupingKey = groupingKeyGenerator.apply(inputRow)
- // The value is the inputRow.
- value = inputRow
- true
- } else {
- false
- }
- }
-
- override def getKey(): UnsafeRow = {
- groupingKey
- }
-
- override def getValue(): InternalRow = {
- value
- }
-
- override def close(): Unit = {
- // Do nothing
- }
- }
- }
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala
index f4c14a9b35..4d37106e00 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala
@@ -23,9 +23,8 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.physical.{UnspecifiedDistribution, ClusteredDistribution, AllTuples, Distribution}
-import org.apache.spark.sql.execution.{UnsafeFixedWidthAggregationMap, SparkPlan, UnaryNode}
+import org.apache.spark.sql.execution.{SparkPlan, UnaryNode}
import org.apache.spark.sql.execution.metric.SQLMetrics
-import org.apache.spark.sql.types.StructType
case class SortBasedAggregate(
requiredChildDistributionExpressions: Option[Seq[Expression]],
@@ -79,18 +78,23 @@ case class SortBasedAggregate(
// so return an empty iterator.
Iterator[InternalRow]()
} else {
- val outputIter = SortBasedAggregationIterator.createFromInputIterator(
- groupingExpressions,
+ val groupingKeyProjection = if (UnsafeProjection.canSupport(groupingExpressions)) {
+ UnsafeProjection.create(groupingExpressions, child.output)
+ } else {
+ newMutableProjection(groupingExpressions, child.output)()
+ }
+ val outputIter = new SortBasedAggregationIterator(
+ groupingKeyProjection,
+ groupingExpressions.map(_.toAttribute),
+ child.output,
+ iter,
nonCompleteAggregateExpressions,
nonCompleteAggregateAttributes,
completeAggregateExpressions,
completeAggregateAttributes,
initialInputBufferOffset,
resultExpressions,
- newMutableProjection _,
- newProjection _,
- child.output,
- iter,
+ newMutableProjection,
outputsUnsafeRows,
numInputRows,
numOutputRows)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala
index a9e5d175bf..64c673064f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala
@@ -21,16 +21,16 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression2, AggregateFunction2}
import org.apache.spark.sql.execution.metric.LongSQLMetric
-import org.apache.spark.unsafe.KVIterator
/**
* An iterator used to evaluate [[AggregateFunction2]]. It assumes the input rows have been
* sorted by values of [[groupingKeyAttributes]].
*/
class SortBasedAggregationIterator(
+ groupingKeyProjection: InternalRow => InternalRow,
groupingKeyAttributes: Seq[Attribute],
valueAttributes: Seq[Attribute],
- inputKVIterator: KVIterator[InternalRow, InternalRow],
+ inputIterator: Iterator[InternalRow],
nonCompleteAggregateExpressions: Seq[AggregateExpression2],
nonCompleteAggregateAttributes: Seq[Attribute],
completeAggregateExpressions: Seq[AggregateExpression2],
@@ -90,6 +90,22 @@ class SortBasedAggregationIterator(
// The aggregation buffer used by the sort-based aggregation.
private[this] val sortBasedAggregationBuffer: MutableRow = newBuffer
+ protected def initialize(): Unit = {
+ if (inputIterator.hasNext) {
+ initializeBuffer(sortBasedAggregationBuffer)
+ val inputRow = inputIterator.next()
+ nextGroupingKey = groupingKeyProjection(inputRow).copy()
+ firstRowInNextGroup = inputRow.copy()
+ numInputRows += 1
+ sortedInputHasNewGroup = true
+ } else {
+ // This inputIter is empty.
+ sortedInputHasNewGroup = false
+ }
+ }
+
+ initialize()
+
/** Processes rows in the current group. It will stop when it find a new group. */
protected def processCurrentSortedGroup(): Unit = {
currentGroupingKey = nextGroupingKey
@@ -101,18 +117,15 @@ class SortBasedAggregationIterator(
// The search will stop when we see the next group or there is no
// input row left in the iter.
- var hasNext = inputKVIterator.next()
- while (!findNextPartition && hasNext) {
+ while (!findNextPartition && inputIterator.hasNext) {
// Get the grouping key.
- val groupingKey = inputKVIterator.getKey
- val currentRow = inputKVIterator.getValue
+ val currentRow = inputIterator.next()
+ val groupingKey = groupingKeyProjection(currentRow)
numInputRows += 1
// Check if the current row belongs the current input row.
if (currentGroupingKey == groupingKey) {
processRow(sortBasedAggregationBuffer, currentRow)
-
- hasNext = inputKVIterator.next()
} else {
// We find a new group.
findNextPartition = true
@@ -149,68 +162,8 @@ class SortBasedAggregationIterator(
}
}
- protected def initialize(): Unit = {
- if (inputKVIterator.next()) {
- initializeBuffer(sortBasedAggregationBuffer)
-
- nextGroupingKey = inputKVIterator.getKey().copy()
- firstRowInNextGroup = inputKVIterator.getValue().copy()
- numInputRows += 1
- sortedInputHasNewGroup = true
- } else {
- // This inputIter is empty.
- sortedInputHasNewGroup = false
- }
- }
-
- initialize()
-
def outputForEmptyGroupingKeyWithoutInput(): InternalRow = {
initializeBuffer(sortBasedAggregationBuffer)
generateOutput(new GenericInternalRow(0), sortBasedAggregationBuffer)
}
}
-
-object SortBasedAggregationIterator {
- // scalastyle:off
- def createFromInputIterator(
- groupingExprs: Seq[NamedExpression],
- nonCompleteAggregateExpressions: Seq[AggregateExpression2],
- nonCompleteAggregateAttributes: Seq[Attribute],
- completeAggregateExpressions: Seq[AggregateExpression2],
- completeAggregateAttributes: Seq[Attribute],
- initialInputBufferOffset: Int,
- resultExpressions: Seq[NamedExpression],
- newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
- newProjection: (Seq[Expression], Seq[Attribute]) => Projection,
- inputAttributes: Seq[Attribute],
- inputIter: Iterator[InternalRow],
- outputsUnsafeRows: Boolean,
- numInputRows: LongSQLMetric,
- numOutputRows: LongSQLMetric): SortBasedAggregationIterator = {
- val kvIterator = if (UnsafeProjection.canSupport(groupingExprs)) {
- AggregationIterator.unsafeKVIterator(
- groupingExprs,
- inputAttributes,
- inputIter).asInstanceOf[KVIterator[InternalRow, InternalRow]]
- } else {
- AggregationIterator.kvIterator(groupingExprs, newProjection, inputAttributes, inputIter)
- }
-
- new SortBasedAggregationIterator(
- groupingExprs.map(_.toAttribute),
- inputAttributes,
- kvIterator,
- nonCompleteAggregateExpressions,
- nonCompleteAggregateAttributes,
- completeAggregateExpressions,
- completeAggregateAttributes,
- initialInputBufferOffset,
- resultExpressions,
- newMutableProjection,
- outputsUnsafeRows,
- numInputRows,
- numOutputRows)
- }
- // scalastyle:on
-}