-rw-r--r--  core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java  |  8
-rw-r--r--  core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java  |  2
-rw-r--r--  core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java  |  4
-rw-r--r--  core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala  |  8
-rw-r--r--  core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala  |  2
-rw-r--r--  core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala  |  4
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockManager.scala  |  2
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala (renamed from core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala)  | 96
-rw-r--r--  core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala  |  2
-rw-r--r--  core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala  |  4
-rw-r--r--  core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala  |  1
-rw-r--r--  core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala  |  5
-rw-r--r--  core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala  |  8
-rw-r--r--  core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala  |  4
-rw-r--r--  core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala (renamed from core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala)  |  2
-rw-r--r--  core/src/test/scala/org/apache/spark/util/collection/PartitionedSerializedPairBufferSuite.scala  | 52
16 files changed, 90 insertions(+), 114 deletions(-)
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
index d3d6280284..0b8b604e18 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
@@ -75,7 +75,7 @@ final class BypassMergeSortShuffleWriter<K, V> implements SortShuffleFileWriter<
private final Serializer serializer;
/** Array of file writers, one for each partition */
- private BlockObjectWriter[] partitionWriters;
+ private DiskBlockObjectWriter[] partitionWriters;
public BypassMergeSortShuffleWriter(
SparkConf conf,
@@ -101,7 +101,7 @@ final class BypassMergeSortShuffleWriter<K, V> implements SortShuffleFileWriter<
}
final SerializerInstance serInstance = serializer.newInstance();
final long openStartTime = System.nanoTime();
- partitionWriters = new BlockObjectWriter[numPartitions];
+ partitionWriters = new DiskBlockObjectWriter[numPartitions];
for (int i = 0; i < numPartitions; i++) {
final Tuple2<TempShuffleBlockId, File> tempShuffleBlockIdPlusFile =
blockManager.diskBlockManager().createTempShuffleBlock();
@@ -121,7 +121,7 @@ final class BypassMergeSortShuffleWriter<K, V> implements SortShuffleFileWriter<
partitionWriters[partitioner.getPartition(key)].write(key, record._2());
}
- for (BlockObjectWriter writer : partitionWriters) {
+ for (DiskBlockObjectWriter writer : partitionWriters) {
writer.commitAndClose();
}
}
@@ -169,7 +169,7 @@ final class BypassMergeSortShuffleWriter<K, V> implements SortShuffleFileWriter<
if (partitionWriters != null) {
try {
final DiskBlockManager diskBlockManager = blockManager.diskBlockManager();
- for (BlockObjectWriter writer : partitionWriters) {
+ for (DiskBlockObjectWriter writer : partitionWriters) {
// This method explicitly does _not_ throw exceptions:
writer.revertPartialWritesAndClose();
if (!diskBlockManager.getFile(writer.blockId()).delete()) {
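The hunks above only retype the per-partition writer array, but they also show the pattern the writer API is built around: fan each record out to writers[partition], commit every writer on success, and on failure revert the partial writes and delete the temporary files. A condensed Scala sketch of that pattern follows (helper names are illustrative, not from the patch; the real implementation is the Java class above, and these classes are private[spark]):

    import org.apache.spark.storage.{DiskBlockManager, DiskBlockObjectWriter}

    // Sketch only: one writer per partition, then commit all or revert-and-delete all.
    def writeAll(records: Iterator[(Any, Any)],
                 partitionOf: Any => Int,
                 writers: Array[DiskBlockObjectWriter]): Unit =
      records.foreach { case (k, v) => writers(partitionOf(k)).write(k, v) }

    def commitAll(writers: Array[DiskBlockObjectWriter]): Unit =
      writers.foreach(_.commitAndClose())

    def cleanupAll(writers: Array[DiskBlockObjectWriter],
                   diskBlockManager: DiskBlockManager): Unit =
      writers.foreach { w =>
        w.revertPartialWritesAndClose()                // documented not to throw
        diskBlockManager.getFile(w.blockId).delete()   // best-effort temp file removal
      }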
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java
index 5628957320..1d460432be 100644
--- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java
@@ -157,7 +157,7 @@ final class UnsafeShuffleExternalSorter {
// Currently, we need to open a new DiskBlockObjectWriter for each partition; we can avoid this
// after SPARK-5581 is fixed.
- BlockObjectWriter writer;
+ DiskBlockObjectWriter writer;
// Small writes to DiskBlockObjectWriter will be fairly inefficient. Since there doesn't seem to
// be an API to directly transfer bytes from managed memory to the disk writer, we buffer
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java
index b8d6665980..71eed29563 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java
@@ -26,7 +26,7 @@ import org.apache.spark.executor.ShuffleWriteMetrics;
import org.apache.spark.serializer.DummySerializerInstance;
import org.apache.spark.storage.BlockId;
import org.apache.spark.storage.BlockManager;
-import org.apache.spark.storage.BlockObjectWriter;
+import org.apache.spark.storage.DiskBlockObjectWriter;
import org.apache.spark.storage.TempLocalBlockId;
import org.apache.spark.unsafe.PlatformDependent;
@@ -47,7 +47,7 @@ final class UnsafeSorterSpillWriter {
private final File file;
private final BlockId blockId;
private final int numRecordsToWrite;
- private BlockObjectWriter writer;
+ private DiskBlockObjectWriter writer;
private int numRecordsSpilled = 0;
public UnsafeSorterSpillWriter(
diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala
index 6c3b3080d2..f6a96d81e7 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala
@@ -35,7 +35,7 @@ import org.apache.spark.util.collection.{PrimitiveKeyOpenHashMap, PrimitiveVecto
/** A group of writers for a ShuffleMapTask, one writer per reducer. */
private[spark] trait ShuffleWriterGroup {
- val writers: Array[BlockObjectWriter]
+ val writers: Array[DiskBlockObjectWriter]
/** @param success Indicates all writes were successful. If false, no blocks will be recorded. */
def releaseWriters(success: Boolean)
@@ -113,15 +113,15 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf)
val openStartTime = System.nanoTime
val serializerInstance = serializer.newInstance()
- val writers: Array[BlockObjectWriter] = if (consolidateShuffleFiles) {
+ val writers: Array[DiskBlockObjectWriter] = if (consolidateShuffleFiles) {
fileGroup = getUnusedFileGroup()
- Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
+ Array.tabulate[DiskBlockObjectWriter](numBuckets) { bucketId =>
val blockId = ShuffleBlockId(shuffleId, mapId, bucketId)
blockManager.getDiskWriter(blockId, fileGroup(bucketId), serializerInstance, bufferSize,
writeMetrics)
}
} else {
- Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
+ Array.tabulate[DiskBlockObjectWriter](numBuckets) { bucketId =>
val blockId = ShuffleBlockId(shuffleId, mapId, bucketId)
val blockFile = blockManager.diskBlockManager.getFile(blockId)
// Because of previous failures, the shuffle file may already exist on this machine.
diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
index d9c63b6e7b..fae69551e7 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
@@ -114,7 +114,7 @@ private[spark] class IndexShuffleBlockResolver(conf: SparkConf) extends ShuffleB
}
private[spark] object IndexShuffleBlockResolver {
- // No-op reduce ID used in interactions with disk store and BlockObjectWriter.
+ // No-op reduce ID used in interactions with disk store and DiskBlockObjectWriter.
// The disk store currently expects puts to relate to a (map, reduce) pair, but in the sort
// shuffle outputs for several reduces are glommed into a single file.
// TODO: Avoid this entirely by having the DiskBlockObjectWriter not require a BlockId.
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
index eb87cee159..41df70c602 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
@@ -22,7 +22,7 @@ import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.scheduler.MapStatus
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle._
-import org.apache.spark.storage.BlockObjectWriter
+import org.apache.spark.storage.DiskBlockObjectWriter
private[spark] class HashShuffleWriter[K, V](
shuffleBlockResolver: FileShuffleBlockResolver,
@@ -102,7 +102,7 @@ private[spark] class HashShuffleWriter[K, V](
private def commitWritesAndBuildStatus(): MapStatus = {
// Commit the writes. Get the size of each bucket block (total block size).
- val sizes: Array[Long] = shuffle.writers.map { writer: BlockObjectWriter =>
+ val sizes: Array[Long] = shuffle.writers.map { writer: DiskBlockObjectWriter =>
writer.commitAndClose()
writer.fileSegment().length
}
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 1beafa1771..86493673d9 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -648,7 +648,7 @@ private[spark] class BlockManager(
file: File,
serializerInstance: SerializerInstance,
bufferSize: Int,
- writeMetrics: ShuffleWriteMetrics): BlockObjectWriter = {
+ writeMetrics: ShuffleWriteMetrics): DiskBlockObjectWriter = {
val compressStream: OutputStream => OutputStream = wrapForCompression(blockId, _)
val syncWrites = conf.getBoolean("spark.shuffle.sync", false)
new DiskBlockObjectWriter(blockId, file, serializerInstance, bufferSize, compressStream,
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
index 7eeabd1e04..49d9154f95 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
@@ -26,66 +26,25 @@ import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.util.Utils
/**
- * An interface for writing JVM objects to some underlying storage. This interface allows
- * appending data to an existing block, and can guarantee atomicity in the case of faults
- * as it allows the caller to revert partial writes.
+ * A class for writing JVM objects directly to a file on disk. This class allows data to be appended
+ * to an existing block and can guarantee atomicity in the case of faults as it allows the caller to
+ * revert partial writes.
*
- * This interface does not support concurrent writes. Also, once the writer has
- * been opened, it cannot be reopened again.
- */
-private[spark] abstract class BlockObjectWriter(val blockId: BlockId) extends OutputStream {
-
- def open(): BlockObjectWriter
-
- def close()
-
- def isOpen: Boolean
-
- /**
- * Flush the partial writes and commit them as a single atomic block.
- */
- def commitAndClose(): Unit
-
- /**
- * Reverts writes that haven't been flushed yet. Callers should invoke this function
- * when there are runtime exceptions. This method will not throw, though it may be
- * unsuccessful in truncating written data.
- */
- def revertPartialWritesAndClose()
-
- /**
- * Writes a key-value pair.
- */
- def write(key: Any, value: Any)
-
- /**
- * Notify the writer that a record worth of bytes has been written with OutputStream#write.
- */
- def recordWritten()
-
- /**
- * Returns the file segment of committed data that this Writer has written.
- * This is only valid after commitAndClose() has been called.
- */
- def fileSegment(): FileSegment
-}
-
-/**
- * BlockObjectWriter which writes directly to a file on disk. Appends to the given file.
+ * This class does not support concurrent writes. Also, once the writer has been opened it cannot be
+ * reopened again.
*/
private[spark] class DiskBlockObjectWriter(
- blockId: BlockId,
+ val blockId: BlockId,
file: File,
serializerInstance: SerializerInstance,
bufferSize: Int,
compressStream: OutputStream => OutputStream,
syncWrites: Boolean,
- // These write metrics concurrently shared with other active BlockObjectWriter's who
+ // These write metrics are concurrently shared with other active DiskBlockObjectWriters who
// are themselves performing writes. All updates must be relative.
writeMetrics: ShuffleWriteMetrics)
- extends BlockObjectWriter(blockId)
- with Logging
-{
+ extends OutputStream
+ with Logging {
/** The file channel, used for repositioning / truncating the file. */
private var channel: FileChannel = null
@@ -122,7 +81,7 @@ private[spark] class DiskBlockObjectWriter(
*/
private var numRecordsWritten = 0
- override def open(): BlockObjectWriter = {
+ def open(): DiskBlockObjectWriter = {
if (hasBeenClosed) {
throw new IllegalStateException("Writer already closed. Cannot be reopened.")
}
@@ -159,9 +118,12 @@ private[spark] class DiskBlockObjectWriter(
}
}
- override def isOpen: Boolean = objOut != null
+ def isOpen: Boolean = objOut != null
- override def commitAndClose(): Unit = {
+ /**
+ * Flush the partial writes and commit them as a single atomic block.
+ */
+ def commitAndClose(): Unit = {
if (initialized) {
// NOTE: Because Kryo doesn't flush the underlying stream we explicitly flush both the
// serializer stream and the lower level stream.
@@ -177,9 +139,15 @@ private[spark] class DiskBlockObjectWriter(
commitAndCloseHasBeenCalled = true
}
- // Discard current writes. We do this by flushing the outstanding writes and then
- // truncating the file to its initial position.
- override def revertPartialWritesAndClose() {
+
+ /**
+ * Reverts writes that haven't been flushed yet. Callers should invoke this function
+ * when there are runtime exceptions. This method will not throw, though it may be
+ * unsuccessful in truncating written data.
+ */
+ def revertPartialWritesAndClose() {
+ // Discard current writes. We do this by flushing the outstanding writes and then
+ // truncating the file to its initial position.
try {
if (initialized) {
writeMetrics.decShuffleBytesWritten(reportedPosition - initialPosition)
@@ -201,7 +169,10 @@ private[spark] class DiskBlockObjectWriter(
}
}
- override def write(key: Any, value: Any) {
+ /**
+ * Writes a key-value pair.
+ */
+ def write(key: Any, value: Any) {
if (!initialized) {
open()
}
@@ -221,7 +192,10 @@ private[spark] class DiskBlockObjectWriter(
bs.write(kvBytes, offs, len)
}
- override def recordWritten(): Unit = {
+ /**
+ * Notify the writer that a record worth of bytes has been written with OutputStream#write.
+ */
+ def recordWritten(): Unit = {
numRecordsWritten += 1
writeMetrics.incShuffleRecordsWritten(1)
@@ -230,7 +204,11 @@ private[spark] class DiskBlockObjectWriter(
}
}
- override def fileSegment(): FileSegment = {
+ /**
+ * Returns the file segment of committed data that this Writer has written.
+ * This is only valid after commitAndClose() has been called.
+ */
+ def fileSegment(): FileSegment = {
if (!commitAndCloseHasBeenCalled) {
throw new IllegalStateException(
"fileSegment() is only valid after commitAndClose() has been called")
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala
index 516aaa44d0..ae60f3b0cb 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala
@@ -37,7 +37,7 @@ private[spark] class ChainedBuffer(chunkSize: Int) {
private var _size: Long = 0
/**
- * Feed bytes from this buffer into a BlockObjectWriter.
+ * Feed bytes from this buffer into a DiskBlockObjectWriter.
*
* @param pos Offset in the buffer to read from.
* @param os OutputStream to read into.
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
index 757dec66c2..ba7ec834d6 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -30,7 +30,7 @@ import org.apache.spark._
import org.apache.spark.serializer._
import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.shuffle.sort.{SortShuffleFileWriter, SortShuffleWriter}
-import org.apache.spark.storage.{BlockId, BlockObjectWriter}
+import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter}
/**
* Sorts and potentially merges a number of key-value pairs of type (K, V) to produce key-combiner
@@ -250,7 +250,7 @@ private[spark] class ExternalSorter[K, V, C](
// These variables are reset after each flush
var objectsWritten: Long = 0
var spillMetrics: ShuffleWriteMetrics = null
- var writer: BlockObjectWriter = null
+ var writer: DiskBlockObjectWriter = null
def openWriter(): Unit = {
assert (writer == null && spillMetrics == null)
spillMetrics = new ShuffleWriteMetrics
diff --git a/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala
index 04bb7fc78c..f5844d5353 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala
@@ -19,7 +19,6 @@ package org.apache.spark.util.collection
import java.util.Comparator
-import org.apache.spark.storage.BlockObjectWriter
import org.apache.spark.util.collection.WritablePartitionedPairCollection._
/**
diff --git a/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala
index ae9a48729e..87a786b02d 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala
@@ -21,9 +21,8 @@ import java.io.InputStream
import java.nio.IntBuffer
import java.util.Comparator
-import org.apache.spark.SparkEnv
import org.apache.spark.serializer.{JavaSerializerInstance, SerializerInstance}
-import org.apache.spark.storage.BlockObjectWriter
+import org.apache.spark.storage.DiskBlockObjectWriter
import org.apache.spark.util.collection.PartitionedSerializedPairBuffer._
/**
@@ -136,7 +135,7 @@ private[spark] class PartitionedSerializedPairBuffer[K, V](
// current position in the meta buffer in ints
var pos = 0
- def writeNext(writer: BlockObjectWriter): Unit = {
+ def writeNext(writer: DiskBlockObjectWriter): Unit = {
val keyStart = getKeyStartPos(metaBuffer, pos)
val keyValLen = metaBuffer.get(pos + KEY_VAL_LEN)
pos += RECORD_SIZE
diff --git a/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala b/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala
index 7bc5989865..38848e9018 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala
@@ -19,7 +19,7 @@ package org.apache.spark.util.collection
import java.util.Comparator
-import org.apache.spark.storage.BlockObjectWriter
+import org.apache.spark.storage.DiskBlockObjectWriter
/**
* A common interface for size-tracking collections of key-value pairs that
@@ -51,7 +51,7 @@ private[spark] trait WritablePartitionedPairCollection[K, V] {
new WritablePartitionedIterator {
private[this] var cur = if (it.hasNext) it.next() else null
- def writeNext(writer: BlockObjectWriter): Unit = {
+ def writeNext(writer: DiskBlockObjectWriter): Unit = {
writer.write(cur._1._2, cur._2)
cur = if (it.hasNext) it.next() else null
}
@@ -91,11 +91,11 @@ private[spark] object WritablePartitionedPairCollection {
}
/**
- * Iterator that writes elements to a BlockObjectWriter instead of returning them. Each element
+ * Iterator that writes elements to a DiskBlockObjectWriter instead of returning them. Each element
* has an associated partition.
*/
private[spark] trait WritablePartitionedIterator {
- def writeNext(writer: BlockObjectWriter): Unit
+ def writeNext(writer: DiskBlockObjectWriter): Unit
def hasNext(): Boolean
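As the updated comment says, a WritablePartitionedIterator pushes elements into a DiskBlockObjectWriter instead of returning them, and every element carries a partition. Below is a hedged consumer sketch, assuming (as the test suite later in this diff does) that the iterator also exposes nextPartition() and yields elements ordered by partition, so that one committed file segment corresponds to one partition:

    import org.apache.spark.storage.DiskBlockObjectWriter
    import org.apache.spark.util.collection.WritablePartitionedIterator

    // Sketch only: drain a single partition into the supplied writer and seal its segment.
    // Assumes `it` yields elements grouped and ordered by partition id.
    def drainPartition(it: WritablePartitionedIterator,
                       partitionId: Int,
                       writer: DiskBlockObjectWriter): Unit = {
      while (it.hasNext() && it.nextPartition() == partitionId) {
        it.writeNext(writer)   // serializes the current element through writer.write(key, value)
      }
      writer.commitAndClose()  // one atomic segment per partition
    }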
diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
index 542f8f4512..cc7342f1ec 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
@@ -68,8 +68,8 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
any[SerializerInstance],
anyInt(),
any[ShuffleWriteMetrics]
- )).thenAnswer(new Answer[BlockObjectWriter] {
- override def answer(invocation: InvocationOnMock): BlockObjectWriter = {
+ )).thenAnswer(new Answer[DiskBlockObjectWriter] {
+ override def answer(invocation: InvocationOnMock): DiskBlockObjectWriter = {
val args = invocation.getArguments
new DiskBlockObjectWriter(
args(0).asInstanceOf[BlockId],
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala
index 7bdea724fe..66af6e1a79 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala
@@ -26,7 +26,7 @@ import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.serializer.JavaSerializer
import org.apache.spark.util.Utils
-class BlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach {
+class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach {
var tempDir: File = _
diff --git a/core/src/test/scala/org/apache/spark/util/collection/PartitionedSerializedPairBufferSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/PartitionedSerializedPairBufferSuite.scala
index 6d2459d48d..3b67f62064 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/PartitionedSerializedPairBufferSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/PartitionedSerializedPairBufferSuite.scala
@@ -17,15 +17,20 @@
package org.apache.spark.util.collection
-import java.io.{ByteArrayInputStream, ByteArrayOutputStream, InputStream}
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import com.google.common.io.ByteStreams
+import org.mockito.Matchers.any
+import org.mockito.Mockito._
+import org.mockito.Mockito.RETURNS_SMART_NULLS
+import org.mockito.invocation.InvocationOnMock
+import org.mockito.stubbing.Answer
import org.scalatest.Matchers._
import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.serializer.KryoSerializer
-import org.apache.spark.storage.{FileSegment, BlockObjectWriter}
+import org.apache.spark.storage.DiskBlockObjectWriter
class PartitionedSerializedPairBufferSuite extends SparkFunSuite {
test("OrderedInputStream single record") {
@@ -79,13 +84,13 @@ class PartitionedSerializedPairBufferSuite extends SparkFunSuite {
val struct = SomeStruct("something", 5)
buffer.insert(4, 10, struct)
val it = buffer.destructiveSortedWritablePartitionedIterator(None)
- val writer = new SimpleBlockObjectWriter
+ val (writer, baos) = createMockWriter()
assert(it.hasNext)
it.nextPartition should be (4)
it.writeNext(writer)
assert(!it.hasNext)
- val stream = serializerInstance.deserializeStream(writer.getInputStream)
+ val stream = serializerInstance.deserializeStream(new ByteArrayInputStream(baos.toByteArray))
stream.readObject[AnyRef]() should be (10)
stream.readObject[AnyRef]() should be (struct)
}
@@ -101,7 +106,7 @@ class PartitionedSerializedPairBufferSuite extends SparkFunSuite {
buffer.insert(5, 3, struct3)
val it = buffer.destructiveSortedWritablePartitionedIterator(None)
- val writer = new SimpleBlockObjectWriter
+ val (writer, baos) = createMockWriter()
assert(it.hasNext)
it.nextPartition should be (4)
it.writeNext(writer)
@@ -113,7 +118,7 @@ class PartitionedSerializedPairBufferSuite extends SparkFunSuite {
it.writeNext(writer)
assert(!it.hasNext)
- val stream = serializerInstance.deserializeStream(writer.getInputStream)
+ val stream = serializerInstance.deserializeStream(new ByteArrayInputStream(baos.toByteArray))
val iter = stream.asIterator
iter.next() should be (2)
iter.next() should be (struct2)
@@ -123,26 +128,21 @@ class PartitionedSerializedPairBufferSuite extends SparkFunSuite {
iter.next() should be (struct1)
assert(!iter.hasNext)
}
-}
-
-case class SomeStruct(val str: String, val num: Int)
-
-class SimpleBlockObjectWriter extends BlockObjectWriter(null) {
- val baos = new ByteArrayOutputStream()
- override def write(bytes: Array[Byte], offs: Int, len: Int): Unit = {
- baos.write(bytes, offs, len)
+ def createMockWriter(): (DiskBlockObjectWriter, ByteArrayOutputStream) = {
+ val writer = mock(classOf[DiskBlockObjectWriter], RETURNS_SMART_NULLS)
+ val baos = new ByteArrayOutputStream()
+ when(writer.write(any(), any(), any())).thenAnswer(new Answer[Unit] {
+ override def answer(invocationOnMock: InvocationOnMock): Unit = {
+ val args = invocationOnMock.getArguments
+ val bytes = args(0).asInstanceOf[Array[Byte]]
+ val offset = args(1).asInstanceOf[Int]
+ val length = args(2).asInstanceOf[Int]
+ baos.write(bytes, offset, length)
+ }
+ })
+ (writer, baos)
}
-
- def getInputStream(): InputStream = new ByteArrayInputStream(baos.toByteArray)
-
- override def open(): BlockObjectWriter = this
- override def close(): Unit = { }
- override def isOpen: Boolean = true
- override def commitAndClose(): Unit = { }
- override def revertPartialWritesAndClose(): Unit = { }
- override def fileSegment(): FileSegment = null
- override def write(key: Any, value: Any): Unit = { }
- override def recordWritten(): Unit = { }
- override def write(b: Int): Unit = { }
}
+
+case class SomeStruct(str: String, num: Int)