author     Josh Rosen <joshrosen@databricks.com>      2015-07-14 12:56:17 -0700
committer  Reynold Xin <rxin@databricks.com>          2015-07-14 12:56:17 -0700
commit     d267c2834a639aaebd0559355c6a82613abb689b (patch)
tree       d96f025b32797c82611332d1499986f9f081f5de /core
parent     8fb3a65cbb714120d612e58ef9d12b0521a83260 (diff)
download   spark-d267c2834a639aaebd0559355c6a82613abb689b.tar.gz
           spark-d267c2834a639aaebd0559355c6a82613abb689b.tar.bz2
           spark-d267c2834a639aaebd0559355c6a82613abb689b.zip
[SPARK-9031] Merge BlockObjectWriter and DiskBlockObjectWriter to remove abstract class
BlockObjectWriter has only one concrete non-test class, DiskBlockObjectWriter. In order to simplify the code in preparation for other refactorings, I think that we should remove this base class and have only DiskBlockObjectWriter. While at one time we may have planned to have multiple BlockObjectWriter implementations, that doesn't seem to have happened, so the extra abstraction seems unnecessary.

Author: Josh Rosen <joshrosen@databricks.com>

Closes #7391 from JoshRosen/shuffle-write-interface-refactoring and squashes the following commits:

c418e33 [Josh Rosen] Fix compilation
5047995 [Josh Rosen] Fix comments
d5dc548 [Josh Rosen] Update references in comments
89dc797 [Josh Rosen] Rename test suite.
5755918 [Josh Rosen] Remove unnecessary val in case class
1607c91 [Josh Rosen] Merge BlockObjectWriter and DiskBlockObjectWriter
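
To make the effect of the merge concrete, here is a minimal Scala sketch (not part of the patch) of the write path once DiskBlockObjectWriter is the only writer class. The object and method names (DiskBlockObjectWriterExample, writeAndCommit) and the parameter wiring are illustrative assumptions; getDiskWriter, write, commitAndClose and fileSegment are the calls visible in the hunks below.

    import java.io.File

    import org.apache.spark.executor.ShuffleWriteMetrics
    import org.apache.spark.serializer.SerializerInstance
    import org.apache.spark.storage.{BlockId, BlockManager, DiskBlockObjectWriter}

    object DiskBlockObjectWriterExample {
      /** Writes all records to one block file and returns the committed segment's size. */
      def writeAndCommit(
          blockManager: BlockManager,
          blockId: BlockId,
          file: File,
          serializerInstance: SerializerInstance,
          bufferSize: Int,
          writeMetrics: ShuffleWriteMetrics,
          records: Iterator[(Any, Any)]): Long = {
        // After this patch, getDiskWriter returns the concrete DiskBlockObjectWriter directly.
        val writer: DiskBlockObjectWriter =
          blockManager.getDiskWriter(blockId, file, serializerInstance, bufferSize, writeMetrics)
        records.foreach { case (k, v) => writer.write(k, v) }
        // Flushes the buffered writes and commits them as a single atomic block.
        writer.commitAndClose()
        // Only valid after commitAndClose(); reports the size of the committed data.
        writer.fileSegment().length
      }
    }
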
Diffstat (limited to 'core')
-rw-r--r--  core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java | 8
-rw-r--r--  core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java | 2
-rw-r--r--  core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java | 4
-rw-r--r--  core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala | 8
-rw-r--r--  core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala | 2
-rw-r--r--  core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala | 4
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockManager.scala | 2
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala (renamed from core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala) | 96
-rw-r--r--  core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala | 2
-rw-r--r--  core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala | 4
-rw-r--r--  core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala | 1
-rw-r--r--  core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala | 5
-rw-r--r--  core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala | 8
-rw-r--r--  core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala | 4
-rw-r--r--  core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala (renamed from core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala) | 2
-rw-r--r--  core/src/test/scala/org/apache/spark/util/collection/PartitionedSerializedPairBufferSuite.scala | 52
16 files changed, 90 insertions, 114 deletions
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
index d3d6280284..0b8b604e18 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
@@ -75,7 +75,7 @@ final class BypassMergeSortShuffleWriter<K, V> implements SortShuffleFileWriter<
private final Serializer serializer;
/** Array of file writers, one for each partition */
- private BlockObjectWriter[] partitionWriters;
+ private DiskBlockObjectWriter[] partitionWriters;
public BypassMergeSortShuffleWriter(
SparkConf conf,
@@ -101,7 +101,7 @@ final class BypassMergeSortShuffleWriter<K, V> implements SortShuffleFileWriter<
}
final SerializerInstance serInstance = serializer.newInstance();
final long openStartTime = System.nanoTime();
- partitionWriters = new BlockObjectWriter[numPartitions];
+ partitionWriters = new DiskBlockObjectWriter[numPartitions];
for (int i = 0; i < numPartitions; i++) {
final Tuple2<TempShuffleBlockId, File> tempShuffleBlockIdPlusFile =
blockManager.diskBlockManager().createTempShuffleBlock();
@@ -121,7 +121,7 @@ final class BypassMergeSortShuffleWriter<K, V> implements SortShuffleFileWriter<
partitionWriters[partitioner.getPartition(key)].write(key, record._2());
}
- for (BlockObjectWriter writer : partitionWriters) {
+ for (DiskBlockObjectWriter writer : partitionWriters) {
writer.commitAndClose();
}
}
@@ -169,7 +169,7 @@ final class BypassMergeSortShuffleWriter<K, V> implements SortShuffleFileWriter<
if (partitionWriters != null) {
try {
final DiskBlockManager diskBlockManager = blockManager.diskBlockManager();
- for (BlockObjectWriter writer : partitionWriters) {
+ for (DiskBlockObjectWriter writer : partitionWriters) {
// This method explicitly does _not_ throw exceptions:
writer.revertPartialWritesAndClose();
if (!diskBlockManager.getFile(writer.blockId()).delete()) {
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java
index 5628957320..1d460432be 100644
--- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java
@@ -157,7 +157,7 @@ final class UnsafeShuffleExternalSorter {
// Currently, we need to open a new DiskBlockObjectWriter for each partition; we can avoid this
// after SPARK-5581 is fixed.
- BlockObjectWriter writer;
+ DiskBlockObjectWriter writer;
// Small writes to DiskBlockObjectWriter will be fairly inefficient. Since there doesn't seem to
// be an API to directly transfer bytes from managed memory to the disk writer, we buffer
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java
index b8d6665980..71eed29563 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java
@@ -26,7 +26,7 @@ import org.apache.spark.executor.ShuffleWriteMetrics;
import org.apache.spark.serializer.DummySerializerInstance;
import org.apache.spark.storage.BlockId;
import org.apache.spark.storage.BlockManager;
-import org.apache.spark.storage.BlockObjectWriter;
+import org.apache.spark.storage.DiskBlockObjectWriter;
import org.apache.spark.storage.TempLocalBlockId;
import org.apache.spark.unsafe.PlatformDependent;
@@ -47,7 +47,7 @@ final class UnsafeSorterSpillWriter {
private final File file;
private final BlockId blockId;
private final int numRecordsToWrite;
- private BlockObjectWriter writer;
+ private DiskBlockObjectWriter writer;
private int numRecordsSpilled = 0;
public UnsafeSorterSpillWriter(
diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala
index 6c3b3080d2..f6a96d81e7 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala
@@ -35,7 +35,7 @@ import org.apache.spark.util.collection.{PrimitiveKeyOpenHashMap, PrimitiveVecto
/** A group of writers for a ShuffleMapTask, one writer per reducer. */
private[spark] trait ShuffleWriterGroup {
- val writers: Array[BlockObjectWriter]
+ val writers: Array[DiskBlockObjectWriter]
/** @param success Indicates all writes were successful. If false, no blocks will be recorded. */
def releaseWriters(success: Boolean)
@@ -113,15 +113,15 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf)
val openStartTime = System.nanoTime
val serializerInstance = serializer.newInstance()
- val writers: Array[BlockObjectWriter] = if (consolidateShuffleFiles) {
+ val writers: Array[DiskBlockObjectWriter] = if (consolidateShuffleFiles) {
fileGroup = getUnusedFileGroup()
- Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
+ Array.tabulate[DiskBlockObjectWriter](numBuckets) { bucketId =>
val blockId = ShuffleBlockId(shuffleId, mapId, bucketId)
blockManager.getDiskWriter(blockId, fileGroup(bucketId), serializerInstance, bufferSize,
writeMetrics)
}
} else {
- Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
+ Array.tabulate[DiskBlockObjectWriter](numBuckets) { bucketId =>
val blockId = ShuffleBlockId(shuffleId, mapId, bucketId)
val blockFile = blockManager.diskBlockManager.getFile(blockId)
// Because of previous failures, the shuffle file may already exist on this machine.
diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
index d9c63b6e7b..fae69551e7 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
@@ -114,7 +114,7 @@ private[spark] class IndexShuffleBlockResolver(conf: SparkConf) extends ShuffleB
}
private[spark] object IndexShuffleBlockResolver {
- // No-op reduce ID used in interactions with disk store and BlockObjectWriter.
+ // No-op reduce ID used in interactions with disk store and DiskBlockObjectWriter.
// The disk store currently expects puts to relate to a (map, reduce) pair, but in the sort
// shuffle outputs for several reduces are glommed into a single file.
// TODO: Avoid this entirely by having the DiskBlockObjectWriter not require a BlockId.
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
index eb87cee159..41df70c602 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
@@ -22,7 +22,7 @@ import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.scheduler.MapStatus
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle._
-import org.apache.spark.storage.BlockObjectWriter
+import org.apache.spark.storage.DiskBlockObjectWriter
private[spark] class HashShuffleWriter[K, V](
shuffleBlockResolver: FileShuffleBlockResolver,
@@ -102,7 +102,7 @@ private[spark] class HashShuffleWriter[K, V](
private def commitWritesAndBuildStatus(): MapStatus = {
// Commit the writes. Get the size of each bucket block (total block size).
- val sizes: Array[Long] = shuffle.writers.map { writer: BlockObjectWriter =>
+ val sizes: Array[Long] = shuffle.writers.map { writer: DiskBlockObjectWriter =>
writer.commitAndClose()
writer.fileSegment().length
}
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 1beafa1771..86493673d9 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -648,7 +648,7 @@ private[spark] class BlockManager(
file: File,
serializerInstance: SerializerInstance,
bufferSize: Int,
- writeMetrics: ShuffleWriteMetrics): BlockObjectWriter = {
+ writeMetrics: ShuffleWriteMetrics): DiskBlockObjectWriter = {
val compressStream: OutputStream => OutputStream = wrapForCompression(blockId, _)
val syncWrites = conf.getBoolean("spark.shuffle.sync", false)
new DiskBlockObjectWriter(blockId, file, serializerInstance, bufferSize, compressStream,
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
index 7eeabd1e04..49d9154f95 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
@@ -26,66 +26,25 @@ import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.util.Utils
/**
- * An interface for writing JVM objects to some underlying storage. This interface allows
- * appending data to an existing block, and can guarantee atomicity in the case of faults
- * as it allows the caller to revert partial writes.
+ * A class for writing JVM objects directly to a file on disk. This class allows data to be appended
+ * to an existing block and can guarantee atomicity in the case of faults as it allows the caller to
+ * revert partial writes.
*
- * This interface does not support concurrent writes. Also, once the writer has
- * been opened, it cannot be reopened again.
- */
-private[spark] abstract class BlockObjectWriter(val blockId: BlockId) extends OutputStream {
-
- def open(): BlockObjectWriter
-
- def close()
-
- def isOpen: Boolean
-
- /**
- * Flush the partial writes and commit them as a single atomic block.
- */
- def commitAndClose(): Unit
-
- /**
- * Reverts writes that haven't been flushed yet. Callers should invoke this function
- * when there are runtime exceptions. This method will not throw, though it may be
- * unsuccessful in truncating written data.
- */
- def revertPartialWritesAndClose()
-
- /**
- * Writes a key-value pair.
- */
- def write(key: Any, value: Any)
-
- /**
- * Notify the writer that a record worth of bytes has been written with OutputStream#write.
- */
- def recordWritten()
-
- /**
- * Returns the file segment of committed data that this Writer has written.
- * This is only valid after commitAndClose() has been called.
- */
- def fileSegment(): FileSegment
-}
-
-/**
- * BlockObjectWriter which writes directly to a file on disk. Appends to the given file.
+ * This class does not support concurrent writes. Also, once the writer has been opened it cannot be
+ * reopened again.
*/
private[spark] class DiskBlockObjectWriter(
- blockId: BlockId,
+ val blockId: BlockId,
file: File,
serializerInstance: SerializerInstance,
bufferSize: Int,
compressStream: OutputStream => OutputStream,
syncWrites: Boolean,
- // These write metrics concurrently shared with other active BlockObjectWriter's who
+ // These write metrics concurrently shared with other active DiskBlockObjectWriters who
// are themselves performing writes. All updates must be relative.
writeMetrics: ShuffleWriteMetrics)
- extends BlockObjectWriter(blockId)
- with Logging
-{
+ extends OutputStream
+ with Logging {
/** The file channel, used for repositioning / truncating the file. */
private var channel: FileChannel = null
@@ -122,7 +81,7 @@ private[spark] class DiskBlockObjectWriter(
*/
private var numRecordsWritten = 0
- override def open(): BlockObjectWriter = {
+ def open(): DiskBlockObjectWriter = {
if (hasBeenClosed) {
throw new IllegalStateException("Writer already closed. Cannot be reopened.")
}
@@ -159,9 +118,12 @@ private[spark] class DiskBlockObjectWriter(
}
}
- override def isOpen: Boolean = objOut != null
+ def isOpen: Boolean = objOut != null
- override def commitAndClose(): Unit = {
+ /**
+ * Flush the partial writes and commit them as a single atomic block.
+ */
+ def commitAndClose(): Unit = {
if (initialized) {
// NOTE: Because Kryo doesn't flush the underlying stream we explicitly flush both the
// serializer stream and the lower level stream.
@@ -177,9 +139,15 @@ private[spark] class DiskBlockObjectWriter(
commitAndCloseHasBeenCalled = true
}
- // Discard current writes. We do this by flushing the outstanding writes and then
- // truncating the file to its initial position.
- override def revertPartialWritesAndClose() {
+
+ /**
+ * Reverts writes that haven't been flushed yet. Callers should invoke this function
+ * when there are runtime exceptions. This method will not throw, though it may be
+ * unsuccessful in truncating written data.
+ */
+ def revertPartialWritesAndClose() {
+ // Discard current writes. We do this by flushing the outstanding writes and then
+ // truncating the file to its initial position.
try {
if (initialized) {
writeMetrics.decShuffleBytesWritten(reportedPosition - initialPosition)
@@ -201,7 +169,10 @@ private[spark] class DiskBlockObjectWriter(
}
}
- override def write(key: Any, value: Any) {
+ /**
+ * Writes a key-value pair.
+ */
+ def write(key: Any, value: Any) {
if (!initialized) {
open()
}
@@ -221,7 +192,10 @@ private[spark] class DiskBlockObjectWriter(
bs.write(kvBytes, offs, len)
}
- override def recordWritten(): Unit = {
+ /**
+ * Notify the writer that a record worth of bytes has been written with OutputStream#write.
+ */
+ def recordWritten(): Unit = {
numRecordsWritten += 1
writeMetrics.incShuffleRecordsWritten(1)
@@ -230,7 +204,11 @@ private[spark] class DiskBlockObjectWriter(
}
}
- override def fileSegment(): FileSegment = {
+ /**
+ * Returns the file segment of committed data that this Writer has written.
+ * This is only valid after commitAndClose() has been called.
+ */
+ def fileSegment(): FileSegment = {
if (!commitAndCloseHasBeenCalled) {
throw new IllegalStateException(
"fileSegment() is only valid after commitAndClose() has been called")
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala
index 516aaa44d0..ae60f3b0cb 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala
@@ -37,7 +37,7 @@ private[spark] class ChainedBuffer(chunkSize: Int) {
private var _size: Long = 0
/**
- * Feed bytes from this buffer into a BlockObjectWriter.
+ * Feed bytes from this buffer into a DiskBlockObjectWriter.
*
* @param pos Offset in the buffer to read from.
* @param os OutputStream to read into.
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
index 757dec66c2..ba7ec834d6 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -30,7 +30,7 @@ import org.apache.spark._
import org.apache.spark.serializer._
import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.shuffle.sort.{SortShuffleFileWriter, SortShuffleWriter}
-import org.apache.spark.storage.{BlockId, BlockObjectWriter}
+import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter}
/**
* Sorts and potentially merges a number of key-value pairs of type (K, V) to produce key-combiner
@@ -250,7 +250,7 @@ private[spark] class ExternalSorter[K, V, C](
// These variables are reset after each flush
var objectsWritten: Long = 0
var spillMetrics: ShuffleWriteMetrics = null
- var writer: BlockObjectWriter = null
+ var writer: DiskBlockObjectWriter = null
def openWriter(): Unit = {
assert (writer == null && spillMetrics == null)
spillMetrics = new ShuffleWriteMetrics
diff --git a/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala
index 04bb7fc78c..f5844d5353 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala
@@ -19,7 +19,6 @@ package org.apache.spark.util.collection
import java.util.Comparator
-import org.apache.spark.storage.BlockObjectWriter
import org.apache.spark.util.collection.WritablePartitionedPairCollection._
/**
diff --git a/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala
index ae9a48729e..87a786b02d 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala
@@ -21,9 +21,8 @@ import java.io.InputStream
import java.nio.IntBuffer
import java.util.Comparator
-import org.apache.spark.SparkEnv
import org.apache.spark.serializer.{JavaSerializerInstance, SerializerInstance}
-import org.apache.spark.storage.BlockObjectWriter
+import org.apache.spark.storage.DiskBlockObjectWriter
import org.apache.spark.util.collection.PartitionedSerializedPairBuffer._
/**
@@ -136,7 +135,7 @@ private[spark] class PartitionedSerializedPairBuffer[K, V](
// current position in the meta buffer in ints
var pos = 0
- def writeNext(writer: BlockObjectWriter): Unit = {
+ def writeNext(writer: DiskBlockObjectWriter): Unit = {
val keyStart = getKeyStartPos(metaBuffer, pos)
val keyValLen = metaBuffer.get(pos + KEY_VAL_LEN)
pos += RECORD_SIZE
diff --git a/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala b/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala
index 7bc5989865..38848e9018 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala
@@ -19,7 +19,7 @@ package org.apache.spark.util.collection
import java.util.Comparator
-import org.apache.spark.storage.BlockObjectWriter
+import org.apache.spark.storage.DiskBlockObjectWriter
/**
* A common interface for size-tracking collections of key-value pairs that
@@ -51,7 +51,7 @@ private[spark] trait WritablePartitionedPairCollection[K, V] {
new WritablePartitionedIterator {
private[this] var cur = if (it.hasNext) it.next() else null
- def writeNext(writer: BlockObjectWriter): Unit = {
+ def writeNext(writer: DiskBlockObjectWriter): Unit = {
writer.write(cur._1._2, cur._2)
cur = if (it.hasNext) it.next() else null
}
@@ -91,11 +91,11 @@ private[spark] object WritablePartitionedPairCollection {
}
/**
- * Iterator that writes elements to a BlockObjectWriter instead of returning them. Each element
+ * Iterator that writes elements to a DiskBlockObjectWriter instead of returning them. Each element
* has an associated partition.
*/
private[spark] trait WritablePartitionedIterator {
- def writeNext(writer: BlockObjectWriter): Unit
+ def writeNext(writer: DiskBlockObjectWriter): Unit
def hasNext(): Boolean
diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
index 542f8f4512..cc7342f1ec 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
@@ -68,8 +68,8 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
any[SerializerInstance],
anyInt(),
any[ShuffleWriteMetrics]
- )).thenAnswer(new Answer[BlockObjectWriter] {
- override def answer(invocation: InvocationOnMock): BlockObjectWriter = {
+ )).thenAnswer(new Answer[DiskBlockObjectWriter] {
+ override def answer(invocation: InvocationOnMock): DiskBlockObjectWriter = {
val args = invocation.getArguments
new DiskBlockObjectWriter(
args(0).asInstanceOf[BlockId],
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala
index 7bdea724fe..66af6e1a79 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala
@@ -26,7 +26,7 @@ import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.serializer.JavaSerializer
import org.apache.spark.util.Utils
-class BlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach {
+class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach {
var tempDir: File = _
diff --git a/core/src/test/scala/org/apache/spark/util/collection/PartitionedSerializedPairBufferSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/PartitionedSerializedPairBufferSuite.scala
index 6d2459d48d..3b67f62064 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/PartitionedSerializedPairBufferSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/PartitionedSerializedPairBufferSuite.scala
@@ -17,15 +17,20 @@
package org.apache.spark.util.collection
-import java.io.{ByteArrayInputStream, ByteArrayOutputStream, InputStream}
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import com.google.common.io.ByteStreams
+import org.mockito.Matchers.any
+import org.mockito.Mockito._
+import org.mockito.Mockito.RETURNS_SMART_NULLS
+import org.mockito.invocation.InvocationOnMock
+import org.mockito.stubbing.Answer
import org.scalatest.Matchers._
import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.serializer.KryoSerializer
-import org.apache.spark.storage.{FileSegment, BlockObjectWriter}
+import org.apache.spark.storage.DiskBlockObjectWriter
class PartitionedSerializedPairBufferSuite extends SparkFunSuite {
test("OrderedInputStream single record") {
@@ -79,13 +84,13 @@ class PartitionedSerializedPairBufferSuite extends SparkFunSuite {
val struct = SomeStruct("something", 5)
buffer.insert(4, 10, struct)
val it = buffer.destructiveSortedWritablePartitionedIterator(None)
- val writer = new SimpleBlockObjectWriter
+ val (writer, baos) = createMockWriter()
assert(it.hasNext)
it.nextPartition should be (4)
it.writeNext(writer)
assert(!it.hasNext)
- val stream = serializerInstance.deserializeStream(writer.getInputStream)
+ val stream = serializerInstance.deserializeStream(new ByteArrayInputStream(baos.toByteArray))
stream.readObject[AnyRef]() should be (10)
stream.readObject[AnyRef]() should be (struct)
}
@@ -101,7 +106,7 @@ class PartitionedSerializedPairBufferSuite extends SparkFunSuite {
buffer.insert(5, 3, struct3)
val it = buffer.destructiveSortedWritablePartitionedIterator(None)
- val writer = new SimpleBlockObjectWriter
+ val (writer, baos) = createMockWriter()
assert(it.hasNext)
it.nextPartition should be (4)
it.writeNext(writer)
@@ -113,7 +118,7 @@ class PartitionedSerializedPairBufferSuite extends SparkFunSuite {
it.writeNext(writer)
assert(!it.hasNext)
- val stream = serializerInstance.deserializeStream(writer.getInputStream)
+ val stream = serializerInstance.deserializeStream(new ByteArrayInputStream(baos.toByteArray))
val iter = stream.asIterator
iter.next() should be (2)
iter.next() should be (struct2)
@@ -123,26 +128,21 @@ class PartitionedSerializedPairBufferSuite extends SparkFunSuite {
iter.next() should be (struct1)
assert(!iter.hasNext)
}
-}
-
-case class SomeStruct(val str: String, val num: Int)
-
-class SimpleBlockObjectWriter extends BlockObjectWriter(null) {
- val baos = new ByteArrayOutputStream()
- override def write(bytes: Array[Byte], offs: Int, len: Int): Unit = {
- baos.write(bytes, offs, len)
+ def createMockWriter(): (DiskBlockObjectWriter, ByteArrayOutputStream) = {
+ val writer = mock(classOf[DiskBlockObjectWriter], RETURNS_SMART_NULLS)
+ val baos = new ByteArrayOutputStream()
+ when(writer.write(any(), any(), any())).thenAnswer(new Answer[Unit] {
+ override def answer(invocationOnMock: InvocationOnMock): Unit = {
+ val args = invocationOnMock.getArguments
+ val bytes = args(0).asInstanceOf[Array[Byte]]
+ val offset = args(1).asInstanceOf[Int]
+ val length = args(2).asInstanceOf[Int]
+ baos.write(bytes, offset, length)
+ }
+ })
+ (writer, baos)
}
-
- def getInputStream(): InputStream = new ByteArrayInputStream(baos.toByteArray)
-
- override def open(): BlockObjectWriter = this
- override def close(): Unit = { }
- override def isOpen: Boolean = true
- override def commitAndClose(): Unit = { }
- override def revertPartialWritesAndClose(): Unit = { }
- override def fileSegment(): FileSegment = null
- override def write(key: Any, value: Any): Unit = { }
- override def recordWritten(): Unit = { }
- override def write(b: Int): Unit = { }
}
+
+case class SomeStruct(str: String, num: Int)