diff options
author | Davies Liu <davies@databricks.com> | 2015-10-30 15:47:40 -0700 |
---|---|---|
committer | Davies Liu <davies.liu@gmail.com> | 2015-10-30 15:47:40 -0700 |
commit | 45029bfdea42eb8964f2ba697859687393d2a558 (patch) | |
tree | 45173a40ba6548f69f797d307ffb3a299bf6872e /sql | |
parent | bb5a2af034196620d869fc9b1a400e014e718b8c (diff) | |
download | spark-45029bfdea42eb8964f2ba697859687393d2a558.tar.gz spark-45029bfdea42eb8964f2ba697859687393d2a558.tar.bz2 spark-45029bfdea42eb8964f2ba697859687393d2a558.zip |
[SPARK-11423] remove MapPartitionsWithPreparationRDD
Since we no longer need to reserve a page before calling compute(), MapPartitionsWithPreparationRDD is not needed anymore.
This PR basically reverts #8543, #8511, #8038, and #8011.
Author: Davies Liu <davies@databricks.com>
Closes #9381 from davies/remove_prepare2.
Diffstat (limited to 'sql')
4 files changed, 70 insertions, 124 deletions
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java index 889f970034..d4b6d75b4d 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java @@ -19,9 +19,8 @@ package org.apache.spark.sql.execution; import java.io.IOException; -import com.google.common.annotations.VisibleForTesting; - import org.apache.spark.SparkEnv; +import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.UnsafeProjection; import org.apache.spark.sql.catalyst.expressions.UnsafeRow; @@ -31,7 +30,6 @@ import org.apache.spark.unsafe.KVIterator; import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.map.BytesToBytesMap; import org.apache.spark.unsafe.memory.MemoryLocation; -import org.apache.spark.memory.TaskMemoryManager; /** * Unsafe-based HashMap for performing aggregations where the aggregated values are fixed-width. @@ -218,11 +216,6 @@ public final class UnsafeFixedWidthAggregationMap { return map.getPeakMemoryUsedBytes(); } - @VisibleForTesting - public int getNumDataPages() { - return map.getNumDataPages(); - } - /** * Free the memory associated with this map. This is idempotent and can be called multiple times. 
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index 0d3a4b36c1..15616915f7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -17,15 +17,14 @@ package org.apache.spark.sql.execution.aggregate -import org.apache.spark.TaskContext -import org.apache.spark.rdd.{MapPartitionsWithPreparationRDD, RDD} +import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ -import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2 import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2 import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution.{UnsafeFixedWidthAggregationMap, UnaryNode, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.execution.{SparkPlan, UnaryNode, UnsafeFixedWidthAggregationMap} import org.apache.spark.sql.types.StructType case class TungstenAggregate( @@ -84,59 +83,39 @@ case class TungstenAggregate( val dataSize = longMetric("dataSize") val spillSize = longMetric("spillSize") - /** - * Set up the underlying unsafe data structures used before computing the parent partition. - * This makes sure our iterator is not starved by other operators in the same task. 
- */ - def preparePartition(): TungstenAggregationIterator = { - new TungstenAggregationIterator( - groupingExpressions, - nonCompleteAggregateExpressions, - nonCompleteAggregateAttributes, - completeAggregateExpressions, - completeAggregateAttributes, - initialInputBufferOffset, - resultExpressions, - newMutableProjection, - child.output, - testFallbackStartsAt, - numInputRows, - numOutputRows, - dataSize, - spillSize) - } + child.execute().mapPartitions { iter => - /** Compute a partition using the iterator already set up previously. */ - def executePartition( - context: TaskContext, - partitionIndex: Int, - aggregationIterator: TungstenAggregationIterator, - parentIterator: Iterator[InternalRow]): Iterator[UnsafeRow] = { - val hasInput = parentIterator.hasNext - if (!hasInput) { - // We're not using the underlying map, so we just can free it here - aggregationIterator.free() - if (groupingExpressions.isEmpty) { + val hasInput = iter.hasNext + if (!hasInput && groupingExpressions.nonEmpty) { + // This is a grouped aggregate and the input iterator is empty, + // so return an empty iterator. + Iterator.empty + } else { + val aggregationIterator = + new TungstenAggregationIterator( + groupingExpressions, + nonCompleteAggregateExpressions, + nonCompleteAggregateAttributes, + completeAggregateExpressions, + completeAggregateAttributes, + initialInputBufferOffset, + resultExpressions, + newMutableProjection, + child.output, + iter, + testFallbackStartsAt, + numInputRows, + numOutputRows, + dataSize, + spillSize) + if (!hasInput && groupingExpressions.isEmpty) { numOutputRows += 1 Iterator.single[UnsafeRow](aggregationIterator.outputForEmptyGroupingKeyWithoutInput()) } else { - // This is a grouped aggregate and the input iterator is empty, - // so return an empty iterator. 
- Iterator.empty + aggregationIterator } - } else { - aggregationIterator.start(parentIterator) - aggregationIterator } } - - // Note: we need to set up the iterator in each partition before computing the - // parent partition, so we cannot simply use `mapPartitions` here (SPARK-9747). - val resultRdd = { - new MapPartitionsWithPreparationRDD[UnsafeRow, InternalRow, TungstenAggregationIterator]( - child.execute(), preparePartition, executePartition, preservesPartitioning = true) - } - resultRdd.asInstanceOf[RDD[InternalRow]] } override def simpleString: String = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala index fb2fc98e34..713a4db0cd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala @@ -74,6 +74,8 @@ import org.apache.spark.sql.types.StructType * the function used to create mutable projections. * @param originalInputAttributes * attributes of representing input rows from `inputIter`. + * @param inputIter + * the iterator containing input [[UnsafeRow]]s. 
*/ class TungstenAggregationIterator( groupingExpressions: Seq[NamedExpression], @@ -85,6 +87,7 @@ class TungstenAggregationIterator( resultExpressions: Seq[NamedExpression], newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), originalInputAttributes: Seq[Attribute], + inputIter: Iterator[InternalRow], testFallbackStartsAt: Option[Int], numInputRows: LongSQLMetric, numOutputRows: LongSQLMetric, @@ -92,9 +95,6 @@ class TungstenAggregationIterator( spillSize: LongSQLMetric) extends Iterator[UnsafeRow] with Logging { - // The parent partition iterator, to be initialized later in `start` - private[this] var inputIter: Iterator[InternalRow] = null - /////////////////////////////////////////////////////////////////////////// // Part 1: Initializing aggregate functions. /////////////////////////////////////////////////////////////////////////// @@ -486,15 +486,11 @@ class TungstenAggregationIterator( false // disable tracking of performance metrics ) - // Exposed for testing - private[aggregate] def getHashMap: UnsafeFixedWidthAggregationMap = hashMap - // The function used to read and process input rows. When processing input rows, // it first uses hash-based aggregation by putting groups and their buffers in // hashMap. If we could not allocate more memory for the map, we switch to // sort-based aggregation (by calling switchToSortBasedAggregation). private def processInputs(): Unit = { - assert(inputIter != null, "attempted to process input when iterator was null") if (groupingExpressions.isEmpty) { // If there is no grouping expressions, we can just reuse the same buffer over and over again. // Note that it would be better to eliminate the hash map entirely in the future. @@ -526,7 +522,6 @@ class TungstenAggregationIterator( // that it switch to sort-based aggregation after `fallbackStartsAt` input rows have // been processed. 
private def processInputsWithControlledFallback(fallbackStartsAt: Int): Unit = { - assert(inputIter != null, "attempted to process input when iterator was null") var i = 0 while (!sortBased && inputIter.hasNext) { val newInput = inputIter.next() @@ -567,15 +562,11 @@ class TungstenAggregationIterator( * Switch to sort-based aggregation when the hash-based approach is unable to acquire memory. */ private def switchToSortBasedAggregation(firstKey: UnsafeRow, firstInput: InternalRow): Unit = { - assert(inputIter != null, "attempted to process input when iterator was null") logInfo("falling back to sort based aggregation.") // Step 1: Get the ExternalSorter containing sorted entries of the map. externalSorter = hashMap.destructAndCreateExternalSorter() - // Step 2: Free the memory used by the map. - hashMap.free() - - // Step 3: If we have aggregate function with mode Partial or Complete, + // Step 2: If we have aggregate function with mode Partial or Complete, // we need to process input rows to get aggregation buffer. // So, later in the sort-based aggregation iterator, we can do merge. // If aggregate functions are with mode Final and PartialMerge, @@ -770,31 +761,27 @@ class TungstenAggregationIterator( /** * Start processing input rows. - * Only after this method is called will this iterator be non-empty. */ - def start(parentIter: Iterator[InternalRow]): Unit = { - inputIter = parentIter - testFallbackStartsAt match { - case None => - processInputs() - case Some(fallbackStartsAt) => - // This is the testing path. processInputsWithControlledFallback is same as processInputs - // except that it switches to sort-based aggregation after `fallbackStartsAt` input rows - // have been processed. - processInputsWithControlledFallback(fallbackStartsAt) - } + testFallbackStartsAt match { + case None => + processInputs() + case Some(fallbackStartsAt) => + // This is the testing path. 
processInputsWithControlledFallback is same as processInputs + // except that it switches to sort-based aggregation after `fallbackStartsAt` input rows + // have been processed. + processInputsWithControlledFallback(fallbackStartsAt) + } - // If we did not switch to sort-based aggregation in processInputs, - // we pre-load the first key-value pair from the map (to make hasNext idempotent). - if (!sortBased) { - // First, set aggregationBufferMapIterator. - aggregationBufferMapIterator = hashMap.iterator() - // Pre-load the first key-value pair from the aggregationBufferMapIterator. - mapIteratorHasNext = aggregationBufferMapIterator.next() - // If the map is empty, we just free it. - if (!mapIteratorHasNext) { - hashMap.free() - } + // If we did not switch to sort-based aggregation in processInputs, + // we pre-load the first key-value pair from the map (to make hasNext idempotent). + if (!sortBased) { + // First, set aggregationBufferMapIterator. + aggregationBufferMapIterator = hashMap.iterator() + // Pre-load the first key-value pair from the aggregationBufferMapIterator. + mapIteratorHasNext = aggregationBufferMapIterator.next() + // If the map is empty, we just free it. + if (!mapIteratorHasNext) { + hashMap.free() } } @@ -868,13 +855,16 @@ class TungstenAggregationIterator( * Generate a output row when there is no input and there is no grouping expression. */ def outputForEmptyGroupingKeyWithoutInput(): UnsafeRow = { - assert(groupingExpressions.isEmpty) - assert(inputIter == null) - generateOutput(UnsafeRow.createFromByteArray(0, 0), initialAggregationBuffer) - } - - /** Free memory used in the underlying map. */ - def free(): Unit = { - hashMap.free() + if (groupingExpressions.isEmpty) { + sortBasedAggregationBuffer.copyFrom(initialAggregationBuffer) + // We create a output row and copy it. So, we can free the map. 
+ val resultCopy = + generateOutput(UnsafeRow.createFromByteArray(0, 0), sortBasedAggregationBuffer).copy() + hashMap.free() + resultCopy + } else { + throw new IllegalStateException( + "This method should not be called when groupingExpressions is not empty.") + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala index dd92dda480..1a3832a698 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution -import org.apache.spark.rdd.{MapPartitionsWithPreparationRDD, RDD} +import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ @@ -26,7 +26,7 @@ import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types.StructType import org.apache.spark.util.CompletionIterator import org.apache.spark.util.collection.ExternalSorter -import org.apache.spark.{SparkEnv, InternalAccumulator, TaskContext} +import org.apache.spark.{InternalAccumulator, SparkEnv, TaskContext} //////////////////////////////////////////////////////////////////////////////////////////////////// // This file defines various sort operators. @@ -77,6 +77,7 @@ case class Sort( * @param testSpillFrequency Method for configuring periodic spilling in unit tests. If set, will * spill every `frequency` records. */ + case class TungstenSort( sortOrder: Seq[SortOrder], global: Boolean, @@ -106,11 +107,7 @@ case class TungstenSort( val dataSize = longMetric("dataSize") val spillSize = longMetric("spillSize") - /** - * Set up the sorter in each partition before computing the parent partition. - * This makes sure our sorter is not starved by other sorters used in the same task. 
- */ - def preparePartition(): UnsafeExternalRowSorter = { + child.execute().mapPartitions { iter => val ordering = newOrdering(sortOrder, childOutput) // The comparator for comparing prefix @@ -131,33 +128,20 @@ case class TungstenSort( if (testSpillFrequency > 0) { sorter.setTestSpillFrequency(testSpillFrequency) } - sorter - } - /** Compute a partition using the sorter already set up previously. */ - def executePartition( - taskContext: TaskContext, - partitionIndex: Int, - sorter: UnsafeExternalRowSorter, - parentIterator: Iterator[InternalRow]): Iterator[InternalRow] = { // Remember spill data size of this task before execute this operator so that we can // figure out how many bytes we spilled for this operator. val spillSizeBefore = TaskContext.get().taskMetrics().memoryBytesSpilled - val sortedIterator = sorter.sort(parentIterator.asInstanceOf[Iterator[UnsafeRow]]) + val sortedIterator = sorter.sort(iter.asInstanceOf[Iterator[UnsafeRow]]) dataSize += sorter.getPeakMemoryUsage spillSize += TaskContext.get().taskMetrics().memoryBytesSpilled - spillSizeBefore - taskContext.internalMetricsToAccumulators( + TaskContext.get().internalMetricsToAccumulators( InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.getPeakMemoryUsage) sortedIterator } - - // Note: we need to set up the external sorter in each partition before computing - // the parent partition, so we cannot simply use `mapPartitions` here (SPARK-9709). - new MapPartitionsWithPreparationRDD[InternalRow, InternalRow, UnsafeExternalRowSorter]( - child.execute(), preparePartition, executePartition, preservesPartitioning = true) } } |