author     Andrew Or <andrew@databricks.com>    2015-08-07 14:20:13 -0700
committer  Yin Huai <yhuai@databricks.com>      2015-08-07 14:20:13 -0700
commit     881548ab20fa4c4b635c51d956b14bd13981e2f4
tree       c29a335f8f90d664a77bab926d866468922c762b
parent     05d04e10a8ea030bea840c3c5ba93ecac479a039
[SPARK-9674] Re-enable ignored test in SQLQuerySuite
The original code that this test exercises was removed in
https://github.com/apache/spark/commit/9270bd06fd0b16892e3f37213b5bc7813ea11fdd.
The test was ignored shortly before that, so the regression was never caught.
This patch re-enables the test and adds the code necessary to make it pass.

JoshRosen yhuai

Author: Andrew Or <andrew@databricks.com>

Closes #8015 from andrewor14/SPARK-9674 and squashes the following commits:

225eac2 [Andrew Or] Merge branch 'master' of github.com:apache/spark into SPARK-9674
8c24209 [Andrew Or] Fix NPE
e541d64 [Andrew Or] Track aggregation memory for both sort and hash
0be3a42 [Andrew Or] Fix test
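The heart of the patch is a small, monotonic peak-memory tracker added to BytesToBytesMap. Below is a minimal Scala sketch of that pattern; `PeakTrackingMap`, `allocatePage`, and `dataPagesBytes` are hypothetical stand-ins (the real `getTotalMemoryConsumption()` sums the data pages plus the nullable bitset and long array), and only the update-then-read discipline mirrors the actual change.

```scala
// Minimal sketch of the peak-tracking pattern this patch adds to
// BytesToBytesMap. All names here are illustrative stand-ins, not Spark API.
class PeakTrackingMap {
  private var dataPagesBytes = 0L      // stands in for the map's data pages
  private var peakMemoryUsedBytes = 0L

  // Stand-in for getTotalMemoryConsumption(): current, not peak, usage.
  private def currentMemoryUsed: Long = dataPagesBytes

  private def updatePeakMemoryUsed(): Unit = {
    val mem = currentMemoryUsed
    if (mem > peakMemoryUsedBytes) {
      peakMemoryUsedBytes = mem
    }
  }

  /** Return the peak memory used so far, in bytes. */
  def getPeakMemoryUsedBytes: Long = {
    updatePeakMemoryUsed()
    peakMemoryUsedBytes
  }

  def allocatePage(sizeBytes: Long): Unit = {
    dataPagesBytes += sizeBytes
  }

  /** Record the peak before releasing memory, so the metric survives free(). */
  def free(): Unit = {
    updatePeakMemoryUsed()
    dataPagesBytes = 0L
  }
}
```

The key design choice, visible in the hunks below, is that free() snapshots the peak before nulling out the backing structures, which is why the current-usage accessor must tolerate a freed map.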
-rw-r--r--  core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java                                | 37
-rw-r--r--  core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java                   | 20
-rw-r--r--  sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java          |  7
-rw-r--r--  sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java                  |  7
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala | 32
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala                                   |  8
6 files changed, 85 insertions(+), 26 deletions(-)
diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
index 0636ae7c8d..7f79cd13aa 100644
--- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
+++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
@@ -109,7 +109,7 @@ public final class BytesToBytesMap {
* Position {@code 2 * i} in the array is used to track a pointer to the key at index {@code i},
* while position {@code 2 * i + 1} in the array holds key's full 32-bit hashcode.
*/
- private LongArray longArray;
+ @Nullable private LongArray longArray;
// TODO: we're wasting 32 bits of space here; we can probably store fewer bits of the hashcode
// and exploit word-alignment to use fewer bits to hold the address. This might let us store
// only one long per map entry, increasing the chance that this array will fit in cache at the
@@ -124,7 +124,7 @@ public final class BytesToBytesMap {
* A {@link BitSet} used to track location of the map where the key is set.
* Size of the bitset should be half of the size of the long array.
*/
- private BitSet bitset;
+ @Nullable private BitSet bitset;
private final double loadFactor;
@@ -166,6 +166,8 @@ public final class BytesToBytesMap {
private long numHashCollisions = 0;
+ private long peakMemoryUsedBytes = 0L;
+
public BytesToBytesMap(
TaskMemoryManager taskMemoryManager,
ShuffleMemoryManager shuffleMemoryManager,
@@ -321,6 +323,9 @@ public final class BytesToBytesMap {
Object keyBaseObject,
long keyBaseOffset,
int keyRowLengthBytes) {
+ assert(bitset != null);
+ assert(longArray != null);
+
if (enablePerfMetrics) {
numKeyLookups++;
}
@@ -410,6 +415,7 @@ public final class BytesToBytesMap {
}
private Location with(int pos, int keyHashcode, boolean isDefined) {
+ assert(longArray != null);
this.pos = pos;
this.isDefined = isDefined;
this.keyHashcode = keyHashcode;
@@ -525,6 +531,9 @@ public final class BytesToBytesMap {
assert (!isDefined) : "Can only set value once for a key";
assert (keyLengthBytes % 8 == 0);
assert (valueLengthBytes % 8 == 0);
+ assert(bitset != null);
+ assert(longArray != null);
+
if (numElements == MAX_CAPACITY) {
throw new IllegalStateException("BytesToBytesMap has reached maximum capacity");
}
@@ -658,6 +667,7 @@ public final class BytesToBytesMap {
* This method is idempotent and can be called multiple times.
*/
public void free() {
+ updatePeakMemoryUsed();
longArray = null;
bitset = null;
Iterator<MemoryBlock> dataPagesIterator = dataPages.iterator();
@@ -684,14 +694,30 @@ public final class BytesToBytesMap {
/**
* Returns the total amount of memory, in bytes, consumed by this map's managed structures.
- * Note that this is also the peak memory used by this map, since the map is append-only.
*/
public long getTotalMemoryConsumption() {
long totalDataPagesSize = 0L;
for (MemoryBlock dataPage : dataPages) {
totalDataPagesSize += dataPage.size();
}
- return totalDataPagesSize + bitset.memoryBlock().size() + longArray.memoryBlock().size();
+ return totalDataPagesSize +
+ ((bitset != null) ? bitset.memoryBlock().size() : 0L) +
+ ((longArray != null) ? longArray.memoryBlock().size() : 0L);
+ }
+
+ private void updatePeakMemoryUsed() {
+ long mem = getTotalMemoryConsumption();
+ if (mem > peakMemoryUsedBytes) {
+ peakMemoryUsedBytes = mem;
+ }
+ }
+
+ /**
+ * Return the peak memory used so far, in bytes.
+ */
+ public long getPeakMemoryUsedBytes() {
+ updatePeakMemoryUsed();
+ return peakMemoryUsedBytes;
}
/**
@@ -731,6 +757,9 @@ public final class BytesToBytesMap {
*/
@VisibleForTesting
void growAndRehash() {
+ assert(bitset != null);
+ assert(longArray != null);
+
long resizeStartTime = -1;
if (enablePerfMetrics) {
resizeStartTime = System.nanoTime();
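Because free() nulls out longArray and bitset, getTotalMemoryConsumption can now be called on a freed map, hence the new null guards above. A hedged Scala sketch of the same guard, with Option standing in for the Java @Nullable fields (the function name and parameters are illustrative):

```scala
// Sketch of the freed-state guard: once the map is freed, the bitset and
// long array no longer exist and must count as zero bytes. Names are
// illustrative; the real code null-checks Java fields rather than Options.
def totalMemoryConsumption(
    dataPageSizes: Seq[Long],
    bitsetBytes: Option[Long],      // None after free()
    longArrayBytes: Option[Long]): Long = {
  dataPageSizes.sum + bitsetBytes.getOrElse(0L) + longArrayBytes.getOrElse(0L)
}
```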
diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
index 0b11562980..e56a3f0b6d 100644
--- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
+++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
@@ -525,7 +525,7 @@ public abstract class AbstractBytesToBytesMapSuite {
}
@Test
- public void testTotalMemoryConsumption() {
+ public void testPeakMemoryUsed() {
final long recordLengthBytes = 24;
final long pageSizeBytes = 256 + 8; // 8 bytes for end-of-page marker
final long numRecordsPerPage = (pageSizeBytes - 8) / recordLengthBytes;
@@ -536,8 +536,8 @@ public abstract class AbstractBytesToBytesMapSuite {
// monotonically increasing. More specifically, every time we allocate a new page it
// should increase by exactly the size of the page. In this regard, the memory usage
// at any given time is also the peak memory used.
- long previousMemory = map.getTotalMemoryConsumption();
- long newMemory;
+ long previousPeakMemory = map.getPeakMemoryUsedBytes();
+ long newPeakMemory;
try {
for (long i = 0; i < numRecordsPerPage * 10; i++) {
final long[] value = new long[]{i};
@@ -548,15 +548,21 @@ public abstract class AbstractBytesToBytesMapSuite {
value,
PlatformDependent.LONG_ARRAY_OFFSET,
8);
- newMemory = map.getTotalMemoryConsumption();
+ newPeakMemory = map.getPeakMemoryUsedBytes();
if (i % numRecordsPerPage == 0) {
// We allocated a new page for this record, so peak memory should change
- assertEquals(previousMemory + pageSizeBytes, newMemory);
+ assertEquals(previousPeakMemory + pageSizeBytes, newPeakMemory);
} else {
- assertEquals(previousMemory, newMemory);
+ assertEquals(previousPeakMemory, newPeakMemory);
}
- previousMemory = newMemory;
+ previousPeakMemory = newPeakMemory;
}
+
+ // Freeing the map should not change the peak memory
+ map.free();
+ newPeakMemory = map.getPeakMemoryUsedBytes();
+ assertEquals(previousPeakMemory, newPeakMemory);
+
} finally {
map.free();
}
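The renamed test checks two properties: the peak grows by exactly one page size whenever a new page is allocated, and it is unchanged after free(). A self-contained Scala sketch of that shape, using plain asserts in place of JUnit (pageSize mirrors the suite's 256 + 8 bytes; everything else is illustrative):

```scala
// Self-contained sketch of the test's logic: peak memory is monotonic,
// grows page-by-page, and survives free(). All names are illustrative.
object PeakMemoryCheck extends App {
  var peak = 0L
  def recordPeak(current: Long): Long = {
    peak = math.max(peak, current)
    peak
  }

  val pageSize = 256L + 8L   // mirrors the suite's page size
  var used = 0L
  for (page <- 1 to 3) {
    used += pageSize                        // allocating a page raises usage
    assert(recordPeak(used) == page * pageSize)
  }
  used = 0L                                 // "free" the map
  assert(recordPeak(used) == 3 * pageSize)  // peak unchanged after free
}
```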
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
index efb33530da..b08a4a13a2 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
@@ -210,11 +210,10 @@ public final class UnsafeFixedWidthAggregationMap {
}
/**
- * The memory used by this map's managed structures, in bytes.
- * Note that this is also the peak memory used by this map, since the map is append-only.
+ * Return the peak memory used so far, in bytes.
*/
- public long getMemoryUsage() {
- return map.getTotalMemoryConsumption();
+ public long getPeakMemoryUsedBytes() {
+ return map.getPeakMemoryUsedBytes();
}
/**
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
index 9a65c9d3a4..69d6784713 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
@@ -160,6 +160,13 @@ public final class UnsafeKVExternalSorter {
}
/**
+ * Return the peak memory used so far, in bytes.
+ */
+ public long getPeakMemoryUsedBytes() {
+ return sorter.getPeakMemoryUsedBytes();
+ }
+
+ /**
* Marks the current page as no-more-space-available, and as a result, either allocate a
* new page or spill when we see the next record.
*/
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
index 4d5e98a3e9..440bef32f4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.execution.aggregate
import org.apache.spark.unsafe.KVIterator
-import org.apache.spark.{Logging, SparkEnv, TaskContext}
+import org.apache.spark.{InternalAccumulator, Logging, SparkEnv, TaskContext}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner
@@ -397,14 +397,20 @@ class TungstenAggregationIterator(
private[this] var mapIteratorHasNext: Boolean = false
///////////////////////////////////////////////////////////////////////////
- // Part 4: The function used to switch this iterator from hash-based
- // aggregation to sort-based aggregation.
+ // Part 3: Methods and fields used by sort-based aggregation.
///////////////////////////////////////////////////////////////////////////
+ // This sorter is used for sort-based aggregation. It is initialized as soon as
+ // we switch from hash-based to sort-based aggregation. Otherwise, it is not used.
+ private[this] var externalSorter: UnsafeKVExternalSorter = null
+
+ /**
+ * Switch to sort-based aggregation when the hash-based approach is unable to acquire memory.
+ */
private def switchToSortBasedAggregation(firstKey: UnsafeRow, firstInput: UnsafeRow): Unit = {
logInfo("falling back to sort based aggregation.")
// Step 1: Get the ExternalSorter containing sorted entries of the map.
- val externalSorter: UnsafeKVExternalSorter = hashMap.destructAndCreateExternalSorter()
+ externalSorter = hashMap.destructAndCreateExternalSorter()
// Step 2: Free the memory used by the map.
hashMap.free()
@@ -601,7 +607,7 @@ class TungstenAggregationIterator(
}
///////////////////////////////////////////////////////////////////////////
- // Par 7: Iterator's public methods.
+ // Part 7: Iterator's public methods.
///////////////////////////////////////////////////////////////////////////
override final def hasNext: Boolean = {
@@ -610,7 +616,7 @@ class TungstenAggregationIterator(
override final def next(): UnsafeRow = {
if (hasNext) {
- if (sortBased) {
+ val res = if (sortBased) {
// Process the current group.
processCurrentSortedGroup()
// Generate output row for the current group.
@@ -641,6 +647,19 @@ class TungstenAggregationIterator(
result
}
}
+
+ // If this is the last record, update the task's peak memory usage. Since we destroy
+ // the map to create the sorter, their memory usages should not overlap, so it is safe
+ // to just use the max of the two.
+ if (!hasNext) {
+ val mapMemory = hashMap.getPeakMemoryUsedBytes
+ val sorterMemory = Option(externalSorter).map(_.getPeakMemoryUsedBytes).getOrElse(0L)
+ val peakMemory = Math.max(mapMemory, sorterMemory)
+ TaskContext.get().internalMetricsToAccumulators(
+ InternalAccumulator.PEAK_EXECUTION_MEMORY).add(peakMemory)
+ }
+
+ res
} else {
// no more result
throw new NoSuchElementException
@@ -651,6 +670,7 @@ class TungstenAggregationIterator(
// Part 8: A utility function used to generate a output row when there is no
// input and there is no grouping expression.
///////////////////////////////////////////////////////////////////////////
+
def outputForEmptyGroupingKeyWithoutInput(): UnsafeRow = {
if (groupingExpressions.isEmpty) {
sortBasedAggregationBuffer.copyFrom(initialAggregationBuffer)
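The accounting at the end of iteration relies on the map and the sorter never being fully live at the same time: the map is destructed to create the sorter, so the task's peak is the max of the two peaks, not their sum. A minimal Scala sketch of that argument (names illustrative):

```scala
// Sketch of the non-overlapping-lifetimes accounting: because the hash map
// is destroyed to create the sorter, max() rather than + gives the true peak.
def taskPeakMemory(mapPeak: Long, sorterPeak: Option[Long]): Long =
  math.max(mapPeak, sorterPeak.getOrElse(0L))

// Hash-only query (no fallback): taskPeakMemory(1L << 20, None)            == (1L << 20)
// Fallback to sort-based path:   taskPeakMemory(1L << 20, Some(1L << 22))  == (1L << 22)
```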
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index c64aa7a07d..b14ef9bab9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -267,7 +267,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
if (!hasGeneratedAgg) {
fail(
s"""
- |Codegen is enabled, but query $sqlText does not have GeneratedAggregate in the plan.
+ |Codegen is enabled, but query $sqlText does not have TungstenAggregate in the plan.
|${df.queryExecution.simpleString}
""".stripMargin)
}
@@ -1602,10 +1602,8 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
Row(new CalendarInterval(-(12 * 3 - 3), -(7L * MICROS_PER_WEEK + 123))))
}
- ignore("aggregation with codegen updates peak execution memory") {
- withSQLConf(
- (SQLConf.CODEGEN_ENABLED.key, "true"),
- (SQLConf.USE_SQL_AGGREGATE2.key, "false")) {
+ test("aggregation with codegen updates peak execution memory") {
+ withSQLConf((SQLConf.CODEGEN_ENABLED.key, "true")) {
val sc = sqlContext.sparkContext
AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "aggregation with codegen") {
testCodeGen(