-rw-r--r--  core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala              |  4
-rw-r--r--  core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala    | 16
-rw-r--r--  core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala    | 16
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockManager.scala              | 12
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala         | 77
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala       |  9
-rw-r--r--  core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala | 18
-rw-r--r--  core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala    | 17
-rw-r--r--  core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala    | 65
-rw-r--r--  core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala     |  9
-rw-r--r--  tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala          |  3
11 files changed, 164 insertions, 82 deletions
diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
index 56cd8723a3..11a6e10243 100644
--- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
+++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
@@ -190,10 +190,10 @@ class ShuffleWriteMetrics extends Serializable {
/**
* Number of bytes written for the shuffle by this task
*/
- var shuffleBytesWritten: Long = _
+ @volatile var shuffleBytesWritten: Long = _
/**
* Time the task spent blocking on writes to disk or buffer cache, in nanoseconds
*/
- var shuffleWriteTime: Long = _
+ @volatile var shuffleWriteTime: Long = _
}
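
The only change here is marking the two counters @volatile: they are now bumped incrementally by the task's writer thread and may be read concurrently by another thread (for example a metrics reporter). A minimal sketch of the visibility concern this addresses, not Spark code; @volatile guarantees the reader sees fresh values but does not make `+=` atomic, which is acceptable because only the writing thread mutates the fields.

object VolatileMetricsSketch {
  class WriteMetricsSketch {
    @volatile var bytesWritten: Long = 0L
  }

  def main(args: Array[String]): Unit = {
    val metrics = new WriteMetricsSketch
    val writer = new Thread(new Runnable {
      def run(): Unit = (1 to 100).foreach { _ =>
        metrics.bytesWritten += 1024   // only this thread writes the field
        Thread.sleep(1)
      }
    })
    val reporter = new Thread(new Runnable {
      def run(): Unit = (1 to 5).foreach { _ =>
        println(s"bytes so far: ${metrics.bytesWritten}")  // sees fresh values thanks to @volatile
        Thread.sleep(30)
      }
    })
    writer.start(); reporter.start()
    writer.join(); reporter.join()
  }
}
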
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
index 45d3b8b9b8..51e454d931 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
@@ -39,10 +39,14 @@ private[spark] class HashShuffleWriter[K, V](
// we don't try deleting files, etc twice.
private var stopping = false
+ private val writeMetrics = new ShuffleWriteMetrics()
+ metrics.shuffleWriteMetrics = Some(writeMetrics)
+
private val blockManager = SparkEnv.get.blockManager
private val shuffleBlockManager = blockManager.shuffleBlockManager
private val ser = Serializer.getSerializer(dep.serializer.getOrElse(null))
- private val shuffle = shuffleBlockManager.forMapTask(dep.shuffleId, mapId, numOutputSplits, ser)
+ private val shuffle = shuffleBlockManager.forMapTask(dep.shuffleId, mapId, numOutputSplits, ser,
+ writeMetrics)
/** Write a bunch of records to this task's output */
override def write(records: Iterator[_ <: Product2[K, V]]): Unit = {
@@ -99,22 +103,12 @@ private[spark] class HashShuffleWriter[K, V](
private def commitWritesAndBuildStatus(): MapStatus = {
// Commit the writes. Get the size of each bucket block (total block size).
- var totalBytes = 0L
- var totalTime = 0L
val compressedSizes = shuffle.writers.map { writer: BlockObjectWriter =>
writer.commitAndClose()
val size = writer.fileSegment().length
- totalBytes += size
- totalTime += writer.timeWriting()
MapOutputTracker.compressSize(size)
}
- // Update shuffle metrics.
- val shuffleMetrics = new ShuffleWriteMetrics
- shuffleMetrics.shuffleBytesWritten = totalBytes
- shuffleMetrics.shuffleWriteTime = totalTime
- metrics.shuffleWriteMetrics = Some(shuffleMetrics)
-
new MapStatus(blockManager.blockManagerId, compressedSizes)
}
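
Rather than summing sizes and times from each writer after commit, the hash-based writer now creates one ShuffleWriteMetrics up front, registers it on the task metrics, and threads it through forMapTask so every per-bucket writer accumulates into the same object while data is being written. A minimal sketch of that sharing pattern, with hypothetical names rather than the Spark classes:

import java.io.ByteArrayOutputStream

// One mutable metrics object shared by several sinks; totals accumulate as writes
// happen instead of being recomputed from each writer afterwards.
object SharedMetricsSketch {
  class WriteMetricsSketch { var bytesWritten: Long = 0L }

  class CountingSink(metrics: WriteMetricsSketch) {
    private val out = new ByteArrayOutputStream()
    def write(bytes: Array[Byte]): Unit = {
      out.write(bytes, 0, bytes.length)
      metrics.bytesWritten += bytes.length   // incremental update, no post-hoc summing
    }
  }

  def main(args: Array[String]): Unit = {
    val metrics = new WriteMetricsSketch
    val buckets = Array.fill(4)(new CountingSink(metrics))   // one sink per reduce bucket
    buckets.foreach(_.write(Array.fill[Byte](128)(1.toByte)))
    println(s"total bytes across all buckets: ${metrics.bytesWritten}")  // 512
  }
}
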
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
index 24db2f287a..e54e6383d2 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
@@ -52,6 +52,9 @@ private[spark] class SortShuffleWriter[K, V, C](
private var mapStatus: MapStatus = null
+ private val writeMetrics = new ShuffleWriteMetrics()
+ context.taskMetrics.shuffleWriteMetrics = Some(writeMetrics)
+
/** Write a bunch of records to this task's output */
override def write(records: Iterator[_ <: Product2[K, V]]): Unit = {
// Get an iterator with the elements for each partition ID
@@ -84,13 +87,10 @@ private[spark] class SortShuffleWriter[K, V, C](
val offsets = new Array[Long](numPartitions + 1)
val lengths = new Array[Long](numPartitions)
- // Statistics
- var totalBytes = 0L
- var totalTime = 0L
-
for ((id, elements) <- partitions) {
if (elements.hasNext) {
- val writer = blockManager.getDiskWriter(blockId, outputFile, ser, fileBufferSize)
+ val writer = blockManager.getDiskWriter(blockId, outputFile, ser, fileBufferSize,
+ writeMetrics)
for (elem <- elements) {
writer.write(elem)
}
@@ -98,18 +98,12 @@ private[spark] class SortShuffleWriter[K, V, C](
val segment = writer.fileSegment()
offsets(id + 1) = segment.offset + segment.length
lengths(id) = segment.length
- totalTime += writer.timeWriting()
- totalBytes += segment.length
} else {
// The partition is empty; don't create a new writer to avoid writing headers, etc
offsets(id + 1) = offsets(id)
}
}
- val shuffleMetrics = new ShuffleWriteMetrics
- shuffleMetrics.shuffleBytesWritten = totalBytes
- shuffleMetrics.shuffleWriteTime = totalTime
- context.taskMetrics.shuffleWriteMetrics = Some(shuffleMetrics)
context.taskMetrics.memoryBytesSpilled += sorter.memoryBytesSpilled
context.taskMetrics.diskBytesSpilled += sorter.diskBytesSpilled
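
With timeWriting() and bytesWritten removed from BlockObjectWriter, the sort-based path derives per-partition offsets and lengths purely from fileSegment(), while byte and time totals flow into the shared writeMetrics. A small illustrative sketch of the cumulative-offset bookkeeping, using made-up lengths:

object OffsetSketch {
  def main(args: Array[String]): Unit = {
    val lengths = Array[Long](100, 0, 250, 40)          // bytes written per partition (illustrative)
    val offsets = new Array[Long](lengths.length + 1)   // offsets(i) = start of partition i in the file
    for (i <- lengths.indices) {
      offsets(i + 1) = offsets(i) + lengths(i)          // an empty partition keeps the previous offset
    }
    println(offsets.mkString(", "))                     // 0, 100, 100, 350, 390
  }
}
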
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 3876cf43e2..8d21b02b74 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -29,7 +29,7 @@ import akka.actor.{ActorSystem, Cancellable, Props}
import sun.nio.ch.DirectBuffer
import org.apache.spark._
-import org.apache.spark.executor.{DataReadMethod, InputMetrics}
+import org.apache.spark.executor.{DataReadMethod, InputMetrics, ShuffleWriteMetrics}
import org.apache.spark.io.CompressionCodec
import org.apache.spark.network._
import org.apache.spark.serializer.Serializer
@@ -562,17 +562,19 @@ private[spark] class BlockManager(
/**
* A short circuited method to get a block writer that can write data directly to disk.
- * The Block will be appended to the File specified by filename. This is currently used for
- * writing shuffle files out. Callers should handle error cases.
+ * The Block will be appended to the File specified by filename. Callers should handle error
+ * cases.
*/
def getDiskWriter(
blockId: BlockId,
file: File,
serializer: Serializer,
- bufferSize: Int): BlockObjectWriter = {
+ bufferSize: Int,
+ writeMetrics: ShuffleWriteMetrics): BlockObjectWriter = {
val compressStream: OutputStream => OutputStream = wrapForCompression(blockId, _)
val syncWrites = conf.getBoolean("spark.shuffle.sync", false)
- new DiskBlockObjectWriter(blockId, file, serializer, bufferSize, compressStream, syncWrites)
+ new DiskBlockObjectWriter(blockId, file, serializer, bufferSize, compressStream, syncWrites,
+ writeMetrics)
}
/**
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
index 01d46e1ffc..adda971fd7 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
@@ -22,6 +22,7 @@ import java.nio.channels.FileChannel
import org.apache.spark.Logging
import org.apache.spark.serializer.{SerializationStream, Serializer}
+import org.apache.spark.executor.ShuffleWriteMetrics
/**
* An interface for writing JVM objects to some underlying storage. This interface allows
@@ -60,41 +61,26 @@ private[spark] abstract class BlockObjectWriter(val blockId: BlockId) {
* This is only valid after commitAndClose() has been called.
*/
def fileSegment(): FileSegment
-
- /**
- * Cumulative time spent performing blocking writes, in ns.
- */
- def timeWriting(): Long
-
- /**
- * Number of bytes written so far
- */
- def bytesWritten: Long
}
-/** BlockObjectWriter which writes directly to a file on disk. Appends to the given file. */
+/**
+ * BlockObjectWriter which writes directly to a file on disk. Appends to the given file.
+ * The given write metrics will be updated incrementally, but will not necessarily be current until
+ * commitAndClose is called.
+ */
private[spark] class DiskBlockObjectWriter(
blockId: BlockId,
file: File,
serializer: Serializer,
bufferSize: Int,
compressStream: OutputStream => OutputStream,
- syncWrites: Boolean)
+ syncWrites: Boolean,
+ writeMetrics: ShuffleWriteMetrics)
extends BlockObjectWriter(blockId)
with Logging
{
-
/** Intercepts write calls and tracks total time spent writing. Not thread safe. */
private class TimeTrackingOutputStream(out: OutputStream) extends OutputStream {
- def timeWriting = _timeWriting
- private var _timeWriting = 0L
-
- private def callWithTiming(f: => Unit) = {
- val start = System.nanoTime()
- f
- _timeWriting += (System.nanoTime() - start)
- }
-
def write(i: Int): Unit = callWithTiming(out.write(i))
override def write(b: Array[Byte]) = callWithTiming(out.write(b))
override def write(b: Array[Byte], off: Int, len: Int) = callWithTiming(out.write(b, off, len))
@@ -111,7 +97,11 @@ private[spark] class DiskBlockObjectWriter(
private val initialPosition = file.length()
private var finalPosition: Long = -1
private var initialized = false
- private var _timeWriting = 0L
+
+ /** Calling channel.position() to update the write metrics can be a little bit expensive, so we
+ * only call it every N writes */
+ private var writesSinceMetricsUpdate = 0
+ private var lastPosition = initialPosition
override def open(): BlockObjectWriter = {
fos = new FileOutputStream(file, true)
@@ -128,14 +118,11 @@ private[spark] class DiskBlockObjectWriter(
if (syncWrites) {
// Force outstanding writes to disk and track how long it takes
objOut.flush()
- val start = System.nanoTime()
- fos.getFD.sync()
- _timeWriting += System.nanoTime() - start
+ def sync = fos.getFD.sync()
+ callWithTiming(sync)
}
objOut.close()
- _timeWriting += ts.timeWriting
-
channel = null
bs = null
fos = null
@@ -153,6 +140,7 @@ private[spark] class DiskBlockObjectWriter(
// serializer stream and the lower level stream.
objOut.flush()
bs.flush()
+ updateBytesWritten()
close()
}
finalPosition = file.length()
@@ -162,6 +150,8 @@ private[spark] class DiskBlockObjectWriter(
// truncating the file to its initial position.
override def revertPartialWritesAndClose() {
try {
+ writeMetrics.shuffleBytesWritten -= (lastPosition - initialPosition)
+
if (initialized) {
objOut.flush()
bs.flush()
@@ -184,19 +174,36 @@ private[spark] class DiskBlockObjectWriter(
if (!initialized) {
open()
}
+
objOut.writeObject(value)
+
+ if (writesSinceMetricsUpdate == 32) {
+ writesSinceMetricsUpdate = 0
+ updateBytesWritten()
+ } else {
+ writesSinceMetricsUpdate += 1
+ }
}
override def fileSegment(): FileSegment = {
- new FileSegment(file, initialPosition, bytesWritten)
+ new FileSegment(file, initialPosition, finalPosition - initialPosition)
}
- // Only valid if called after close()
- override def timeWriting() = _timeWriting
+ private def updateBytesWritten() {
+ val pos = channel.position()
+ writeMetrics.shuffleBytesWritten += (pos - lastPosition)
+ lastPosition = pos
+ }
+
+ private def callWithTiming(f: => Unit) = {
+ val start = System.nanoTime()
+ f
+ writeMetrics.shuffleWriteTime += (System.nanoTime() - start)
+ }
- // Only valid if called after commit()
- override def bytesWritten: Long = {
- assert(finalPosition != -1, "bytesWritten is only valid after successful commit()")
- finalPosition - initialPosition
+ // For testing
+ private[spark] def flush() {
+ objOut.flush()
+ bs.flush()
}
}
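
The substance of the change is above: DiskBlockObjectWriter now maintains the shared metrics itself, sampling channel.position() at most every 32 writes (and once more at commit), timing writes and fsyncs through callWithTiming, and subtracting everything it had reported if the writes are reverted. A standalone sketch of the "sample an expensive counter only every N operations, reconcile on close" idea, with hypothetical names and no Spark dependencies:

import java.io.{File, FileOutputStream}

object SampledCounterSketch {
  class ByteCounter { var bytesWritten: Long = 0L }

  // Amortizes an expensive position lookup by sampling it only every `sampleEvery`
  // writes, mirroring updateBytesWritten / writesSinceMetricsUpdate above.
  class SampledCountingWriter(file: File, counter: ByteCounter, sampleEvery: Int = 32) {
    private val fos = new FileOutputStream(file, true)
    private val channel = fos.getChannel
    private var lastPosition = channel.position()
    private var writesSinceUpdate = 0

    def write(bytes: Array[Byte]): Unit = {
      fos.write(bytes)
      writesSinceUpdate += 1
      if (writesSinceUpdate >= sampleEvery) {
        updateCounter()
        writesSinceUpdate = 0
      }
    }

    private def updateCounter(): Unit = {
      val pos = channel.position()
      counter.bytesWritten += pos - lastPosition
      lastPosition = pos
    }

    def close(): Unit = {
      updateCounter()   // reconcile whatever the sampling interval missed
      fos.close()
    }
  }

  def main(args: Array[String]): Unit = {
    val file = File.createTempFile("sampled-counter", ".bin")
    file.deleteOnExit()
    val counter = new ByteCounter
    val writer = new SampledCountingWriter(file, counter)
    (1 to 100).foreach(_ => writer.write(Array.fill[Byte](10)(0.toByte)))
    writer.close()
    println(s"counted ${counter.bytesWritten} bytes, file has ${file.length()}")  // both 1000
  }
}
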
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
index f9fdffae8b..3565719b54 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
@@ -29,6 +29,7 @@ import org.apache.spark.storage.ShuffleBlockManager.ShuffleFileGroup
import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap}
import org.apache.spark.util.collection.{PrimitiveKeyOpenHashMap, PrimitiveVector}
import org.apache.spark.shuffle.sort.SortShuffleManager
+import org.apache.spark.executor.ShuffleWriteMetrics
/** A group of writers for a ShuffleMapTask, one writer per reducer. */
private[spark] trait ShuffleWriterGroup {
@@ -111,7 +112,8 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging {
* Get a ShuffleWriterGroup for the given map task, which will register it as complete
* when the writers are closed successfully
*/
- def forMapTask(shuffleId: Int, mapId: Int, numBuckets: Int, serializer: Serializer) = {
+ def forMapTask(shuffleId: Int, mapId: Int, numBuckets: Int, serializer: Serializer,
+ writeMetrics: ShuffleWriteMetrics) = {
new ShuffleWriterGroup {
shuffleStates.putIfAbsent(shuffleId, new ShuffleState(numBuckets))
private val shuffleState = shuffleStates(shuffleId)
@@ -121,7 +123,8 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging {
fileGroup = getUnusedFileGroup()
Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
val blockId = ShuffleBlockId(shuffleId, mapId, bucketId)
- blockManager.getDiskWriter(blockId, fileGroup(bucketId), serializer, bufferSize)
+ blockManager.getDiskWriter(blockId, fileGroup(bucketId), serializer, bufferSize,
+ writeMetrics)
}
} else {
Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
@@ -136,7 +139,7 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging {
logWarning(s"Failed to remove existing shuffle file $blockFile")
}
}
- blockManager.getDiskWriter(blockId, blockFile, serializer, bufferSize)
+ blockManager.getDiskWriter(blockId, blockFile, serializer, bufferSize, writeMetrics)
}
}
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
index 260a5c3888..9f85b94a70 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
@@ -31,6 +31,7 @@ import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.serializer.{DeserializationStream, Serializer}
import org.apache.spark.storage.{BlockId, BlockManager}
import org.apache.spark.util.collection.ExternalAppendOnlyMap.HashComparator
+import org.apache.spark.executor.ShuffleWriteMetrics
/**
* :: DeveloperApi ::
@@ -102,6 +103,10 @@ class ExternalAppendOnlyMap[K, V, C](
private var _diskBytesSpilled = 0L
private val fileBufferSize = sparkConf.getInt("spark.shuffle.file.buffer.kb", 32) * 1024
+
+ // Write metrics for current spill
+ private var curWriteMetrics: ShuffleWriteMetrics = _
+
private val keyComparator = new HashComparator[K]
private val ser = serializer.newInstance()
@@ -172,7 +177,9 @@ class ExternalAppendOnlyMap[K, V, C](
logInfo("Thread %d spilling in-memory map of %d MB to disk (%d time%s so far)"
.format(threadId, mapSize / (1024 * 1024), spillCount, if (spillCount > 1) "s" else ""))
val (blockId, file) = diskBlockManager.createTempBlock()
- var writer = blockManager.getDiskWriter(blockId, file, serializer, fileBufferSize)
+ curWriteMetrics = new ShuffleWriteMetrics()
+ var writer = blockManager.getDiskWriter(blockId, file, serializer, fileBufferSize,
+ curWriteMetrics)
var objectsWritten = 0
// List of batch sizes (bytes) in the order they are written to disk
@@ -183,9 +190,8 @@ class ExternalAppendOnlyMap[K, V, C](
val w = writer
writer = null
w.commitAndClose()
- val bytesWritten = w.bytesWritten
- batchSizes.append(bytesWritten)
- _diskBytesSpilled += bytesWritten
+ _diskBytesSpilled += curWriteMetrics.shuffleBytesWritten
+ batchSizes.append(curWriteMetrics.shuffleBytesWritten)
objectsWritten = 0
}
@@ -199,7 +205,9 @@ class ExternalAppendOnlyMap[K, V, C](
if (objectsWritten == serializerBatchSize) {
flush()
- writer = blockManager.getDiskWriter(blockId, file, serializer, fileBufferSize)
+ curWriteMetrics = new ShuffleWriteMetrics()
+ writer = blockManager.getDiskWriter(blockId, file, serializer, fileBufferSize,
+ curWriteMetrics)
}
}
if (objectsWritten > 0) {
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
index 3f93afd57b..eb4849ebc6 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -28,6 +28,7 @@ import com.google.common.io.ByteStreams
import org.apache.spark.{Aggregator, SparkEnv, Logging, Partitioner}
import org.apache.spark.serializer.{DeserializationStream, Serializer}
import org.apache.spark.storage.BlockId
+import org.apache.spark.executor.ShuffleWriteMetrics
/**
* Sorts and potentially merges a number of key-value pairs of type (K, V) to produce key-combiner
@@ -112,11 +113,14 @@ private[spark] class ExternalSorter[K, V, C](
// What threshold of elementsRead we start estimating map size at.
private val trackMemoryThreshold = 1000
- // Spilling statistics
+ // Total spilling statistics
private var spillCount = 0
private var _memoryBytesSpilled = 0L
private var _diskBytesSpilled = 0L
+ // Write metrics for current spill
+ private var curWriteMetrics: ShuffleWriteMetrics = _
+
// How much of the shared memory pool this collection has claimed
private var myMemoryThreshold = 0L
@@ -239,7 +243,8 @@ private[spark] class ExternalSorter[K, V, C](
logInfo("Thread %d spilling in-memory batch of %d MB to disk (%d spill%s so far)"
.format(threadId, memorySize / (1024 * 1024), spillCount, if (spillCount > 1) "s" else ""))
val (blockId, file) = diskBlockManager.createTempBlock()
- var writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize)
+ curWriteMetrics = new ShuffleWriteMetrics()
+ var writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics)
var objectsWritten = 0 // Objects written since the last flush
// List of batch sizes (bytes) in the order they are written to disk
@@ -254,9 +259,8 @@ private[spark] class ExternalSorter[K, V, C](
val w = writer
writer = null
w.commitAndClose()
- val bytesWritten = w.bytesWritten
- batchSizes.append(bytesWritten)
- _diskBytesSpilled += bytesWritten
+ _diskBytesSpilled += curWriteMetrics.shuffleBytesWritten
+ batchSizes.append(curWriteMetrics.shuffleBytesWritten)
objectsWritten = 0
}
@@ -275,7 +279,8 @@ private[spark] class ExternalSorter[K, V, C](
if (objectsWritten == serializerBatchSize) {
flush()
- writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize)
+ curWriteMetrics = new ShuffleWriteMetrics()
+ writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics)
}
}
if (objectsWritten > 0) {
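
Both ExternalAppendOnlyMap and ExternalSorter now allocate a fresh ShuffleWriteMetrics for every spill writer and read shuffleBytesWritten after commitAndClose, replacing the removed writer.bytesWritten. Because the metrics object is replaced whenever a new batch writer is opened, each flush records exactly that batch's bytes. A toy sketch of the per-batch accounting, with hypothetical names and fixed byte counts standing in for real writes:

import scala.collection.mutable.ArrayBuffer

object SpillAccountingSketch {
  class BatchMetrics { var bytesWritten: Long = 0L }

  def main(args: Array[String]): Unit = {
    val batchSizes = ArrayBuffer[Long]()
    var diskBytesSpilled = 0L

    // Each batch gets its own metrics object, as each new disk writer does in the diff;
    // the writer would update bytesWritten incrementally while the batch is written.
    for (bytesInThisBatch <- Seq(4096L, 1024L, 512L)) {
      val curMetrics = new BatchMetrics
      curMetrics.bytesWritten += bytesInThisBatch
      diskBytesSpilled += curMetrics.bytesWritten
      batchSizes.append(curMetrics.bytesWritten)
    }
    println(s"batch sizes: ${batchSizes.mkString(", ")}; total spilled: $diskBytesSpilled")
  }
}
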
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala
new file mode 100644
index 0000000000..bbc7e1357b
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.storage
+
+import org.scalatest.FunSuite
+import java.io.File
+import org.apache.spark.executor.ShuffleWriteMetrics
+import org.apache.spark.serializer.JavaSerializer
+import org.apache.spark.SparkConf
+
+class BlockObjectWriterSuite extends FunSuite {
+ test("verify write metrics") {
+ val file = new File("somefile")
+ file.deleteOnExit()
+ val writeMetrics = new ShuffleWriteMetrics()
+ val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file,
+ new JavaSerializer(new SparkConf()), 1024, os => os, true, writeMetrics)
+
+ writer.write(Long.box(20))
+ // Metrics don't update on every write
+ assert(writeMetrics.shuffleBytesWritten == 0)
+ // After 32 writes, metrics should update
+ for (i <- 0 until 32) {
+ writer.flush()
+ writer.write(Long.box(i))
+ }
+ assert(writeMetrics.shuffleBytesWritten > 0)
+ writer.commitAndClose()
+ assert(file.length() == writeMetrics.shuffleBytesWritten)
+ }
+
+ test("verify write metrics on revert") {
+ val file = new File("somefile")
+ file.deleteOnExit()
+ val writeMetrics = new ShuffleWriteMetrics()
+ val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file,
+ new JavaSerializer(new SparkConf()), 1024, os => os, true, writeMetrics)
+
+ writer.write(Long.box(20))
+ // Metrics don't update on every write
+ assert(writeMetrics.shuffleBytesWritten == 0)
+ // After 32 writes, metrics should update
+ for (i <- 0 until 32) {
+ writer.flush()
+ writer.write(Long.box(i))
+ }
+ assert(writeMetrics.shuffleBytesWritten > 0)
+ writer.revertPartialWritesAndClose()
+ assert(writeMetrics.shuffleBytesWritten == 0)
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala
index 985ac93947..b8299e2ea1 100644
--- a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala
@@ -30,6 +30,7 @@ import org.apache.spark.SparkConf
import org.apache.spark.scheduler.LiveListenerBus
import org.apache.spark.serializer.JavaSerializer
import org.apache.spark.util.{AkkaUtils, Utils}
+import org.apache.spark.executor.ShuffleWriteMetrics
class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach with BeforeAndAfterAll {
private val testConf = new SparkConf(false)
@@ -153,7 +154,7 @@ class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach with Before
val shuffleManager = store.shuffleBlockManager
- val shuffle1 = shuffleManager.forMapTask(1, 1, 1, serializer)
+ val shuffle1 = shuffleManager.forMapTask(1, 1, 1, serializer, new ShuffleWriteMetrics)
for (writer <- shuffle1.writers) {
writer.write("test1")
writer.write("test2")
@@ -165,7 +166,8 @@ class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach with Before
val shuffle1Segment = shuffle1.writers(0).fileSegment()
shuffle1.releaseWriters(success = true)
- val shuffle2 = shuffleManager.forMapTask(1, 2, 1, new JavaSerializer(testConf))
+ val shuffle2 = shuffleManager.forMapTask(1, 2, 1, new JavaSerializer(testConf),
+ new ShuffleWriteMetrics)
for (writer <- shuffle2.writers) {
writer.write("test3")
@@ -183,7 +185,8 @@ class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach with Before
// of block based on remaining data in file : which could mess things up when there is concurrent read
// and writes happening to the same shuffle group.
- val shuffle3 = shuffleManager.forMapTask(1, 3, 1, new JavaSerializer(testConf))
+ val shuffle3 = shuffleManager.forMapTask(1, 3, 1, new JavaSerializer(testConf),
+ new ShuffleWriteMetrics)
for (writer <- shuffle3.writers) {
writer.write("test3")
writer.write("test4")
diff --git a/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala b/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala
index 8a05fcb449..17bf7c2541 100644
--- a/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala
+++ b/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala
@@ -23,6 +23,7 @@ import java.util.concurrent.atomic.AtomicLong
import org.apache.spark.SparkContext
import org.apache.spark.serializer.KryoSerializer
import org.apache.spark.util.Utils
+import org.apache.spark.executor.ShuffleWriteMetrics
/**
* Internal utility for micro-benchmarking shuffle write performance.
@@ -56,7 +57,7 @@ object StoragePerfTester {
def writeOutputBytes(mapId: Int, total: AtomicLong) = {
val shuffle = blockManager.shuffleBlockManager.forMapTask(1, mapId, numOutputSplits,
- new KryoSerializer(sc.conf))
+ new KryoSerializer(sc.conf), new ShuffleWriteMetrics())
val writers = shuffle.writers
for (i <- 1 to recordsPerMap) {
writers(i % numOutputSplits).write(writeData)