author     Andrew Or <andrew@databricks.com>     2015-08-12 10:08:35 -0700
committer  Reynold Xin <rxin@databricks.com>     2015-08-12 10:08:35 -0700
commit     e0110792ef71ebfd3727b970346a2e13695990a4 (patch)
tree       bf2a56847391ed0e2ead0a589d85871176b1ac3c /sql
parent     66d87c1d76bea2b81993156ac1fa7dad6c312ebf (diff)
[SPARK-9747] [SQL] Avoid starving an unsafe operator in aggregation
This is the sister patch to #8011, but for aggregation.

In a nutshell: create the `TungstenAggregationIterator` before computing the parent partition. Internally this creates a `BytesToBytesMap`, which as of this patch acquires a page in its constructor. This ensures that the aggregation operator is not starved, since we reserve at least one page in advance.

rxin yhuai

Author: Andrew Or <andrew@databricks.com>

Closes #8038 from andrewor14/unsafe-starve-memory-agg.
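The heart of the fix is an ordering guarantee: the operator reserves its first memory page before any parent rows are computed, so a sibling operator in the same task cannot drain the pool first. A minimal, self-contained sketch of that idea (hypothetical names, not Spark's actual API):

trait MemoryPool {
  def tryAcquirePage(): Option[Array[Byte]]
}

class EagerMap(pool: MemoryPool) {
  // Reserve the first page eagerly, at construction time, mirroring what
  // BytesToBytesMap now does. If this fails we fail fast, before any
  // parent rows have been computed.
  private val firstPage: Array[Byte] =
    pool.tryAcquirePage().getOrElse(
      throw new OutOfMemoryError("could not reserve the first page"))

  def free(): Unit = {
    // A real implementation would return firstPage to the pool here.
  }
}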
Diffstat (limited to 'sql')
-rw-r--r--  sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java              |  7
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala               | 72
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala     | 88
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala | 56
4 files changed, 162 insertions(+), 61 deletions(-)
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
index 5cce41d5a7..09511ff35f 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
@@ -19,6 +19,8 @@ package org.apache.spark.sql.execution;
import java.io.IOException;
+import com.google.common.annotations.VisibleForTesting;
+
import org.apache.spark.SparkEnv;
import org.apache.spark.shuffle.ShuffleMemoryManager;
import org.apache.spark.sql.catalyst.InternalRow;
@@ -220,6 +222,11 @@ public final class UnsafeFixedWidthAggregationMap {
return map.getPeakMemoryUsedBytes();
}
+ @VisibleForTesting
+ public int getNumDataPages() {
+ return map.getNumDataPages();
+ }
+
/**
* Free the memory associated with this map. This is idempotent and can be called multiple times.
*/
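The accessor simply forwards to the underlying `BytesToBytesMap`, giving tests a read-only view of how many pages the map holds right after construction. A hedged sketch of the pattern with a hypothetical class:

// Hypothetical stand-in showing the @VisibleForTesting accessor pattern:
// keep the field private, expose a read-only view for assertions.
final class PageCounter {
  private var numPages: Int = 1          // one page acquired at construction;
                                         // grows as the map expands
  def getNumDataPages: Int = numPages    // test-only observation point
}

assert(new PageCounter().getNumDataPages == 1)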
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
index 6b5935a7ce..c40ca97379 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -17,12 +17,13 @@
package org.apache.spark.sql.execution.aggregate
-import org.apache.spark.rdd.RDD
+import org.apache.spark.TaskContext
+import org.apache.spark.rdd.{MapPartitionsWithPreparationRDD, RDD}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.physical.{UnspecifiedDistribution, ClusteredDistribution, AllTuples, Distribution}
+import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.{UnaryNode, SparkPlan}
import org.apache.spark.sql.execution.metric.SQLMetrics
@@ -68,35 +69,56 @@ case class TungstenAggregate(
protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
val numInputRows = longMetric("numInputRows")
val numOutputRows = longMetric("numOutputRows")
- child.execute().mapPartitions { iter =>
- val hasInput = iter.hasNext
- if (!hasInput && groupingExpressions.nonEmpty) {
- // This is a grouped aggregate and the input iterator is empty,
- // so return an empty iterator.
- Iterator.empty.asInstanceOf[Iterator[UnsafeRow]]
- } else {
- val aggregationIterator =
- new TungstenAggregationIterator(
- groupingExpressions,
- nonCompleteAggregateExpressions,
- completeAggregateExpressions,
- initialInputBufferOffset,
- resultExpressions,
- newMutableProjection,
- child.output,
- iter,
- testFallbackStartsAt,
- numInputRows,
- numOutputRows)
-
- if (!hasInput && groupingExpressions.isEmpty) {
+
+ /**
+ * Set up the underlying unsafe data structures before computing the parent partition.
+ * This makes sure our iterator is not starved by other operators in the same task.
+ */
+ def preparePartition(): TungstenAggregationIterator = {
+ new TungstenAggregationIterator(
+ groupingExpressions,
+ nonCompleteAggregateExpressions,
+ completeAggregateExpressions,
+ initialInputBufferOffset,
+ resultExpressions,
+ newMutableProjection,
+ child.output,
+ testFallbackStartsAt,
+ numInputRows,
+ numOutputRows)
+ }
+
+ /** Compute a partition using the iterator set up by `preparePartition`. */
+ def executePartition(
+ context: TaskContext,
+ partitionIndex: Int,
+ aggregationIterator: TungstenAggregationIterator,
+ parentIterator: Iterator[InternalRow]): Iterator[UnsafeRow] = {
+ val hasInput = parentIterator.hasNext
+ if (!hasInput) {
+ // We're not using the underlying map, so we can just free it here
+ aggregationIterator.free()
+ if (groupingExpressions.isEmpty) {
numOutputRows += 1
Iterator.single[UnsafeRow](aggregationIterator.outputForEmptyGroupingKeyWithoutInput())
} else {
- aggregationIterator
+ // This is a grouped aggregate and the input iterator is empty,
+ // so return an empty iterator.
+ Iterator[UnsafeRow]()
}
+ } else {
+ aggregationIterator.start(parentIterator)
+ aggregationIterator
}
}
+
+ // Note: we need to set up the iterator in each partition before computing the
+ // parent partition, so we cannot simply use `mapPartitions` here (SPARK-9747).
+ val resultRdd = {
+ new MapPartitionsWithPreparationRDD[UnsafeRow, InternalRow, TungstenAggregationIterator](
+ child.execute(), preparePartition, executePartition, preservesPartitioning = true)
+ }
+ resultRdd.asInstanceOf[RDD[InternalRow]]
}
override def simpleString: String = {
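Plain `mapPartitions` cannot guarantee this ordering, because its closure runs only once the partition, parent computation included, is being evaluated. A self-contained sketch (no Spark; the helper name is made up) of the two-phase evaluation order `MapPartitionsWithPreparationRDD` provides:

// Hypothetical helper illustrating the prepare-then-execute ordering.
def computeWithPreparation[M, T, U](
    preparePartition: () => M,
    executePartition: (M, Iterator[T]) => Iterator[U],
    computeParent: () => Iterator[T]): Iterator[U] = {
  val prepared = preparePartition()   // step 1: reserve resources up front
  val parent = computeParent()        // step 2: only now compute the parent
  executePartition(prepared, parent)  // step 3: consume the parent's rows
}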
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
index 1f383dd044..af7e0fcedb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
@@ -72,8 +72,6 @@ import org.apache.spark.sql.types.StructType
* the function used to create mutable projections.
* @param originalInputAttributes
* attributes representing input rows from `inputIter`.
- * @param inputIter
- * the iterator containing input [[UnsafeRow]]s.
*/
class TungstenAggregationIterator(
groupingExpressions: Seq[NamedExpression],
@@ -83,12 +81,14 @@ class TungstenAggregationIterator(
resultExpressions: Seq[NamedExpression],
newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
originalInputAttributes: Seq[Attribute],
- inputIter: Iterator[InternalRow],
testFallbackStartsAt: Option[Int],
numInputRows: LongSQLMetric,
numOutputRows: LongSQLMetric)
extends Iterator[UnsafeRow] with Logging {
+ // The parent partition iterator, to be initialized later in `start`
+ private[this] var inputIter: Iterator[InternalRow] = null
+
///////////////////////////////////////////////////////////////////////////
// Part 1: Initializing aggregate functions.
///////////////////////////////////////////////////////////////////////////
@@ -348,11 +348,15 @@ class TungstenAggregationIterator(
false // disable tracking of performance metrics
)
+ // Exposed for testing
+ private[aggregate] def getHashMap: UnsafeFixedWidthAggregationMap = hashMap
+
// The function used to read and process input rows. When processing input rows,
// it first uses hash-based aggregation by putting groups and their buffers in
// hashMap. If we could not allocate more memory for the map, we switch to
// sort-based aggregation (by calling switchToSortBasedAggregation).
private def processInputs(): Unit = {
+ assert(inputIter != null, "attempted to process input when iterator was null")
while (!sortBased && inputIter.hasNext) {
val newInput = inputIter.next()
numInputRows += 1
@@ -372,6 +376,7 @@ class TungstenAggregationIterator(
// that it switches to sort-based aggregation after `fallbackStartsAt` input rows have
// been processed.
private def processInputsWithControlledFallback(fallbackStartsAt: Int): Unit = {
+ assert(inputIter != null, "attempted to process input when iterator was null")
var i = 0
while (!sortBased && inputIter.hasNext) {
val newInput = inputIter.next()
@@ -412,6 +417,7 @@ class TungstenAggregationIterator(
* Switch to sort-based aggregation when the hash-based approach is unable to acquire memory.
*/
private def switchToSortBasedAggregation(firstKey: UnsafeRow, firstInput: InternalRow): Unit = {
+ assert(inputIter != null, "attempted to process input when iterator was null")
logInfo("falling back to sort based aggregation.")
// Step 1: Get the ExternalSorter containing sorted entries of the map.
externalSorter = hashMap.destructAndCreateExternalSorter()
@@ -431,6 +437,11 @@ class TungstenAggregationIterator(
case _ => false
}
+ // Note: Since we spill the sorter's contents immediately after creating it, we must insert
+ // something into the sorter here to ensure that we acquire at least a page of memory.
+ // This is done through `externalSorter.insertKV`, which will trigger the page allocation.
+ // Otherwise, children operators may steal the window of opportunity and starve our sorter.
+
if (needsProcess) {
// First, we create a buffer.
val buffer = createNewAggregationBuffer()
@@ -588,27 +599,33 @@ class TungstenAggregationIterator(
// have not switched to sort-based aggregation.
///////////////////////////////////////////////////////////////////////////
- // Starts to process input rows.
- testFallbackStartsAt match {
- case None =>
- processInputs()
- case Some(fallbackStartsAt) =>
- // This is the testing path. processInputsWithControlledFallback is same as processInputs
- // except that it switches to sort-based aggregation after `fallbackStartsAt` input rows
- // have been processed.
- processInputsWithControlledFallback(fallbackStartsAt)
- }
+ /**
+ * Start processing input rows.
+ * Only after this method is called will this iterator be non-empty.
+ */
+ def start(parentIter: Iterator[InternalRow]): Unit = {
+ inputIter = parentIter
+ testFallbackStartsAt match {
+ case None =>
+ processInputs()
+ case Some(fallbackStartsAt) =>
+ // This is the testing path. processInputsWithControlledFallback is the same as processInputs
+ // except that it switches to sort-based aggregation after `fallbackStartsAt` input rows
+ // have been processed.
+ processInputsWithControlledFallback(fallbackStartsAt)
+ }
- // If we did not switch to sort-based aggregation in processInputs,
- // we pre-load the first key-value pair from the map (to make hasNext idempotent).
- if (!sortBased) {
- // First, set aggregationBufferMapIterator.
- aggregationBufferMapIterator = hashMap.iterator()
- // Pre-load the first key-value pair from the aggregationBufferMapIterator.
- mapIteratorHasNext = aggregationBufferMapIterator.next()
- // If the map is empty, we just free it.
- if (!mapIteratorHasNext) {
- hashMap.free()
+ // If we did not switch to sort-based aggregation in processInputs,
+ // we pre-load the first key-value pair from the map (to make hasNext idempotent).
+ if (!sortBased) {
+ // First, set aggregationBufferMapIterator.
+ aggregationBufferMapIterator = hashMap.iterator()
+ // Pre-load the first key-value pair from the aggregationBufferMapIterator.
+ mapIteratorHasNext = aggregationBufferMapIterator.next()
+ // If the map is empty, we just free it.
+ if (!mapIteratorHasNext) {
+ hashMap.free()
+ }
}
}
@@ -673,21 +690,20 @@ class TungstenAggregationIterator(
}
///////////////////////////////////////////////////////////////////////////
- // Part 8: A utility function used to generate a output row when there is no
- // input and there is no grouping expression.
+ // Part 8: Utility functions
///////////////////////////////////////////////////////////////////////////
+ /**
+ * Generate an output row when there is no input and there is no grouping expression.
+ */
def outputForEmptyGroupingKeyWithoutInput(): UnsafeRow = {
- if (groupingExpressions.isEmpty) {
- sortBasedAggregationBuffer.copyFrom(initialAggregationBuffer)
- // We create a output row and copy it. So, we can free the map.
- val resultCopy =
- generateOutput(UnsafeRow.createFromByteArray(0, 0), sortBasedAggregationBuffer).copy()
- hashMap.free()
- resultCopy
- } else {
- throw new IllegalStateException(
- "This method should not be called when groupingExpressions is not empty.")
- }
+ assert(groupingExpressions.isEmpty)
+ assert(inputIter == null)
+ generateOutput(UnsafeRow.createFromByteArray(0, 0), initialAggregationBuffer)
+ }
+
+ /** Free the memory used by the underlying map. */
+ def free(): Unit = {
+ hashMap.free()
}
}
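After this patch the iterator has a three-step lifecycle: the constructor acquires at least one page, `start(parentIter)` attaches the input and begins processing, and `free()` releases the map if the input turns out to be empty. A hedged sketch of that deferred-start contract (hypothetical minimal class):

// Toy iterator mirroring the start()/free() contract introduced above.
class DeferredIterator[T] extends Iterator[T] {
  private var input: Iterator[T] = null   // attached later by start()

  def start(parent: Iterator[T]): Unit = { input = parent }

  def free(): Unit = { /* release resources acquired at construction */ }

  // Empty until start() is called; fail fast if consumed too early.
  override def hasNext: Boolean = input != null && input.hasNext
  override def next(): T = {
    assert(input != null, "attempted to consume before start() was called")
    input.next()
  }
}

The caller, like `executePartition` above, chooses between `free()` when the parent is empty and `start()` otherwise.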
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala
new file mode 100644
index 0000000000..ac22c2f3c0
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.aggregate
+
+import org.apache.spark._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.execution.metric.SQLMetrics
+import org.apache.spark.sql.test.TestSQLContext
+import org.apache.spark.unsafe.memory.TaskMemoryManager
+
+class TungstenAggregationIteratorSuite extends SparkFunSuite {
+
+ test("memory acquired on construction") {
+ // set up environment
+ val ctx = TestSQLContext
+
+ val taskMemoryManager = new TaskMemoryManager(SparkEnv.get.executorMemoryManager)
+ val taskContext = new TaskContextImpl(0, 0, 0, 0, taskMemoryManager, null, Seq.empty)
+ TaskContext.setTaskContext(taskContext)
+
+ // Assert that a page is allocated before processing starts
+ var iter: TungstenAggregationIterator = null
+ try {
+ val newMutableProjection = (expr: Seq[Expression], schema: Seq[Attribute]) => {
+ () => new InterpretedMutableProjection(expr, schema)
+ }
+ val dummyAccum = SQLMetrics.createLongMetric(ctx.sparkContext, "dummy")
+ iter = new TungstenAggregationIterator(Seq.empty, Seq.empty, Seq.empty, 0,
+ Seq.empty, newMutableProjection, Seq.empty, None, dummyAccum, dummyAccum)
+ val numPages = iter.getHashMap.getNumDataPages
+ assert(numPages === 1)
+ } finally {
+ // Clean up
+ if (iter != null) {
+ iter.free()
+ }
+ TaskContext.unset()
+ }
+ }
+}