author     Davies Liu <davies@databricks.com>  2015-10-30 15:47:40 -0700
committer  Davies Liu <davies.liu@gmail.com>   2015-10-30 15:47:40 -0700
commit     45029bfdea42eb8964f2ba697859687393d2a558 (patch)
tree       45173a40ba6548f69f797d307ffb3a299bf6872e
parent     bb5a2af034196620d869fc9b1a400e014e718b8c (diff)
[SPARK-11423] remove MapPartitionsWithPreparationRDD
Since we do not need to preserve a page before calling compute(), MapPartitionsWithPreparationRDD is no longer needed. This PR essentially reverts #8543, #8511, #8038, and #8011.

Author: Davies Liu <davies@databricks.com>

Closes #9381 from davies/remove_prepare2.
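For context, a minimal sketch contrasting the removed prepare/execute API with the plain mapPartitions call that replaces it (Out/In/State, makeState, and run are hypothetical placeholders, not from the patch):

  // Before (removed API): per-partition state was built in a separate prepare
  // step so it could be allocated before the parent partition's compute() ran.
  val before = new MapPartitionsWithPreparationRDD[Out, In, State](
    parent,
    () => makeState(),                            // preparePartition
    (ctx, idx, state, iter) => run(state, iter))  // executePartition

  // After: build the state inside the closure; no prepare phase is needed
  // because execution memory no longer has to be reserved in advance.
  val after = parent.mapPartitions { iter =>
    val state = makeState()
    run(state, iter)
  }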
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDD.scala            66
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala                        13
-rw-r--r--  core/src/test/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDDSuite.scala       66
-rw-r--r--  project/MimaExcludes.scala                                                                 6
-rw-r--r--  sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java  9
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala  79
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala  78
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala                         28
8 files changed, 75 insertions(+), 270 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDD.scala b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDD.scala
deleted file mode 100644
index 417ff5278d..0000000000
--- a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDD.scala
+++ /dev/null
@@ -1,66 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.rdd
-
-import scala.collection.mutable.ArrayBuffer
-import scala.reflect.ClassTag
-
-import org.apache.spark.{Partition, Partitioner, TaskContext}
-
-/**
- * An RDD that applies a user provided function to every partition of the parent RDD, and
- * additionally allows the user to prepare each partition before computing the parent partition.
- */
-private[spark] class MapPartitionsWithPreparationRDD[U: ClassTag, T: ClassTag, M: ClassTag](
- prev: RDD[T],
- preparePartition: () => M,
- executePartition: (TaskContext, Int, M, Iterator[T]) => Iterator[U],
- preservesPartitioning: Boolean = false)
- extends RDD[U](prev) {
-
- override val partitioner: Option[Partitioner] = {
- if (preservesPartitioning) firstParent[T].partitioner else None
- }
-
- override def getPartitions: Array[Partition] = firstParent[T].partitions
-
- // In certain join operations, prepare can be called on the same partition multiple times.
- // In this case, we need to ensure that each call to compute gets a separate prepare argument.
- private[this] val preparedArguments: ArrayBuffer[M] = new ArrayBuffer[M]
-
- /**
- * Prepare a partition for a single call to compute.
- */
- def prepare(): Unit = {
- preparedArguments += preparePartition()
- }
-
- /**
- * Prepare a partition before computing it from its parent.
- */
- override def compute(partition: Partition, context: TaskContext): Iterator[U] = {
- val prepared =
- if (preparedArguments.isEmpty) {
- preparePartition()
- } else {
- preparedArguments.remove(0)
- }
- val parentIterator = firstParent[T].iterator(partition, context)
- executePartition(context, partition.index, prepared, parentIterator)
- }
-}
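The preparedArguments queue above existed because zipPartitions could prepare the same parent more than once before any compute() ran. A sketch of that usage (prepare and execute are placeholders):

  val a = new MapPartitionsWithPreparationRDD[Int, Int, Unit](parent, prepare, execute)
  val b = new MapPartitionsWithPreparationRDD[Int, Int, Unit](parent, prepare, execute)
  // ZippedPartitionsRDD2.compute() first called tryPrepareParents(), queueing one
  // prepared argument for a and one for b, before either parent partition computed.
  a.zipPartitions(b) { (left, right) => left }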
diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala
index 70bf04de64..4333a679c8 100644
--- a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala
@@ -73,16 +73,6 @@ private[spark] abstract class ZippedPartitionsBaseRDD[V: ClassTag](
super.clearDependencies()
rdds = null
}
-
- /**
- * Call the prepare method of every parent that has one.
- * This is needed for reserving execution memory in advance.
- */
- protected def tryPrepareParents(): Unit = {
- rdds.collect {
- case rdd: MapPartitionsWithPreparationRDD[_, _, _] => rdd.prepare()
- }
- }
}
private[spark] class ZippedPartitionsRDD2[A: ClassTag, B: ClassTag, V: ClassTag](
@@ -94,7 +84,6 @@ private[spark] class ZippedPartitionsRDD2[A: ClassTag, B: ClassTag, V: ClassTag]
extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2), preservesPartitioning) {
override def compute(s: Partition, context: TaskContext): Iterator[V] = {
- tryPrepareParents()
val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions
f(rdd1.iterator(partitions(0), context), rdd2.iterator(partitions(1), context))
}
@@ -118,7 +107,6 @@ private[spark] class ZippedPartitionsRDD3
extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3), preservesPartitioning) {
override def compute(s: Partition, context: TaskContext): Iterator[V] = {
- tryPrepareParents()
val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions
f(rdd1.iterator(partitions(0), context),
rdd2.iterator(partitions(1), context),
@@ -146,7 +134,6 @@ private[spark] class ZippedPartitionsRDD4
extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3, rdd4), preservesPartitioning) {
override def compute(s: Partition, context: TaskContext): Iterator[V] = {
- tryPrepareParents()
val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions
f(rdd1.iterator(partitions(0), context),
rdd2.iterator(partitions(1), context),
diff --git a/core/src/test/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDDSuite.scala
deleted file mode 100644
index e281e817e4..0000000000
--- a/core/src/test/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDDSuite.scala
+++ /dev/null
@@ -1,66 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.rdd
-
-import scala.collection.mutable
-
-import org.apache.spark.{LocalSparkContext, SparkContext, SparkFunSuite, TaskContext}
-
-class MapPartitionsWithPreparationRDDSuite extends SparkFunSuite with LocalSparkContext {
-
- test("prepare called before parent partition is computed") {
- sc = new SparkContext("local", "test")
-
- // Have the parent partition push a number to the list
- val parent = sc.parallelize(1 to 100, 1).mapPartitions { iter =>
- TestObject.things.append(20)
- iter
- }
-
- // Push a different number during the prepare phase
- val preparePartition = () => { TestObject.things.append(10) }
-
- // Push yet another number during the execution phase
- val executePartition = (
- taskContext: TaskContext,
- partitionIndex: Int,
- notUsed: Unit,
- parentIterator: Iterator[Int]) => {
- TestObject.things.append(30)
- TestObject.things.iterator
- }
-
- // Verify that the numbers are pushed in the order expected
- val rdd = new MapPartitionsWithPreparationRDD[Int, Int, Unit](
- parent, preparePartition, executePartition)
- val result = rdd.collect()
- assert(result === Array(10, 20, 30))
-
- TestObject.things.clear()
- // Zip two of these RDDs, both should be prepared before the parent is executed
- val rdd2 = new MapPartitionsWithPreparationRDD[Int, Int, Unit](
- parent, preparePartition, executePartition)
- val result2 = rdd.zipPartitions(rdd2)((a, b) => a).collect()
- assert(result2 === Array(10, 10, 20, 30, 20, 30))
- }
-
-}
-
-private object TestObject {
- val things = new mutable.ListBuffer[Int]
-}
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index b5e661d3ec..8282f7ea62 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -107,7 +107,11 @@ object MimaExcludes {
"org.apache.spark.sql.SQLContext.createSession")
) ++ Seq(
ProblemFilters.exclude[MissingMethodProblem](
- "org.apache.spark.SparkContext.preferredNodeLocationData_=")
+ "org.apache.spark.SparkContext.preferredNodeLocationData_="),
+ ProblemFilters.exclude[MissingClassProblem](
+ "org.apache.spark.rdd.MapPartitionsWithPreparationRDD"),
+ ProblemFilters.exclude[MissingClassProblem](
+ "org.apache.spark.rdd.MapPartitionsWithPreparationRDD$")
)
case v if v.startsWith("1.5") =>
Seq(
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
index 889f970034..d4b6d75b4d 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
@@ -19,9 +19,8 @@ package org.apache.spark.sql.execution;
import java.io.IOException;
-import com.google.common.annotations.VisibleForTesting;
-
import org.apache.spark.SparkEnv;
+import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection;
import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
@@ -31,7 +30,6 @@ import org.apache.spark.unsafe.KVIterator;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.map.BytesToBytesMap;
import org.apache.spark.unsafe.memory.MemoryLocation;
-import org.apache.spark.memory.TaskMemoryManager;
/**
* Unsafe-based HashMap for performing aggregations where the aggregated values are fixed-width.
@@ -218,11 +216,6 @@ public final class UnsafeFixedWidthAggregationMap {
return map.getPeakMemoryUsedBytes();
}
- @VisibleForTesting
- public int getNumDataPages() {
- return map.getNumDataPages();
- }
-
/**
* Free the memory associated with this map. This is idempotent and can be called multiple times.
*/
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
index 0d3a4b36c1..15616915f7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -17,15 +17,14 @@
package org.apache.spark.sql.execution.aggregate
-import org.apache.spark.TaskContext
-import org.apache.spark.rdd.{MapPartitionsWithPreparationRDD, RDD}
+import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.errors._
-import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2
import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.execution.{UnsafeFixedWidthAggregationMap, UnaryNode, SparkPlan}
import org.apache.spark.sql.execution.metric.SQLMetrics
+import org.apache.spark.sql.execution.{SparkPlan, UnaryNode, UnsafeFixedWidthAggregationMap}
import org.apache.spark.sql.types.StructType
case class TungstenAggregate(
@@ -84,59 +83,39 @@ case class TungstenAggregate(
val dataSize = longMetric("dataSize")
val spillSize = longMetric("spillSize")
- /**
- * Set up the underlying unsafe data structures used before computing the parent partition.
- * This makes sure our iterator is not starved by other operators in the same task.
- */
- def preparePartition(): TungstenAggregationIterator = {
- new TungstenAggregationIterator(
- groupingExpressions,
- nonCompleteAggregateExpressions,
- nonCompleteAggregateAttributes,
- completeAggregateExpressions,
- completeAggregateAttributes,
- initialInputBufferOffset,
- resultExpressions,
- newMutableProjection,
- child.output,
- testFallbackStartsAt,
- numInputRows,
- numOutputRows,
- dataSize,
- spillSize)
- }
+ child.execute().mapPartitions { iter =>
- /** Compute a partition using the iterator already set up previously. */
- def executePartition(
- context: TaskContext,
- partitionIndex: Int,
- aggregationIterator: TungstenAggregationIterator,
- parentIterator: Iterator[InternalRow]): Iterator[UnsafeRow] = {
- val hasInput = parentIterator.hasNext
- if (!hasInput) {
- // We're not using the underlying map, so we just can free it here
- aggregationIterator.free()
- if (groupingExpressions.isEmpty) {
+ val hasInput = iter.hasNext
+ if (!hasInput && groupingExpressions.nonEmpty) {
+ // This is a grouped aggregate and the input iterator is empty,
+ // so return an empty iterator.
+ Iterator.empty
+ } else {
+ val aggregationIterator =
+ new TungstenAggregationIterator(
+ groupingExpressions,
+ nonCompleteAggregateExpressions,
+ nonCompleteAggregateAttributes,
+ completeAggregateExpressions,
+ completeAggregateAttributes,
+ initialInputBufferOffset,
+ resultExpressions,
+ newMutableProjection,
+ child.output,
+ iter,
+ testFallbackStartsAt,
+ numInputRows,
+ numOutputRows,
+ dataSize,
+ spillSize)
+ if (!hasInput && groupingExpressions.isEmpty) {
numOutputRows += 1
Iterator.single[UnsafeRow](aggregationIterator.outputForEmptyGroupingKeyWithoutInput())
} else {
- // This is a grouped aggregate and the input iterator is empty,
- // so return an empty iterator.
- Iterator.empty
+ aggregationIterator
}
- } else {
- aggregationIterator.start(parentIterator)
- aggregationIterator
}
}
-
- // Note: we need to set up the iterator in each partition before computing the
- // parent partition, so we cannot simply use `mapPartitions` here (SPARK-9747).
- val resultRdd = {
- new MapPartitionsWithPreparationRDD[UnsafeRow, InternalRow, TungstenAggregationIterator](
- child.execute(), preparePartition, executePartition, preservesPartitioning = true)
- }
- resultRdd.asInstanceOf[RDD[InternalRow]]
}
override def simpleString: String = {
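Condensed, the new doExecute body introduced above reduces to the following shape (the long TungstenAggregationIterator constructor call is elided; all names come from the hunk):

  child.execute().mapPartitions { iter =>
    val hasInput = iter.hasNext
    if (!hasInput && groupingExpressions.nonEmpty) {
      Iterator.empty                 // grouped aggregate over an empty partition
    } else {
      val aggIter = new TungstenAggregationIterator(/* ...args as above..., iter, ... */)
      if (!hasInput && groupingExpressions.isEmpty) {
        numOutputRows += 1
        Iterator.single(aggIter.outputForEmptyGroupingKeyWithoutInput())
      } else {
        aggIter                      // already consumed `iter` in its constructor
      }
    }
  }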
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
index fb2fc98e34..713a4db0cd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
@@ -74,6 +74,8 @@ import org.apache.spark.sql.types.StructType
* the function used to create mutable projections.
* @param originalInputAttributes
* attributes of representing input rows from `inputIter`.
+ * @param inputIter
+ * the iterator containing input [[UnsafeRow]]s.
*/
class TungstenAggregationIterator(
groupingExpressions: Seq[NamedExpression],
@@ -85,6 +87,7 @@ class TungstenAggregationIterator(
resultExpressions: Seq[NamedExpression],
newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
originalInputAttributes: Seq[Attribute],
+ inputIter: Iterator[InternalRow],
testFallbackStartsAt: Option[Int],
numInputRows: LongSQLMetric,
numOutputRows: LongSQLMetric,
@@ -92,9 +95,6 @@ class TungstenAggregationIterator(
spillSize: LongSQLMetric)
extends Iterator[UnsafeRow] with Logging {
- // The parent partition iterator, to be initialized later in `start`
- private[this] var inputIter: Iterator[InternalRow] = null
-
///////////////////////////////////////////////////////////////////////////
// Part 1: Initializing aggregate functions.
///////////////////////////////////////////////////////////////////////////
@@ -486,15 +486,11 @@ class TungstenAggregationIterator(
false // disable tracking of performance metrics
)
- // Exposed for testing
- private[aggregate] def getHashMap: UnsafeFixedWidthAggregationMap = hashMap
-
// The function used to read and process input rows. When processing input rows,
// it first uses hash-based aggregation by putting groups and their buffers in
// hashMap. If we could not allocate more memory for the map, we switch to
// sort-based aggregation (by calling switchToSortBasedAggregation).
private def processInputs(): Unit = {
- assert(inputIter != null, "attempted to process input when iterator was null")
if (groupingExpressions.isEmpty) {
// If there is no grouping expressions, we can just reuse the same buffer over and over again.
// Note that it would be better to eliminate the hash map entirely in the future.
@@ -526,7 +522,6 @@ class TungstenAggregationIterator(
// that it switch to sort-based aggregation after `fallbackStartsAt` input rows have
// been processed.
private def processInputsWithControlledFallback(fallbackStartsAt: Int): Unit = {
- assert(inputIter != null, "attempted to process input when iterator was null")
var i = 0
while (!sortBased && inputIter.hasNext) {
val newInput = inputIter.next()
@@ -567,15 +562,11 @@ class TungstenAggregationIterator(
* Switch to sort-based aggregation when the hash-based approach is unable to acquire memory.
*/
private def switchToSortBasedAggregation(firstKey: UnsafeRow, firstInput: InternalRow): Unit = {
- assert(inputIter != null, "attempted to process input when iterator was null")
logInfo("falling back to sort based aggregation.")
// Step 1: Get the ExternalSorter containing sorted entries of the map.
externalSorter = hashMap.destructAndCreateExternalSorter()
- // Step 2: Free the memory used by the map.
- hashMap.free()
-
- // Step 3: If we have aggregate function with mode Partial or Complete,
+ // Step 2: If we have aggregate function with mode Partial or Complete,
// we need to process input rows to get aggregation buffer.
// So, later in the sort-based aggregation iterator, we can do merge.
// If aggregate functions are with mode Final and PartialMerge,
@@ -770,31 +761,27 @@ class TungstenAggregationIterator(
/**
* Start processing input rows.
- * Only after this method is called will this iterator be non-empty.
*/
- def start(parentIter: Iterator[InternalRow]): Unit = {
- inputIter = parentIter
- testFallbackStartsAt match {
- case None =>
- processInputs()
- case Some(fallbackStartsAt) =>
- // This is the testing path. processInputsWithControlledFallback is same as processInputs
- // except that it switches to sort-based aggregation after `fallbackStartsAt` input rows
- // have been processed.
- processInputsWithControlledFallback(fallbackStartsAt)
- }
+ testFallbackStartsAt match {
+ case None =>
+ processInputs()
+ case Some(fallbackStartsAt) =>
+ // This is the testing path. processInputsWithControlledFallback is same as processInputs
+ // except that it switches to sort-based aggregation after `fallbackStartsAt` input rows
+ // have been processed.
+ processInputsWithControlledFallback(fallbackStartsAt)
+ }
- // If we did not switch to sort-based aggregation in processInputs,
- // we pre-load the first key-value pair from the map (to make hasNext idempotent).
- if (!sortBased) {
- // First, set aggregationBufferMapIterator.
- aggregationBufferMapIterator = hashMap.iterator()
- // Pre-load the first key-value pair from the aggregationBufferMapIterator.
- mapIteratorHasNext = aggregationBufferMapIterator.next()
- // If the map is empty, we just free it.
- if (!mapIteratorHasNext) {
- hashMap.free()
- }
+ // If we did not switch to sort-based aggregation in processInputs,
+ // we pre-load the first key-value pair from the map (to make hasNext idempotent).
+ if (!sortBased) {
+ // First, set aggregationBufferMapIterator.
+ aggregationBufferMapIterator = hashMap.iterator()
+ // Pre-load the first key-value pair from the aggregationBufferMapIterator.
+ mapIteratorHasNext = aggregationBufferMapIterator.next()
+ // If the map is empty, we just free it.
+ if (!mapIteratorHasNext) {
+ hashMap.free()
}
}
@@ -868,13 +855,16 @@ class TungstenAggregationIterator(
* Generate a output row when there is no input and there is no grouping expression.
*/
def outputForEmptyGroupingKeyWithoutInput(): UnsafeRow = {
- assert(groupingExpressions.isEmpty)
- assert(inputIter == null)
- generateOutput(UnsafeRow.createFromByteArray(0, 0), initialAggregationBuffer)
- }
-
- /** Free memory used in the underlying map. */
- def free(): Unit = {
- hashMap.free()
+ if (groupingExpressions.isEmpty) {
+ sortBasedAggregationBuffer.copyFrom(initialAggregationBuffer)
+ // We create an output row and copy it, so we can free the map.
+ val resultCopy =
+ generateOutput(UnsafeRow.createFromByteArray(0, 0), sortBasedAggregationBuffer).copy()
+ hashMap.free()
+ resultCopy
+ } else {
+ throw new IllegalStateException(
+ "This method should not be called when groupingExpressions is not empty.")
+ }
}
}
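The key change to the iterator is that it now consumes inputIter in its constructor body rather than in a separate start() call, so it is fully populated the moment it is created. A self-contained sketch of that eager pattern (illustrative only, not the actual class):

  // Eagerly-initializing iterator: all input is processed at construction time,
  // so hasNext/next never touch the source iterator again.
  class EagerAggIterator(input: Iterator[Int]) extends Iterator[Int] {
    private val result: Iterator[Int] = Iterator.single(input.sum)  // stands in for processInputs()
    override def hasNext: Boolean = result.hasNext
    override def next(): Int = result.next()
  }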
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala
index dd92dda480..1a3832a698 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.execution
-import org.apache.spark.rdd.{MapPartitionsWithPreparationRDD, RDD}
+import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.catalyst.expressions._
@@ -26,7 +26,7 @@ import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.CompletionIterator
import org.apache.spark.util.collection.ExternalSorter
-import org.apache.spark.{SparkEnv, InternalAccumulator, TaskContext}
+import org.apache.spark.{InternalAccumulator, SparkEnv, TaskContext}
////////////////////////////////////////////////////////////////////////////////////////////////////
// This file defines various sort operators.
@@ -77,6 +77,7 @@ case class Sort(
* @param testSpillFrequency Method for configuring periodic spilling in unit tests. If set, will
* spill every `frequency` records.
*/
+
case class TungstenSort(
sortOrder: Seq[SortOrder],
global: Boolean,
@@ -106,11 +107,7 @@ case class TungstenSort(
val dataSize = longMetric("dataSize")
val spillSize = longMetric("spillSize")
- /**
- * Set up the sorter in each partition before computing the parent partition.
- * This makes sure our sorter is not starved by other sorters used in the same task.
- */
- def preparePartition(): UnsafeExternalRowSorter = {
+ child.execute().mapPartitions { iter =>
val ordering = newOrdering(sortOrder, childOutput)
// The comparator for comparing prefix
@@ -131,33 +128,20 @@ case class TungstenSort(
if (testSpillFrequency > 0) {
sorter.setTestSpillFrequency(testSpillFrequency)
}
- sorter
- }
- /** Compute a partition using the sorter already set up previously. */
- def executePartition(
- taskContext: TaskContext,
- partitionIndex: Int,
- sorter: UnsafeExternalRowSorter,
- parentIterator: Iterator[InternalRow]): Iterator[InternalRow] = {
// Remember spill data size of this task before execute this operator so that we can
// figure out how many bytes we spilled for this operator.
val spillSizeBefore = TaskContext.get().taskMetrics().memoryBytesSpilled
- val sortedIterator = sorter.sort(parentIterator.asInstanceOf[Iterator[UnsafeRow]])
+ val sortedIterator = sorter.sort(iter.asInstanceOf[Iterator[UnsafeRow]])
dataSize += sorter.getPeakMemoryUsage
spillSize += TaskContext.get().taskMetrics().memoryBytesSpilled - spillSizeBefore
- taskContext.internalMetricsToAccumulators(
+ TaskContext.get().internalMetricsToAccumulators(
InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.getPeakMemoryUsage)
sortedIterator
}
-
- // Note: we need to set up the external sorter in each partition before computing
- // the parent partition, so we cannot simply use `mapPartitions` here (SPARK-9709).
- new MapPartitionsWithPreparationRDD[InternalRow, InternalRow, UnsafeExternalRowSorter](
- child.execute(), preparePartition, executePartition, preservesPartitioning = true)
}
}
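The sort path follows the same recipe: build the sorter inside mapPartitions and measure the spill delta around the sort. A condensed sketch of the new body (createSorter is a hypothetical helper standing in for the ordering/prefix-comparator setup above):

  child.execute().mapPartitions { iter =>
    val sorter = createSorter()   // per-partition, no prepare step required
    val spillSizeBefore = TaskContext.get().taskMetrics().memoryBytesSpilled
    val sortedIterator = sorter.sort(iter.asInstanceOf[Iterator[UnsafeRow]])
    dataSize += sorter.getPeakMemoryUsage
    spillSize += TaskContext.get().taskMetrics().memoryBytesSpilled - spillSizeBefore
    sortedIterator
  }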