-rw-r--r--  core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java | 27
-rw-r--r--  core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java | 38
-rw-r--r--  core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java | 8
-rw-r--r--  core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java | 29
-rw-r--r--  core/src/main/resources/org/apache/spark/ui/static/webui.css | 2
-rw-r--r--  core/src/main/scala/org/apache/spark/Accumulators.scala | 60
-rw-r--r--  core/src/main/scala/org/apache/spark/Aggregator.scala | 24
-rw-r--r--  core/src/main/scala/org/apache/spark/TaskContext.scala | 13
-rw-r--r--  core/src/main/scala/org/apache/spark/TaskContextImpl.scala | 8
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala | 9
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala | 9
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala | 28
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala | 6
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala | 10
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/Stage.scala | 16
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/Task.scala | 18
-rw-r--r--  core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala | 8
-rw-r--r--  core/src/main/scala/org/apache/spark/ui/ToolTips.scala | 7
-rw-r--r--  core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala | 140
-rw-r--r--  core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala | 1
-rw-r--r--  core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala | 13
-rw-r--r--  core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala | 20
-rw-r--r--  core/src/test/java/org/apache/spark/JavaAPISuite.java | 3
-rw-r--r--  core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java | 54
-rw-r--r--  core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java | 39
-rw-r--r--  core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java | 46
-rw-r--r--  core/src/test/scala/org/apache/spark/AccumulatorSuite.scala | 193
-rw-r--r--  core/src/test/scala/org/apache/spark/CacheManagerSuite.scala | 10
-rw-r--r--  core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala | 2
-rw-r--r--  core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala | 6
-rw-r--r--  core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala | 2
-rw-r--r--  core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala | 7
-rw-r--r--  core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala | 2
-rw-r--r--  core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala | 2
-rw-r--r--  core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala | 8
-rw-r--r--  core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala | 76
-rw-r--r--  core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala | 15
-rw-r--r--  core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala | 14
-rw-r--r--  sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java | 7
-rw-r--r--  sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java | 8
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala | 11
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala | 10
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala | 8
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala | 10
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala | 22
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala | 12
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 60
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala | 12
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala | 3
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala | 3
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala | 94
51 files changed, 1070 insertions, 163 deletions
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java
index 1aa6ba4201..bf4eaa59ff 100644
--- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java
@@ -20,6 +20,7 @@ package org.apache.spark.shuffle.unsafe;
import java.io.File;
import java.io.IOException;
import java.util.LinkedList;
+import javax.annotation.Nullable;
import scala.Tuple2;
@@ -86,9 +87,12 @@ final class UnsafeShuffleExternalSorter {
private final LinkedList<SpillInfo> spills = new LinkedList<SpillInfo>();
+ /** Peak memory used by this sorter so far, in bytes. **/
+ private long peakMemoryUsedBytes;
+
// These variables are reset after spilling:
- private UnsafeShuffleInMemorySorter sorter;
- private MemoryBlock currentPage = null;
+ @Nullable private UnsafeShuffleInMemorySorter sorter;
+ @Nullable private MemoryBlock currentPage = null;
private long currentPagePosition = -1;
private long freeSpaceInCurrentPage = 0;
@@ -106,6 +110,7 @@ final class UnsafeShuffleExternalSorter {
this.blockManager = blockManager;
this.taskContext = taskContext;
this.initialSize = initialSize;
+ this.peakMemoryUsedBytes = initialSize;
this.numPartitions = numPartitions;
// Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024;
@@ -279,10 +284,26 @@ final class UnsafeShuffleExternalSorter {
for (MemoryBlock page : allocatedPages) {
totalPageSize += page.size();
}
- return sorter.getMemoryUsage() + totalPageSize;
+ return ((sorter == null) ? 0 : sorter.getMemoryUsage()) + totalPageSize;
+ }
+
+ private void updatePeakMemoryUsed() {
+ long mem = getMemoryUsage();
+ if (mem > peakMemoryUsedBytes) {
+ peakMemoryUsedBytes = mem;
+ }
+ }
+
+ /**
+ * Return the peak memory used so far, in bytes.
+ */
+ long getPeakMemoryUsedBytes() {
+ updatePeakMemoryUsed();
+ return peakMemoryUsedBytes;
}
private long freeMemory() {
+ updatePeakMemoryUsed();
long memoryFreed = 0;
for (MemoryBlock block : allocatedPages) {
memoryManager.freePage(block);
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java
index d47d6fc9c2..6e2eeb37c8 100644
--- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java
@@ -27,6 +27,7 @@ import scala.Product2;
import scala.collection.JavaConversions;
import scala.reflect.ClassTag;
import scala.reflect.ClassTag$;
+import scala.collection.immutable.Map;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.io.ByteStreams;
@@ -78,8 +79,9 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
private final SparkConf sparkConf;
private final boolean transferToEnabled;
- private MapStatus mapStatus = null;
- private UnsafeShuffleExternalSorter sorter = null;
+ @Nullable private MapStatus mapStatus;
+ @Nullable private UnsafeShuffleExternalSorter sorter;
+ private long peakMemoryUsedBytes = 0;
/** Subclass of ByteArrayOutputStream that exposes `buf` directly. */
private static final class MyByteArrayOutputStream extends ByteArrayOutputStream {
@@ -131,9 +133,28 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
@VisibleForTesting
public int maxRecordSizeBytes() {
+ assert(sorter != null);
return sorter.maxRecordSizeBytes;
}
+ private void updatePeakMemoryUsed() {
+ // sorter can be null if this writer is closed
+ if (sorter != null) {
+ long mem = sorter.getPeakMemoryUsedBytes();
+ if (mem > peakMemoryUsedBytes) {
+ peakMemoryUsedBytes = mem;
+ }
+ }
+ }
+
+ /**
+ * Return the peak memory used so far, in bytes.
+ */
+ public long getPeakMemoryUsedBytes() {
+ updatePeakMemoryUsed();
+ return peakMemoryUsedBytes;
+ }
+
/**
* This convenience method should only be called in test code.
*/
@@ -144,7 +165,7 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
@Override
public void write(scala.collection.Iterator<Product2<K, V>> records) throws IOException {
- // Keep track of success so we know if we ecountered an exception
+ // Keep track of success so we know if we encountered an exception
// We do this rather than a standard try/catch/re-throw to handle
// generic throwables.
boolean success = false;
@@ -189,6 +210,8 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
@VisibleForTesting
void closeAndWriteOutput() throws IOException {
+ assert(sorter != null);
+ updatePeakMemoryUsed();
serBuffer = null;
serOutputStream = null;
final SpillInfo[] spills = sorter.closeAndGetSpills();
@@ -209,6 +232,7 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
@VisibleForTesting
void insertRecordIntoSorter(Product2<K, V> record) throws IOException {
+ assert(sorter != null);
final K key = record._1();
final int partitionId = partitioner.getPartition(key);
serBuffer.reset();
@@ -431,6 +455,14 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
@Override
public Option<MapStatus> stop(boolean success) {
try {
+ // Update task metrics from accumulators (null in UnsafeShuffleWriterSuite)
+ Map<String, Accumulator<Object>> internalAccumulators =
+ taskContext.internalMetricsToAccumulators();
+ if (internalAccumulators != null) {
+ internalAccumulators.apply(InternalAccumulator.PEAK_EXECUTION_MEMORY())
+ .add(getPeakMemoryUsedBytes());
+ }
+
if (stopping) {
return Option.apply(null);
} else {
diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
index 01a66084e9..20347433e1 100644
--- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
+++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
@@ -505,7 +505,7 @@ public final class BytesToBytesMap {
// Here, we'll copy the data into our data pages. Because we only store a relative offset from
// the key address instead of storing the absolute address of the value, the key and value
// must be stored in the same memory page.
- // (8 byte key length) (key) (8 byte value length) (value)
+ // (8 byte key length) (key) (value)
final long requiredSize = 8 + keyLengthBytes + valueLengthBytes;
// --- Figure out where to insert the new record ---------------------------------------------
@@ -655,7 +655,10 @@ public final class BytesToBytesMap {
return pageSizeBytes;
}
- /** Returns the total amount of memory, in bytes, consumed by this map's managed structures. */
+ /**
+ * Returns the total amount of memory, in bytes, consumed by this map's managed structures.
+ * Note that this is also the peak memory used by this map, since the map is append-only.
+ */
public long getTotalMemoryConsumption() {
long totalDataPagesSize = 0L;
for (MemoryBlock dataPage : dataPages) {
@@ -674,7 +677,6 @@ public final class BytesToBytesMap {
return timeSpentResizingNs;
}
-
/**
* Returns the average number of probes per key lookup.
*/
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
index b984301cbb..bf5f965a9d 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
@@ -70,13 +70,14 @@ public final class UnsafeExternalSorter {
private final LinkedList<UnsafeSorterSpillWriter> spillWriters = new LinkedList<>();
// These variables are reset after spilling:
- private UnsafeInMemorySorter inMemSorter;
+ @Nullable private UnsafeInMemorySorter inMemSorter;
// Whether the in-mem sorter is created internally, or passed in from outside.
// If it is passed in from outside, we shouldn't release the in-mem sorter's memory.
private boolean isInMemSorterExternal = false;
private MemoryBlock currentPage = null;
private long currentPagePosition = -1;
private long freeSpaceInCurrentPage = 0;
+ private long peakMemoryUsedBytes = 0;
public static UnsafeExternalSorter createWithExistingInMemorySorter(
TaskMemoryManager taskMemoryManager,
@@ -183,6 +184,7 @@ public final class UnsafeExternalSorter {
* Sort and spill the current records in response to memory pressure.
*/
public void spill() throws IOException {
+ assert(inMemSorter != null);
logger.info("Thread {} spilling sort data of {} to disk ({} {} so far)",
Thread.currentThread().getId(),
Utils.bytesToString(getMemoryUsage()),
@@ -219,7 +221,22 @@ public final class UnsafeExternalSorter {
for (MemoryBlock page : allocatedPages) {
totalPageSize += page.size();
}
- return inMemSorter.getMemoryUsage() + totalPageSize;
+ return ((inMemSorter == null) ? 0 : inMemSorter.getMemoryUsage()) + totalPageSize;
+ }
+
+ private void updatePeakMemoryUsed() {
+ long mem = getMemoryUsage();
+ if (mem > peakMemoryUsedBytes) {
+ peakMemoryUsedBytes = mem;
+ }
+ }
+
+ /**
+ * Return the peak memory used so far, in bytes.
+ */
+ public long getPeakMemoryUsedBytes() {
+ updatePeakMemoryUsed();
+ return peakMemoryUsedBytes;
}
@VisibleForTesting
@@ -233,6 +250,7 @@ public final class UnsafeExternalSorter {
* @return the number of bytes freed.
*/
public long freeMemory() {
+ updatePeakMemoryUsed();
long memoryFreed = 0;
for (MemoryBlock block : allocatedPages) {
taskMemoryManager.freePage(block);
@@ -277,7 +295,8 @@ public final class UnsafeExternalSorter {
* @return true if the record can be inserted without requiring more allocations, false otherwise.
*/
private boolean haveSpaceForRecord(int requiredSpace) {
- assert (requiredSpace > 0);
+ assert(requiredSpace > 0);
+ assert(inMemSorter != null);
return (inMemSorter.hasSpaceForAnotherRecord() && (requiredSpace <= freeSpaceInCurrentPage));
}
@@ -290,6 +309,7 @@ public final class UnsafeExternalSorter {
* the record size.
*/
private void allocateSpaceForRecord(int requiredSpace) throws IOException {
+ assert(inMemSorter != null);
// TODO: merge these steps to first calculate total memory requirements for this insert,
// then try to acquire; no point in acquiring sort buffer only to spill due to no space in the
// data page.
@@ -350,6 +370,7 @@ public final class UnsafeExternalSorter {
if (!haveSpaceForRecord(totalSpaceRequired)) {
allocateSpaceForRecord(totalSpaceRequired);
}
+ assert(inMemSorter != null);
final long recordAddress =
taskMemoryManager.encodePageNumberAndOffset(currentPage, currentPagePosition);
@@ -382,6 +403,7 @@ public final class UnsafeExternalSorter {
if (!haveSpaceForRecord(totalSpaceRequired)) {
allocateSpaceForRecord(totalSpaceRequired);
}
+ assert(inMemSorter != null);
final long recordAddress =
taskMemoryManager.encodePageNumberAndOffset(currentPage, currentPagePosition);
@@ -405,6 +427,7 @@ public final class UnsafeExternalSorter {
}
public UnsafeSorterIterator getSortedIterator() throws IOException {
+ assert(inMemSorter != null);
final UnsafeSorterIterator inMemoryIterator = inMemSorter.getSortedIterator();
int numIteratorsToMerge = spillWriters.size() + (inMemoryIterator.hasNext() ? 1 : 0);
if (spillWriters.isEmpty()) {
diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css
index b1cef47042..648cd1b104 100644
--- a/core/src/main/resources/org/apache/spark/ui/static/webui.css
+++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css
@@ -207,7 +207,7 @@ span.additional-metric-title {
/* Hide all additional metrics by default. This is done here rather than using JavaScript to
* avoid slow page loads for stage pages with large numbers (e.g., thousands) of tasks. */
.scheduler_delay, .deserialization_time, .fetch_wait_time, .shuffle_read_remote,
-.serialization_time, .getting_result_time {
+.serialization_time, .getting_result_time, .peak_execution_memory {
display: none;
}
diff --git a/core/src/main/scala/org/apache/spark/Accumulators.scala b/core/src/main/scala/org/apache/spark/Accumulators.scala
index eb75f26718..b6a0119c69 100644
--- a/core/src/main/scala/org/apache/spark/Accumulators.scala
+++ b/core/src/main/scala/org/apache/spark/Accumulators.scala
@@ -152,8 +152,14 @@ class Accumulable[R, T] private[spark] (
in.defaultReadObject()
value_ = zero
deserialized = true
- val taskContext = TaskContext.get()
- taskContext.registerAccumulator(this)
+ // Automatically register the accumulator when it is deserialized with the task closure.
+ // Note that internal accumulators are deserialized before the TaskContext is created and
+ // are registered in the TaskContext constructor.
+ if (!isInternal) {
+ val taskContext = TaskContext.get()
+ assume(taskContext != null, "Task context was null when deserializing user accumulators")
+ taskContext.registerAccumulator(this)
+ }
}
override def toString: String = if (value_ == null) "null" else value_.toString
@@ -248,10 +254,20 @@ GrowableAccumulableParam[R <% Growable[T] with TraversableOnce[T] with Serializa
* @param param helper object defining how to add elements of type `T`
* @tparam T result type
*/
-class Accumulator[T](@transient initialValue: T, param: AccumulatorParam[T], name: Option[String])
- extends Accumulable[T, T](initialValue, param, name) {
+class Accumulator[T] private[spark] (
+ @transient initialValue: T,
+ param: AccumulatorParam[T],
+ name: Option[String],
+ internal: Boolean)
+ extends Accumulable[T, T](initialValue, param, name, internal) {
+
+ def this(initialValue: T, param: AccumulatorParam[T], name: Option[String]) = {
+ this(initialValue, param, name, false)
+ }
- def this(initialValue: T, param: AccumulatorParam[T]) = this(initialValue, param, None)
+ def this(initialValue: T, param: AccumulatorParam[T]) = {
+ this(initialValue, param, None, false)
+ }
}
/**
@@ -342,3 +358,37 @@ private[spark] object Accumulators extends Logging {
}
}
+
+private[spark] object InternalAccumulator {
+ val PEAK_EXECUTION_MEMORY = "peakExecutionMemory"
+ val TEST_ACCUMULATOR = "testAccumulator"
+
+ // For testing only.
+ // This needs to be a def since we don't want to reuse the same accumulator across stages.
+ private def maybeTestAccumulator: Option[Accumulator[Long]] = {
+ if (sys.props.contains("spark.testing")) {
+ Some(new Accumulator(
+ 0L, AccumulatorParam.LongAccumulatorParam, Some(TEST_ACCUMULATOR), internal = true))
+ } else {
+ None
+ }
+ }
+
+ /**
+ * Accumulators for tracking internal metrics.
+ *
+ * These accumulators are created with the stage such that all tasks in the stage will
+ * add to the same set of accumulators. We do this to report the distribution of accumulator
+ * values across all tasks within each stage.
+ */
+ def create(): Seq[Accumulator[Long]] = {
+ Seq(
+ // Execution memory refers to the memory used by internal data structures created
+ // during shuffles, aggregations and joins. The value of this accumulator should be
+ // approximately the sum of the peak sizes across all such data structures created
+ // in this task. For SQL jobs, this only tracks all unsafe operators and ExternalSort.
+ new Accumulator(
+ 0L, AccumulatorParam.LongAccumulatorParam, Some(PEAK_EXECUTION_MEMORY), internal = true)
+ ) ++ maybeTestAccumulator.toSeq
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/Aggregator.scala b/core/src/main/scala/org/apache/spark/Aggregator.scala
index ceeb58075d..289aab9bd9 100644
--- a/core/src/main/scala/org/apache/spark/Aggregator.scala
+++ b/core/src/main/scala/org/apache/spark/Aggregator.scala
@@ -58,12 +58,7 @@ case class Aggregator[K, V, C] (
} else {
val combiners = new ExternalAppendOnlyMap[K, V, C](createCombiner, mergeValue, mergeCombiners)
combiners.insertAll(iter)
- // Update task metrics if context is not null
- // TODO: Make context non optional in a future release
- Option(context).foreach { c =>
- c.taskMetrics.incMemoryBytesSpilled(combiners.memoryBytesSpilled)
- c.taskMetrics.incDiskBytesSpilled(combiners.diskBytesSpilled)
- }
+ updateMetrics(context, combiners)
combiners.iterator
}
}
@@ -89,13 +84,18 @@ case class Aggregator[K, V, C] (
} else {
val combiners = new ExternalAppendOnlyMap[K, C, C](identity, mergeCombiners, mergeCombiners)
combiners.insertAll(iter)
- // Update task metrics if context is not null
- // TODO: Make context non-optional in a future release
- Option(context).foreach { c =>
- c.taskMetrics.incMemoryBytesSpilled(combiners.memoryBytesSpilled)
- c.taskMetrics.incDiskBytesSpilled(combiners.diskBytesSpilled)
- }
+ updateMetrics(context, combiners)
combiners.iterator
}
}
+
+ /** Update task metrics after populating the external map. */
+ private def updateMetrics(context: TaskContext, map: ExternalAppendOnlyMap[_, _, _]): Unit = {
+ Option(context).foreach { c =>
+ c.taskMetrics().incMemoryBytesSpilled(map.memoryBytesSpilled)
+ c.taskMetrics().incDiskBytesSpilled(map.diskBytesSpilled)
+ c.internalMetricsToAccumulators(
+ InternalAccumulator.PEAK_EXECUTION_MEMORY).add(map.peakMemoryUsedBytes)
+ }
+ }
}
diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala
index 5d2c551d58..63cca80b2d 100644
--- a/core/src/main/scala/org/apache/spark/TaskContext.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContext.scala
@@ -61,12 +61,12 @@ object TaskContext {
protected[spark] def unset(): Unit = taskContext.remove()
/**
- * Return an empty task context that is not actually used.
- * Internal use only.
+ * An empty task context that does not represent an actual task.
*/
- private[spark] def empty(): TaskContext = {
- new TaskContextImpl(0, 0, 0, 0, null, null)
+ private[spark] def empty(): TaskContextImpl = {
+ new TaskContextImpl(0, 0, 0, 0, null, null, Seq.empty)
}
+
}
@@ -187,4 +187,9 @@ abstract class TaskContext extends Serializable {
* accumulator id and the value of the Map is the latest accumulator local value.
*/
private[spark] def collectAccumulators(): Map[Long, Any]
+
+ /**
+ * Accumulators for tracking internal metrics indexed by the name.
+ */
+ private[spark] val internalMetricsToAccumulators: Map[String, Accumulator[Long]]
}
diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
index 9ee168ae01..5df94c6d3a 100644
--- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
@@ -32,6 +32,7 @@ private[spark] class TaskContextImpl(
override val attemptNumber: Int,
override val taskMemoryManager: TaskMemoryManager,
@transient private val metricsSystem: MetricsSystem,
+ internalAccumulators: Seq[Accumulator[Long]],
val runningLocally: Boolean = false,
val taskMetrics: TaskMetrics = TaskMetrics.empty)
extends TaskContext
@@ -114,4 +115,11 @@ private[spark] class TaskContextImpl(
private[spark] override def collectAccumulators(): Map[Long, Any] = synchronized {
accumulators.mapValues(_.localValue).toMap
}
+
+ private[spark] override val internalMetricsToAccumulators: Map[String, Accumulator[Long]] = {
+ // Explicitly register internal accumulators here because these are
+ // not captured in the task closure and are already deserialized
+ internalAccumulators.foreach(registerAccumulator)
+ internalAccumulators.map { a => (a.name.get, a) }.toMap
+ }
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
index 130b58882d..9c617fc719 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
@@ -23,8 +23,7 @@ import java.io.{IOException, ObjectOutputStream}
import scala.collection.mutable.ArrayBuffer
-import org.apache.spark.{InterruptibleIterator, Partition, Partitioner, SparkEnv, TaskContext}
-import org.apache.spark.{Dependency, OneToOneDependency, ShuffleDependency}
+import org.apache.spark._
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.util.collection.{ExternalAppendOnlyMap, AppendOnlyMap, CompactBuffer}
import org.apache.spark.util.Utils
@@ -169,8 +168,10 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
for ((it, depNum) <- rddIterators) {
map.insertAll(it.map(pair => (pair._1, new CoGroupValue(pair._2, depNum))))
}
- context.taskMetrics.incMemoryBytesSpilled(map.memoryBytesSpilled)
- context.taskMetrics.incDiskBytesSpilled(map.diskBytesSpilled)
+ context.taskMetrics().incMemoryBytesSpilled(map.memoryBytesSpilled)
+ context.taskMetrics().incDiskBytesSpilled(map.diskBytesSpilled)
+ context.internalMetricsToAccumulators(
+ InternalAccumulator.PEAK_EXECUTION_MEMORY).add(map.peakMemoryUsedBytes)
new InterruptibleIterator(context,
map.iterator.asInstanceOf[Iterator[(K, Array[Iterable[_]])]])
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala
index e0edd7d4ae..11d123eec4 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala
@@ -24,11 +24,12 @@ import org.apache.spark.annotation.DeveloperApi
* Information about an [[org.apache.spark.Accumulable]] modified during a task or stage.
*/
@DeveloperApi
-class AccumulableInfo (
+class AccumulableInfo private[spark] (
val id: Long,
val name: String,
val update: Option[String], // represents a partial update within a task
- val value: String) {
+ val value: String,
+ val internal: Boolean) {
override def equals(other: Any): Boolean = other match {
case acc: AccumulableInfo =>
@@ -40,10 +41,10 @@ class AccumulableInfo (
object AccumulableInfo {
def apply(id: Long, name: String, update: Option[String], value: String): AccumulableInfo = {
- new AccumulableInfo(id, name, update, value)
+ new AccumulableInfo(id, name, update, value, internal = false)
}
def apply(id: Long, name: String, value: String): AccumulableInfo = {
- new AccumulableInfo(id, name, None, value)
+ new AccumulableInfo(id, name, None, value, internal = false)
}
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index c4fa277c21..bb489c6b6e 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -773,16 +773,26 @@ class DAGScheduler(
stage.pendingTasks.clear()
// First figure out the indexes of partition ids to compute.
- val partitionsToCompute: Seq[Int] = {
+ val (allPartitions: Seq[Int], partitionsToCompute: Seq[Int]) = {
stage match {
case stage: ShuffleMapStage =>
- (0 until stage.numPartitions).filter(id => stage.outputLocs(id).isEmpty)
+ val allPartitions = 0 until stage.numPartitions
+ val filteredPartitions = allPartitions.filter { id => stage.outputLocs(id).isEmpty }
+ (allPartitions, filteredPartitions)
case stage: ResultStage =>
val job = stage.resultOfJob.get
- (0 until job.numPartitions).filter(id => !job.finished(id))
+ val allPartitions = 0 until job.numPartitions
+ val filteredPartitions = allPartitions.filter { id => !job.finished(id) }
+ (allPartitions, filteredPartitions)
}
}
+ // Reset internal accumulators only if this stage is not partially submitted
+ // Otherwise, we may override existing accumulator values from some tasks
+ if (allPartitions == partitionsToCompute) {
+ stage.resetInternalAccumulators()
+ }
+
val properties = jobIdToActiveJob.get(stage.firstJobId).map(_.properties).orNull
runningStages += stage
@@ -852,7 +862,8 @@ class DAGScheduler(
partitionsToCompute.map { id =>
val locs = taskIdToLocations(id)
val part = stage.rdd.partitions(id)
- new ShuffleMapTask(stage.id, stage.latestInfo.attemptId, taskBinary, part, locs)
+ new ShuffleMapTask(stage.id, stage.latestInfo.attemptId,
+ taskBinary, part, locs, stage.internalAccumulators)
}
case stage: ResultStage =>
@@ -861,7 +872,8 @@ class DAGScheduler(
val p: Int = job.partitions(id)
val part = stage.rdd.partitions(p)
val locs = taskIdToLocations(id)
- new ResultTask(stage.id, stage.latestInfo.attemptId, taskBinary, part, locs, id)
+ new ResultTask(stage.id, stage.latestInfo.attemptId,
+ taskBinary, part, locs, id, stage.internalAccumulators)
}
}
} catch {
@@ -916,9 +928,11 @@ class DAGScheduler(
// To avoid UI cruft, ignore cases where value wasn't updated
if (acc.name.isDefined && partialValue != acc.zero) {
val name = acc.name.get
- stage.latestInfo.accumulables(id) = AccumulableInfo(id, name, s"${acc.value}")
+ val value = s"${acc.value}"
+ stage.latestInfo.accumulables(id) =
+ new AccumulableInfo(id, name, None, value, acc.isInternal)
event.taskInfo.accumulables +=
- AccumulableInfo(id, name, Some(s"$partialValue"), s"${acc.value}")
+ new AccumulableInfo(id, name, Some(s"$partialValue"), value, acc.isInternal)
}
}
} catch {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
index 9c2606e278..c4dc080e2b 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
@@ -45,8 +45,10 @@ private[spark] class ResultTask[T, U](
taskBinary: Broadcast[Array[Byte]],
partition: Partition,
@transient locs: Seq[TaskLocation],
- val outputId: Int)
- extends Task[U](stageId, stageAttemptId, partition.index) with Serializable {
+ val outputId: Int,
+ internalAccumulators: Seq[Accumulator[Long]])
+ extends Task[U](stageId, stageAttemptId, partition.index, internalAccumulators)
+ with Serializable {
@transient private[this] val preferredLocs: Seq[TaskLocation] = {
if (locs == null) Nil else locs.toSet.toSeq
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
index 14c8c00961..f478f9982a 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
@@ -43,12 +43,14 @@ private[spark] class ShuffleMapTask(
stageAttemptId: Int,
taskBinary: Broadcast[Array[Byte]],
partition: Partition,
- @transient private var locs: Seq[TaskLocation])
- extends Task[MapStatus](stageId, stageAttemptId, partition.index) with Logging {
+ @transient private var locs: Seq[TaskLocation],
+ internalAccumulators: Seq[Accumulator[Long]])
+ extends Task[MapStatus](stageId, stageAttemptId, partition.index, internalAccumulators)
+ with Logging {
/** A constructor used only in test suites. This does not require passing in an RDD. */
def this(partitionId: Int) {
- this(0, 0, null, new Partition { override def index: Int = 0 }, null)
+ this(0, 0, null, new Partition { override def index: Int = 0 }, null, null)
}
@transient private val preferredLocs: Seq[TaskLocation] = {
@@ -69,7 +71,7 @@ private[spark] class ShuffleMapTask(
val manager = SparkEnv.get.shuffleManager
writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context)
writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])
- return writer.stop(success = true).get
+ writer.stop(success = true).get
} catch {
case e: Exception =>
try {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
index 40a333a3e0..de05ee256d 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
@@ -68,6 +68,22 @@ private[spark] abstract class Stage(
val name = callSite.shortForm
val details = callSite.longForm
+ private var _internalAccumulators: Seq[Accumulator[Long]] = Seq.empty
+
+ /** Internal accumulators shared across all tasks in this stage. */
+ def internalAccumulators: Seq[Accumulator[Long]] = _internalAccumulators
+
+ /**
+ * Re-initialize the internal accumulators associated with this stage.
+ *
+ * This is called every time the stage is submitted, *except* when a subset of tasks
+ * belonging to this stage has already finished. Otherwise, reinitializing the internal
+ * accumulators here again will override partial values from the finished tasks.
+ */
+ def resetInternalAccumulators(): Unit = {
+ _internalAccumulators = InternalAccumulator.create()
+ }
+
/**
* Pointer to the [StageInfo] object for the most recent attempt. This needs to be initialized
* here, before any attempts have actually been created, because the DAGScheduler uses this
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index 1978305cfe..9edf9f048f 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -23,7 +23,7 @@ import java.nio.ByteBuffer
import scala.collection.mutable.HashMap
import org.apache.spark.metrics.MetricsSystem
-import org.apache.spark.{SparkEnv, TaskContextImpl, TaskContext}
+import org.apache.spark.{Accumulator, SparkEnv, TaskContextImpl, TaskContext}
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.serializer.SerializerInstance
import org.apache.spark.unsafe.memory.TaskMemoryManager
@@ -47,7 +47,8 @@ import org.apache.spark.util.Utils
private[spark] abstract class Task[T](
val stageId: Int,
val stageAttemptId: Int,
- var partitionId: Int) extends Serializable {
+ val partitionId: Int,
+ internalAccumulators: Seq[Accumulator[Long]]) extends Serializable {
/**
* The key of the Map is the accumulator id and the value of the Map is the latest accumulator
@@ -68,12 +69,13 @@ private[spark] abstract class Task[T](
metricsSystem: MetricsSystem)
: (T, AccumulatorUpdates) = {
context = new TaskContextImpl(
- stageId = stageId,
- partitionId = partitionId,
- taskAttemptId = taskAttemptId,
- attemptNumber = attemptNumber,
- taskMemoryManager = taskMemoryManager,
- metricsSystem = metricsSystem,
+ stageId,
+ partitionId,
+ taskAttemptId,
+ attemptNumber,
+ taskMemoryManager,
+ metricsSystem,
+ internalAccumulators,
runningLocally = false)
TaskContext.setTaskContext(context)
context.taskMetrics.setHostname(Utils.localHostName())
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
index de79fa56f0..0c8f08f0f3 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
@@ -17,7 +17,7 @@
package org.apache.spark.shuffle.hash
-import org.apache.spark.{InterruptibleIterator, Logging, MapOutputTracker, SparkEnv, TaskContext}
+import org.apache.spark._
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader}
import org.apache.spark.storage.{BlockManager, ShuffleBlockFetcherIterator}
@@ -100,8 +100,10 @@ private[spark] class HashShuffleReader[K, C](
// the ExternalSorter won't spill to disk.
val sorter = new ExternalSorter[K, C, C](ordering = Some(keyOrd), serializer = Some(ser))
sorter.insertAll(aggregatedIter)
- context.taskMetrics.incMemoryBytesSpilled(sorter.memoryBytesSpilled)
- context.taskMetrics.incDiskBytesSpilled(sorter.diskBytesSpilled)
+ context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
+ context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
+ context.internalMetricsToAccumulators(
+ InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.peakMemoryUsedBytes)
sorter.iterator
case None =>
aggregatedIter
diff --git a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala
index e2d25e3636..cb122eaed8 100644
--- a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala
+++ b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala
@@ -62,6 +62,13 @@ private[spark] object ToolTips {
"""Time that the executor spent paused for Java garbage collection while the task was
running."""
+ val PEAK_EXECUTION_MEMORY =
+ """Execution memory refers to the memory used by internal data structures created during
+ shuffles, aggregations and joins when Tungsten is enabled. The value of this accumulator
+ should be approximately the sum of the peak sizes across all such data structures created
+ in this task. For SQL jobs, this only tracks all unsafe operators, broadcast joins, and
+ external sort."""
+
val JOB_TIMELINE =
"""Shows when jobs started and ended and when executors joined or left. Drag to scroll.
Click Enable Zooming and use mouse wheel to zoom in/out."""
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
index cf04b5e592..3954c3d1ef 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
@@ -26,6 +26,7 @@ import scala.xml.{Elem, Node, Unparsed}
import org.apache.commons.lang3.StringEscapeUtils
+import org.apache.spark.{InternalAccumulator, SparkConf}
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.scheduler.{AccumulableInfo, TaskInfo}
import org.apache.spark.ui._
@@ -67,6 +68,8 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
// if we find that it's okay.
private val MAX_TIMELINE_TASKS = parent.conf.getInt("spark.ui.timeline.tasks.maximum", 1000)
+ private val displayPeakExecutionMemory =
+ parent.conf.getOption("spark.sql.unsafe.enabled").exists(_.toBoolean)
def render(request: HttpServletRequest): Seq[Node] = {
progressListener.synchronized {
@@ -114,10 +117,11 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
val stageData = stageDataOption.get
val tasks = stageData.taskData.values.toSeq.sortBy(_.taskInfo.launchTime)
-
val numCompleted = tasks.count(_.taskInfo.finished)
- val accumulables = progressListener.stageIdToData((stageId, stageAttemptId)).accumulables
- val hasAccumulators = accumulables.size > 0
+
+ val allAccumulables = progressListener.stageIdToData((stageId, stageAttemptId)).accumulables
+ val externalAccumulables = allAccumulables.values.filter { acc => !acc.internal }
+ val hasAccumulators = externalAccumulables.size > 0
val summary =
<div>
@@ -221,6 +225,15 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
<span class="additional-metric-title">Getting Result Time</span>
</span>
</li>
+ {if (displayPeakExecutionMemory) {
+ <li>
+ <span data-toggle="tooltip"
+ title={ToolTips.PEAK_EXECUTION_MEMORY} data-placement="right">
+ <input type="checkbox" name={TaskDetailsClassNames.PEAK_EXECUTION_MEMORY}/>
+ <span class="additional-metric-title">Peak Execution Memory</span>
+ </span>
+ </li>
+ }}
</ul>
</div>
</div>
@@ -241,11 +254,12 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
val accumulableTable = UIUtils.listingTable(
accumulableHeaders,
accumulableRow,
- accumulables.values.toSeq)
+ externalAccumulables.toSeq)
val currentTime = System.currentTimeMillis()
val (taskTable, taskTableHTML) = try {
val _taskTable = new TaskPagedTable(
+ parent.conf,
UIUtils.prependBaseUri(parent.basePath) +
s"/stages/stage?id=${stageId}&attempt=${stageAttemptId}",
tasks,
@@ -294,12 +308,14 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
else {
def getDistributionQuantiles(data: Seq[Double]): IndexedSeq[Double] =
Distribution(data).get.getQuantiles()
-
def getFormattedTimeQuantiles(times: Seq[Double]): Seq[Node] = {
getDistributionQuantiles(times).map { millis =>
<td>{UIUtils.formatDuration(millis.toLong)}</td>
}
}
+ def getFormattedSizeQuantiles(data: Seq[Double]): Seq[Elem] = {
+ getDistributionQuantiles(data).map(d => <td>{Utils.bytesToString(d.toLong)}</td>)
+ }
val deserializationTimes = validTasks.map { case TaskUIData(_, metrics, _) =>
metrics.get.executorDeserializeTime.toDouble
@@ -349,6 +365,23 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
</span>
</td> +:
getFormattedTimeQuantiles(gettingResultTimes)
+
+ val peakExecutionMemory = validTasks.map { case TaskUIData(info, _, _) =>
+ info.accumulables
+ .find { acc => acc.name == InternalAccumulator.PEAK_EXECUTION_MEMORY }
+ .map { acc => acc.value.toLong }
+ .getOrElse(0L)
+ .toDouble
+ }
+ val peakExecutionMemoryQuantiles = {
+ <td>
+ <span data-toggle="tooltip"
+ title={ToolTips.PEAK_EXECUTION_MEMORY} data-placement="right">
+ Peak Execution Memory
+ </span>
+ </td> +: getFormattedSizeQuantiles(peakExecutionMemory)
+ }
+
// The scheduler delay includes the network delay to send the task to the worker
// machine and to send back the result (but not the time to fetch the task result,
// if it needed to be fetched from the block manager on the worker).
@@ -359,10 +392,6 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
title={ToolTips.SCHEDULER_DELAY} data-placement="right">Scheduler Delay</span></td>
val schedulerDelayQuantiles = schedulerDelayTitle +:
getFormattedTimeQuantiles(schedulerDelays)
-
- def getFormattedSizeQuantiles(data: Seq[Double]): Seq[Elem] =
- getDistributionQuantiles(data).map(d => <td>{Utils.bytesToString(d.toLong)}</td>)
-
def getFormattedSizeQuantilesWithRecords(data: Seq[Double], records: Seq[Double])
: Seq[Elem] = {
val recordDist = getDistributionQuantiles(records).iterator
@@ -466,6 +495,13 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
{serializationQuantiles}
</tr>,
<tr class={TaskDetailsClassNames.GETTING_RESULT_TIME}>{gettingResultQuantiles}</tr>,
+ if (displayPeakExecutionMemory) {
+ <tr class={TaskDetailsClassNames.PEAK_EXECUTION_MEMORY}>
+ {peakExecutionMemoryQuantiles}
+ </tr>
+ } else {
+ Nil
+ },
if (stageData.hasInput) <tr>{inputQuantiles}</tr> else Nil,
if (stageData.hasOutput) <tr>{outputQuantiles}</tr> else Nil,
if (stageData.hasShuffleRead) {
@@ -499,7 +535,7 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
val executorTable = new ExecutorTable(stageId, stageAttemptId, parent)
val maybeAccumulableTable: Seq[Node] =
- if (accumulables.size > 0) { <h4>Accumulators</h4> ++ accumulableTable } else Seq()
+ if (hasAccumulators) { <h4>Accumulators</h4> ++ accumulableTable } else Seq()
val content =
summary ++
@@ -750,29 +786,30 @@ private[ui] case class TaskTableRowBytesSpilledData(
* Contains all data that needs for sorting and generating HTML. Using this one rather than
* TaskUIData to avoid creating duplicate contents during sorting the data.
*/
-private[ui] case class TaskTableRowData(
- index: Int,
- taskId: Long,
- attempt: Int,
- speculative: Boolean,
- status: String,
- taskLocality: String,
- executorIdAndHost: String,
- launchTime: Long,
- duration: Long,
- formatDuration: String,
- schedulerDelay: Long,
- taskDeserializationTime: Long,
- gcTime: Long,
- serializationTime: Long,
- gettingResultTime: Long,
- accumulators: Option[String], // HTML
- input: Option[TaskTableRowInputData],
- output: Option[TaskTableRowOutputData],
- shuffleRead: Option[TaskTableRowShuffleReadData],
- shuffleWrite: Option[TaskTableRowShuffleWriteData],
- bytesSpilled: Option[TaskTableRowBytesSpilledData],
- error: String)
+private[ui] class TaskTableRowData(
+ val index: Int,
+ val taskId: Long,
+ val attempt: Int,
+ val speculative: Boolean,
+ val status: String,
+ val taskLocality: String,
+ val executorIdAndHost: String,
+ val launchTime: Long,
+ val duration: Long,
+ val formatDuration: String,
+ val schedulerDelay: Long,
+ val taskDeserializationTime: Long,
+ val gcTime: Long,
+ val serializationTime: Long,
+ val gettingResultTime: Long,
+ val peakExecutionMemoryUsed: Long,
+ val accumulators: Option[String], // HTML
+ val input: Option[TaskTableRowInputData],
+ val output: Option[TaskTableRowOutputData],
+ val shuffleRead: Option[TaskTableRowShuffleReadData],
+ val shuffleWrite: Option[TaskTableRowShuffleWriteData],
+ val bytesSpilled: Option[TaskTableRowBytesSpilledData],
+ val error: String)
private[ui] class TaskDataSource(
tasks: Seq[TaskUIData],
@@ -816,10 +853,15 @@ private[ui] class TaskDataSource(
val serializationTime = metrics.map(_.resultSerializationTime).getOrElse(0L)
val gettingResultTime = getGettingResultTime(info, currentTime)
- val maybeAccumulators = info.accumulables
- val accumulatorsReadable = maybeAccumulators.map { acc =>
+ val (taskInternalAccumulables, taskExternalAccumulables) =
+ info.accumulables.partition(_.internal)
+ val externalAccumulableReadable = taskExternalAccumulables.map { acc =>
StringEscapeUtils.escapeHtml4(s"${acc.name}: ${acc.update.get}")
}
+ val peakExecutionMemoryUsed = taskInternalAccumulables
+ .find { acc => acc.name == InternalAccumulator.PEAK_EXECUTION_MEMORY }
+ .map { acc => acc.value.toLong }
+ .getOrElse(0L)
val maybeInput = metrics.flatMap(_.inputMetrics)
val inputSortable = maybeInput.map(_.bytesRead).getOrElse(0L)
@@ -923,7 +965,7 @@ private[ui] class TaskDataSource(
None
}
- TaskTableRowData(
+ new TaskTableRowData(
info.index,
info.taskId,
info.attempt,
@@ -939,7 +981,8 @@ private[ui] class TaskDataSource(
gcTime,
serializationTime,
gettingResultTime,
- if (hasAccumulators) Some(accumulatorsReadable.mkString("<br/>")) else None,
+ peakExecutionMemoryUsed,
+ if (hasAccumulators) Some(externalAccumulableReadable.mkString("<br/>")) else None,
input,
output,
shuffleRead,
@@ -1006,6 +1049,10 @@ private[ui] class TaskDataSource(
override def compare(x: TaskTableRowData, y: TaskTableRowData): Int =
Ordering.Long.compare(x.gettingResultTime, y.gettingResultTime)
}
+ case "Peak Execution Memory" => new Ordering[TaskTableRowData] {
+ override def compare(x: TaskTableRowData, y: TaskTableRowData): Int =
+ Ordering.Long.compare(x.peakExecutionMemoryUsed, y.peakExecutionMemoryUsed)
+ }
case "Accumulators" =>
if (hasAccumulators) {
new Ordering[TaskTableRowData] {
@@ -1132,6 +1179,7 @@ private[ui] class TaskDataSource(
}
private[ui] class TaskPagedTable(
+ conf: SparkConf,
basePath: String,
data: Seq[TaskUIData],
hasAccumulators: Boolean,
@@ -1143,7 +1191,11 @@ private[ui] class TaskPagedTable(
currentTime: Long,
pageSize: Int,
sortColumn: String,
- desc: Boolean) extends PagedTable[TaskTableRowData]{
+ desc: Boolean) extends PagedTable[TaskTableRowData] {
+
+ // We only track peak memory used for unsafe operators
+ private val displayPeakExecutionMemory =
+ conf.getOption("spark.sql.unsafe.enabled").exists(_.toBoolean)
override def tableId: String = ""
@@ -1195,6 +1247,13 @@ private[ui] class TaskPagedTable(
("GC Time", ""),
("Result Serialization Time", TaskDetailsClassNames.RESULT_SERIALIZATION_TIME),
("Getting Result Time", TaskDetailsClassNames.GETTING_RESULT_TIME)) ++
+ {
+ if (displayPeakExecutionMemory) {
+ Seq(("Peak Execution Memory", TaskDetailsClassNames.PEAK_EXECUTION_MEMORY))
+ } else {
+ Nil
+ }
+ } ++
{if (hasAccumulators) Seq(("Accumulators", "")) else Nil} ++
{if (hasInput) Seq(("Input Size / Records", "")) else Nil} ++
{if (hasOutput) Seq(("Output Size / Records", "")) else Nil} ++
@@ -1271,6 +1330,11 @@ private[ui] class TaskPagedTable(
<td class={TaskDetailsClassNames.GETTING_RESULT_TIME}>
{UIUtils.formatDuration(task.gettingResultTime)}
</td>
+ {if (displayPeakExecutionMemory) {
+ <td class={TaskDetailsClassNames.PEAK_EXECUTION_MEMORY}>
+ {Utils.bytesToString(task.peakExecutionMemoryUsed)}
+ </td>
+ }}
{if (task.accumulators.nonEmpty) {
<td>{Unparsed(task.accumulators.get)}</td>
}}
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala b/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala
index 9bf67db8ac..d2dfc5a329 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala
@@ -31,4 +31,5 @@ private[spark] object TaskDetailsClassNames {
val SHUFFLE_READ_REMOTE_SIZE = "shuffle_read_remote"
val RESULT_SERIALIZATION_TIME = "serialization_time"
val GETTING_RESULT_TIME = "getting_result_time"
+ val PEAK_EXECUTION_MEMORY = "peak_execution_memory"
}
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
index d166037351..f929b12606 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
@@ -89,6 +89,7 @@ class ExternalAppendOnlyMap[K, V, C](
// Number of bytes spilled in total
private var _diskBytesSpilled = 0L
+ def diskBytesSpilled: Long = _diskBytesSpilled
// Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
private val fileBufferSize =
@@ -97,6 +98,10 @@ class ExternalAppendOnlyMap[K, V, C](
// Write metrics for current spill
private var curWriteMetrics: ShuffleWriteMetrics = _
+ // Peak size of the in-memory map observed so far, in bytes
+ private var _peakMemoryUsedBytes: Long = 0L
+ def peakMemoryUsedBytes: Long = _peakMemoryUsedBytes
+
private val keyComparator = new HashComparator[K]
private val ser = serializer.newInstance()
@@ -126,7 +131,11 @@ class ExternalAppendOnlyMap[K, V, C](
while (entries.hasNext) {
curEntry = entries.next()
- if (maybeSpill(currentMap, currentMap.estimateSize())) {
+ val estimatedSize = currentMap.estimateSize()
+ if (estimatedSize > _peakMemoryUsedBytes) {
+ _peakMemoryUsedBytes = estimatedSize
+ }
+ if (maybeSpill(currentMap, estimatedSize)) {
currentMap = new SizeTrackingAppendOnlyMap[K, C]
}
currentMap.changeValue(curEntry._1, update)
@@ -207,8 +216,6 @@ class ExternalAppendOnlyMap[K, V, C](
spilledMaps.append(new DiskMapIterator(file, blockId, batchSizes))
}
- def diskBytesSpilled: Long = _diskBytesSpilled
-
/**
* Return an iterator that merges the in-memory map with the spilled maps.
* If no spill has occurred, simply return the in-memory map's iterator.
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
index ba7ec834d6..19287edbaf 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -152,6 +152,9 @@ private[spark] class ExternalSorter[K, V, C](
private var _diskBytesSpilled = 0L
def diskBytesSpilled: Long = _diskBytesSpilled
+ // Peak size of the in-memory data structure observed so far, in bytes
+ private var _peakMemoryUsedBytes: Long = 0L
+ def peakMemoryUsedBytes: Long = _peakMemoryUsedBytes
// A comparator for keys K that orders them within a partition to allow aggregation or sorting.
// Can be a partial ordering by hash code if a total ordering is not provided through by the
@@ -224,15 +227,22 @@ private[spark] class ExternalSorter[K, V, C](
return
}
+ var estimatedSize = 0L
if (usingMap) {
- if (maybeSpill(map, map.estimateSize())) {
+ estimatedSize = map.estimateSize()
+ if (maybeSpill(map, estimatedSize)) {
map = new PartitionedAppendOnlyMap[K, C]
}
} else {
- if (maybeSpill(buffer, buffer.estimateSize())) {
+ estimatedSize = buffer.estimateSize()
+ if (maybeSpill(buffer, estimatedSize)) {
buffer = newBuffer()
}
}
+
+ if (estimatedSize > _peakMemoryUsedBytes) {
+ _peakMemoryUsedBytes = estimatedSize
+ }
}
/**
@@ -684,8 +694,10 @@ private[spark] class ExternalSorter[K, V, C](
}
}
- context.taskMetrics.incMemoryBytesSpilled(memoryBytesSpilled)
- context.taskMetrics.incDiskBytesSpilled(diskBytesSpilled)
+ context.taskMetrics().incMemoryBytesSpilled(memoryBytesSpilled)
+ context.taskMetrics().incDiskBytesSpilled(diskBytesSpilled)
+ context.internalMetricsToAccumulators(
+ InternalAccumulator.PEAK_EXECUTION_MEMORY).add(peakMemoryUsedBytes)
lengths
}
diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java
index e948ca3347..ffe4b4baff 100644
--- a/core/src/test/java/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -51,7 +51,6 @@ import org.junit.Test;
import org.apache.spark.api.java.*;
import org.apache.spark.api.java.function.*;
-import org.apache.spark.executor.TaskMetrics;
import org.apache.spark.input.PortableDataStream;
import org.apache.spark.partial.BoundedDouble;
import org.apache.spark.partial.PartialResult;
@@ -1011,7 +1010,7 @@ public class JavaAPISuite implements Serializable {
@Test
public void iterator() {
JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2);
- TaskContext context = new TaskContextImpl(0, 0, 0L, 0, null, null, false, new TaskMetrics());
+ TaskContext context = TaskContext$.MODULE$.empty();
Assert.assertEquals(1, rdd.iterator(rdd.partitions().get(0), context).next().intValue());
}
diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java
index 04fc09b323..98c32bbc29 100644
--- a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java
@@ -190,6 +190,7 @@ public class UnsafeShuffleWriterSuite {
});
when(taskContext.taskMetrics()).thenReturn(taskMetrics);
+ when(taskContext.internalMetricsToAccumulators()).thenReturn(null);
when(shuffleDep.serializer()).thenReturn(Option.<Serializer>apply(serializer));
when(shuffleDep.partitioner()).thenReturn(hashPartitioner);
@@ -542,4 +543,57 @@ public class UnsafeShuffleWriterSuite {
writer.stop(false);
assertSpillFilesWereCleanedUp();
}
+
+ @Test
+ public void testPeakMemoryUsed() throws Exception {
+ final long recordLengthBytes = 8;
+ final long pageSizeBytes = 256;
+ final long numRecordsPerPage = pageSizeBytes / recordLengthBytes;
+ final SparkConf conf = new SparkConf().set("spark.buffer.pageSize", pageSizeBytes + "b");
+ final UnsafeShuffleWriter<Object, Object> writer =
+ new UnsafeShuffleWriter<Object, Object>(
+ blockManager,
+ shuffleBlockResolver,
+ taskMemoryManager,
+ shuffleMemoryManager,
+ new UnsafeShuffleHandle<Object, Object>(0, 1, shuffleDep),
+ 0, // map id
+ taskContext,
+ conf);
+
+ // Peak memory should be monotonically increasing. More specifically, every time
+ // we allocate a new page it should increase by exactly the size of the page.
+ long previousPeakMemory = writer.getPeakMemoryUsedBytes();
+ long newPeakMemory;
+ try {
+ for (int i = 0; i < numRecordsPerPage * 10; i++) {
+ writer.insertRecordIntoSorter(new Tuple2<Object, Object>(1, 1));
+ newPeakMemory = writer.getPeakMemoryUsedBytes();
+ if (i % numRecordsPerPage == 0) {
+ // We allocated a new page for this record, so peak memory should change
+ assertEquals(previousPeakMemory + pageSizeBytes, newPeakMemory);
+ } else {
+ assertEquals(previousPeakMemory, newPeakMemory);
+ }
+ previousPeakMemory = newPeakMemory;
+ }
+
+ // Spilling should not change peak memory
+ writer.forceSorterToSpill();
+ newPeakMemory = writer.getPeakMemoryUsedBytes();
+ assertEquals(previousPeakMemory, newPeakMemory);
+ for (int i = 0; i < numRecordsPerPage; i++) {
+ writer.insertRecordIntoSorter(new Tuple2<Object, Object>(1, 1));
+ }
+ newPeakMemory = writer.getPeakMemoryUsedBytes();
+ assertEquals(previousPeakMemory, newPeakMemory);
+
+ // Closing the writer should not change peak memory
+ writer.closeAndWriteOutput();
+ newPeakMemory = writer.getPeakMemoryUsedBytes();
+ assertEquals(previousPeakMemory, newPeakMemory);
+ } finally {
+ writer.stop(false);
+ }
+ }
}
diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
index dbb7c662d7..0e23a64fb7 100644
--- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
+++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
@@ -25,6 +25,7 @@ import org.junit.*;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import static org.hamcrest.Matchers.greaterThan;
+import static org.junit.Assert.*;
import static org.mockito.AdditionalMatchers.geq;
import static org.mockito.Mockito.*;
@@ -495,4 +496,42 @@ public abstract class AbstractBytesToBytesMapSuite {
map.growAndRehash();
map.free();
}
+
+ @Test
+ public void testTotalMemoryConsumption() {
+ final long recordLengthBytes = 24;
+ final long pageSizeBytes = 256 + 8; // 8 bytes for end-of-page marker
+ final long numRecordsPerPage = (pageSizeBytes - 8) / recordLengthBytes;
+ final BytesToBytesMap map = new BytesToBytesMap(
+ taskMemoryManager, shuffleMemoryManager, 1024, pageSizeBytes);
+
+ // Since BytesToBytesMap is append-only, we expect the total memory consumption to be
+ // monotonically increasing. More specifically, every time we allocate a new page it
+ // should increase by exactly the size of the page. Since nothing is ever freed, the
+ // memory usage at any given time is also the peak memory used.
+ long previousMemory = map.getTotalMemoryConsumption();
+ long newMemory;
+ try {
+ for (long i = 0; i < numRecordsPerPage * 10; i++) {
+ final long[] value = new long[]{i};
+ map.lookup(value, PlatformDependent.LONG_ARRAY_OFFSET, 8).putNewKey(
+ value,
+ PlatformDependent.LONG_ARRAY_OFFSET,
+ 8,
+ value,
+ PlatformDependent.LONG_ARRAY_OFFSET,
+ 8);
+ newMemory = map.getTotalMemoryConsumption();
+ if (i % numRecordsPerPage == 0) {
+ // We allocated a new page for this record, so peak memory should change
+ assertEquals(previousMemory + pageSizeBytes, newMemory);
+ } else {
+ assertEquals(previousMemory, newMemory);
+ }
+ previousMemory = newMemory;
+ }
+ } finally {
+ map.free();
+ }
+ }
}
diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
index 52fa8bcd57..c11949d57a 100644
--- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
+++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
@@ -247,4 +247,50 @@ public class UnsafeExternalSorterSuite {
assertSpillFilesWereCleanedUp();
}
+ @Test
+ public void testPeakMemoryUsed() throws Exception {
+ final long recordLengthBytes = 8;
+ final long pageSizeBytes = 256;
+ final long numRecordsPerPage = pageSizeBytes / recordLengthBytes;
+ final UnsafeExternalSorter sorter = UnsafeExternalSorter.create(
+ taskMemoryManager,
+ shuffleMemoryManager,
+ blockManager,
+ taskContext,
+ recordComparator,
+ prefixComparator,
+ 1024,
+ pageSizeBytes);
+
+ // Peak memory should be monotonically increasing. More specifically, every time
+ // we allocate a new page it should increase by exactly the size of the page.
+ long previousPeakMemory = sorter.getPeakMemoryUsedBytes();
+ long newPeakMemory;
+ try {
+ for (int i = 0; i < numRecordsPerPage * 10; i++) {
+ insertNumber(sorter, i);
+ newPeakMemory = sorter.getPeakMemoryUsedBytes();
+ if (i % numRecordsPerPage == 0) {
+ // We allocated a new page for this record, so peak memory should change
+ assertEquals(previousPeakMemory + pageSizeBytes, newPeakMemory);
+ } else {
+ assertEquals(previousPeakMemory, newPeakMemory);
+ }
+ previousPeakMemory = newPeakMemory;
+ }
+
+ // Spilling should not change peak memory
+ sorter.spill();
+ newPeakMemory = sorter.getPeakMemoryUsedBytes();
+ assertEquals(previousPeakMemory, newPeakMemory);
+ for (int i = 0; i < numRecordsPerPage; i++) {
+ insertNumber(sorter, i);
+ }
+ newPeakMemory = sorter.getPeakMemoryUsedBytes();
+ assertEquals(previousPeakMemory, newPeakMemory);
+ } finally {
+ sorter.freeMemory();
+ }
+ }
+
}
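The three peak-memory tests above all check the same invariant: current memory usage can drop back down when a spill releases pages, but the reported peak never decreases. A minimal sketch of that bookkeeping (PeakTracker is an illustrative name, not a class in this patch):

    class PeakTracker {
      private var peakBytes = 0L
      // Called whenever the current memory footprint changes; the peak only moves up.
      def update(currentBytes: Long): Unit = {
        if (currentBytes > peakBytes) peakBytes = currentBytes
      }
      def getPeakMemoryUsedBytes: Long = peakBytes
    }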
diff --git a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
index e942d6579b..48f549575f 100644
--- a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
@@ -18,13 +18,17 @@
package org.apache.spark
import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
import scala.ref.WeakReference
import org.scalatest.Matchers
+import org.scalatest.exceptions.TestFailedException
+import org.apache.spark.scheduler._
-class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContext {
+class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContext {
+ import InternalAccumulator._
implicit def setAccum[A]: AccumulableParam[mutable.Set[A], A] =
new AccumulableParam[mutable.Set[A], A] {
@@ -155,4 +159,191 @@ class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContex
assert(!Accumulators.originals.get(accId).isDefined)
}
+ test("internal accumulators in TaskContext") {
+ val accums = InternalAccumulator.create()
+ val taskContext = new TaskContextImpl(0, 0, 0, 0, null, null, accums)
+ val internalMetricsToAccums = taskContext.internalMetricsToAccumulators
+ val collectedInternalAccums = taskContext.collectInternalAccumulators()
+ val collectedAccums = taskContext.collectAccumulators()
+ assert(internalMetricsToAccums.size > 0)
+ assert(internalMetricsToAccums.values.forall(_.isInternal))
+ assert(internalMetricsToAccums.contains(TEST_ACCUMULATOR))
+ val testAccum = internalMetricsToAccums(TEST_ACCUMULATOR)
+ assert(collectedInternalAccums.size === internalMetricsToAccums.size)
+ assert(collectedInternalAccums.size === collectedAccums.size)
+ assert(collectedInternalAccums.contains(testAccum.id))
+ assert(collectedAccums.contains(testAccum.id))
+ }
+
+ test("internal accumulators in a stage") {
+ val listener = new SaveInfoListener
+ val numPartitions = 10
+ sc = new SparkContext("local", "test")
+ sc.addSparkListener(listener)
+ // Have each task add 1 to the internal accumulator
+ sc.parallelize(1 to 100, numPartitions).mapPartitions { iter =>
+ TaskContext.get().internalMetricsToAccumulators(TEST_ACCUMULATOR) += 1
+ iter
+ }.count()
+ val stageInfos = listener.getCompletedStageInfos
+ val taskInfos = listener.getCompletedTaskInfos
+ assert(stageInfos.size === 1)
+ assert(taskInfos.size === numPartitions)
+ // The accumulator values should be merged in the stage
+ val stageAccum = findAccumulableInfo(stageInfos.head.accumulables.values, TEST_ACCUMULATOR)
+ assert(stageAccum.value.toLong === numPartitions)
+ // The accumulator should be updated locally on each task
+ val taskAccumValues = taskInfos.map { taskInfo =>
+ val taskAccum = findAccumulableInfo(taskInfo.accumulables, TEST_ACCUMULATOR)
+ assert(taskAccum.update.isDefined)
+ assert(taskAccum.update.get.toLong === 1)
+ taskAccum.value.toLong
+ }
+ // Each task should keep track of the partial value on the way, i.e. 1, 2, ... numPartitions
+ assert(taskAccumValues.sorted === (1L to numPartitions).toSeq)
+ }
+
+ test("internal accumulators in multiple stages") {
+ val listener = new SaveInfoListener
+ val numPartitions = 10
+ sc = new SparkContext("local", "test")
+ sc.addSparkListener(listener)
+ // Each stage creates its own set of internal accumulators so the
+ // values for the same metric should not be mixed up across stages
+ sc.parallelize(1 to 100, numPartitions)
+ .map { i => (i, i) }
+ .mapPartitions { iter =>
+ TaskContext.get().internalMetricsToAccumulators(TEST_ACCUMULATOR) += 1
+ iter
+ }
+ .reduceByKey { case (x, y) => x + y }
+ .mapPartitions { iter =>
+ TaskContext.get().internalMetricsToAccumulators(TEST_ACCUMULATOR) += 10
+ iter
+ }
+ .repartition(numPartitions * 2)
+ .mapPartitions { iter =>
+ TaskContext.get().internalMetricsToAccumulators(TEST_ACCUMULATOR) += 100
+ iter
+ }
+ .count()
+ // We ran 3 stages, and the accumulator values should be distinct
+ val stageInfos = listener.getCompletedStageInfos
+ assert(stageInfos.size === 3)
+ val firstStageAccum = findAccumulableInfo(stageInfos(0).accumulables.values, TEST_ACCUMULATOR)
+ val secondStageAccum = findAccumulableInfo(stageInfos(1).accumulables.values, TEST_ACCUMULATOR)
+ val thirdStageAccum = findAccumulableInfo(stageInfos(2).accumulables.values, TEST_ACCUMULATOR)
+ assert(firstStageAccum.value.toLong === numPartitions)
+ assert(secondStageAccum.value.toLong === numPartitions * 10)
+ assert(thirdStageAccum.value.toLong === numPartitions * 2 * 100)
+ }
+
+ test("internal accumulators in fully resubmitted stages") {
+ testInternalAccumulatorsWithFailedTasks((i: Int) => true) // fail all tasks
+ }
+
+ test("internal accumulators in partially resubmitted stages") {
+ testInternalAccumulatorsWithFailedTasks((i: Int) => i % 2 == 0) // fail a subset
+ }
+
+ /**
+ * Return the accumulable info that matches the specified name.
+ */
+ private def findAccumulableInfo(
+ accums: Iterable[AccumulableInfo],
+ name: String): AccumulableInfo = {
+ accums.find { a => a.name == name }.getOrElse {
+ throw new TestFailedException(s"internal accumulator '$name' not found", 0)
+ }
+ }
+
+ /**
+ * Test whether internal accumulators are merged properly if some tasks fail.
+ */
+ private def testInternalAccumulatorsWithFailedTasks(failCondition: (Int => Boolean)): Unit = {
+ val listener = new SaveInfoListener
+ val numPartitions = 10
+ val numFailedPartitions = (0 until numPartitions).count(failCondition)
+ // Use 1 core and allow each task up to 2 attempts (i.e. retry a failed task once)
+ sc = new SparkContext("local[1, 2]", "test")
+ sc.addSparkListener(listener)
+ sc.parallelize(1 to 100, numPartitions).mapPartitionsWithIndex { case (i, iter) =>
+ val taskContext = TaskContext.get()
+ taskContext.internalMetricsToAccumulators(TEST_ACCUMULATOR) += 1
+ // Fail the first attempts of a subset of the tasks
+ if (failCondition(i) && taskContext.attemptNumber() == 0) {
+ throw new Exception("Failing a task intentionally.")
+ }
+ iter
+ }.count()
+ val stageInfos = listener.getCompletedStageInfos
+ val taskInfos = listener.getCompletedTaskInfos
+ assert(stageInfos.size === 1)
+ assert(taskInfos.size === numPartitions + numFailedPartitions)
+ val stageAccum = findAccumulableInfo(stageInfos.head.accumulables.values, TEST_ACCUMULATOR)
+ // We should not double count values in the merged accumulator
+ assert(stageAccum.value.toLong === numPartitions)
+ val taskAccumValues = taskInfos.flatMap { taskInfo =>
+ if (!taskInfo.failed) {
+ // If a task succeeded, its update value should always be 1
+ val taskAccum = findAccumulableInfo(taskInfo.accumulables, TEST_ACCUMULATOR)
+ assert(taskAccum.update.isDefined)
+ assert(taskAccum.update.get.toLong === 1)
+ Some(taskAccum.value.toLong)
+ } else {
+ // If a task failed, we should not get its accumulator values
+ assert(taskInfo.accumulables.isEmpty)
+ None
+ }
+ }
+ assert(taskAccumValues.sorted === (1L to numPartitions).toSeq)
+ }
+
+}
+
+private[spark] object AccumulatorSuite {
+
+ /**
+ * Run one or more Spark jobs and verify that the peak execution memory accumulator
+ * is updated afterwards.
+ */
+ def verifyPeakExecutionMemorySet(
+ sc: SparkContext,
+ testName: String)(testBody: => Unit): Unit = {
+ val listener = new SaveInfoListener
+ sc.addSparkListener(listener)
+ // Verify that the accumulator does not already exist
+ sc.parallelize(1 to 10).count()
+ val accums = listener.getCompletedStageInfos.flatMap(_.accumulables.values)
+ assert(!accums.exists(_.name == InternalAccumulator.PEAK_EXECUTION_MEMORY))
+ testBody
+ // Verify that peak execution memory is updated
+ val accum = listener.getCompletedStageInfos
+ .flatMap(_.accumulables.values)
+ .find(_.name == InternalAccumulator.PEAK_EXECUTION_MEMORY)
+ .getOrElse {
+ throw new TestFailedException(
+ s"peak execution memory accumulator not set in '$testName'", 0)
+ }
+ assert(accum.value.toLong > 0)
+ }
+}
+
+/**
+ * A simple listener that records the StageInfos and TaskInfos of completed stages and tasks.
+ */
+private class SaveInfoListener extends SparkListener {
+ private val completedStageInfos: ArrayBuffer[StageInfo] = new ArrayBuffer[StageInfo]
+ private val completedTaskInfos: ArrayBuffer[TaskInfo] = new ArrayBuffer[TaskInfo]
+
+ def getCompletedStageInfos: Seq[StageInfo] = completedStageInfos.toArray.toSeq
+ def getCompletedTaskInfos: Seq[TaskInfo] = completedTaskInfos.toArray.toSeq
+
+ override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = {
+ completedStageInfos += stageCompleted.stageInfo
+ }
+
+ override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
+ completedTaskInfos += taskEnd.taskInfo
+ }
}
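For reference, the call pattern the new AccumulatorSuite.verifyPeakExecutionMemorySet helper enables; this exact usage appears in the ExternalSorterSuite hunk later in this patch:

    // Run a job inside the verifier; it fails the test if the job did not set
    // the PEAK_EXECUTION_MEMORY accumulator on some completed stage.
    AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "external sorter") {
      sc.parallelize(1 to 1000, 2).repartition(100).count()
    }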
diff --git a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
index 618a5fb247..cb8bd04e49 100644
--- a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
@@ -21,7 +21,7 @@ import org.mockito.Mockito._
import org.scalatest.BeforeAndAfter
import org.scalatest.mock.MockitoSugar
-import org.apache.spark.executor.DataReadMethod
+import org.apache.spark.executor.{DataReadMethod, TaskMetrics}
import org.apache.spark.rdd.RDD
import org.apache.spark.storage._
@@ -65,7 +65,7 @@ class CacheManagerSuite extends SparkFunSuite with LocalSparkContext with Before
// in blockManager.put is a losing battle. You have been warned.
blockManager = sc.env.blockManager
cacheManager = sc.env.cacheManager
- val context = new TaskContextImpl(0, 0, 0, 0, null, null)
+ val context = TaskContext.empty()
val computeValue = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
val getValue = blockManager.get(RDDBlockId(rdd.id, split.index))
assert(computeValue.toList === List(1, 2, 3, 4))
@@ -77,7 +77,7 @@ class CacheManagerSuite extends SparkFunSuite with LocalSparkContext with Before
val result = new BlockResult(Array(5, 6, 7).iterator, DataReadMethod.Memory, 12)
when(blockManager.get(RDDBlockId(0, 0))).thenReturn(Some(result))
- val context = new TaskContextImpl(0, 0, 0, 0, null, null)
+ val context = TaskContext.empty()
val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
assert(value.toList === List(5, 6, 7))
}
@@ -86,14 +86,14 @@ class CacheManagerSuite extends SparkFunSuite with LocalSparkContext with Before
// Local computation should not persist the resulting value, so don't expect a put().
when(blockManager.get(RDDBlockId(0, 0))).thenReturn(None)
- val context = new TaskContextImpl(0, 0, 0, 0, null, null, true)
+ val context = new TaskContextImpl(0, 0, 0, 0, null, null, Seq.empty, runningLocally = true)
val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
assert(value.toList === List(1, 2, 3, 4))
}
test("verify task metrics updated correctly") {
cacheManager = sc.env.cacheManager
- val context = new TaskContextImpl(0, 0, 0, 0, null, null)
+ val context = TaskContext.empty()
cacheManager.getOrCompute(rdd3, split, context, StorageLevel.MEMORY_ONLY)
assert(context.taskMetrics.updatedBlocks.getOrElse(Seq()).size === 2)
}
diff --git a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala
index 3e8816a4c6..5f73ec8675 100644
--- a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala
@@ -175,7 +175,7 @@ class PipedRDDSuite extends SparkFunSuite with SharedSparkContext {
}
val hadoopPart1 = generateFakeHadoopPartition()
val pipedRdd = new PipedRDD(nums, "printenv " + varName)
- val tContext = new TaskContextImpl(0, 0, 0, 0, null, null)
+ val tContext = TaskContext.empty()
val rddIter = pipedRdd.compute(hadoopPart1, tContext)
val arr = rddIter.toArray
assert(arr(0) == "/some/path")
diff --git a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala
index b3ca150195..f7e16af9d3 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala
@@ -19,9 +19,11 @@ package org.apache.spark.scheduler
import org.apache.spark.TaskContext
-class FakeTask(stageId: Int, prefLocs: Seq[TaskLocation] = Nil) extends Task[Int](stageId, 0, 0) {
+class FakeTask(
+ stageId: Int,
+ prefLocs: Seq[TaskLocation] = Nil)
+ extends Task[Int](stageId, 0, 0, Seq.empty) {
override def runTask(context: TaskContext): Int = 0
-
override def preferredLocations: Seq[TaskLocation] = prefLocs
}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala
index 383855caef..f333247924 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala
@@ -25,7 +25,7 @@ import org.apache.spark.TaskContext
* A Task implementation that fails to serialize.
*/
private[spark] class NotSerializableFakeTask(myId: Int, stageId: Int)
- extends Task[Array[Byte]](stageId, 0, 0) {
+ extends Task[Array[Byte]](stageId, 0, 0, Seq.empty) {
override def runTask(context: TaskContext): Array[Byte] = Array.empty[Byte]
override def preferredLocations: Seq[TaskLocation] = Seq[TaskLocation]()
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
index 9201d1e1f3..450ab7b9fe 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
@@ -57,8 +57,9 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
}
val closureSerializer = SparkEnv.get.closureSerializer.newInstance()
val func = (c: TaskContext, i: Iterator[String]) => i.next()
- val task = new ResultTask[String, String](0, 0,
- sc.broadcast(closureSerializer.serialize((rdd, func)).array), rdd.partitions(0), Seq(), 0)
+ val taskBinary = sc.broadcast(closureSerializer.serialize((rdd, func)).array)
+ val task = new ResultTask[String, String](
+ 0, 0, taskBinary, rdd.partitions(0), Seq.empty, 0, Seq.empty)
intercept[RuntimeException] {
task.run(0, 0, null)
}
@@ -66,7 +67,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
}
test("all TaskCompletionListeners should be called even if some fail") {
- val context = new TaskContextImpl(0, 0, 0, 0, null, null)
+ val context = TaskContext.empty()
val listener = mock(classOf[TaskCompletionListener])
context.addTaskCompletionListener(_ => throw new Exception("blah"))
context.addTaskCompletionListener(listener)
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
index 3abb99c4b2..f7cc4bb61d 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
@@ -136,7 +136,7 @@ class FakeTaskScheduler(sc: SparkContext, liveExecutors: (String, String)* /* ex
/**
* A Task implementation that results in a large serialized task.
*/
-class LargeTask(stageId: Int) extends Task[Array[Byte]](stageId, 0, 0) {
+class LargeTask(stageId: Int) extends Task[Array[Byte]](stageId, 0, 0, Seq.empty) {
val randomBuffer = new Array[Byte](TaskSetManager.TASK_SIZE_TO_WARN_KB * 1024)
val random = new Random(0)
random.nextBytes(randomBuffer)
diff --git a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala
index db718ecabb..05b3afef5b 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala
@@ -138,7 +138,7 @@ class HashShuffleReaderSuite extends SparkFunSuite with LocalSparkContext {
shuffleHandle,
reduceId,
reduceId + 1,
- new TaskContextImpl(0, 0, 0, 0, null, null),
+ TaskContext.empty(),
blockManager,
mapOutputTracker)
diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
index cf8bd8ae69..828153bdbf 100644
--- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
@@ -29,7 +29,7 @@ import org.mockito.invocation.InvocationOnMock
import org.mockito.stubbing.Answer
import org.scalatest.PrivateMethodTester
-import org.apache.spark.{SparkFunSuite, TaskContextImpl}
+import org.apache.spark.{SparkFunSuite, TaskContext}
import org.apache.spark.network._
import org.apache.spark.network.buffer.ManagedBuffer
import org.apache.spark.network.shuffle.BlockFetchingListener
@@ -95,7 +95,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
)
val iterator = new ShuffleBlockFetcherIterator(
- new TaskContextImpl(0, 0, 0, 0, null, null),
+ TaskContext.empty(),
transfer,
blockManager,
blocksByAddress,
@@ -165,7 +165,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
(remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq))
- val taskContext = new TaskContextImpl(0, 0, 0, 0, null, null)
+ val taskContext = TaskContext.empty()
val iterator = new ShuffleBlockFetcherIterator(
taskContext,
transfer,
@@ -227,7 +227,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
(remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq))
- val taskContext = new TaskContextImpl(0, 0, 0, 0, null, null)
+ val taskContext = TaskContext.empty()
val iterator = new ShuffleBlockFetcherIterator(
taskContext,
transfer,
diff --git a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala
new file mode 100644
index 0000000000..98f9314f31
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ui
+
+import javax.servlet.http.HttpServletRequest
+
+import scala.xml.Node
+
+import org.mockito.Mockito.{mock, when, RETURNS_SMART_NULLS}
+
+import org.apache.spark.{LocalSparkContext, SparkConf, SparkFunSuite, Success}
+import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.scheduler._
+import org.apache.spark.ui.jobs.{JobProgressListener, StagePage, StagesTab}
+import org.apache.spark.ui.scope.RDDOperationGraphListener
+
+class StagePageSuite extends SparkFunSuite with LocalSparkContext {
+
+ test("peak execution memory only displayed if unsafe is enabled") {
+ val unsafeConf = "spark.sql.unsafe.enabled"
+ val conf = new SparkConf().set(unsafeConf, "true")
+ val html = renderStagePage(conf).toString().toLowerCase
+ val targetString = "peak execution memory"
+ assert(html.contains(targetString))
+ // Disable unsafe and make sure it's not there
+ val conf2 = new SparkConf().set(unsafeConf, "false")
+ val html2 = renderStagePage(conf2).toString().toLowerCase
+ assert(!html2.contains(targetString))
+ }
+
+ /**
+ * Render a stage page started with the given conf and return the HTML.
+ * This also simulates a completed stage to populate the page with useful content.
+ */
+ private def renderStagePage(conf: SparkConf): Seq[Node] = {
+ val jobListener = new JobProgressListener(conf)
+ val graphListener = new RDDOperationGraphListener(conf)
+ val tab = mock(classOf[StagesTab], RETURNS_SMART_NULLS)
+ val request = mock(classOf[HttpServletRequest])
+ when(tab.conf).thenReturn(conf)
+ when(tab.progressListener).thenReturn(jobListener)
+ when(tab.operationGraphListener).thenReturn(graphListener)
+ when(tab.appName).thenReturn("testing")
+ when(tab.headerTabs).thenReturn(Seq.empty)
+ when(request.getParameter("id")).thenReturn("0")
+ when(request.getParameter("attempt")).thenReturn("0")
+ val page = new StagePage(tab)
+
+ // Simulate a stage in job progress listener
+ val stageInfo = new StageInfo(0, 0, "dummy", 1, Seq.empty, Seq.empty, "details")
+ val taskInfo = new TaskInfo(0, 0, 0, 0, "0", "localhost", TaskLocality.ANY, false)
+ jobListener.onStageSubmitted(SparkListenerStageSubmitted(stageInfo))
+ jobListener.onTaskStart(SparkListenerTaskStart(0, 0, taskInfo))
+ taskInfo.markSuccessful()
+ jobListener.onTaskEnd(
+ SparkListenerTaskEnd(0, 0, "result", Success, taskInfo, TaskMetrics.empty))
+ jobListener.onStageCompleted(SparkListenerStageCompleted(stageInfo))
+ page.render(request)
+ }
+
+}
diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala
index 9c362f0de7..12e9bafcc9 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala
@@ -399,4 +399,19 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext {
sc.stop()
}
+ test("external aggregation updates peak execution memory") {
+ val conf = createSparkConf(loadDefaults = false)
+ .set("spark.shuffle.memoryFraction", "0.001")
+ .set("spark.shuffle.manager", "hash") // make sure we're not also using ExternalSorter
+ sc = new SparkContext("local", "test", conf)
+ // No spilling
+ AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "external map without spilling") {
+ sc.parallelize(1 to 10, 2).map { i => (i, i) }.reduceByKey(_ + _).count()
+ }
+ // With spilling
+ AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "external map with spilling") {
+ sc.parallelize(1 to 1000 * 1000, 2).map { i => (i, i) }.reduceByKey(_ + _).count()
+ }
+ }
+
}
diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
index 986cd8623d..bdb0f4d507 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
@@ -692,7 +692,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext {
sortWithoutBreakingSortingContracts(createSparkConf(true, false))
}
- def sortWithoutBreakingSortingContracts(conf: SparkConf) {
+ private def sortWithoutBreakingSortingContracts(conf: SparkConf) {
conf.set("spark.shuffle.memoryFraction", "0.01")
conf.set("spark.shuffle.manager", "sort")
sc = new SparkContext("local-cluster[1,1,1024]", "test", conf)
@@ -743,5 +743,15 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext {
}
sorter2.stop()
- }
+ }
+
+ test("sorting updates peak execution memory") {
+ val conf = createSparkConf(loadDefaults = false, kryo = false)
+ .set("spark.shuffle.manager", "sort")
+ sc = new SparkContext("local", "test", conf)
+ // Avoid aggregating here to make sure we're not also using ExternalAppendOnlyMap
+ AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "external sorter") {
+ sc.parallelize(1 to 1000, 2).repartition(100).count()
+ }
+ }
}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
index 5e4c6232c9..193906d247 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
@@ -106,6 +106,13 @@ final class UnsafeExternalRowSorter {
sorter.spill();
}
+ /**
+ * Return the peak memory used so far, in bytes.
+ */
+ public long getPeakMemoryUsage() {
+ return sorter.getPeakMemoryUsedBytes();
+ }
+
private void cleanupResources() {
sorter.freeMemory();
}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
index 9e2c9334a7..43d06ce9bd 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
@@ -209,6 +209,14 @@ public final class UnsafeFixedWidthAggregationMap {
}
/**
+ * The memory used by this map's managed structures, in bytes.
+ * Note that this is also the peak memory used by this map, since the map is append-only.
+ */
+ public long getMemoryUsage() {
+ return map.getTotalMemoryConsumption();
+ }
+
+ /**
* Free the memory associated with this map. This is idempotent and can be called multiple times.
*/
public void free() {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
index cd87b8deba..bf4905dc1e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution
import java.io.IOException
-import org.apache.spark.{SparkEnv, TaskContext}
+import org.apache.spark.{InternalAccumulator, SparkEnv, TaskContext}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
@@ -263,11 +263,12 @@ case class GeneratedAggregate(
assert(iter.hasNext, "There should be at least one row for this path")
log.info("Using Unsafe-based aggregator")
val pageSizeBytes = SparkEnv.get.conf.getSizeAsBytes("spark.buffer.pageSize", "64m")
+ val taskContext = TaskContext.get()
val aggregationMap = new UnsafeFixedWidthAggregationMap(
newAggregationBuffer(EmptyRow),
aggregationBufferSchema,
groupKeySchema,
- TaskContext.get.taskMemoryManager(),
+ taskContext.taskMemoryManager(),
SparkEnv.get.shuffleMemoryManager,
1024 * 16, // initial capacity
pageSizeBytes,
@@ -284,6 +285,10 @@ case class GeneratedAggregate(
updateProjection.target(aggregationBuffer)(joinedRow(aggregationBuffer, currentRow))
}
+ // Record the aggregation map's memory usage in the peak execution memory accumulator
+ taskContext.internalMetricsToAccumulators(
+ InternalAccumulator.PEAK_EXECUTION_MEMORY).add(aggregationMap.getMemoryUsage)
+
new Iterator[InternalRow] {
private[this] val mapIterator = aggregationMap.iterator()
private[this] val resultProjection = resultProjectionBuilder()
@@ -300,7 +305,7 @@ case class GeneratedAggregate(
} else {
// This is the last element in the iterator, so let's free the buffer. Before we do,
// though, we need to make a defensive copy of the result so that we don't return an
- // object that might contain dangling pointers to the freed memory
+ // object that might contain dangling pointers to the freed memory.
val resultCopy = result.copy()
aggregationMap.free()
resultCopy
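The snippet below distills the pattern this patch adds to each operator (GeneratedAggregate above, the broadcast joins and sort operators below): fetch the current TaskContext once and add the operator's peak memory to the internal accumulator. operatorPeakBytes is a stand-in for the operator-specific value, e.g. aggregationMap.getMemoryUsage or sorter.getPeakMemoryUsage.

    // Sketched pattern only; operatorPeakBytes is a hypothetical placeholder.
    val taskContext = TaskContext.get()
    taskContext.internalMetricsToAccumulators(
      InternalAccumulator.PEAK_EXECUTION_MEMORY).add(operatorPeakBytes)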
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
index 624efc1b1d..e73e2523a7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.joins
import scala.concurrent._
import scala.concurrent.duration._
+import org.apache.spark.{InternalAccumulator, TaskContext}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
@@ -70,7 +71,14 @@ case class BroadcastHashJoin(
val broadcastRelation = Await.result(broadcastFuture, timeout)
streamedPlan.execute().mapPartitions { streamedIter =>
- hashJoin(streamedIter, broadcastRelation.value)
+ val hashedRelation = broadcastRelation.value
+ hashedRelation match {
+ case unsafe: UnsafeHashedRelation =>
+ TaskContext.get().internalMetricsToAccumulators(
+ InternalAccumulator.PEAK_EXECUTION_MEMORY).add(unsafe.getUnsafeSize)
+ case _ =>
+ }
+ hashJoin(streamedIter, hashedRelation)
}
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala
index 309716a0ef..c35e439cc9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.joins
import scala.concurrent._
import scala.concurrent.duration._
+import org.apache.spark.{InternalAccumulator, TaskContext}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
@@ -75,6 +76,13 @@ case class BroadcastHashOuterJoin(
val hashTable = broadcastRelation.value
val keyGenerator = streamedKeyGenerator
+ hashTable match {
+ case unsafe: UnsafeHashedRelation =>
+ TaskContext.get().internalMetricsToAccumulators(
+ InternalAccumulator.PEAK_EXECUTION_MEMORY).add(unsafe.getUnsafeSize)
+ case _ =>
+ }
+
joinType match {
case LeftOuter =>
streamedIter.flatMap(currentRow => {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
index a60593911f..5bd06fbdca 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.execution.joins
+import org.apache.spark.{InternalAccumulator, TaskContext}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
@@ -51,7 +52,14 @@ case class BroadcastLeftSemiJoinHash(
val broadcastedRelation = sparkContext.broadcast(hashRelation)
left.execute().mapPartitions { streamIter =>
- hashSemiJoin(streamIter, broadcastedRelation.value)
+ val hashedRelation = broadcastedRelation.value
+ hashedRelation match {
+ case unsafe: UnsafeHashedRelation =>
+ TaskContext.get().internalMetricsToAccumulators(
+ InternalAccumulator.PEAK_EXECUTION_MEMORY).add(unsafe.getUnsafeSize)
+ case _ =>
+ }
+ hashSemiJoin(streamIter, hashedRelation)
}
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index cc8bbfd2f8..58b4236f7b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -183,8 +183,27 @@ private[joins] final class UnsafeHashedRelation(
private[joins] def this() = this(null) // Needed for serialization
// Use BytesToBytesMap on executors for better performance (it's created during deserialization)
+ // This is only used for broadcast joins when running in distributed mode
@transient private[this] var binaryMap: BytesToBytesMap = _
+ /**
+ * Return the size of the unsafe map on the executors.
+ *
+ * For broadcast joins, this hashed relation is bigger on the driver because it is
+ * represented as a Java hash map there. While serializing the map to the executors,
+ * however, we rehash the contents into a binary map to reduce the memory footprint on
+ * the executors.
+ *
+ * For non-broadcast joins or in local mode, return 0.
+ */
+ def getUnsafeSize: Long = {
+ if (binaryMap != null) {
+ binaryMap.getTotalMemoryConsumption
+ } else {
+ 0
+ }
+ }
+
override def get(key: InternalRow): Seq[InternalRow] = {
val unsafeKey = key.asInstanceOf[UnsafeRow]
@@ -214,7 +233,7 @@ private[joins] final class UnsafeHashedRelation(
}
} else {
- // Use the JavaHashMap in Local mode or ShuffleHashJoin
+ // Use the Java HashMap in local mode or for non-broadcast joins (e.g. ShuffleHashJoin)
hashTable.get(unsafeKey)
}
}
@@ -316,6 +335,7 @@ private[joins] object UnsafeHashedRelation {
keyGenerator: UnsafeProjection,
sizeEstimate: Int): HashedRelation = {
+ // Use a Java hash table here because unsafe maps expect fixed size records
val hashTable = new JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]](sizeEstimate)
// Create a mapping of buildKeys -> rows
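Putting getUnsafeSize together with the operator changes: the accumulator only receives a non-zero value on executors, where the broadcast relation has been deserialized into a BytesToBytesMap; on the driver, or in local mode, the relation is still backed by the Java hash map and getUnsafeSize returns 0. A sketch of the consuming side, mirroring the BroadcastHashJoin hunk above:

    broadcastRelation.value match {
      case unsafe: UnsafeHashedRelation =>
        TaskContext.get().internalMetricsToAccumulators(
          InternalAccumulator.PEAK_EXECUTION_MEMORY).add(unsafe.getUnsafeSize)
      case _ =>  // relations still backed by the Java hash map report no unsafe memory
    }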
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala
index 92cf328c76..3192b6ebe9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.execution
+import org.apache.spark.{InternalAccumulator, TaskContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.errors._
@@ -76,6 +77,11 @@ case class ExternalSort(
val sorter = new ExternalSorter[InternalRow, Null, InternalRow](ordering = Some(ordering))
sorter.insertAll(iterator.map(r => (r.copy(), null)))
val baseIterator = sorter.iterator.map(_._1)
+ val context = TaskContext.get()
+ context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
+ context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
+ context.internalMetricsToAccumulators(
+ InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.peakMemoryUsedBytes)
// TODO(marmbrus): The complex type signature below thwarts inference for no reason.
CompletionIterator[InternalRow, Iterator[InternalRow]](baseIterator, sorter.stop())
}, preservesPartitioning = true)
@@ -137,7 +143,11 @@ case class TungstenSort(
if (testSpillFrequency > 0) {
sorter.setTestSpillFrequency(testSpillFrequency)
}
- sorter.sort(iter.asInstanceOf[Iterator[UnsafeRow]])
+ val sortedIterator = sorter.sort(iter.asInstanceOf[Iterator[UnsafeRow]])
+ val taskContext = TaskContext.get()
+ taskContext.internalMetricsToAccumulators(
+ InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.getPeakMemoryUsage)
+ sortedIterator
}, preservesPartitioning = true)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index f1abae0720..29dfcf2575 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -21,6 +21,7 @@ import java.sql.Timestamp
import org.scalatest.BeforeAndAfterAll
+import org.apache.spark.AccumulatorSuite
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
import org.apache.spark.sql.catalyst.DefaultParserDialect
import org.apache.spark.sql.catalyst.errors.DialectException
@@ -258,6 +259,23 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
}
}
+ private def testCodeGen(sqlText: String, expectedResults: Seq[Row]): Unit = {
+ val df = sql(sqlText)
+ // First, check if we have GeneratedAggregate.
+ val hasGeneratedAgg = df.queryExecution.executedPlan
+ .collect { case _: GeneratedAggregate | _: aggregate.Aggregate => true }
+ .nonEmpty
+ if (!hasGeneratedAgg) {
+ fail(
+ s"""
+ |Codegen is enabled, but query $sqlText does not have GeneratedAggregate in the plan.
+ |${df.queryExecution.simpleString}
+ """.stripMargin)
+ }
+ // Then, check results.
+ checkAnswer(df, expectedResults)
+ }
+
test("aggregation with codegen") {
val originalValue = sqlContext.conf.codegenEnabled
sqlContext.setConf(SQLConf.CODEGEN_ENABLED, true)
@@ -267,26 +285,6 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
.unionAll(sqlContext.table("testData"))
.registerTempTable("testData3x")
- def testCodeGen(sqlText: String, expectedResults: Seq[Row]): Unit = {
- val df = sql(sqlText)
- // First, check if we have GeneratedAggregate.
- var hasGeneratedAgg = false
- df.queryExecution.executedPlan.foreach {
- case generatedAgg: GeneratedAggregate => hasGeneratedAgg = true
- case newAggregate: aggregate.Aggregate => hasGeneratedAgg = true
- case _ =>
- }
- if (!hasGeneratedAgg) {
- fail(
- s"""
- |Codegen is enabled, but query $sqlText does not have GeneratedAggregate in the plan.
- |${df.queryExecution.simpleString}
- """.stripMargin)
- }
- // Then, check results.
- checkAnswer(df, expectedResults)
- }
-
try {
// Just to group rows.
testCodeGen(
@@ -1605,6 +1603,28 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
Row(new CalendarInterval(-(12 * 3 - 3), -(7L * MICROS_PER_WEEK + 123))))
}
+ test("aggregation with codegen updates peak execution memory") {
+ withSQLConf(
+ (SQLConf.CODEGEN_ENABLED.key, "true"),
+ (SQLConf.USE_SQL_AGGREGATE2.key, "false")) {
+ val sc = sqlContext.sparkContext
+ AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "aggregation with codegen") {
+ testCodeGen(
+ "SELECT key, count(value) FROM testData GROUP BY key",
+ (1 to 100).map(i => Row(i, 1)))
+ }
+ }
+ }
+
+ test("external sorting updates peak execution memory") {
+ withSQLConf((SQLConf.EXTERNAL_SORT.key, "true")) {
+ val sc = sqlContext.sparkContext
+ AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "external sort") {
+ sortTest()
+ }
+ }
+ }
+
test("SPARK-9511: error with table starting with number") {
val df = sqlContext.sparkContext.parallelize(1 to 10).map(i => (i, i.toString))
.toDF("num", "str")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala
index c794984851..88bce0e319 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala
@@ -21,6 +21,7 @@ import scala.util.Random
import org.scalatest.BeforeAndAfterAll
+import org.apache.spark.AccumulatorSuite
import org.apache.spark.sql.{RandomDataGenerator, Row, SQLConf}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.test.TestSQLContext
@@ -59,6 +60,17 @@ class TungstenSortSuite extends SparkPlanTest with BeforeAndAfterAll {
)
}
+ test("sorting updates peak execution memory") {
+ val sc = TestSQLContext.sparkContext
+ AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "unsafe external sort") {
+ checkThatPlansAgree(
+ (1 to 100).map(v => Tuple1(v)).toDF("a"),
+ (child: SparkPlan) => TungstenSort('a.asc :: Nil, true, child),
+ (child: SparkPlan) => Sort('a.asc :: Nil, global = true, child),
+ sortAnswers = false)
+ }
+ }
+
// Test sorting on different data types
for (
dataType <- DataTypeTestUtils.atomicTypes ++ Set(NullType);
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
index 7c591f6143..ef827b0fe9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
@@ -69,7 +69,8 @@ class UnsafeFixedWidthAggregationMapSuite extends SparkFunSuite with Matchers {
taskAttemptId = Random.nextInt(10000),
attemptNumber = 0,
taskMemoryManager = taskMemoryManager,
- metricsSystem = null))
+ metricsSystem = null,
+ internalAccumulators = Seq.empty))
try {
f
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
index 0282b25b9d..601a5a07ad 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
@@ -76,7 +76,8 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite {
taskAttemptId = 98456,
attemptNumber = 0,
taskMemoryManager = taskMemMgr,
- metricsSystem = null))
+ metricsSystem = null,
+ internalAccumulators = Seq.empty))
// Create the data converters
val kExternalConverter = CatalystTypeConverters.createToCatalystConverter(keySchema)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
new file mode 100644
index 0000000000..0554e11d25
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
@@ -0,0 +1,94 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements. See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+// TODO: uncomment the test here! It is currently failing due to
+// bad interaction with org.apache.spark.sql.test.TestSQLContext.
+
+// scalastyle:off
+//package org.apache.spark.sql.execution.joins
+//
+//import scala.reflect.ClassTag
+//
+//import org.scalatest.BeforeAndAfterAll
+//
+//import org.apache.spark.{AccumulatorSuite, SparkConf, SparkContext}
+//import org.apache.spark.sql.functions._
+//import org.apache.spark.sql.{SQLConf, SQLContext, QueryTest}
+//
+///**
+// * Test various broadcast join operators with unsafe enabled.
+// *
+// * This needs to be its own suite because [[org.apache.spark.sql.test.TestSQLContext]] runs
+// * in local mode, but for tests in this suite we need to run Spark in local-cluster mode.
+// * In particular, the use of [[org.apache.spark.unsafe.map.BytesToBytesMap]] in
+// * [[org.apache.spark.sql.execution.joins.UnsafeHashedRelation]] is not triggered without
+// * serializing the hashed relation, which does not happen in local mode.
+// */
+//class BroadcastJoinSuite extends QueryTest with BeforeAndAfterAll {
+// private var sc: SparkContext = null
+// private var sqlContext: SQLContext = null
+//
+// /**
+// * Create a new [[SQLContext]] running in local-cluster mode with unsafe and codegen enabled.
+// */
+// override def beforeAll(): Unit = {
+// super.beforeAll()
+// val conf = new SparkConf()
+// .setMaster("local-cluster[2,1,1024]")
+// .setAppName("testing")
+// sc = new SparkContext(conf)
+// sqlContext = new SQLContext(sc)
+// sqlContext.setConf(SQLConf.UNSAFE_ENABLED, true)
+// sqlContext.setConf(SQLConf.CODEGEN_ENABLED, true)
+// }
+//
+// override def afterAll(): Unit = {
+// sc.stop()
+// sc = null
+// sqlContext = null
+// }
+//
+// /**
+// * Test whether the specified broadcast join updates the peak execution memory accumulator.
+// */
+// private def testBroadcastJoin[T: ClassTag](name: String, joinType: String): Unit = {
+// AccumulatorSuite.verifyPeakExecutionMemorySet(sc, name) {
+// val df1 = sqlContext.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value")
+// val df2 = sqlContext.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value")
+// // Comparison at the end is for broadcast left semi join
+// val joinExpression = df1("key") === df2("key") && df1("value") > df2("value")
+// val df3 = df1.join(broadcast(df2), joinExpression, joinType)
+// val plan = df3.queryExecution.executedPlan
+// assert(plan.collect { case p: T => p }.size === 1)
+// plan.executeCollect()
+// }
+// }
+//
+// test("unsafe broadcast hash join updates peak execution memory") {
+// testBroadcastJoin[BroadcastHashJoin]("unsafe broadcast hash join", "inner")
+// }
+//
+// test("unsafe broadcast hash outer join updates peak execution memory") {
+// testBroadcastJoin[BroadcastHashOuterJoin]("unsafe broadcast hash outer join", "left_outer")
+// }
+//
+// test("unsafe broadcast left semi join updates peak execution memory") {
+// testBroadcastJoin[BroadcastLeftSemiJoinHash]("unsafe broadcast left semi join", "leftsemi")
+// }
+//
+//}
+// scalastyle:on