author     Andrew Or <andrew@databricks.com>    2015-08-03 14:22:07 -0700
committer  Josh Rosen <joshrosen@databricks.com>    2015-08-03 14:22:07 -0700
commit     702aa9d7fb16c98a50e046edfd76b8a7861d0391 (patch)
tree       42f3a10ebd1086fda3d7b185b73cb452623964ca
parent     e4765a46833baff1dd7465c4cf50e947de7e8f21 (diff)
[SPARK-8735] [SQL] Expose memory usage for shuffles, joins and aggregations
This patch exposes the memory used by internal data structures on the SparkUI. This tracks memory used by all spilling operations and SQL operators backed by Tungsten, e.g. `BroadcastHashJoin`, `ExternalSort`, `GeneratedAggregate`, etc. The metric exposed is "peak execution memory", which broadly refers to the peak in-memory size of each of these data structures. A separate patch will extend this by linking the new information to the SQL operators themselves.

<img width="950" alt="screen shot 2015-07-29 at 7 43 17 pm" src="https://cloud.githubusercontent.com/assets/2133137/8974776/b90fc980-362a-11e5-9e2b-842da75b1641.png">
<img width="802" alt="screen shot 2015-07-29 at 7 43 05 pm" src="https://cloud.githubusercontent.com/assets/2133137/8974777/baa76492-362a-11e5-9b77-e364a6a6b64e.png">

Author: Andrew Or <andrew@databricks.com>

Closes #7770 from andrewor14/expose-memory-metrics and squashes the following commits:

9abecb9 [Andrew Or] Merge branch 'master' of github.com:apache/spark into expose-memory-metrics
f5b0d68 [Andrew Or] Merge branch 'master' of github.com:apache/spark into expose-memory-metrics
d7df332 [Andrew Or] Merge branch 'master' of github.com:apache/spark into expose-memory-metrics
8eefbc5 [Andrew Or] Fix non-failing tests
9de2a12 [Andrew Or] Fix tests due to another logical merge conflict
876bfa4 [Andrew Or] Fix failing test after logical merge conflict
361a359 [Andrew Or] Merge branch 'master' of github.com:apache/spark into expose-memory-metrics
40b4802 [Andrew Or] Fix style?
d0fef87 [Andrew Or] Fix tests?
b3b92f6 [Andrew Or] Address comments
0625d73 [Andrew Or] Merge branch 'master' of github.com:apache/spark into expose-memory-metrics
c00a197 [Andrew Or] Fix potential NPEs
10da1cd [Andrew Or] Fix compile
17f4c2d [Andrew Or] Fix compile?
a87b4d0 [Andrew Or] Fix compile?
d70874d [Andrew Or] Fix test compile + address comments
2840b7d [Andrew Or] Merge branch 'master' of github.com:apache/spark into expose-memory-metrics
6aa2f7a [Andrew Or] Merge branch 'master' of github.com:apache/spark into expose-memory-metrics
b889a68 [Andrew Or] Minor changes: comments, spacing, style
663a303 [Andrew Or] UnsafeShuffleWriter: update peak memory before close
d090a94 [Andrew Or] Fix style
2480d84 [Andrew Or] Expand test coverage
5f1235b [Andrew Or] Merge branch 'master' of github.com:apache/spark into expose-memory-metrics
1ecf678 [Andrew Or] Minor changes: comments, style, unused imports
0b6926c [Andrew Or] Oops
111a05e [Andrew Or] Merge branch 'master' of github.com:apache/spark into expose-memory-metrics
a7a39a5 [Andrew Or] Strengthen presence check for accumulator
a919eb7 [Andrew Or] Add tests for unsafe shuffle writer
23c845d [Andrew Or] Add tests for SQL operators
a757550 [Andrew Or] Address comments
b5c51c1 [Andrew Or] Re-enable test in JavaAPISuite
5107691 [Andrew Or] Add tests for internal accumulators
59231e4 [Andrew Or] Fix tests
9528d09 [Andrew Or] Merge branch 'master' of github.com:apache/spark into expose-memory-metrics
5b5e6f3 [Andrew Or] Add peak execution memory to summary table + tooltip
92b4b6b [Andrew Or] Display peak execution memory on the UI
eee5437 [Andrew Or] Merge branch 'master' of github.com:apache/spark into expose-memory-metrics
d9b9015 [Andrew Or] Track execution memory in unsafe shuffles
770ee54 [Andrew Or] Track execution memory in broadcast joins
9c605a4 [Andrew Or] Track execution memory in GeneratedAggregate
9e824f2 [Andrew Or] Add back execution memory tracking for *ExternalSort
4ef4cb1 [Andrew Or] Merge branch 'master' of github.com:apache/spark into expose-memory-metrics
e6c3e2f [Andrew Or] Move internal accumulators creation to Stage
a417592 [Andrew Or] Expose memory metrics in UnsafeExternalSorter
3c4f042 [Andrew Or] Track memory usage in ExternalAppendOnlyMap / ExternalSorter
bd7ab3f [Andrew Or] Add internal accumulators to TaskContext
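
The pattern the patch repeats in each spilling data structure is small enough to sketch before the diff: keep a running maximum of the structure's estimated in-memory size and expose it as the peak. The following minimal Scala sketch only illustrates that idea; the class and method names are illustrative, not Spark's API.

// Minimal sketch of the peak-memory pattern applied to each spilling structure.
class PeakMemoryTracker {
  private var peakBytes = 0L

  // Called whenever the structure re-estimates its in-memory size.
  def update(currentBytes: Long): Unit = {
    if (currentBytes > peakBytes) {
      peakBytes = currentBytes
    }
  }

  def peakMemoryUsedBytes: Long = peakBytes
}

object PeakMemoryTrackerDemo {
  def main(args: Array[String]): Unit = {
    val tracker = new PeakMemoryTracker
    // Sizes grow, then shrink after a spill; the peak keeps the maximum observed.
    Seq(1L << 20, 4L << 20, 2L << 20).foreach(tracker.update)
    println(tracker.peakMemoryUsedBytes) // prints 4194304
  }
}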
-rw-r--r--  core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java | 27
-rw-r--r--  core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java | 38
-rw-r--r--  core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java | 8
-rw-r--r--  core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java | 29
-rw-r--r--  core/src/main/resources/org/apache/spark/ui/static/webui.css | 2
-rw-r--r--  core/src/main/scala/org/apache/spark/Accumulators.scala | 60
-rw-r--r--  core/src/main/scala/org/apache/spark/Aggregator.scala | 24
-rw-r--r--  core/src/main/scala/org/apache/spark/TaskContext.scala | 13
-rw-r--r--  core/src/main/scala/org/apache/spark/TaskContextImpl.scala | 8
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala | 9
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala | 9
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala | 28
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala | 6
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala | 10
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/Stage.scala | 16
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/Task.scala | 18
-rw-r--r--  core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala | 8
-rw-r--r--  core/src/main/scala/org/apache/spark/ui/ToolTips.scala | 7
-rw-r--r--  core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala | 140
-rw-r--r--  core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala | 1
-rw-r--r--  core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala | 13
-rw-r--r--  core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala | 20
-rw-r--r--  core/src/test/java/org/apache/spark/JavaAPISuite.java | 3
-rw-r--r--  core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java | 54
-rw-r--r--  core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java | 39
-rw-r--r--  core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java | 46
-rw-r--r--  core/src/test/scala/org/apache/spark/AccumulatorSuite.scala | 193
-rw-r--r--  core/src/test/scala/org/apache/spark/CacheManagerSuite.scala | 10
-rw-r--r--  core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala | 2
-rw-r--r--  core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala | 6
-rw-r--r--  core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala | 2
-rw-r--r--  core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala | 7
-rw-r--r--  core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala | 2
-rw-r--r--  core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala | 2
-rw-r--r--  core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala | 8
-rw-r--r--  core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala | 76
-rw-r--r--  core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala | 15
-rw-r--r--  core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala | 14
-rw-r--r--  sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java | 7
-rw-r--r--  sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java | 8
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala | 11
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala | 10
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala | 8
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala | 10
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala | 22
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala | 12
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 60
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala | 12
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala | 3
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala | 3
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala | 94
51 files changed, 1070 insertions, 163 deletions
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java
index 1aa6ba4201..bf4eaa59ff 100644
--- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java
@@ -20,6 +20,7 @@ package org.apache.spark.shuffle.unsafe;
import java.io.File;
import java.io.IOException;
import java.util.LinkedList;
+import javax.annotation.Nullable;
import scala.Tuple2;
@@ -86,9 +87,12 @@ final class UnsafeShuffleExternalSorter {
private final LinkedList<SpillInfo> spills = new LinkedList<SpillInfo>();
+ /** Peak memory used by this sorter so far, in bytes. **/
+ private long peakMemoryUsedBytes;
+
// These variables are reset after spilling:
- private UnsafeShuffleInMemorySorter sorter;
- private MemoryBlock currentPage = null;
+ @Nullable private UnsafeShuffleInMemorySorter sorter;
+ @Nullable private MemoryBlock currentPage = null;
private long currentPagePosition = -1;
private long freeSpaceInCurrentPage = 0;
@@ -106,6 +110,7 @@ final class UnsafeShuffleExternalSorter {
this.blockManager = blockManager;
this.taskContext = taskContext;
this.initialSize = initialSize;
+ this.peakMemoryUsedBytes = initialSize;
this.numPartitions = numPartitions;
// Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024;
@@ -279,10 +284,26 @@ final class UnsafeShuffleExternalSorter {
for (MemoryBlock page : allocatedPages) {
totalPageSize += page.size();
}
- return sorter.getMemoryUsage() + totalPageSize;
+ return ((sorter == null) ? 0 : sorter.getMemoryUsage()) + totalPageSize;
+ }
+
+ private void updatePeakMemoryUsed() {
+ long mem = getMemoryUsage();
+ if (mem > peakMemoryUsedBytes) {
+ peakMemoryUsedBytes = mem;
+ }
+ }
+
+ /**
+ * Return the peak memory used so far, in bytes.
+ */
+ long getPeakMemoryUsedBytes() {
+ updatePeakMemoryUsed();
+ return peakMemoryUsedBytes;
}
private long freeMemory() {
+ updatePeakMemoryUsed();
long memoryFreed = 0;
for (MemoryBlock block : allocatedPages) {
memoryManager.freePage(block);
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java
index d47d6fc9c2..6e2eeb37c8 100644
--- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java
@@ -27,6 +27,7 @@ import scala.Product2;
import scala.collection.JavaConversions;
import scala.reflect.ClassTag;
import scala.reflect.ClassTag$;
+import scala.collection.immutable.Map;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.io.ByteStreams;
@@ -78,8 +79,9 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
private final SparkConf sparkConf;
private final boolean transferToEnabled;
- private MapStatus mapStatus = null;
- private UnsafeShuffleExternalSorter sorter = null;
+ @Nullable private MapStatus mapStatus;
+ @Nullable private UnsafeShuffleExternalSorter sorter;
+ private long peakMemoryUsedBytes = 0;
/** Subclass of ByteArrayOutputStream that exposes `buf` directly. */
private static final class MyByteArrayOutputStream extends ByteArrayOutputStream {
@@ -131,9 +133,28 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
@VisibleForTesting
public int maxRecordSizeBytes() {
+ assert(sorter != null);
return sorter.maxRecordSizeBytes;
}
+ private void updatePeakMemoryUsed() {
+ // sorter can be null if this writer is closed
+ if (sorter != null) {
+ long mem = sorter.getPeakMemoryUsedBytes();
+ if (mem > peakMemoryUsedBytes) {
+ peakMemoryUsedBytes = mem;
+ }
+ }
+ }
+
+ /**
+ * Return the peak memory used so far, in bytes.
+ */
+ public long getPeakMemoryUsedBytes() {
+ updatePeakMemoryUsed();
+ return peakMemoryUsedBytes;
+ }
+
/**
* This convenience method should only be called in test code.
*/
@@ -144,7 +165,7 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
@Override
public void write(scala.collection.Iterator<Product2<K, V>> records) throws IOException {
- // Keep track of success so we know if we ecountered an exception
+ // Keep track of success so we know if we encountered an exception
// We do this rather than a standard try/catch/re-throw to handle
// generic throwables.
boolean success = false;
@@ -189,6 +210,8 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
@VisibleForTesting
void closeAndWriteOutput() throws IOException {
+ assert(sorter != null);
+ updatePeakMemoryUsed();
serBuffer = null;
serOutputStream = null;
final SpillInfo[] spills = sorter.closeAndGetSpills();
@@ -209,6 +232,7 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
@VisibleForTesting
void insertRecordIntoSorter(Product2<K, V> record) throws IOException {
+ assert(sorter != null);
final K key = record._1();
final int partitionId = partitioner.getPartition(key);
serBuffer.reset();
@@ -431,6 +455,14 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
@Override
public Option<MapStatus> stop(boolean success) {
try {
+ // Update task metrics from accumulators (null in UnsafeShuffleWriterSuite)
+ Map<String, Accumulator<Object>> internalAccumulators =
+ taskContext.internalMetricsToAccumulators();
+ if (internalAccumulators != null) {
+ internalAccumulators.apply(InternalAccumulator.PEAK_EXECUTION_MEMORY())
+ .add(getPeakMemoryUsedBytes());
+ }
+
if (stopping) {
return Option.apply(null);
} else {
diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
index 01a66084e9..20347433e1 100644
--- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
+++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
@@ -505,7 +505,7 @@ public final class BytesToBytesMap {
// Here, we'll copy the data into our data pages. Because we only store a relative offset from
// the key address instead of storing the absolute address of the value, the key and value
// must be stored in the same memory page.
- // (8 byte key length) (key) (8 byte value length) (value)
+ // (8 byte key length) (key) (value)
final long requiredSize = 8 + keyLengthBytes + valueLengthBytes;
// --- Figure out where to insert the new record ---------------------------------------------
@@ -655,7 +655,10 @@ public final class BytesToBytesMap {
return pageSizeBytes;
}
- /** Returns the total amount of memory, in bytes, consumed by this map's managed structures. */
+ /**
+ * Returns the total amount of memory, in bytes, consumed by this map's managed structures.
+ * Note that this is also the peak memory used by this map, since the map is append-only.
+ */
public long getTotalMemoryConsumption() {
long totalDataPagesSize = 0L;
for (MemoryBlock dataPage : dataPages) {
@@ -674,7 +677,6 @@ public final class BytesToBytesMap {
return timeSpentResizingNs;
}
-
/**
* Returns the average number of probes per key lookup.
*/
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
index b984301cbb..bf5f965a9d 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
@@ -70,13 +70,14 @@ public final class UnsafeExternalSorter {
private final LinkedList<UnsafeSorterSpillWriter> spillWriters = new LinkedList<>();
// These variables are reset after spilling:
- private UnsafeInMemorySorter inMemSorter;
+ @Nullable private UnsafeInMemorySorter inMemSorter;
// Whether the in-mem sorter is created internally, or passed in from outside.
// If it is passed in from outside, we shouldn't release the in-mem sorter's memory.
private boolean isInMemSorterExternal = false;
private MemoryBlock currentPage = null;
private long currentPagePosition = -1;
private long freeSpaceInCurrentPage = 0;
+ private long peakMemoryUsedBytes = 0;
public static UnsafeExternalSorter createWithExistingInMemorySorter(
TaskMemoryManager taskMemoryManager,
@@ -183,6 +184,7 @@ public final class UnsafeExternalSorter {
* Sort and spill the current records in response to memory pressure.
*/
public void spill() throws IOException {
+ assert(inMemSorter != null);
logger.info("Thread {} spilling sort data of {} to disk ({} {} so far)",
Thread.currentThread().getId(),
Utils.bytesToString(getMemoryUsage()),
@@ -219,7 +221,22 @@ public final class UnsafeExternalSorter {
for (MemoryBlock page : allocatedPages) {
totalPageSize += page.size();
}
- return inMemSorter.getMemoryUsage() + totalPageSize;
+ return ((inMemSorter == null) ? 0 : inMemSorter.getMemoryUsage()) + totalPageSize;
+ }
+
+ private void updatePeakMemoryUsed() {
+ long mem = getMemoryUsage();
+ if (mem > peakMemoryUsedBytes) {
+ peakMemoryUsedBytes = mem;
+ }
+ }
+
+ /**
+ * Return the peak memory used so far, in bytes.
+ */
+ public long getPeakMemoryUsedBytes() {
+ updatePeakMemoryUsed();
+ return peakMemoryUsedBytes;
}
@VisibleForTesting
@@ -233,6 +250,7 @@ public final class UnsafeExternalSorter {
* @return the number of bytes freed.
*/
public long freeMemory() {
+ updatePeakMemoryUsed();
long memoryFreed = 0;
for (MemoryBlock block : allocatedPages) {
taskMemoryManager.freePage(block);
@@ -277,7 +295,8 @@ public final class UnsafeExternalSorter {
* @return true if the record can be inserted without requiring more allocations, false otherwise.
*/
private boolean haveSpaceForRecord(int requiredSpace) {
- assert (requiredSpace > 0);
+ assert(requiredSpace > 0);
+ assert(inMemSorter != null);
return (inMemSorter.hasSpaceForAnotherRecord() && (requiredSpace <= freeSpaceInCurrentPage));
}
@@ -290,6 +309,7 @@ public final class UnsafeExternalSorter {
* the record size.
*/
private void allocateSpaceForRecord(int requiredSpace) throws IOException {
+ assert(inMemSorter != null);
// TODO: merge these steps to first calculate total memory requirements for this insert,
// then try to acquire; no point in acquiring sort buffer only to spill due to no space in the
// data page.
@@ -350,6 +370,7 @@ public final class UnsafeExternalSorter {
if (!haveSpaceForRecord(totalSpaceRequired)) {
allocateSpaceForRecord(totalSpaceRequired);
}
+ assert(inMemSorter != null);
final long recordAddress =
taskMemoryManager.encodePageNumberAndOffset(currentPage, currentPagePosition);
@@ -382,6 +403,7 @@ public final class UnsafeExternalSorter {
if (!haveSpaceForRecord(totalSpaceRequired)) {
allocateSpaceForRecord(totalSpaceRequired);
}
+ assert(inMemSorter != null);
final long recordAddress =
taskMemoryManager.encodePageNumberAndOffset(currentPage, currentPagePosition);
@@ -405,6 +427,7 @@ public final class UnsafeExternalSorter {
}
public UnsafeSorterIterator getSortedIterator() throws IOException {
+ assert(inMemSorter != null);
final UnsafeSorterIterator inMemoryIterator = inMemSorter.getSortedIterator();
int numIteratorsToMerge = spillWriters.size() + (inMemoryIterator.hasNext() ? 1 : 0);
if (spillWriters.isEmpty()) {
diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css
index b1cef47042..648cd1b104 100644
--- a/core/src/main/resources/org/apache/spark/ui/static/webui.css
+++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css
@@ -207,7 +207,7 @@ span.additional-metric-title {
/* Hide all additional metrics by default. This is done here rather than using JavaScript to
* avoid slow page loads for stage pages with large numbers (e.g., thousands) of tasks. */
.scheduler_delay, .deserialization_time, .fetch_wait_time, .shuffle_read_remote,
-.serialization_time, .getting_result_time {
+.serialization_time, .getting_result_time, .peak_execution_memory {
display: none;
}
diff --git a/core/src/main/scala/org/apache/spark/Accumulators.scala b/core/src/main/scala/org/apache/spark/Accumulators.scala
index eb75f26718..b6a0119c69 100644
--- a/core/src/main/scala/org/apache/spark/Accumulators.scala
+++ b/core/src/main/scala/org/apache/spark/Accumulators.scala
@@ -152,8 +152,14 @@ class Accumulable[R, T] private[spark] (
in.defaultReadObject()
value_ = zero
deserialized = true
- val taskContext = TaskContext.get()
- taskContext.registerAccumulator(this)
+ // Automatically register the accumulator when it is deserialized with the task closure.
+ // Note that internal accumulators are deserialized before the TaskContext is created and
+ // are registered in the TaskContext constructor.
+ if (!isInternal) {
+ val taskContext = TaskContext.get()
+ assume(taskContext != null, "Task context was null when deserializing user accumulators")
+ taskContext.registerAccumulator(this)
+ }
}
override def toString: String = if (value_ == null) "null" else value_.toString
@@ -248,10 +254,20 @@ GrowableAccumulableParam[R <% Growable[T] with TraversableOnce[T] with Serializa
* @param param helper object defining how to add elements of type `T`
* @tparam T result type
*/
-class Accumulator[T](@transient initialValue: T, param: AccumulatorParam[T], name: Option[String])
- extends Accumulable[T, T](initialValue, param, name) {
+class Accumulator[T] private[spark] (
+ @transient initialValue: T,
+ param: AccumulatorParam[T],
+ name: Option[String],
+ internal: Boolean)
+ extends Accumulable[T, T](initialValue, param, name, internal) {
+
+ def this(initialValue: T, param: AccumulatorParam[T], name: Option[String]) = {
+ this(initialValue, param, name, false)
+ }
- def this(initialValue: T, param: AccumulatorParam[T]) = this(initialValue, param, None)
+ def this(initialValue: T, param: AccumulatorParam[T]) = {
+ this(initialValue, param, None, false)
+ }
}
/**
@@ -342,3 +358,37 @@ private[spark] object Accumulators extends Logging {
}
}
+
+private[spark] object InternalAccumulator {
+ val PEAK_EXECUTION_MEMORY = "peakExecutionMemory"
+ val TEST_ACCUMULATOR = "testAccumulator"
+
+ // For testing only.
+ // This needs to be a def since we don't want to reuse the same accumulator across stages.
+ private def maybeTestAccumulator: Option[Accumulator[Long]] = {
+ if (sys.props.contains("spark.testing")) {
+ Some(new Accumulator(
+ 0L, AccumulatorParam.LongAccumulatorParam, Some(TEST_ACCUMULATOR), internal = true))
+ } else {
+ None
+ }
+ }
+
+ /**
+ * Accumulators for tracking internal metrics.
+ *
+ * These accumulators are created with the stage such that all tasks in the stage will
+ * add to the same set of accumulators. We do this to report the distribution of accumulator
+ * values across all tasks within each stage.
+ */
+ def create(): Seq[Accumulator[Long]] = {
+ Seq(
+ // Execution memory refers to the memory used by internal data structures created
+ // during shuffles, aggregations and joins. The value of this accumulator should be
+ // approximately the sum of the peak sizes across all such data structures created
+ // in this task. For SQL jobs, this only tracks all unsafe operators and ExternalSort.
+ new Accumulator(
+ 0L, AccumulatorParam.LongAccumulatorParam, Some(PEAK_EXECUTION_MEMORY), internal = true)
+ ) ++ maybeTestAccumulator.toSeq
+ }
+}
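
To make the data flow of these per-stage accumulators concrete, here is a simplified, non-Spark model of how every task in a stage adds its peak to the same peakExecutionMemory accumulator; the real reporting calls appear in later hunks of this diff (e.g. Aggregator, CoGroupedRDD, ExternalSorter). The accumulator class below is a stand-in written for illustration only.

import java.util.concurrent.atomic.AtomicLong

object InternalAccumulatorSketch {
  // Stand-in for an Accumulator[Long]; every task in a stage adds to the same instance.
  final class LongAccumulator(val name: String) {
    private val sum = new AtomicLong(0L)
    def add(v: Long): Unit = { sum.addAndGet(v); () }
    def value: Long = sum.get()
  }

  def main(args: Array[String]): Unit = {
    // Created once per stage submission, analogous to InternalAccumulator.create().
    val peakExecutionMemory = new LongAccumulator("peakExecutionMemory")

    // Each task reports the peak size of its spilling structures on completion,
    // the way the diff's later hunks call
    // context.internalMetricsToAccumulators(PEAK_EXECUTION_MEMORY).add(...).
    val taskPeaks = Seq(64L << 20, 128L << 20, 96L << 20)
    taskPeaks.foreach(peakExecutionMemory.add)

    // The driver reads the merged value and the UI shows the per-task distribution.
    println(s"${peakExecutionMemory.name} = ${peakExecutionMemory.value} bytes")
  }
}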
diff --git a/core/src/main/scala/org/apache/spark/Aggregator.scala b/core/src/main/scala/org/apache/spark/Aggregator.scala
index ceeb58075d..289aab9bd9 100644
--- a/core/src/main/scala/org/apache/spark/Aggregator.scala
+++ b/core/src/main/scala/org/apache/spark/Aggregator.scala
@@ -58,12 +58,7 @@ case class Aggregator[K, V, C] (
} else {
val combiners = new ExternalAppendOnlyMap[K, V, C](createCombiner, mergeValue, mergeCombiners)
combiners.insertAll(iter)
- // Update task metrics if context is not null
- // TODO: Make context non optional in a future release
- Option(context).foreach { c =>
- c.taskMetrics.incMemoryBytesSpilled(combiners.memoryBytesSpilled)
- c.taskMetrics.incDiskBytesSpilled(combiners.diskBytesSpilled)
- }
+ updateMetrics(context, combiners)
combiners.iterator
}
}
@@ -89,13 +84,18 @@ case class Aggregator[K, V, C] (
} else {
val combiners = new ExternalAppendOnlyMap[K, C, C](identity, mergeCombiners, mergeCombiners)
combiners.insertAll(iter)
- // Update task metrics if context is not null
- // TODO: Make context non-optional in a future release
- Option(context).foreach { c =>
- c.taskMetrics.incMemoryBytesSpilled(combiners.memoryBytesSpilled)
- c.taskMetrics.incDiskBytesSpilled(combiners.diskBytesSpilled)
- }
+ updateMetrics(context, combiners)
combiners.iterator
}
}
+
+ /** Update task metrics after populating the external map. */
+ private def updateMetrics(context: TaskContext, map: ExternalAppendOnlyMap[_, _, _]): Unit = {
+ Option(context).foreach { c =>
+ c.taskMetrics().incMemoryBytesSpilled(map.memoryBytesSpilled)
+ c.taskMetrics().incDiskBytesSpilled(map.diskBytesSpilled)
+ c.internalMetricsToAccumulators(
+ InternalAccumulator.PEAK_EXECUTION_MEMORY).add(map.peakMemoryUsedBytes)
+ }
+ }
}
diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala
index 5d2c551d58..63cca80b2d 100644
--- a/core/src/main/scala/org/apache/spark/TaskContext.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContext.scala
@@ -61,12 +61,12 @@ object TaskContext {
protected[spark] def unset(): Unit = taskContext.remove()
/**
- * Return an empty task context that is not actually used.
- * Internal use only.
+ * An empty task context that does not represent an actual task.
*/
- private[spark] def empty(): TaskContext = {
- new TaskContextImpl(0, 0, 0, 0, null, null)
+ private[spark] def empty(): TaskContextImpl = {
+ new TaskContextImpl(0, 0, 0, 0, null, null, Seq.empty)
}
+
}
@@ -187,4 +187,9 @@ abstract class TaskContext extends Serializable {
* accumulator id and the value of the Map is the latest accumulator local value.
*/
private[spark] def collectAccumulators(): Map[Long, Any]
+
+ /**
+ * Accumulators for tracking internal metrics indexed by the name.
+ */
+ private[spark] val internalMetricsToAccumulators: Map[String, Accumulator[Long]]
}
diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
index 9ee168ae01..5df94c6d3a 100644
--- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
@@ -32,6 +32,7 @@ private[spark] class TaskContextImpl(
override val attemptNumber: Int,
override val taskMemoryManager: TaskMemoryManager,
@transient private val metricsSystem: MetricsSystem,
+ internalAccumulators: Seq[Accumulator[Long]],
val runningLocally: Boolean = false,
val taskMetrics: TaskMetrics = TaskMetrics.empty)
extends TaskContext
@@ -114,4 +115,11 @@ private[spark] class TaskContextImpl(
private[spark] override def collectAccumulators(): Map[Long, Any] = synchronized {
accumulators.mapValues(_.localValue).toMap
}
+
+ private[spark] override val internalMetricsToAccumulators: Map[String, Accumulator[Long]] = {
+ // Explicitly register internal accumulators here because these are
+ // not captured in the task closure and are already deserialized
+ internalAccumulators.foreach(registerAccumulator)
+ internalAccumulators.map { a => (a.name.get, a) }.toMap
+ }
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
index 130b58882d..9c617fc719 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
@@ -23,8 +23,7 @@ import java.io.{IOException, ObjectOutputStream}
import scala.collection.mutable.ArrayBuffer
-import org.apache.spark.{InterruptibleIterator, Partition, Partitioner, SparkEnv, TaskContext}
-import org.apache.spark.{Dependency, OneToOneDependency, ShuffleDependency}
+import org.apache.spark._
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.util.collection.{ExternalAppendOnlyMap, AppendOnlyMap, CompactBuffer}
import org.apache.spark.util.Utils
@@ -169,8 +168,10 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
for ((it, depNum) <- rddIterators) {
map.insertAll(it.map(pair => (pair._1, new CoGroupValue(pair._2, depNum))))
}
- context.taskMetrics.incMemoryBytesSpilled(map.memoryBytesSpilled)
- context.taskMetrics.incDiskBytesSpilled(map.diskBytesSpilled)
+ context.taskMetrics().incMemoryBytesSpilled(map.memoryBytesSpilled)
+ context.taskMetrics().incDiskBytesSpilled(map.diskBytesSpilled)
+ context.internalMetricsToAccumulators(
+ InternalAccumulator.PEAK_EXECUTION_MEMORY).add(map.peakMemoryUsedBytes)
new InterruptibleIterator(context,
map.iterator.asInstanceOf[Iterator[(K, Array[Iterable[_]])]])
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala
index e0edd7d4ae..11d123eec4 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala
@@ -24,11 +24,12 @@ import org.apache.spark.annotation.DeveloperApi
* Information about an [[org.apache.spark.Accumulable]] modified during a task or stage.
*/
@DeveloperApi
-class AccumulableInfo (
+class AccumulableInfo private[spark] (
val id: Long,
val name: String,
val update: Option[String], // represents a partial update within a task
- val value: String) {
+ val value: String,
+ val internal: Boolean) {
override def equals(other: Any): Boolean = other match {
case acc: AccumulableInfo =>
@@ -40,10 +41,10 @@ class AccumulableInfo (
object AccumulableInfo {
def apply(id: Long, name: String, update: Option[String], value: String): AccumulableInfo = {
- new AccumulableInfo(id, name, update, value)
+ new AccumulableInfo(id, name, update, value, internal = false)
}
def apply(id: Long, name: String, value: String): AccumulableInfo = {
- new AccumulableInfo(id, name, None, value)
+ new AccumulableInfo(id, name, None, value, internal = false)
}
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index c4fa277c21..bb489c6b6e 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -773,16 +773,26 @@ class DAGScheduler(
stage.pendingTasks.clear()
// First figure out the indexes of partition ids to compute.
- val partitionsToCompute: Seq[Int] = {
+ val (allPartitions: Seq[Int], partitionsToCompute: Seq[Int]) = {
stage match {
case stage: ShuffleMapStage =>
- (0 until stage.numPartitions).filter(id => stage.outputLocs(id).isEmpty)
+ val allPartitions = 0 until stage.numPartitions
+ val filteredPartitions = allPartitions.filter { id => stage.outputLocs(id).isEmpty }
+ (allPartitions, filteredPartitions)
case stage: ResultStage =>
val job = stage.resultOfJob.get
- (0 until job.numPartitions).filter(id => !job.finished(id))
+ val allPartitions = 0 until job.numPartitions
+ val filteredPartitions = allPartitions.filter { id => !job.finished(id) }
+ (allPartitions, filteredPartitions)
}
}
+ // Reset internal accumulators only if this stage is not partially submitted
+ // Otherwise, we may override existing accumulator values from some tasks
+ if (allPartitions == partitionsToCompute) {
+ stage.resetInternalAccumulators()
+ }
+
val properties = jobIdToActiveJob.get(stage.firstJobId).map(_.properties).orNull
runningStages += stage
@@ -852,7 +862,8 @@ class DAGScheduler(
partitionsToCompute.map { id =>
val locs = taskIdToLocations(id)
val part = stage.rdd.partitions(id)
- new ShuffleMapTask(stage.id, stage.latestInfo.attemptId, taskBinary, part, locs)
+ new ShuffleMapTask(stage.id, stage.latestInfo.attemptId,
+ taskBinary, part, locs, stage.internalAccumulators)
}
case stage: ResultStage =>
@@ -861,7 +872,8 @@ class DAGScheduler(
val p: Int = job.partitions(id)
val part = stage.rdd.partitions(p)
val locs = taskIdToLocations(id)
- new ResultTask(stage.id, stage.latestInfo.attemptId, taskBinary, part, locs, id)
+ new ResultTask(stage.id, stage.latestInfo.attemptId,
+ taskBinary, part, locs, id, stage.internalAccumulators)
}
}
} catch {
@@ -916,9 +928,11 @@ class DAGScheduler(
// To avoid UI cruft, ignore cases where value wasn't updated
if (acc.name.isDefined && partialValue != acc.zero) {
val name = acc.name.get
- stage.latestInfo.accumulables(id) = AccumulableInfo(id, name, s"${acc.value}")
+ val value = s"${acc.value}"
+ stage.latestInfo.accumulables(id) =
+ new AccumulableInfo(id, name, None, value, acc.isInternal)
event.taskInfo.accumulables +=
- AccumulableInfo(id, name, Some(s"$partialValue"), s"${acc.value}")
+ new AccumulableInfo(id, name, Some(s"$partialValue"), value, acc.isInternal)
}
}
} catch {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
index 9c2606e278..c4dc080e2b 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
@@ -45,8 +45,10 @@ private[spark] class ResultTask[T, U](
taskBinary: Broadcast[Array[Byte]],
partition: Partition,
@transient locs: Seq[TaskLocation],
- val outputId: Int)
- extends Task[U](stageId, stageAttemptId, partition.index) with Serializable {
+ val outputId: Int,
+ internalAccumulators: Seq[Accumulator[Long]])
+ extends Task[U](stageId, stageAttemptId, partition.index, internalAccumulators)
+ with Serializable {
@transient private[this] val preferredLocs: Seq[TaskLocation] = {
if (locs == null) Nil else locs.toSet.toSeq
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
index 14c8c00961..f478f9982a 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
@@ -43,12 +43,14 @@ private[spark] class ShuffleMapTask(
stageAttemptId: Int,
taskBinary: Broadcast[Array[Byte]],
partition: Partition,
- @transient private var locs: Seq[TaskLocation])
- extends Task[MapStatus](stageId, stageAttemptId, partition.index) with Logging {
+ @transient private var locs: Seq[TaskLocation],
+ internalAccumulators: Seq[Accumulator[Long]])
+ extends Task[MapStatus](stageId, stageAttemptId, partition.index, internalAccumulators)
+ with Logging {
/** A constructor used only in test suites. This does not require passing in an RDD. */
def this(partitionId: Int) {
- this(0, 0, null, new Partition { override def index: Int = 0 }, null)
+ this(0, 0, null, new Partition { override def index: Int = 0 }, null, null)
}
@transient private val preferredLocs: Seq[TaskLocation] = {
@@ -69,7 +71,7 @@ private[spark] class ShuffleMapTask(
val manager = SparkEnv.get.shuffleManager
writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context)
writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])
- return writer.stop(success = true).get
+ writer.stop(success = true).get
} catch {
case e: Exception =>
try {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
index 40a333a3e0..de05ee256d 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
@@ -68,6 +68,22 @@ private[spark] abstract class Stage(
val name = callSite.shortForm
val details = callSite.longForm
+ private var _internalAccumulators: Seq[Accumulator[Long]] = Seq.empty
+
+ /** Internal accumulators shared across all tasks in this stage. */
+ def internalAccumulators: Seq[Accumulator[Long]] = _internalAccumulators
+
+ /**
+ * Re-initialize the internal accumulators associated with this stage.
+ *
+ * This is called every time the stage is submitted, *except* when a subset of tasks
+ * belonging to this stage has already finished. Otherwise, reinitializing the internal
+ * accumulators here again will override partial values from the finished tasks.
+ */
+ def resetInternalAccumulators(): Unit = {
+ _internalAccumulators = InternalAccumulator.create()
+ }
+
/**
* Pointer to the [StageInfo] object for the most recent attempt. This needs to be initialized
* here, before any attempts have actually been created, because the DAGScheduler uses this
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index 1978305cfe..9edf9f048f 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -23,7 +23,7 @@ import java.nio.ByteBuffer
import scala.collection.mutable.HashMap
import org.apache.spark.metrics.MetricsSystem
-import org.apache.spark.{SparkEnv, TaskContextImpl, TaskContext}
+import org.apache.spark.{Accumulator, SparkEnv, TaskContextImpl, TaskContext}
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.serializer.SerializerInstance
import org.apache.spark.unsafe.memory.TaskMemoryManager
@@ -47,7 +47,8 @@ import org.apache.spark.util.Utils
private[spark] abstract class Task[T](
val stageId: Int,
val stageAttemptId: Int,
- var partitionId: Int) extends Serializable {
+ val partitionId: Int,
+ internalAccumulators: Seq[Accumulator[Long]]) extends Serializable {
/**
* The key of the Map is the accumulator id and the value of the Map is the latest accumulator
@@ -68,12 +69,13 @@ private[spark] abstract class Task[T](
metricsSystem: MetricsSystem)
: (T, AccumulatorUpdates) = {
context = new TaskContextImpl(
- stageId = stageId,
- partitionId = partitionId,
- taskAttemptId = taskAttemptId,
- attemptNumber = attemptNumber,
- taskMemoryManager = taskMemoryManager,
- metricsSystem = metricsSystem,
+ stageId,
+ partitionId,
+ taskAttemptId,
+ attemptNumber,
+ taskMemoryManager,
+ metricsSystem,
+ internalAccumulators,
runningLocally = false)
TaskContext.setTaskContext(context)
context.taskMetrics.setHostname(Utils.localHostName())
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
index de79fa56f0..0c8f08f0f3 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
@@ -17,7 +17,7 @@
package org.apache.spark.shuffle.hash
-import org.apache.spark.{InterruptibleIterator, Logging, MapOutputTracker, SparkEnv, TaskContext}
+import org.apache.spark._
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader}
import org.apache.spark.storage.{BlockManager, ShuffleBlockFetcherIterator}
@@ -100,8 +100,10 @@ private[spark] class HashShuffleReader[K, C](
// the ExternalSorter won't spill to disk.
val sorter = new ExternalSorter[K, C, C](ordering = Some(keyOrd), serializer = Some(ser))
sorter.insertAll(aggregatedIter)
- context.taskMetrics.incMemoryBytesSpilled(sorter.memoryBytesSpilled)
- context.taskMetrics.incDiskBytesSpilled(sorter.diskBytesSpilled)
+ context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
+ context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
+ context.internalMetricsToAccumulators(
+ InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.peakMemoryUsedBytes)
sorter.iterator
case None =>
aggregatedIter
diff --git a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala
index e2d25e3636..cb122eaed8 100644
--- a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala
+++ b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala
@@ -62,6 +62,13 @@ private[spark] object ToolTips {
"""Time that the executor spent paused for Java garbage collection while the task was
running."""
+ val PEAK_EXECUTION_MEMORY =
+ """Execution memory refers to the memory used by internal data structures created during
+ shuffles, aggregations and joins when Tungsten is enabled. The value of this accumulator
+ should be approximately the sum of the peak sizes across all such data structures created
+ in this task. For SQL jobs, this only tracks all unsafe operators, broadcast joins, and
+ external sort."""
+
val JOB_TIMELINE =
"""Shows when jobs started and ended and when executors joined or left. Drag to scroll.
Click Enable Zooming and use mouse wheel to zoom in/out."""
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
index cf04b5e592..3954c3d1ef 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
@@ -26,6 +26,7 @@ import scala.xml.{Elem, Node, Unparsed}
import org.apache.commons.lang3.StringEscapeUtils
+import org.apache.spark.{InternalAccumulator, SparkConf}
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.scheduler.{AccumulableInfo, TaskInfo}
import org.apache.spark.ui._
@@ -67,6 +68,8 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
// if we find that it's okay.
private val MAX_TIMELINE_TASKS = parent.conf.getInt("spark.ui.timeline.tasks.maximum", 1000)
+ private val displayPeakExecutionMemory =
+ parent.conf.getOption("spark.sql.unsafe.enabled").exists(_.toBoolean)
def render(request: HttpServletRequest): Seq[Node] = {
progressListener.synchronized {
@@ -114,10 +117,11 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
val stageData = stageDataOption.get
val tasks = stageData.taskData.values.toSeq.sortBy(_.taskInfo.launchTime)
-
val numCompleted = tasks.count(_.taskInfo.finished)
- val accumulables = progressListener.stageIdToData((stageId, stageAttemptId)).accumulables
- val hasAccumulators = accumulables.size > 0
+
+ val allAccumulables = progressListener.stageIdToData((stageId, stageAttemptId)).accumulables
+ val externalAccumulables = allAccumulables.values.filter { acc => !acc.internal }
+ val hasAccumulators = externalAccumulables.size > 0
val summary =
<div>
@@ -221,6 +225,15 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
<span class="additional-metric-title">Getting Result Time</span>
</span>
</li>
+ {if (displayPeakExecutionMemory) {
+ <li>
+ <span data-toggle="tooltip"
+ title={ToolTips.PEAK_EXECUTION_MEMORY} data-placement="right">
+ <input type="checkbox" name={TaskDetailsClassNames.PEAK_EXECUTION_MEMORY}/>
+ <span class="additional-metric-title">Peak Execution Memory</span>
+ </span>
+ </li>
+ }}
</ul>
</div>
</div>
@@ -241,11 +254,12 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
val accumulableTable = UIUtils.listingTable(
accumulableHeaders,
accumulableRow,
- accumulables.values.toSeq)
+ externalAccumulables.toSeq)
val currentTime = System.currentTimeMillis()
val (taskTable, taskTableHTML) = try {
val _taskTable = new TaskPagedTable(
+ parent.conf,
UIUtils.prependBaseUri(parent.basePath) +
s"/stages/stage?id=${stageId}&attempt=${stageAttemptId}",
tasks,
@@ -294,12 +308,14 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
else {
def getDistributionQuantiles(data: Seq[Double]): IndexedSeq[Double] =
Distribution(data).get.getQuantiles()
-
def getFormattedTimeQuantiles(times: Seq[Double]): Seq[Node] = {
getDistributionQuantiles(times).map { millis =>
<td>{UIUtils.formatDuration(millis.toLong)}</td>
}
}
+ def getFormattedSizeQuantiles(data: Seq[Double]): Seq[Elem] = {
+ getDistributionQuantiles(data).map(d => <td>{Utils.bytesToString(d.toLong)}</td>)
+ }
val deserializationTimes = validTasks.map { case TaskUIData(_, metrics, _) =>
metrics.get.executorDeserializeTime.toDouble
@@ -349,6 +365,23 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
</span>
</td> +:
getFormattedTimeQuantiles(gettingResultTimes)
+
+ val peakExecutionMemory = validTasks.map { case TaskUIData(info, _, _) =>
+ info.accumulables
+ .find { acc => acc.name == InternalAccumulator.PEAK_EXECUTION_MEMORY }
+ .map { acc => acc.value.toLong }
+ .getOrElse(0L)
+ .toDouble
+ }
+ val peakExecutionMemoryQuantiles = {
+ <td>
+ <span data-toggle="tooltip"
+ title={ToolTips.PEAK_EXECUTION_MEMORY} data-placement="right">
+ Peak Execution Memory
+ </span>
+ </td> +: getFormattedSizeQuantiles(peakExecutionMemory)
+ }
+
// The scheduler delay includes the network delay to send the task to the worker
// machine and to send back the result (but not the time to fetch the task result,
// if it needed to be fetched from the block manager on the worker).
@@ -359,10 +392,6 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
title={ToolTips.SCHEDULER_DELAY} data-placement="right">Scheduler Delay</span></td>
val schedulerDelayQuantiles = schedulerDelayTitle +:
getFormattedTimeQuantiles(schedulerDelays)
-
- def getFormattedSizeQuantiles(data: Seq[Double]): Seq[Elem] =
- getDistributionQuantiles(data).map(d => <td>{Utils.bytesToString(d.toLong)}</td>)
-
def getFormattedSizeQuantilesWithRecords(data: Seq[Double], records: Seq[Double])
: Seq[Elem] = {
val recordDist = getDistributionQuantiles(records).iterator
@@ -466,6 +495,13 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
{serializationQuantiles}
</tr>,
<tr class={TaskDetailsClassNames.GETTING_RESULT_TIME}>{gettingResultQuantiles}</tr>,
+ if (displayPeakExecutionMemory) {
+ <tr class={TaskDetailsClassNames.PEAK_EXECUTION_MEMORY}>
+ {peakExecutionMemoryQuantiles}
+ </tr>
+ } else {
+ Nil
+ },
if (stageData.hasInput) <tr>{inputQuantiles}</tr> else Nil,
if (stageData.hasOutput) <tr>{outputQuantiles}</tr> else Nil,
if (stageData.hasShuffleRead) {
@@ -499,7 +535,7 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
val executorTable = new ExecutorTable(stageId, stageAttemptId, parent)
val maybeAccumulableTable: Seq[Node] =
- if (accumulables.size > 0) { <h4>Accumulators</h4> ++ accumulableTable } else Seq()
+ if (hasAccumulators) { <h4>Accumulators</h4> ++ accumulableTable } else Seq()
val content =
summary ++
@@ -750,29 +786,30 @@ private[ui] case class TaskTableRowBytesSpilledData(
* Contains all data that needs for sorting and generating HTML. Using this one rather than
* TaskUIData to avoid creating duplicate contents during sorting the data.
*/
-private[ui] case class TaskTableRowData(
- index: Int,
- taskId: Long,
- attempt: Int,
- speculative: Boolean,
- status: String,
- taskLocality: String,
- executorIdAndHost: String,
- launchTime: Long,
- duration: Long,
- formatDuration: String,
- schedulerDelay: Long,
- taskDeserializationTime: Long,
- gcTime: Long,
- serializationTime: Long,
- gettingResultTime: Long,
- accumulators: Option[String], // HTML
- input: Option[TaskTableRowInputData],
- output: Option[TaskTableRowOutputData],
- shuffleRead: Option[TaskTableRowShuffleReadData],
- shuffleWrite: Option[TaskTableRowShuffleWriteData],
- bytesSpilled: Option[TaskTableRowBytesSpilledData],
- error: String)
+private[ui] class TaskTableRowData(
+ val index: Int,
+ val taskId: Long,
+ val attempt: Int,
+ val speculative: Boolean,
+ val status: String,
+ val taskLocality: String,
+ val executorIdAndHost: String,
+ val launchTime: Long,
+ val duration: Long,
+ val formatDuration: String,
+ val schedulerDelay: Long,
+ val taskDeserializationTime: Long,
+ val gcTime: Long,
+ val serializationTime: Long,
+ val gettingResultTime: Long,
+ val peakExecutionMemoryUsed: Long,
+ val accumulators: Option[String], // HTML
+ val input: Option[TaskTableRowInputData],
+ val output: Option[TaskTableRowOutputData],
+ val shuffleRead: Option[TaskTableRowShuffleReadData],
+ val shuffleWrite: Option[TaskTableRowShuffleWriteData],
+ val bytesSpilled: Option[TaskTableRowBytesSpilledData],
+ val error: String)
private[ui] class TaskDataSource(
tasks: Seq[TaskUIData],
@@ -816,10 +853,15 @@ private[ui] class TaskDataSource(
val serializationTime = metrics.map(_.resultSerializationTime).getOrElse(0L)
val gettingResultTime = getGettingResultTime(info, currentTime)
- val maybeAccumulators = info.accumulables
- val accumulatorsReadable = maybeAccumulators.map { acc =>
+ val (taskInternalAccumulables, taskExternalAccumulables) =
+ info.accumulables.partition(_.internal)
+ val externalAccumulableReadable = taskExternalAccumulables.map { acc =>
StringEscapeUtils.escapeHtml4(s"${acc.name}: ${acc.update.get}")
}
+ val peakExecutionMemoryUsed = taskInternalAccumulables
+ .find { acc => acc.name == InternalAccumulator.PEAK_EXECUTION_MEMORY }
+ .map { acc => acc.value.toLong }
+ .getOrElse(0L)
val maybeInput = metrics.flatMap(_.inputMetrics)
val inputSortable = maybeInput.map(_.bytesRead).getOrElse(0L)
@@ -923,7 +965,7 @@ private[ui] class TaskDataSource(
None
}
- TaskTableRowData(
+ new TaskTableRowData(
info.index,
info.taskId,
info.attempt,
@@ -939,7 +981,8 @@ private[ui] class TaskDataSource(
gcTime,
serializationTime,
gettingResultTime,
- if (hasAccumulators) Some(accumulatorsReadable.mkString("<br/>")) else None,
+ peakExecutionMemoryUsed,
+ if (hasAccumulators) Some(externalAccumulableReadable.mkString("<br/>")) else None,
input,
output,
shuffleRead,
@@ -1006,6 +1049,10 @@ private[ui] class TaskDataSource(
override def compare(x: TaskTableRowData, y: TaskTableRowData): Int =
Ordering.Long.compare(x.gettingResultTime, y.gettingResultTime)
}
+ case "Peak Execution Memory" => new Ordering[TaskTableRowData] {
+ override def compare(x: TaskTableRowData, y: TaskTableRowData): Int =
+ Ordering.Long.compare(x.peakExecutionMemoryUsed, y.peakExecutionMemoryUsed)
+ }
case "Accumulators" =>
if (hasAccumulators) {
new Ordering[TaskTableRowData] {
@@ -1132,6 +1179,7 @@ private[ui] class TaskDataSource(
}
private[ui] class TaskPagedTable(
+ conf: SparkConf,
basePath: String,
data: Seq[TaskUIData],
hasAccumulators: Boolean,
@@ -1143,7 +1191,11 @@ private[ui] class TaskPagedTable(
currentTime: Long,
pageSize: Int,
sortColumn: String,
- desc: Boolean) extends PagedTable[TaskTableRowData]{
+ desc: Boolean) extends PagedTable[TaskTableRowData] {
+
+ // We only track peak memory used for unsafe operators
+ private val displayPeakExecutionMemory =
+ conf.getOption("spark.sql.unsafe.enabled").exists(_.toBoolean)
override def tableId: String = ""
@@ -1195,6 +1247,13 @@ private[ui] class TaskPagedTable(
("GC Time", ""),
("Result Serialization Time", TaskDetailsClassNames.RESULT_SERIALIZATION_TIME),
("Getting Result Time", TaskDetailsClassNames.GETTING_RESULT_TIME)) ++
+ {
+ if (displayPeakExecutionMemory) {
+ Seq(("Peak Execution Memory", TaskDetailsClassNames.PEAK_EXECUTION_MEMORY))
+ } else {
+ Nil
+ }
+ } ++
{if (hasAccumulators) Seq(("Accumulators", "")) else Nil} ++
{if (hasInput) Seq(("Input Size / Records", "")) else Nil} ++
{if (hasOutput) Seq(("Output Size / Records", "")) else Nil} ++
@@ -1271,6 +1330,11 @@ private[ui] class TaskPagedTable(
<td class={TaskDetailsClassNames.GETTING_RESULT_TIME}>
{UIUtils.formatDuration(task.gettingResultTime)}
</td>
+ {if (displayPeakExecutionMemory) {
+ <td class={TaskDetailsClassNames.PEAK_EXECUTION_MEMORY}>
+ {Utils.bytesToString(task.peakExecutionMemoryUsed)}
+ </td>
+ }}
{if (task.accumulators.nonEmpty) {
<td>{Unparsed(task.accumulators.get)}</td>
}}
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala b/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala
index 9bf67db8ac..d2dfc5a329 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala
@@ -31,4 +31,5 @@ private[spark] object TaskDetailsClassNames {
val SHUFFLE_READ_REMOTE_SIZE = "shuffle_read_remote"
val RESULT_SERIALIZATION_TIME = "serialization_time"
val GETTING_RESULT_TIME = "getting_result_time"
+ val PEAK_EXECUTION_MEMORY = "peak_execution_memory"
}
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
index d166037351..f929b12606 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
@@ -89,6 +89,7 @@ class ExternalAppendOnlyMap[K, V, C](
// Number of bytes spilled in total
private var _diskBytesSpilled = 0L
+ def diskBytesSpilled: Long = _diskBytesSpilled
// Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
private val fileBufferSize =
@@ -97,6 +98,10 @@ class ExternalAppendOnlyMap[K, V, C](
// Write metrics for current spill
private var curWriteMetrics: ShuffleWriteMetrics = _
+ // Peak size of the in-memory map observed so far, in bytes
+ private var _peakMemoryUsedBytes: Long = 0L
+ def peakMemoryUsedBytes: Long = _peakMemoryUsedBytes
+
private val keyComparator = new HashComparator[K]
private val ser = serializer.newInstance()
@@ -126,7 +131,11 @@ class ExternalAppendOnlyMap[K, V, C](
while (entries.hasNext) {
curEntry = entries.next()
- if (maybeSpill(currentMap, currentMap.estimateSize())) {
+ val estimatedSize = currentMap.estimateSize()
+ if (estimatedSize > _peakMemoryUsedBytes) {
+ _peakMemoryUsedBytes = estimatedSize
+ }
+ if (maybeSpill(currentMap, estimatedSize)) {
currentMap = new SizeTrackingAppendOnlyMap[K, C]
}
currentMap.changeValue(curEntry._1, update)
@@ -207,8 +216,6 @@ class ExternalAppendOnlyMap[K, V, C](
spilledMaps.append(new DiskMapIterator(file, blockId, batchSizes))
}
- def diskBytesSpilled: Long = _diskBytesSpilled
-
/**
* Return an iterator that merges the in-memory map with the spilled maps.
* If no spill has occurred, simply return the in-memory map's iterator.
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
index ba7ec834d6..19287edbaf 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -152,6 +152,9 @@ private[spark] class ExternalSorter[K, V, C](
private var _diskBytesSpilled = 0L
def diskBytesSpilled: Long = _diskBytesSpilled
+ // Peak size of the in-memory data structure observed so far, in bytes
+ private var _peakMemoryUsedBytes: Long = 0L
+ def peakMemoryUsedBytes: Long = _peakMemoryUsedBytes
// A comparator for keys K that orders them within a partition to allow aggregation or sorting.
// Can be a partial ordering by hash code if a total ordering is not provided through by the
@@ -224,15 +227,22 @@ private[spark] class ExternalSorter[K, V, C](
return
}
+ var estimatedSize = 0L
if (usingMap) {
- if (maybeSpill(map, map.estimateSize())) {
+ estimatedSize = map.estimateSize()
+ if (maybeSpill(map, estimatedSize)) {
map = new PartitionedAppendOnlyMap[K, C]
}
} else {
- if (maybeSpill(buffer, buffer.estimateSize())) {
+ estimatedSize = buffer.estimateSize()
+ if (maybeSpill(buffer, estimatedSize)) {
buffer = newBuffer()
}
}
+
+ if (estimatedSize > _peakMemoryUsedBytes) {
+ _peakMemoryUsedBytes = estimatedSize
+ }
}
/**
@@ -684,8 +694,10 @@ private[spark] class ExternalSorter[K, V, C](
}
}
- context.taskMetrics.incMemoryBytesSpilled(memoryBytesSpilled)
- context.taskMetrics.incDiskBytesSpilled(diskBytesSpilled)
+ context.taskMetrics().incMemoryBytesSpilled(memoryBytesSpilled)
+ context.taskMetrics().incDiskBytesSpilled(diskBytesSpilled)
+ context.internalMetricsToAccumulators(
+ InternalAccumulator.PEAK_EXECUTION_MEMORY).add(peakMemoryUsedBytes)
lengths
}
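At the end of writePartitionedFile the sorter hands that peak to the per-task PEAK_EXECUTION_MEMORY internal accumulator. A hedged sketch of the reporting step, with a simplified accumulator and metric map standing in for the Spark internals:

```scala
// Hedged sketch of the per-task reporting step; LongAccumulator and the
// metrics map are simplified stand-ins for Spark's internal accumulators.
object PeakMemoryReporting {
  final class LongAccumulator {
    private var _value = 0L
    def add(delta: Long): Unit = { _value += delta }
    def value: Long = _value
  }

  val PEAK_EXECUTION_MEMORY = "peakExecutionMemory"

  // Mirrors context.internalMetricsToAccumulators(PEAK_EXECUTION_MEMORY).add(peak)
  def reportPeak(
      metrics: scala.collection.mutable.Map[String, LongAccumulator],
      peakBytes: Long): Unit = {
    metrics(PEAK_EXECUTION_MEMORY).add(peakBytes)
  }
}
```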
diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java
index e948ca3347..ffe4b4baff 100644
--- a/core/src/test/java/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -51,7 +51,6 @@ import org.junit.Test;
import org.apache.spark.api.java.*;
import org.apache.spark.api.java.function.*;
-import org.apache.spark.executor.TaskMetrics;
import org.apache.spark.input.PortableDataStream;
import org.apache.spark.partial.BoundedDouble;
import org.apache.spark.partial.PartialResult;
@@ -1011,7 +1010,7 @@ public class JavaAPISuite implements Serializable {
@Test
public void iterator() {
JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2);
- TaskContext context = new TaskContextImpl(0, 0, 0L, 0, null, null, false, new TaskMetrics());
+ TaskContext context = TaskContext$.MODULE$.empty();
Assert.assertEquals(1, rdd.iterator(rdd.partitions().get(0), context).next().intValue());
}
diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java
index 04fc09b323..98c32bbc29 100644
--- a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java
@@ -190,6 +190,7 @@ public class UnsafeShuffleWriterSuite {
});
when(taskContext.taskMetrics()).thenReturn(taskMetrics);
+ when(taskContext.internalMetricsToAccumulators()).thenReturn(null);
when(shuffleDep.serializer()).thenReturn(Option.<Serializer>apply(serializer));
when(shuffleDep.partitioner()).thenReturn(hashPartitioner);
@@ -542,4 +543,57 @@ public class UnsafeShuffleWriterSuite {
writer.stop(false);
assertSpillFilesWereCleanedUp();
}
+
+ @Test
+ public void testPeakMemoryUsed() throws Exception {
+ final long recordLengthBytes = 8;
+ final long pageSizeBytes = 256;
+ final long numRecordsPerPage = pageSizeBytes / recordLengthBytes;
+ final SparkConf conf = new SparkConf().set("spark.buffer.pageSize", pageSizeBytes + "b");
+ final UnsafeShuffleWriter<Object, Object> writer =
+ new UnsafeShuffleWriter<Object, Object>(
+ blockManager,
+ shuffleBlockResolver,
+ taskMemoryManager,
+ shuffleMemoryManager,
+ new UnsafeShuffleHandle<Object, Object>(0, 1, shuffleDep),
+ 0, // map id
+ taskContext,
+ conf);
+
+ // Peak memory should be monotonically increasing. More specifically, every time
+ // we allocate a new page it should increase by exactly the size of the page.
+ long previousPeakMemory = writer.getPeakMemoryUsedBytes();
+ long newPeakMemory;
+ try {
+ for (int i = 0; i < numRecordsPerPage * 10; i++) {
+ writer.insertRecordIntoSorter(new Tuple2<Object, Object>(1, 1));
+ newPeakMemory = writer.getPeakMemoryUsedBytes();
+ if (i % numRecordsPerPage == 0) {
+ // We allocated a new page for this record, so peak memory should change
+ assertEquals(previousPeakMemory + pageSizeBytes, newPeakMemory);
+ } else {
+ assertEquals(previousPeakMemory, newPeakMemory);
+ }
+ previousPeakMemory = newPeakMemory;
+ }
+
+ // Spilling should not change peak memory
+ writer.forceSorterToSpill();
+ newPeakMemory = writer.getPeakMemoryUsedBytes();
+ assertEquals(previousPeakMemory, newPeakMemory);
+ for (int i = 0; i < numRecordsPerPage; i++) {
+ writer.insertRecordIntoSorter(new Tuple2<Object, Object>(1, 1));
+ }
+ newPeakMemory = writer.getPeakMemoryUsedBytes();
+ assertEquals(previousPeakMemory, newPeakMemory);
+
+ // Closing the writer should not change peak memory
+ writer.closeAndWriteOutput();
+ newPeakMemory = writer.getPeakMemoryUsedBytes();
+ assertEquals(previousPeakMemory, newPeakMemory);
+ } finally {
+ writer.stop(false);
+ }
+ }
}
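The assertions in testPeakMemoryUsed rely on simple page arithmetic: with fixed-size records packed into fixed-size pages, record i forces a fresh page exactly when i is a multiple of records-per-page, and each fresh page raises the peak by one page size. A small illustrative sketch of that arithmetic (not a Spark API):

```scala
// Illustrative page arithmetic behind the peak-memory tests above.
object PageArithmetic {
  def recordsPerPage(pageSizeBytes: Long, recordLengthBytes: Long): Long =
    pageSizeBytes / recordLengthBytes

  // Record i triggers a new page allocation iff it starts a new page.
  def allocatesNewPage(recordIndex: Long, recordsPerPage: Long): Boolean =
    recordIndex % recordsPerPage == 0

  // Ceiling division: how many pages n records occupy.
  def pagesNeeded(numRecords: Long, recordsPerPage: Long): Long =
    (numRecords + recordsPerPage - 1) / recordsPerPage
}
```

The same reasoning drives the analogous tests added to AbstractBytesToBytesMapSuite and UnsafeExternalSorterSuite below.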
diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
index dbb7c662d7..0e23a64fb7 100644
--- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
+++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
@@ -25,6 +25,7 @@ import org.junit.*;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import static org.hamcrest.Matchers.greaterThan;
+import static org.junit.Assert.*;
import static org.mockito.AdditionalMatchers.geq;
import static org.mockito.Mockito.*;
@@ -495,4 +496,42 @@ public abstract class AbstractBytesToBytesMapSuite {
map.growAndRehash();
map.free();
}
+
+ @Test
+ public void testTotalMemoryConsumption() {
+ final long recordLengthBytes = 24;
+ final long pageSizeBytes = 256 + 8; // 8 bytes for end-of-page marker
+ final long numRecordsPerPage = (pageSizeBytes - 8) / recordLengthBytes;
+ final BytesToBytesMap map = new BytesToBytesMap(
+ taskMemoryManager, shuffleMemoryManager, 1024, pageSizeBytes);
+
+ // Since BytesToBytesMap is append-only, we expect the total memory consumption to be
+ // monotonically increasing. More specifically, every time we allocate a new page it
+ // should increase by exactly the size of the page. In this regard, the memory usage
+ // at any given time is also the peak memory used.
+ long previousMemory = map.getTotalMemoryConsumption();
+ long newMemory;
+ try {
+ for (long i = 0; i < numRecordsPerPage * 10; i++) {
+ final long[] value = new long[]{i};
+ map.lookup(value, PlatformDependent.LONG_ARRAY_OFFSET, 8).putNewKey(
+ value,
+ PlatformDependent.LONG_ARRAY_OFFSET,
+ 8,
+ value,
+ PlatformDependent.LONG_ARRAY_OFFSET,
+ 8);
+ newMemory = map.getTotalMemoryConsumption();
+ if (i % numRecordsPerPage == 0) {
+ // We allocated a new page for this record, so peak memory should change
+ assertEquals(previousMemory + pageSizeBytes, newMemory);
+ } else {
+ assertEquals(previousMemory, newMemory);
+ }
+ previousMemory = newMemory;
+ }
+ } finally {
+ map.free();
+ }
+ }
}
diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
index 52fa8bcd57..c11949d57a 100644
--- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
+++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
@@ -247,4 +247,50 @@ public class UnsafeExternalSorterSuite {
assertSpillFilesWereCleanedUp();
}
+ @Test
+ public void testPeakMemoryUsed() throws Exception {
+ final long recordLengthBytes = 8;
+ final long pageSizeBytes = 256;
+ final long numRecordsPerPage = pageSizeBytes / recordLengthBytes;
+ final UnsafeExternalSorter sorter = UnsafeExternalSorter.create(
+ taskMemoryManager,
+ shuffleMemoryManager,
+ blockManager,
+ taskContext,
+ recordComparator,
+ prefixComparator,
+ 1024,
+ pageSizeBytes);
+
+ // Peak memory should be monotonically increasing. More specifically, every time
+ // we allocate a new page it should increase by exactly the size of the page.
+ long previousPeakMemory = sorter.getPeakMemoryUsedBytes();
+ long newPeakMemory;
+ try {
+ for (int i = 0; i < numRecordsPerPage * 10; i++) {
+ insertNumber(sorter, i);
+ newPeakMemory = sorter.getPeakMemoryUsedBytes();
+ if (i % numRecordsPerPage == 0) {
+ // We allocated a new page for this record, so peak memory should change
+ assertEquals(previousPeakMemory + pageSizeBytes, newPeakMemory);
+ } else {
+ assertEquals(previousPeakMemory, newPeakMemory);
+ }
+ previousPeakMemory = newPeakMemory;
+ }
+
+ // Spilling should not change peak memory
+ sorter.spill();
+ newPeakMemory = sorter.getPeakMemoryUsedBytes();
+ assertEquals(previousPeakMemory, newPeakMemory);
+ for (int i = 0; i < numRecordsPerPage; i++) {
+ insertNumber(sorter, i);
+ }
+ newPeakMemory = sorter.getPeakMemoryUsedBytes();
+ assertEquals(previousPeakMemory, newPeakMemory);
+ } finally {
+ sorter.freeMemory();
+ }
+ }
+
}
diff --git a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
index e942d6579b..48f549575f 100644
--- a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
@@ -18,13 +18,17 @@
package org.apache.spark
import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
import scala.ref.WeakReference
import org.scalatest.Matchers
+import org.scalatest.exceptions.TestFailedException
+import org.apache.spark.scheduler._
-class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContext {
+class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContext {
+ import InternalAccumulator._
implicit def setAccum[A]: AccumulableParam[mutable.Set[A], A] =
new AccumulableParam[mutable.Set[A], A] {
@@ -155,4 +159,191 @@ class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContex
assert(!Accumulators.originals.get(accId).isDefined)
}
+ test("internal accumulators in TaskContext") {
+ val accums = InternalAccumulator.create()
+ val taskContext = new TaskContextImpl(0, 0, 0, 0, null, null, accums)
+ val internalMetricsToAccums = taskContext.internalMetricsToAccumulators
+ val collectedInternalAccums = taskContext.collectInternalAccumulators()
+ val collectedAccums = taskContext.collectAccumulators()
+ assert(internalMetricsToAccums.size > 0)
+ assert(internalMetricsToAccums.values.forall(_.isInternal))
+ assert(internalMetricsToAccums.contains(TEST_ACCUMULATOR))
+ val testAccum = internalMetricsToAccums(TEST_ACCUMULATOR)
+ assert(collectedInternalAccums.size === internalMetricsToAccums.size)
+ assert(collectedInternalAccums.size === collectedAccums.size)
+ assert(collectedInternalAccums.contains(testAccum.id))
+ assert(collectedAccums.contains(testAccum.id))
+ }
+
+ test("internal accumulators in a stage") {
+ val listener = new SaveInfoListener
+ val numPartitions = 10
+ sc = new SparkContext("local", "test")
+ sc.addSparkListener(listener)
+ // Have each task add 1 to the internal accumulator
+ sc.parallelize(1 to 100, numPartitions).mapPartitions { iter =>
+ TaskContext.get().internalMetricsToAccumulators(TEST_ACCUMULATOR) += 1
+ iter
+ }.count()
+ val stageInfos = listener.getCompletedStageInfos
+ val taskInfos = listener.getCompletedTaskInfos
+ assert(stageInfos.size === 1)
+ assert(taskInfos.size === numPartitions)
+ // The accumulator values should be merged in the stage
+ val stageAccum = findAccumulableInfo(stageInfos.head.accumulables.values, TEST_ACCUMULATOR)
+ assert(stageAccum.value.toLong === numPartitions)
+ // The accumulator should be updated locally on each task
+ val taskAccumValues = taskInfos.map { taskInfo =>
+ val taskAccum = findAccumulableInfo(taskInfo.accumulables, TEST_ACCUMULATOR)
+ assert(taskAccum.update.isDefined)
+ assert(taskAccum.update.get.toLong === 1)
+ taskAccum.value.toLong
+ }
+ // Each task should keep track of the partial value on the way, i.e. 1, 2, ... numPartitions
+ assert(taskAccumValues.sorted === (1L to numPartitions).toSeq)
+ }
+
+ test("internal accumulators in multiple stages") {
+ val listener = new SaveInfoListener
+ val numPartitions = 10
+ sc = new SparkContext("local", "test")
+ sc.addSparkListener(listener)
+ // Each stage creates its own set of internal accumulators so the
+ // values for the same metric should not be mixed up across stages
+ sc.parallelize(1 to 100, numPartitions)
+ .map { i => (i, i) }
+ .mapPartitions { iter =>
+ TaskContext.get().internalMetricsToAccumulators(TEST_ACCUMULATOR) += 1
+ iter
+ }
+ .reduceByKey { case (x, y) => x + y }
+ .mapPartitions { iter =>
+ TaskContext.get().internalMetricsToAccumulators(TEST_ACCUMULATOR) += 10
+ iter
+ }
+ .repartition(numPartitions * 2)
+ .mapPartitions { iter =>
+ TaskContext.get().internalMetricsToAccumulators(TEST_ACCUMULATOR) += 100
+ iter
+ }
+ .count()
+ // We ran 3 stages, and the accumulator values should be distinct
+ val stageInfos = listener.getCompletedStageInfos
+ assert(stageInfos.size === 3)
+ val firstStageAccum = findAccumulableInfo(stageInfos(0).accumulables.values, TEST_ACCUMULATOR)
+ val secondStageAccum = findAccumulableInfo(stageInfos(1).accumulables.values, TEST_ACCUMULATOR)
+ val thirdStageAccum = findAccumulableInfo(stageInfos(2).accumulables.values, TEST_ACCUMULATOR)
+ assert(firstStageAccum.value.toLong === numPartitions)
+ assert(secondStageAccum.value.toLong === numPartitions * 10)
+ assert(thirdStageAccum.value.toLong === numPartitions * 2 * 100)
+ }
+
+ test("internal accumulators in fully resubmitted stages") {
+ testInternalAccumulatorsWithFailedTasks((i: Int) => true) // fail all tasks
+ }
+
+ test("internal accumulators in partially resubmitted stages") {
+ testInternalAccumulatorsWithFailedTasks((i: Int) => i % 2 == 0) // fail a subset
+ }
+
+ /**
+ * Return the accumulable info that matches the specified name.
+ */
+ private def findAccumulableInfo(
+ accums: Iterable[AccumulableInfo],
+ name: String): AccumulableInfo = {
+ accums.find { a => a.name == name }.getOrElse {
+ throw new TestFailedException(s"internal accumulator '$name' not found", 0)
+ }
+ }
+
+ /**
+ * Test whether internal accumulators are merged properly if some tasks fail.
+ */
+ private def testInternalAccumulatorsWithFailedTasks(failCondition: (Int => Boolean)): Unit = {
+ val listener = new SaveInfoListener
+ val numPartitions = 10
+ val numFailedPartitions = (0 until numPartitions).count(failCondition)
+ // This says use 1 core and retry tasks up to 2 times
+ sc = new SparkContext("local[1, 2]", "test")
+ sc.addSparkListener(listener)
+ sc.parallelize(1 to 100, numPartitions).mapPartitionsWithIndex { case (i, iter) =>
+ val taskContext = TaskContext.get()
+ taskContext.internalMetricsToAccumulators(TEST_ACCUMULATOR) += 1
+ // Fail the first attempts of a subset of the tasks
+ if (failCondition(i) && taskContext.attemptNumber() == 0) {
+ throw new Exception("Failing a task intentionally.")
+ }
+ iter
+ }.count()
+ val stageInfos = listener.getCompletedStageInfos
+ val taskInfos = listener.getCompletedTaskInfos
+ assert(stageInfos.size === 1)
+ assert(taskInfos.size === numPartitions + numFailedPartitions)
+ val stageAccum = findAccumulableInfo(stageInfos.head.accumulables.values, TEST_ACCUMULATOR)
+ // We should not double count values in the merged accumulator
+ assert(stageAccum.value.toLong === numPartitions)
+ val taskAccumValues = taskInfos.flatMap { taskInfo =>
+ if (!taskInfo.failed) {
+ // If a task succeeded, its update value should always be 1
+ val taskAccum = findAccumulableInfo(taskInfo.accumulables, TEST_ACCUMULATOR)
+ assert(taskAccum.update.isDefined)
+ assert(taskAccum.update.get.toLong === 1)
+ Some(taskAccum.value.toLong)
+ } else {
+ // If a task failed, we should not get its accumulator values
+ assert(taskInfo.accumulables.isEmpty)
+ None
+ }
+ }
+ assert(taskAccumValues.sorted === (1L to numPartitions).toSeq)
+ }
+
+}
+
+private[spark] object AccumulatorSuite {
+
+ /**
+ * Run one or more Spark jobs and verify that the peak execution memory accumulator
+ * is updated afterwards.
+ */
+ def verifyPeakExecutionMemorySet(
+ sc: SparkContext,
+ testName: String)(testBody: => Unit): Unit = {
+ val listener = new SaveInfoListener
+ sc.addSparkListener(listener)
+ // Verify that the accumulator does not already exist
+ sc.parallelize(1 to 10).count()
+ val accums = listener.getCompletedStageInfos.flatMap(_.accumulables.values)
+ assert(!accums.exists(_.name == InternalAccumulator.PEAK_EXECUTION_MEMORY))
+ testBody
+ // Verify that peak execution memory is updated
+ val accum = listener.getCompletedStageInfos
+ .flatMap(_.accumulables.values)
+ .find(_.name == InternalAccumulator.PEAK_EXECUTION_MEMORY)
+ .getOrElse {
+ throw new TestFailedException(
+ s"peak execution memory accumulator not set in '$testName'", 0)
+ }
+ assert(accum.value.toLong > 0)
+ }
+}
+
+/**
+ * A simple listener that keeps track of the TaskInfos and StageInfos of all completed jobs.
+ */
+private class SaveInfoListener extends SparkListener {
+ private val completedStageInfos: ArrayBuffer[StageInfo] = new ArrayBuffer[StageInfo]
+ private val completedTaskInfos: ArrayBuffer[TaskInfo] = new ArrayBuffer[TaskInfo]
+
+ def getCompletedStageInfos: Seq[StageInfo] = completedStageInfos.toArray.toSeq
+ def getCompletedTaskInfos: Seq[TaskInfo] = completedTaskInfos.toArray.toSeq
+
+ override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = {
+ completedStageInfos += stageCompleted.stageInfo
+ }
+
+ override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
+ completedTaskInfos += taskEnd.taskInfo
+ }
}
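findAccumulableInfo above turns a missing accumulable into an explicit test failure rather than an Option. A stand-alone sketch of that lookup shape, with a plain case class in place of AccumulableInfo:

```scala
// Stand-alone sketch of the find-by-name lookup used in the tests above.
// Info is a stand-in for AccumulableInfo; the exception type is illustrative.
object FindByName {
  case class Info(name: String, value: Long)

  def findByName(infos: Iterable[Info], name: String): Info =
    infos.find(_.name == name).getOrElse {
      throw new NoSuchElementException(s"accumulable '$name' not found")
    }
}
```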
diff --git a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
index 618a5fb247..cb8bd04e49 100644
--- a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
@@ -21,7 +21,7 @@ import org.mockito.Mockito._
import org.scalatest.BeforeAndAfter
import org.scalatest.mock.MockitoSugar
-import org.apache.spark.executor.DataReadMethod
+import org.apache.spark.executor.{DataReadMethod, TaskMetrics}
import org.apache.spark.rdd.RDD
import org.apache.spark.storage._
@@ -65,7 +65,7 @@ class CacheManagerSuite extends SparkFunSuite with LocalSparkContext with Before
// in blockManager.put is a losing battle. You have been warned.
blockManager = sc.env.blockManager
cacheManager = sc.env.cacheManager
- val context = new TaskContextImpl(0, 0, 0, 0, null, null)
+ val context = TaskContext.empty()
val computeValue = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
val getValue = blockManager.get(RDDBlockId(rdd.id, split.index))
assert(computeValue.toList === List(1, 2, 3, 4))
@@ -77,7 +77,7 @@ class CacheManagerSuite extends SparkFunSuite with LocalSparkContext with Before
val result = new BlockResult(Array(5, 6, 7).iterator, DataReadMethod.Memory, 12)
when(blockManager.get(RDDBlockId(0, 0))).thenReturn(Some(result))
- val context = new TaskContextImpl(0, 0, 0, 0, null, null)
+ val context = TaskContext.empty()
val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
assert(value.toList === List(5, 6, 7))
}
@@ -86,14 +86,14 @@ class CacheManagerSuite extends SparkFunSuite with LocalSparkContext with Before
// Local computation should not persist the resulting value, so don't expect a put().
when(blockManager.get(RDDBlockId(0, 0))).thenReturn(None)
- val context = new TaskContextImpl(0, 0, 0, 0, null, null, true)
+ val context = new TaskContextImpl(0, 0, 0, 0, null, null, Seq.empty, runningLocally = true)
val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
assert(value.toList === List(1, 2, 3, 4))
}
test("verify task metrics updated correctly") {
cacheManager = sc.env.cacheManager
- val context = new TaskContextImpl(0, 0, 0, 0, null, null)
+ val context = TaskContext.empty()
cacheManager.getOrCompute(rdd3, split, context, StorageLevel.MEMORY_ONLY)
assert(context.taskMetrics.updatedBlocks.getOrElse(Seq()).size === 2)
}
diff --git a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala
index 3e8816a4c6..5f73ec8675 100644
--- a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala
@@ -175,7 +175,7 @@ class PipedRDDSuite extends SparkFunSuite with SharedSparkContext {
}
val hadoopPart1 = generateFakeHadoopPartition()
val pipedRdd = new PipedRDD(nums, "printenv " + varName)
- val tContext = new TaskContextImpl(0, 0, 0, 0, null, null)
+ val tContext = TaskContext.empty()
val rddIter = pipedRdd.compute(hadoopPart1, tContext)
val arr = rddIter.toArray
assert(arr(0) == "/some/path")
diff --git a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala
index b3ca150195..f7e16af9d3 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala
@@ -19,9 +19,11 @@ package org.apache.spark.scheduler
import org.apache.spark.TaskContext
-class FakeTask(stageId: Int, prefLocs: Seq[TaskLocation] = Nil) extends Task[Int](stageId, 0, 0) {
+class FakeTask(
+ stageId: Int,
+ prefLocs: Seq[TaskLocation] = Nil)
+ extends Task[Int](stageId, 0, 0, Seq.empty) {
override def runTask(context: TaskContext): Int = 0
-
override def preferredLocations: Seq[TaskLocation] = prefLocs
}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala
index 383855caef..f333247924 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala
@@ -25,7 +25,7 @@ import org.apache.spark.TaskContext
* A Task implementation that fails to serialize.
*/
private[spark] class NotSerializableFakeTask(myId: Int, stageId: Int)
- extends Task[Array[Byte]](stageId, 0, 0) {
+ extends Task[Array[Byte]](stageId, 0, 0, Seq.empty) {
override def runTask(context: TaskContext): Array[Byte] = Array.empty[Byte]
override def preferredLocations: Seq[TaskLocation] = Seq[TaskLocation]()
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
index 9201d1e1f3..450ab7b9fe 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
@@ -57,8 +57,9 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
}
val closureSerializer = SparkEnv.get.closureSerializer.newInstance()
val func = (c: TaskContext, i: Iterator[String]) => i.next()
- val task = new ResultTask[String, String](0, 0,
- sc.broadcast(closureSerializer.serialize((rdd, func)).array), rdd.partitions(0), Seq(), 0)
+ val taskBinary = sc.broadcast(closureSerializer.serialize((rdd, func)).array)
+ val task = new ResultTask[String, String](
+ 0, 0, taskBinary, rdd.partitions(0), Seq.empty, 0, Seq.empty)
intercept[RuntimeException] {
task.run(0, 0, null)
}
@@ -66,7 +67,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
}
test("all TaskCompletionListeners should be called even if some fail") {
- val context = new TaskContextImpl(0, 0, 0, 0, null, null)
+ val context = TaskContext.empty()
val listener = mock(classOf[TaskCompletionListener])
context.addTaskCompletionListener(_ => throw new Exception("blah"))
context.addTaskCompletionListener(listener)
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
index 3abb99c4b2..f7cc4bb61d 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
@@ -136,7 +136,7 @@ class FakeTaskScheduler(sc: SparkContext, liveExecutors: (String, String)* /* ex
/**
* A Task implementation that results in a large serialized task.
*/
-class LargeTask(stageId: Int) extends Task[Array[Byte]](stageId, 0, 0) {
+class LargeTask(stageId: Int) extends Task[Array[Byte]](stageId, 0, 0, Seq.empty) {
val randomBuffer = new Array[Byte](TaskSetManager.TASK_SIZE_TO_WARN_KB * 1024)
val random = new Random(0)
random.nextBytes(randomBuffer)
diff --git a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala
index db718ecabb..05b3afef5b 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala
@@ -138,7 +138,7 @@ class HashShuffleReaderSuite extends SparkFunSuite with LocalSparkContext {
shuffleHandle,
reduceId,
reduceId + 1,
- new TaskContextImpl(0, 0, 0, 0, null, null),
+ TaskContext.empty(),
blockManager,
mapOutputTracker)
diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
index cf8bd8ae69..828153bdbf 100644
--- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
@@ -29,7 +29,7 @@ import org.mockito.invocation.InvocationOnMock
import org.mockito.stubbing.Answer
import org.scalatest.PrivateMethodTester
-import org.apache.spark.{SparkFunSuite, TaskContextImpl}
+import org.apache.spark.{SparkFunSuite, TaskContext}
import org.apache.spark.network._
import org.apache.spark.network.buffer.ManagedBuffer
import org.apache.spark.network.shuffle.BlockFetchingListener
@@ -95,7 +95,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
)
val iterator = new ShuffleBlockFetcherIterator(
- new TaskContextImpl(0, 0, 0, 0, null, null),
+ TaskContext.empty(),
transfer,
blockManager,
blocksByAddress,
@@ -165,7 +165,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
(remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq))
- val taskContext = new TaskContextImpl(0, 0, 0, 0, null, null)
+ val taskContext = TaskContext.empty()
val iterator = new ShuffleBlockFetcherIterator(
taskContext,
transfer,
@@ -227,7 +227,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
(remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq))
- val taskContext = new TaskContextImpl(0, 0, 0, 0, null, null)
+ val taskContext = TaskContext.empty()
val iterator = new ShuffleBlockFetcherIterator(
taskContext,
transfer,
diff --git a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala
new file mode 100644
index 0000000000..98f9314f31
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ui
+
+import javax.servlet.http.HttpServletRequest
+
+import scala.xml.Node
+
+import org.mockito.Mockito.{mock, when, RETURNS_SMART_NULLS}
+
+import org.apache.spark.{LocalSparkContext, SparkConf, SparkFunSuite, Success}
+import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.scheduler._
+import org.apache.spark.ui.jobs.{JobProgressListener, StagePage, StagesTab}
+import org.apache.spark.ui.scope.RDDOperationGraphListener
+
+class StagePageSuite extends SparkFunSuite with LocalSparkContext {
+
+ test("peak execution memory only displayed if unsafe is enabled") {
+ val unsafeConf = "spark.sql.unsafe.enabled"
+ val conf = new SparkConf().set(unsafeConf, "true")
+ val html = renderStagePage(conf).toString().toLowerCase
+ val targetString = "peak execution memory"
+ assert(html.contains(targetString))
+ // Disable unsafe and make sure it's not there
+ val conf2 = new SparkConf().set(unsafeConf, "false")
+ val html2 = renderStagePage(conf2).toString().toLowerCase
+ assert(!html2.contains(targetString))
+ }
+
+ /**
+ * Render a stage page started with the given conf and return the HTML.
+ * This also runs a dummy stage to populate the page with useful content.
+ */
+ private def renderStagePage(conf: SparkConf): Seq[Node] = {
+ val jobListener = new JobProgressListener(conf)
+ val graphListener = new RDDOperationGraphListener(conf)
+ val tab = mock(classOf[StagesTab], RETURNS_SMART_NULLS)
+ val request = mock(classOf[HttpServletRequest])
+ when(tab.conf).thenReturn(conf)
+ when(tab.progressListener).thenReturn(jobListener)
+ when(tab.operationGraphListener).thenReturn(graphListener)
+ when(tab.appName).thenReturn("testing")
+ when(tab.headerTabs).thenReturn(Seq.empty)
+ when(request.getParameter("id")).thenReturn("0")
+ when(request.getParameter("attempt")).thenReturn("0")
+ val page = new StagePage(tab)
+
+ // Simulate a stage in job progress listener
+ val stageInfo = new StageInfo(0, 0, "dummy", 1, Seq.empty, Seq.empty, "details")
+ val taskInfo = new TaskInfo(0, 0, 0, 0, "0", "localhost", TaskLocality.ANY, false)
+ jobListener.onStageSubmitted(SparkListenerStageSubmitted(stageInfo))
+ jobListener.onTaskStart(SparkListenerTaskStart(0, 0, taskInfo))
+ taskInfo.markSuccessful()
+ jobListener.onTaskEnd(
+ SparkListenerTaskEnd(0, 0, "result", Success, taskInfo, TaskMetrics.empty))
+ jobListener.onStageCompleted(SparkListenerStageCompleted(stageInfo))
+ page.render(request)
+ }
+
+}
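The new suite only checks the rendered HTML for the string "peak execution memory", toggled by spark.sql.unsafe.enabled. As a sketch of the conditional-display idea it exercises (column names and helper are illustrative, not the actual StagePage code):

```scala
// Illustrative sketch, not the actual StagePage code: include the peak
// execution memory column only when the unsafe flag is enabled.
object SummaryColumns {
  def columns(unsafeEnabled: Boolean): Seq[String] = {
    val base = Seq("Duration", "GC Time", "Shuffle Read Size / Records")
    if (unsafeEnabled) base :+ "Peak Execution Memory" else base
  }
}
```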
diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala
index 9c362f0de7..12e9bafcc9 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala
@@ -399,4 +399,19 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext {
sc.stop()
}
+ test("external aggregation updates peak execution memory") {
+ val conf = createSparkConf(loadDefaults = false)
+ .set("spark.shuffle.memoryFraction", "0.001")
+ .set("spark.shuffle.manager", "hash") // make sure we're not also using ExternalSorter
+ sc = new SparkContext("local", "test", conf)
+ // No spilling
+ AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "external map without spilling") {
+ sc.parallelize(1 to 10, 2).map { i => (i, i) }.reduceByKey(_ + _).count()
+ }
+ // With spilling
+ AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "external map with spilling") {
+ sc.parallelize(1 to 1000 * 1000, 2).map { i => (i, i) }.reduceByKey(_ + _).count()
+ }
+ }
+
}
diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
index 986cd8623d..bdb0f4d507 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
@@ -692,7 +692,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext {
sortWithoutBreakingSortingContracts(createSparkConf(true, false))
}
- def sortWithoutBreakingSortingContracts(conf: SparkConf) {
+ private def sortWithoutBreakingSortingContracts(conf: SparkConf) {
conf.set("spark.shuffle.memoryFraction", "0.01")
conf.set("spark.shuffle.manager", "sort")
sc = new SparkContext("local-cluster[1,1,1024]", "test", conf)
@@ -743,5 +743,15 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext {
}
sorter2.stop()
- }
+ }
+
+ test("sorting updates peak execution memory") {
+ val conf = createSparkConf(loadDefaults = false, kryo = false)
+ .set("spark.shuffle.manager", "sort")
+ sc = new SparkContext("local", "test", conf)
+ // Avoid aggregating here to make sure we're not also using ExternalAppendOnlyMap
+ AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "external sorter") {
+ sc.parallelize(1 to 1000, 2).repartition(100).count()
+ }
+ }
}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
index 5e4c6232c9..193906d247 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
@@ -106,6 +106,13 @@ final class UnsafeExternalRowSorter {
sorter.spill();
}
+ /**
+ * Return the peak memory used so far, in bytes.
+ */
+ public long getPeakMemoryUsage() {
+ return sorter.getPeakMemoryUsedBytes();
+ }
+
private void cleanupResources() {
sorter.freeMemory();
}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
index 9e2c9334a7..43d06ce9bd 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
@@ -209,6 +209,14 @@ public final class UnsafeFixedWidthAggregationMap {
}
/**
+ * The memory used by this map's managed structures, in bytes.
+ * Note that this is also the peak memory used by this map, since the map is append-only.
+ */
+ public long getMemoryUsage() {
+ return map.getTotalMemoryConsumption();
+ }
+
+ /**
* Free the memory associated with this map. This is idempotent and can be called multiple times.
*/
public void free() {
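getMemoryUsage can stand in for peak memory only because the underlying map is append-only: its total consumption never shrinks, so the latest total is also the maximum observed. A minimal sketch of that invariant (not Spark code):

```scala
// Minimal sketch of the append-only invariant: totals only grow, so the
// current total is also the peak. Not Spark code.
final class AppendOnlyPages {
  private var _totalBytes = 0L
  def allocatePage(pageSizeBytes: Long): Unit = { _totalBytes += pageSizeBytes }
  def totalBytes: Long = _totalBytes // equals the peak while nothing is freed
}
```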
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
index cd87b8deba..bf4905dc1e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution
import java.io.IOException
-import org.apache.spark.{SparkEnv, TaskContext}
+import org.apache.spark.{InternalAccumulator, SparkEnv, TaskContext}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
@@ -263,11 +263,12 @@ case class GeneratedAggregate(
assert(iter.hasNext, "There should be at least one row for this path")
log.info("Using Unsafe-based aggregator")
val pageSizeBytes = SparkEnv.get.conf.getSizeAsBytes("spark.buffer.pageSize", "64m")
+ val taskContext = TaskContext.get()
val aggregationMap = new UnsafeFixedWidthAggregationMap(
newAggregationBuffer(EmptyRow),
aggregationBufferSchema,
groupKeySchema,
- TaskContext.get.taskMemoryManager(),
+ taskContext.taskMemoryManager(),
SparkEnv.get.shuffleMemoryManager,
1024 * 16, // initial capacity
pageSizeBytes,
@@ -284,6 +285,10 @@ case class GeneratedAggregate(
updateProjection.target(aggregationBuffer)(joinedRow(aggregationBuffer, currentRow))
}
+ // Record memory used in the process
+ taskContext.internalMetricsToAccumulators(
+ InternalAccumulator.PEAK_EXECUTION_MEMORY).add(aggregationMap.getMemoryUsage)
+
new Iterator[InternalRow] {
private[this] val mapIterator = aggregationMap.iterator()
private[this] val resultProjection = resultProjectionBuilder()
@@ -300,7 +305,7 @@ case class GeneratedAggregate(
} else {
// This is the last element in the iterator, so let's free the buffer. Before we do,
// though, we need to make a defensive copy of the result so that we don't return an
- // object that might contain dangling pointers to the freed memory
+ // object that might contain dangling pointers to the freed memory.
val resultCopy = result.copy()
aggregationMap.free()
resultCopy
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
index 624efc1b1d..e73e2523a7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.joins
import scala.concurrent._
import scala.concurrent.duration._
+import org.apache.spark.{InternalAccumulator, TaskContext}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
@@ -70,7 +71,14 @@ case class BroadcastHashJoin(
val broadcastRelation = Await.result(broadcastFuture, timeout)
streamedPlan.execute().mapPartitions { streamedIter =>
- hashJoin(streamedIter, broadcastRelation.value)
+ val hashedRelation = broadcastRelation.value
+ hashedRelation match {
+ case unsafe: UnsafeHashedRelation =>
+ TaskContext.get().internalMetricsToAccumulators(
+ InternalAccumulator.PEAK_EXECUTION_MEMORY).add(unsafe.getUnsafeSize)
+ case _ =>
+ }
+ hashJoin(streamedIter, hashedRelation)
}
}
}
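The same match-and-record step reappears in BroadcastHashOuterJoin and BroadcastLeftSemiJoinHash below: only an UnsafeHashedRelation contributes to the peak, while a Java-hash-map relation contributes nothing. Purely as a sketch with stand-in types (no such helper exists in the patch):

```scala
// Sketch only: the peak-memory contribution of a broadcast relation as the
// three join operators compute it. The types stand in for HashedRelation /
// UnsafeHashedRelation; contributionBytes is a hypothetical helper.
object BroadcastRelationMetrics {
  sealed trait Relation
  final case class UnsafeRelation(totalMemoryConsumption: Long) extends Relation
  case object JavaRelation extends Relation

  def contributionBytes(relation: Relation): Long = relation match {
    case UnsafeRelation(bytes) => bytes // only the binary map is counted
    case _                     => 0L    // Java hash map relations add nothing
  }
}
```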
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala
index 309716a0ef..c35e439cc9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.joins
import scala.concurrent._
import scala.concurrent.duration._
+import org.apache.spark.{InternalAccumulator, TaskContext}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
@@ -75,6 +76,13 @@ case class BroadcastHashOuterJoin(
val hashTable = broadcastRelation.value
val keyGenerator = streamedKeyGenerator
+ hashTable match {
+ case unsafe: UnsafeHashedRelation =>
+ TaskContext.get().internalMetricsToAccumulators(
+ InternalAccumulator.PEAK_EXECUTION_MEMORY).add(unsafe.getUnsafeSize)
+ case _ =>
+ }
+
joinType match {
case LeftOuter =>
streamedIter.flatMap(currentRow => {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
index a60593911f..5bd06fbdca 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.execution.joins
+import org.apache.spark.{InternalAccumulator, TaskContext}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
@@ -51,7 +52,14 @@ case class BroadcastLeftSemiJoinHash(
val broadcastedRelation = sparkContext.broadcast(hashRelation)
left.execute().mapPartitions { streamIter =>
- hashSemiJoin(streamIter, broadcastedRelation.value)
+ val hashedRelation = broadcastedRelation.value
+ hashedRelation match {
+ case unsafe: UnsafeHashedRelation =>
+ TaskContext.get().internalMetricsToAccumulators(
+ InternalAccumulator.PEAK_EXECUTION_MEMORY).add(unsafe.getUnsafeSize)
+ case _ =>
+ }
+ hashSemiJoin(streamIter, hashedRelation)
}
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index cc8bbfd2f8..58b4236f7b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -183,8 +183,27 @@ private[joins] final class UnsafeHashedRelation(
private[joins] def this() = this(null) // Needed for serialization
// Use BytesToBytesMap in executor for better performance (it's created when deserialization)
+ // This is used in broadcast joins and distributed mode only
@transient private[this] var binaryMap: BytesToBytesMap = _
+ /**
+ * Return the size of the unsafe map on the executors.
+ *
+ * For broadcast joins, this hashed relation is bigger on the driver because it is
+ * represented as a Java hash map there. While serializing the map to the executors,
+ * however, we rehash the contents in a binary map to reduce the memory footprint on
+ * the executors.
+ *
+ * For non-broadcast joins or in local mode, return 0.
+ */
+ def getUnsafeSize: Long = {
+ if (binaryMap != null) {
+ binaryMap.getTotalMemoryConsumption
+ } else {
+ 0
+ }
+ }
+
override def get(key: InternalRow): Seq[InternalRow] = {
val unsafeKey = key.asInstanceOf[UnsafeRow]
@@ -214,7 +233,7 @@ private[joins] final class UnsafeHashedRelation(
}
} else {
- // Use the JavaHashMap in Local mode or ShuffleHashJoin
+ // Use the Java HashMap in local mode or for non-broadcast joins (e.g. ShuffleHashJoin)
hashTable.get(unsafeKey)
}
}
@@ -316,6 +335,7 @@ private[joins] object UnsafeHashedRelation {
keyGenerator: UnsafeProjection,
sizeEstimate: Int): HashedRelation = {
+ // Use a Java hash table here because unsafe maps expect fixed size records
val hashTable = new JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]](sizeEstimate)
// Create a mapping of buildKeys -> rows
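The doc comment on getUnsafeSize is the key detail for the joins above: the binary map exists only after the relation has been serialized to an executor, so on the driver or in local mode the reported size is zero. A hedged stand-alone sketch of that contract, with BinaryMap standing in for BytesToBytesMap:

```scala
// Stand-alone sketch of the getUnsafeSize contract: report the binary map's
// size only when it exists (executor side, after deserialization), else 0.
// BinaryMap stands in for org.apache.spark.unsafe.map.BytesToBytesMap.
final class BinaryMap(val getTotalMemoryConsumption: Long)

final class RelationSketch {
  @transient private var binaryMap: BinaryMap = _

  // In the real code path the binary map is built during deserialization.
  def setBinaryMap(map: BinaryMap): Unit = { binaryMap = map }

  def getUnsafeSize: Long =
    if (binaryMap != null) binaryMap.getTotalMemoryConsumption else 0L
}
```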
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala
index 92cf328c76..3192b6ebe9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.execution
+import org.apache.spark.{InternalAccumulator, TaskContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.errors._
@@ -76,6 +77,11 @@ case class ExternalSort(
val sorter = new ExternalSorter[InternalRow, Null, InternalRow](ordering = Some(ordering))
sorter.insertAll(iterator.map(r => (r.copy(), null)))
val baseIterator = sorter.iterator.map(_._1)
+ val context = TaskContext.get()
+ context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
+ context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
+ context.internalMetricsToAccumulators(
+ InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.peakMemoryUsedBytes)
// TODO(marmbrus): The complex type signature below thwarts inference for no reason.
CompletionIterator[InternalRow, Iterator[InternalRow]](baseIterator, sorter.stop())
}, preservesPartitioning = true)
@@ -137,7 +143,11 @@ case class TungstenSort(
if (testSpillFrequency > 0) {
sorter.setTestSpillFrequency(testSpillFrequency)
}
- sorter.sort(iter.asInstanceOf[Iterator[UnsafeRow]])
+ val sortedIterator = sorter.sort(iter.asInstanceOf[Iterator[UnsafeRow]])
+ val taskContext = TaskContext.get()
+ taskContext.internalMetricsToAccumulators(
+ InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.getPeakMemoryUsage)
+ sortedIterator
}, preservesPartitioning = true)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index f1abae0720..29dfcf2575 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -21,6 +21,7 @@ import java.sql.Timestamp
import org.scalatest.BeforeAndAfterAll
+import org.apache.spark.AccumulatorSuite
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
import org.apache.spark.sql.catalyst.DefaultParserDialect
import org.apache.spark.sql.catalyst.errors.DialectException
@@ -258,6 +259,23 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
}
}
+ private def testCodeGen(sqlText: String, expectedResults: Seq[Row]): Unit = {
+ val df = sql(sqlText)
+ // First, check if we have GeneratedAggregate.
+ val hasGeneratedAgg = df.queryExecution.executedPlan
+ .collect { case _: GeneratedAggregate | _: aggregate.Aggregate => true }
+ .nonEmpty
+ if (!hasGeneratedAgg) {
+ fail(
+ s"""
+ |Codegen is enabled, but query $sqlText does not have GeneratedAggregate in the plan.
+ |${df.queryExecution.simpleString}
+ """.stripMargin)
+ }
+ // Then, check results.
+ checkAnswer(df, expectedResults)
+ }
+
test("aggregation with codegen") {
val originalValue = sqlContext.conf.codegenEnabled
sqlContext.setConf(SQLConf.CODEGEN_ENABLED, true)
@@ -267,26 +285,6 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
.unionAll(sqlContext.table("testData"))
.registerTempTable("testData3x")
- def testCodeGen(sqlText: String, expectedResults: Seq[Row]): Unit = {
- val df = sql(sqlText)
- // First, check if we have GeneratedAggregate.
- var hasGeneratedAgg = false
- df.queryExecution.executedPlan.foreach {
- case generatedAgg: GeneratedAggregate => hasGeneratedAgg = true
- case newAggregate: aggregate.Aggregate => hasGeneratedAgg = true
- case _ =>
- }
- if (!hasGeneratedAgg) {
- fail(
- s"""
- |Codegen is enabled, but query $sqlText does not have GeneratedAggregate in the plan.
- |${df.queryExecution.simpleString}
- """.stripMargin)
- }
- // Then, check results.
- checkAnswer(df, expectedResults)
- }
-
try {
// Just to group rows.
testCodeGen(
@@ -1605,6 +1603,28 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
Row(new CalendarInterval(-(12 * 3 - 3), -(7L * MICROS_PER_WEEK + 123))))
}
+ test("aggregation with codegen updates peak execution memory") {
+ withSQLConf(
+ (SQLConf.CODEGEN_ENABLED.key, "true"),
+ (SQLConf.USE_SQL_AGGREGATE2.key, "false")) {
+ val sc = sqlContext.sparkContext
+ AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "aggregation with codegen") {
+ testCodeGen(
+ "SELECT key, count(value) FROM testData GROUP BY key",
+ (1 to 100).map(i => Row(i, 1)))
+ }
+ }
+ }
+
+ test("external sorting updates peak execution memory") {
+ withSQLConf((SQLConf.EXTERNAL_SORT.key, "true")) {
+ val sc = sqlContext.sparkContext
+ AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "external sort") {
+ sortTest()
+ }
+ }
+ }
+
test("SPARK-9511: error with table starting with number") {
val df = sqlContext.sparkContext.parallelize(1 to 10).map(i => (i, i.toString))
.toDF("num", "str")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala
index c794984851..88bce0e319 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala
@@ -21,6 +21,7 @@ import scala.util.Random
import org.scalatest.BeforeAndAfterAll
+import org.apache.spark.AccumulatorSuite
import org.apache.spark.sql.{RandomDataGenerator, Row, SQLConf}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.test.TestSQLContext
@@ -59,6 +60,17 @@ class TungstenSortSuite extends SparkPlanTest with BeforeAndAfterAll {
)
}
+ test("sorting updates peak execution memory") {
+ val sc = TestSQLContext.sparkContext
+ AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "unsafe external sort") {
+ checkThatPlansAgree(
+ (1 to 100).map(v => Tuple1(v)).toDF("a"),
+ (child: SparkPlan) => TungstenSort('a.asc :: Nil, true, child),
+ (child: SparkPlan) => Sort('a.asc :: Nil, global = true, child),
+ sortAnswers = false)
+ }
+ }
+
// Test sorting on different data types
for (
dataType <- DataTypeTestUtils.atomicTypes ++ Set(NullType);
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
index 7c591f6143..ef827b0fe9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
@@ -69,7 +69,8 @@ class UnsafeFixedWidthAggregationMapSuite extends SparkFunSuite with Matchers {
taskAttemptId = Random.nextInt(10000),
attemptNumber = 0,
taskMemoryManager = taskMemoryManager,
- metricsSystem = null))
+ metricsSystem = null,
+ internalAccumulators = Seq.empty))
try {
f
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
index 0282b25b9d..601a5a07ad 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
@@ -76,7 +76,8 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite {
taskAttemptId = 98456,
attemptNumber = 0,
taskMemoryManager = taskMemMgr,
- metricsSystem = null))
+ metricsSystem = null,
+ internalAccumulators = Seq.empty))
// Create the data converters
val kExternalConverter = CatalystTypeConverters.createToCatalystConverter(keySchema)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
new file mode 100644
index 0000000000..0554e11d25
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
@@ -0,0 +1,94 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements. See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+// TODO: uncomment the test here! It is currently failing due to
+// bad interaction with org.apache.spark.sql.test.TestSQLContext.
+
+// scalastyle:off
+//package org.apache.spark.sql.execution.joins
+//
+//import scala.reflect.ClassTag
+//
+//import org.scalatest.BeforeAndAfterAll
+//
+//import org.apache.spark.{AccumulatorSuite, SparkConf, SparkContext}
+//import org.apache.spark.sql.functions._
+//import org.apache.spark.sql.{SQLConf, SQLContext, QueryTest}
+//
+///**
+// * Test various broadcast join operators with unsafe enabled.
+// *
+// * This needs to be its own suite because [[org.apache.spark.sql.test.TestSQLContext]] runs
+// * in local mode, but for tests in this suite we need to run Spark in local-cluster mode.
+// * In particular, the use of [[org.apache.spark.unsafe.map.BytesToBytesMap]] in
+// * [[org.apache.spark.sql.execution.joins.UnsafeHashedRelation]] is not triggered without
+// * serializing the hashed relation, which does not happen in local mode.
+// */
+//class BroadcastJoinSuite extends QueryTest with BeforeAndAfterAll {
+// private var sc: SparkContext = null
+// private var sqlContext: SQLContext = null
+//
+// /**
+// * Create a new [[SQLContext]] running in local-cluster mode with unsafe and codegen enabled.
+// */
+// override def beforeAll(): Unit = {
+// super.beforeAll()
+// val conf = new SparkConf()
+// .setMaster("local-cluster[2,1,1024]")
+// .setAppName("testing")
+// sc = new SparkContext(conf)
+// sqlContext = new SQLContext(sc)
+// sqlContext.setConf(SQLConf.UNSAFE_ENABLED, true)
+// sqlContext.setConf(SQLConf.CODEGEN_ENABLED, true)
+// }
+//
+// override def afterAll(): Unit = {
+// sc.stop()
+// sc = null
+// sqlContext = null
+// }
+//
+// /**
+// * Test whether the specified broadcast join updates the peak execution memory accumulator.
+// */
+// private def testBroadcastJoin[T: ClassTag](name: String, joinType: String): Unit = {
+// AccumulatorSuite.verifyPeakExecutionMemorySet(sc, name) {
+// val df1 = sqlContext.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value")
+// val df2 = sqlContext.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value")
+// // Comparison at the end is for broadcast left semi join
+// val joinExpression = df1("key") === df2("key") && df1("value") > df2("value")
+// val df3 = df1.join(broadcast(df2), joinExpression, joinType)
+// val plan = df3.queryExecution.executedPlan
+// assert(plan.collect { case p: T => p }.size === 1)
+// plan.executeCollect()
+// }
+// }
+//
+// test("unsafe broadcast hash join updates peak execution memory") {
+// testBroadcastJoin[BroadcastHashJoin]("unsafe broadcast hash join", "inner")
+// }
+//
+// test("unsafe broadcast hash outer join updates peak execution memory") {
+// testBroadcastJoin[BroadcastHashOuterJoin]("unsafe broadcast hash outer join", "left_outer")
+// }
+//
+// test("unsafe broadcast left semi join updates peak execution memory") {
+// testBroadcastJoin[BroadcastLeftSemiJoinHash]("unsafe broadcast left semi join", "leftsemi")
+// }
+//
+//}
+// scalastyle:on