author     Andrew Or <andrew@databricks.com>    2015-08-12 10:08:35 -0700
committer  Reynold Xin <rxin@databricks.com>    2015-08-12 10:08:47 -0700
commit     4c6b1296d20f594f71e63b0772b5290ef21ddd21
tree       23daff9f711a6a548366da26bea61e53bd3bef08
parent     2d86faddd87b6e61565cbdf18dadaf4aeb2b223e
[SPARK-9747] [SQL] Avoid starving an unsafe operator in aggregation
This is the sister patch to #8011, but for aggregation. In a nutshell: create the `TungstenAggregationIterator` before computing the parent partition. Internally this creates a `BytesToBytesMap`, which, as of this patch, acquires a page in its constructor. This ensures that the aggregation operator is not starved, since we reserve at least one page in advance. (A minimal sketch of this reserve-on-construct pattern follows the change summary below.)

rxin yhuai

Author: Andrew Or <andrew@databricks.com>

Closes #8038 from andrewor14/unsafe-starve-memory-agg.

(cherry picked from commit e0110792ef71ebfd3727b970346a2e13695990a4)
Signed-off-by: Reynold Xin <rxin@databricks.com>
-rw-r--r--  core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java                                     | 34
-rw-r--r--  core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java              |  9
-rw-r--r--  core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java                        | 11
-rw-r--r--  sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java               |  7
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala                | 72
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala      | 88
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala | 56
7 files changed, 201 insertions(+), 76 deletions(-)
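
The heart of the change, in miniature: an operator that defers its first memory request can find the task's shuffle-memory budget already exhausted by a sibling operator, whereas an operator that reserves a page at construction time is guaranteed to make progress. Below is a minimal, self-contained Scala sketch of the reserve-on-construct pattern, using a hypothetical SimpleMemoryBudget in place of Spark's ShuffleMemoryManager; it illustrates the idea and is not the actual implementation.

    // Hypothetical stand-in for ShuffleMemoryManager: grants as much of a request
    // as the remaining budget allows, like tryToAcquire.
    class SimpleMemoryBudget(total: Long) {
      private var used = 0L
      def tryToAcquire(bytes: Long): Long = synchronized {
        val granted = math.min(bytes, total - used)
        used += granted
        granted
      }
      def release(bytes: Long): Unit = synchronized { used -= bytes }
    }

    // Hypothetical stand-in for BytesToBytesMap's page bookkeeping.
    class StarvationSafePages(budget: SimpleMemoryBudget, pageSizeBytes: Long) {
      private var numPages = 0

      // Reserve the first page eagerly in the constructor, mirroring what
      // BytesToBytesMap does after this patch (SPARK-9747), so sibling operators
      // cannot consume the entire budget before this map ever runs.
      require(acquireNewPage(), "could not reserve the initial page")

      private def acquireNewPage(): Boolean = {
        val granted = budget.tryToAcquire(pageSizeBytes)
        if (granted != pageSizeBytes) {
          budget.release(granted) // give back a partial grant
          false
        } else {
          numPages += 1
          true
        }
      }

      def getNumDataPages: Int = numPages

      def free(): Unit = {
        budget.release(numPages.toLong * pageSizeBytes)
        numPages = 0
      }
    }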
diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
index 85b46ec8bf..87ed47e88c 100644
--- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
+++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
@@ -193,6 +193,11 @@ public final class BytesToBytesMap {
TaskMemoryManager.MAXIMUM_PAGE_SIZE_BYTES);
}
allocate(initialCapacity);
+
+ // Acquire a new page as soon as we construct the map to ensure that we have at least
+ // one page to work with. Otherwise, other operators in the same task may starve this
+ // map (SPARK-9747).
+ acquireNewPage();
}
public BytesToBytesMap(
@@ -574,16 +579,9 @@ public final class BytesToBytesMap {
final long lengthOffsetInPage = currentDataPage.getBaseOffset() + pageCursor;
Platform.putInt(pageBaseObject, lengthOffsetInPage, END_OF_PAGE_MARKER);
}
- final long memoryGranted = shuffleMemoryManager.tryToAcquire(pageSizeBytes);
- if (memoryGranted != pageSizeBytes) {
- shuffleMemoryManager.release(memoryGranted);
- logger.debug("Failed to acquire {} bytes of memory", pageSizeBytes);
+ if (!acquireNewPage()) {
return false;
}
- MemoryBlock newPage = taskMemoryManager.allocatePage(pageSizeBytes);
- dataPages.add(newPage);
- pageCursor = 0;
- currentDataPage = newPage;
dataPage = currentDataPage;
dataPageBaseObject = currentDataPage.getBaseObject();
dataPageInsertOffset = currentDataPage.getBaseOffset();
@@ -643,6 +641,24 @@ public final class BytesToBytesMap {
}
/**
+ * Acquire a new page from the {@link ShuffleMemoryManager}.
+ * @return whether there is enough space to allocate the new page.
+ */
+ private boolean acquireNewPage() {
+ final long memoryGranted = shuffleMemoryManager.tryToAcquire(pageSizeBytes);
+ if (memoryGranted != pageSizeBytes) {
+ shuffleMemoryManager.release(memoryGranted);
+ logger.debug("Failed to acquire {} bytes of memory", pageSizeBytes);
+ return false;
+ }
+ MemoryBlock newPage = taskMemoryManager.allocatePage(pageSizeBytes);
+ dataPages.add(newPage);
+ pageCursor = 0;
+ currentDataPage = newPage;
+ return true;
+ }
+
+ /**
* Allocate new data structures for this map. When calling this outside of the constructor,
* make sure to keep references to the old data structures so that you can free them.
*
@@ -748,7 +764,7 @@ public final class BytesToBytesMap {
}
@VisibleForTesting
- int getNumDataPages() {
+ public int getNumDataPages() {
return dataPages.size();
}
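
Assuming the hypothetical sketch above, the effect of the constructor-time reservation can be demonstrated as follows (illustrative only, not Spark's API):

    // With a budget of two pages, a greedy sibling can no longer take everything,
    // because the map reserved its first page when it was constructed.
    val budget = new SimpleMemoryBudget(total = 2 * 1024)
    val map = new StarvationSafePages(budget, pageSizeBytes = 1024) // page 1 reserved here
    val grantedToSibling = budget.tryToAcquire(2 * 1024)            // only 1024 left to grant
    assert(map.getNumDataPages == 1)
    assert(grantedToSibling == 1024)
    map.free()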
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
index 9601aafe55..fc364e0a89 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
@@ -132,16 +132,15 @@ public final class UnsafeExternalSorter {
if (existingInMemorySorter == null) {
initializeForWriting();
+ // Acquire a new page as soon as we construct the sorter to ensure that we have at
+ // least one page to work with. Otherwise, other operators in the same task may starve
+ // this sorter (SPARK-9709). We don't need to do this if we already have an existing sorter.
+ acquireNewPage();
} else {
this.isInMemSorterExternal = true;
this.inMemSorter = existingInMemorySorter;
}
- // Acquire a new page as soon as we construct the sorter to ensure that we have at
- // least one page to work with. Otherwise, other operators in the same task may starve
- // this sorter (SPARK-9709).
- acquireNewPage();
-
// Register a cleanup task with TaskContext to ensure that memory is guaranteed to be freed at
// the end of the task. This is necessary to avoid memory leaks when the downstream operator
// does not fully consume the sorter's output (e.g. sort followed by limit).
diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
index 1a79c20c35..ab480b60ad 100644
--- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
+++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
@@ -543,7 +543,7 @@ public abstract class AbstractBytesToBytesMapSuite {
Platform.LONG_ARRAY_OFFSET,
8);
newPeakMemory = map.getPeakMemoryUsedBytes();
- if (i % numRecordsPerPage == 0) {
+ if (i % numRecordsPerPage == 0 && i > 0) {
// We allocated a new page for this record, so peak memory should change
assertEquals(previousPeakMemory + pageSizeBytes, newPeakMemory);
} else {
@@ -561,4 +561,13 @@ public abstract class AbstractBytesToBytesMapSuite {
map.free();
}
}
+
+ @Test
+ public void testAcquirePageInConstructor() {
+ final BytesToBytesMap map = new BytesToBytesMap(
+ taskMemoryManager, shuffleMemoryManager, 1, PAGE_SIZE_BYTES);
+ assertEquals(1, map.getNumDataPages());
+ map.free();
+ }
+
}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
index 5cce41d5a7..09511ff35f 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
@@ -19,6 +19,8 @@ package org.apache.spark.sql.execution;
import java.io.IOException;
+import com.google.common.annotations.VisibleForTesting;
+
import org.apache.spark.SparkEnv;
import org.apache.spark.shuffle.ShuffleMemoryManager;
import org.apache.spark.sql.catalyst.InternalRow;
@@ -220,6 +222,11 @@ public final class UnsafeFixedWidthAggregationMap {
return map.getPeakMemoryUsedBytes();
}
+ @VisibleForTesting
+ public int getNumDataPages() {
+ return map.getNumDataPages();
+ }
+
/**
* Free the memory associated with this map. This is idempotent and can be called multiple times.
*/
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
index 6b5935a7ce..c40ca97379 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -17,12 +17,13 @@
package org.apache.spark.sql.execution.aggregate
-import org.apache.spark.rdd.RDD
+import org.apache.spark.TaskContext
+import org.apache.spark.rdd.{MapPartitionsWithPreparationRDD, RDD}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.physical.{UnspecifiedDistribution, ClusteredDistribution, AllTuples, Distribution}
+import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.{UnaryNode, SparkPlan}
import org.apache.spark.sql.execution.metric.SQLMetrics
@@ -68,35 +69,56 @@ case class TungstenAggregate(
protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
val numInputRows = longMetric("numInputRows")
val numOutputRows = longMetric("numOutputRows")
- child.execute().mapPartitions { iter =>
- val hasInput = iter.hasNext
- if (!hasInput && groupingExpressions.nonEmpty) {
- // This is a grouped aggregate and the input iterator is empty,
- // so return an empty iterator.
- Iterator.empty.asInstanceOf[Iterator[UnsafeRow]]
- } else {
- val aggregationIterator =
- new TungstenAggregationIterator(
- groupingExpressions,
- nonCompleteAggregateExpressions,
- completeAggregateExpressions,
- initialInputBufferOffset,
- resultExpressions,
- newMutableProjection,
- child.output,
- iter,
- testFallbackStartsAt,
- numInputRows,
- numOutputRows)
-
- if (!hasInput && groupingExpressions.isEmpty) {
+
+ /**
+ * Set up the underlying unsafe data structures used before computing the parent partition.
+ * This makes sure our iterator is not starved by other operators in the same task.
+ */
+ def preparePartition(): TungstenAggregationIterator = {
+ new TungstenAggregationIterator(
+ groupingExpressions,
+ nonCompleteAggregateExpressions,
+ completeAggregateExpressions,
+ initialInputBufferOffset,
+ resultExpressions,
+ newMutableProjection,
+ child.output,
+ testFallbackStartsAt,
+ numInputRows,
+ numOutputRows)
+ }
+
+ /** Compute a partition using the iterator already set up previously. */
+ def executePartition(
+ context: TaskContext,
+ partitionIndex: Int,
+ aggregationIterator: TungstenAggregationIterator,
+ parentIterator: Iterator[InternalRow]): Iterator[UnsafeRow] = {
+ val hasInput = parentIterator.hasNext
+ if (!hasInput) {
+ // We're not using the underlying map, so we can just free it here
+ aggregationIterator.free()
+ if (groupingExpressions.isEmpty) {
numOutputRows += 1
Iterator.single[UnsafeRow](aggregationIterator.outputForEmptyGroupingKeyWithoutInput())
} else {
- aggregationIterator
+ // This is a grouped aggregate and the input iterator is empty,
+ // so return an empty iterator.
+ Iterator[UnsafeRow]()
}
+ } else {
+ aggregationIterator.start(parentIterator)
+ aggregationIterator
}
}
+
+ // Note: we need to set up the iterator in each partition before computing the
+ // parent partition, so we cannot simply use `mapPartitions` here (SPARK-9747).
+ val resultRdd = {
+ new MapPartitionsWithPreparationRDD[UnsafeRow, InternalRow, TungstenAggregationIterator](
+ child.execute(), preparePartition, executePartition, preservesPartitioning = true)
+ }
+ resultRdd.asInstanceOf[RDD[InternalRow]]
}
override def simpleString: String = {
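
MapPartitionsWithPreparationRDD is used here but defined in core (it was introduced by the sister patch #8011 and does not appear in this diff). Below is a simplified sketch of its prepare-then-execute contract, written against the public RDD API under the hypothetical name PreparedRDD; the real class may differ in its details.

    import scala.reflect.ClassTag

    import org.apache.spark.{Partition, TaskContext}
    import org.apache.spark.rdd.RDD

    // Simplified sketch: run a preparation step before touching the parent iterator,
    // so the prepared object (here, the aggregation iterator) reserves memory first.
    class PreparedRDD[U: ClassTag, T: ClassTag, M](
        prev: RDD[T],
        preparePartition: () => M,
        executePartition: (TaskContext, Int, M, Iterator[T]) => Iterator[U])
      extends RDD[U](prev) {

      override def getPartitions: Array[Partition] = firstParent[T].partitions

      override def compute(split: Partition, context: TaskContext): Iterator[U] = {
        // Prepare before computing the parent partition: this ordering is the
        // whole point of SPARK-9747.
        val prepared = preparePartition()
        val parentIterator = firstParent[T].iterator(split, context)
        executePartition(context, split.index, prepared, parentIterator)
      }
    }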
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
index 1f383dd044..af7e0fcedb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
@@ -72,8 +72,6 @@ import org.apache.spark.sql.types.StructType
* the function used to create mutable projections.
* @param originalInputAttributes
* attributes of representing input rows from `inputIter`.
- * @param inputIter
- * the iterator containing input [[UnsafeRow]]s.
*/
class TungstenAggregationIterator(
groupingExpressions: Seq[NamedExpression],
@@ -83,12 +81,14 @@ class TungstenAggregationIterator(
resultExpressions: Seq[NamedExpression],
newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
originalInputAttributes: Seq[Attribute],
- inputIter: Iterator[InternalRow],
testFallbackStartsAt: Option[Int],
numInputRows: LongSQLMetric,
numOutputRows: LongSQLMetric)
extends Iterator[UnsafeRow] with Logging {
+ // The parent partition iterator, to be initialized later in `start`
+ private[this] var inputIter: Iterator[InternalRow] = null
+
///////////////////////////////////////////////////////////////////////////
// Part 1: Initializing aggregate functions.
///////////////////////////////////////////////////////////////////////////
@@ -348,11 +348,15 @@ class TungstenAggregationIterator(
false // disable tracking of performance metrics
)
+ // Exposed for testing
+ private[aggregate] def getHashMap: UnsafeFixedWidthAggregationMap = hashMap
+
// The function used to read and process input rows. When processing input rows,
// it first uses hash-based aggregation by putting groups and their buffers in
// hashMap. If we could not allocate more memory for the map, we switch to
// sort-based aggregation (by calling switchToSortBasedAggregation).
private def processInputs(): Unit = {
+ assert(inputIter != null, "attempted to process input when iterator was null")
while (!sortBased && inputIter.hasNext) {
val newInput = inputIter.next()
numInputRows += 1
@@ -372,6 +376,7 @@ class TungstenAggregationIterator(
// that it switches to sort-based aggregation after `fallbackStartsAt` input rows have
// been processed.
private def processInputsWithControlledFallback(fallbackStartsAt: Int): Unit = {
+ assert(inputIter != null, "attempted to process input when iterator was null")
var i = 0
while (!sortBased && inputIter.hasNext) {
val newInput = inputIter.next()
@@ -412,6 +417,7 @@ class TungstenAggregationIterator(
* Switch to sort-based aggregation when the hash-based approach is unable to acquire memory.
*/
private def switchToSortBasedAggregation(firstKey: UnsafeRow, firstInput: InternalRow): Unit = {
+ assert(inputIter != null, "attempted to process input when iterator was null")
logInfo("falling back to sort based aggregation.")
// Step 1: Get the ExternalSorter containing sorted entries of the map.
externalSorter = hashMap.destructAndCreateExternalSorter()
@@ -431,6 +437,11 @@ class TungstenAggregationIterator(
case _ => false
}
+ // Note: Since we spill the sorter's contents immediately after creating it, we must insert
+ // something into the sorter here to ensure that we acquire at least a page of memory.
+ // This is done through `externalSorter.insertKV`, which will trigger the page allocation.
+ // Otherwise, children operators may steal the window of opportunity and starve our sorter.
+
if (needsProcess) {
// First, we create a buffer.
val buffer = createNewAggregationBuffer()
@@ -588,27 +599,33 @@ class TungstenAggregationIterator(
// have not switched to sort-based aggregation.
///////////////////////////////////////////////////////////////////////////
- // Starts to process input rows.
- testFallbackStartsAt match {
- case None =>
- processInputs()
- case Some(fallbackStartsAt) =>
- // This is the testing path. processInputsWithControlledFallback is same as processInputs
- // except that it switches to sort-based aggregation after `fallbackStartsAt` input rows
- // have been processed.
- processInputsWithControlledFallback(fallbackStartsAt)
- }
+ /**
+ * Start processing input rows.
+ * Only after this method is called will this iterator be non-empty.
+ */
+ def start(parentIter: Iterator[InternalRow]): Unit = {
+ inputIter = parentIter
+ testFallbackStartsAt match {
+ case None =>
+ processInputs()
+ case Some(fallbackStartsAt) =>
+ // This is the testing path. processInputsWithControlledFallback is the same as processInputs
+ // except that it switches to sort-based aggregation after `fallbackStartsAt` input rows
+ // have been processed.
+ processInputsWithControlledFallback(fallbackStartsAt)
+ }
- // If we did not switch to sort-based aggregation in processInputs,
- // we pre-load the first key-value pair from the map (to make hasNext idempotent).
- if (!sortBased) {
- // First, set aggregationBufferMapIterator.
- aggregationBufferMapIterator = hashMap.iterator()
- // Pre-load the first key-value pair from the aggregationBufferMapIterator.
- mapIteratorHasNext = aggregationBufferMapIterator.next()
- // If the map is empty, we just free it.
- if (!mapIteratorHasNext) {
- hashMap.free()
+ // If we did not switch to sort-based aggregation in processInputs,
+ // we pre-load the first key-value pair from the map (to make hasNext idempotent).
+ if (!sortBased) {
+ // First, set aggregationBufferMapIterator.
+ aggregationBufferMapIterator = hashMap.iterator()
+ // Pre-load the first key-value pair from the aggregationBufferMapIterator.
+ mapIteratorHasNext = aggregationBufferMapIterator.next()
+ // If the map is empty, we just free it.
+ if (!mapIteratorHasNext) {
+ hashMap.free()
+ }
}
}
@@ -673,21 +690,20 @@ class TungstenAggregationIterator(
}
///////////////////////////////////////////////////////////////////////////
- // Part 8: A utility function used to generate a output row when there is no
- // input and there is no grouping expression.
+ // Part 8: Utility functions
///////////////////////////////////////////////////////////////////////////
+ /**
+ * Generate an output row when there is no input and there is no grouping expression.
+ */
def outputForEmptyGroupingKeyWithoutInput(): UnsafeRow = {
- if (groupingExpressions.isEmpty) {
- sortBasedAggregationBuffer.copyFrom(initialAggregationBuffer)
- // We create a output row and copy it. So, we can free the map.
- val resultCopy =
- generateOutput(UnsafeRow.createFromByteArray(0, 0), sortBasedAggregationBuffer).copy()
- hashMap.free()
- resultCopy
- } else {
- throw new IllegalStateException(
- "This method should not be called when groupingExpressions is not empty.")
- }
+ assert(groupingExpressions.isEmpty)
+ assert(inputIter == null)
+ generateOutput(UnsafeRow.createFromByteArray(0, 0), initialAggregationBuffer)
+ }
+
+ /** Free memory used in the underlying map. */
+ def free(): Unit = {
+ hashMap.free()
}
}
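
The iterator now has a three-step lifecycle: the constructor reserves memory, start() attaches the parent input, and free() releases the reservation if the input turns out to be empty. Below is a toy Scala model of that contract, with hypothetical names; it mirrors the shape of the lifecycle, not the real class.

    // Toy model of the construct / start / free lifecycle introduced by this patch.
    class TwoPhaseIterator[T](onConstruct: () => Unit, onFree: () => Unit)
        extends Iterator[T] {
      onConstruct() // reserve memory eagerly, as the real constructor now does
      private var input: Iterator[T] = Iterator.empty
      def start(parent: Iterator[T]): Unit = { input = parent }
      def free(): Unit = onFree()
      override def hasNext: Boolean = input.hasNext
      override def next(): T = input.next()
    }

    // Usage mirrors executePartition in TungstenAggregate: construct first, then
    // either start consuming input or free the reservation immediately.
    val it = new TwoPhaseIterator[Int](
      onConstruct = () => println("page reserved"),
      onFree = () => println("page released"))
    val parent = Iterator(1, 2, 3)
    val result = if (parent.hasNext) { it.start(parent); it } else { it.free(); Iterator.empty }
    println(result.sum) // prints 6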
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala
new file mode 100644
index 0000000000..ac22c2f3c0
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.aggregate
+
+import org.apache.spark._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.InterpretedMutableProjection
+import org.apache.spark.sql.execution.metric.SQLMetrics
+import org.apache.spark.sql.test.TestSQLContext
+import org.apache.spark.unsafe.memory.TaskMemoryManager
+
+class TungstenAggregationIteratorSuite extends SparkFunSuite {
+
+ test("memory acquired on construction") {
+ // set up environment
+ val ctx = TestSQLContext
+
+ val taskMemoryManager = new TaskMemoryManager(SparkEnv.get.executorMemoryManager)
+ val taskContext = new TaskContextImpl(0, 0, 0, 0, taskMemoryManager, null, Seq.empty)
+ TaskContext.setTaskContext(taskContext)
+
+ // Assert that a page is allocated before processing starts
+ var iter: TungstenAggregationIterator = null
+ try {
+ val newMutableProjection = (expr: Seq[Expression], schema: Seq[Attribute]) => {
+ () => new InterpretedMutableProjection(expr, schema)
+ }
+ val dummyAccum = SQLMetrics.createLongMetric(ctx.sparkContext, "dummy")
+ iter = new TungstenAggregationIterator(Seq.empty, Seq.empty, Seq.empty, 0,
+ Seq.empty, newMutableProjection, Seq.empty, None, dummyAccum, dummyAccum)
+ val numPages = iter.getHashMap.getNumDataPages
+ assert(numPages === 1)
+ } finally {
+ // Clean up
+ if (iter != null) {
+ iter.free()
+ }
+ TaskContext.unset()
+ }
+ }
+}