Diffstat (limited to 'core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala')
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala | 89
1 file changed, 66 insertions(+), 23 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala
index ec1b0f7149..205d469f48 100644
--- a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala
+++ b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala
@@ -33,7 +33,7 @@ import org.apache.spark.memory.{MemoryManager, MemoryMode}
import org.apache.spark.serializer.{SerializationStream, SerializerManager}
import org.apache.spark.storage.{BlockId, BlockInfoManager, StorageLevel}
import org.apache.spark.unsafe.Platform
-import org.apache.spark.util.{CompletionIterator, SizeEstimator, Utils}
+import org.apache.spark.util.{SizeEstimator, Utils}
import org.apache.spark.util.collection.SizeTrackingVector
import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream}
@@ -277,6 +277,7 @@ private[spark] class MemoryStore(
"released too much unroll memory")
Left(new PartiallyUnrolledIterator(
this,
+ MemoryMode.ON_HEAP,
unrollMemoryUsedByThisBlock,
unrolled = arrayValues.toIterator,
rest = Iterator.empty))
@@ -285,7 +286,11 @@ private[spark] class MemoryStore(
// We ran out of space while unrolling the values for this block
logUnrollFailureMessage(blockId, vector.estimateSize())
Left(new PartiallyUnrolledIterator(
- this, unrollMemoryUsedByThisBlock, unrolled = vector.iterator, rest = values))
+ this,
+ MemoryMode.ON_HEAP,
+ unrollMemoryUsedByThisBlock,
+ unrolled = vector.iterator,
+ rest = values))
}
}
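
For orientation: the `Left(...)` branches above hand back a PartiallyUnrolledIterator when the block does not fit; on success the method returns a `Right` with the stored size (not shown in these hunks). A minimal, self-contained sketch of that Either contract, with hypothetical names (`putAll`, `budget`) rather than Spark's actual signature:

def putAll[T](values: Iterator[T], budget: Int): Either[Iterator[T], Long] = {
  val unrolled = scala.collection.mutable.ArrayBuffer.empty[T]
  while (values.hasNext && unrolled.size < budget) {
    unrolled += values.next()
  }
  if (!values.hasNext) Right(unrolled.size.toLong) // success: everything fit in the budget
  else Left(unrolled.iterator ++ values)           // failure: hand every value back to the caller
}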
@@ -394,7 +399,7 @@ private[spark] class MemoryStore(
redirectableStream,
unrollMemoryUsedByThisBlock,
memoryMode,
- bbos.toChunkedByteBuffer,
+ bbos,
values,
classTag))
}
@@ -655,6 +660,7 @@ private[spark] class MemoryStore(
* The result of a failed [[MemoryStore.putIteratorAsValues()]] call.
*
* @param memoryStore the memoryStore, used for freeing memory.
+ * @param memoryMode the memory mode (on- or off-heap).
* @param unrollMemory the amount of unroll memory used by the values in `unrolled`.
* @param unrolled an iterator for the partially-unrolled values.
* @param rest the rest of the original iterator passed to
@@ -662,13 +668,14 @@ private[spark] class MemoryStore(
*/
private[storage] class PartiallyUnrolledIterator[T](
memoryStore: MemoryStore,
+ memoryMode: MemoryMode,
unrollMemory: Long,
private[this] var unrolled: Iterator[T],
rest: Iterator[T])
extends Iterator[T] {
private def releaseUnrollMemory(): Unit = {
- memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP, unrollMemory)
+ memoryStore.releaseUnrollMemoryForThisTask(memoryMode, unrollMemory)
// SPARK-17503: Garbage collects the unrolling memory before the life end of
// PartiallyUnrolledIterator.
unrolled = null
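
The hunk above parameterizes the iterator over the memory mode instead of hard-coding MemoryMode.ON_HEAP, and nulls out `unrolled` once its memory is released so the values can be garbage-collected. A simplified, self-contained sketch of the class's prefix-then-rest behavior, with a hypothetical name and a release callback standing in for releaseUnrollMemoryForThisTask:

class PrefixThenRestIterator[T](
    private var prefix: Iterator[T],
    rest: Iterator[T],
    release: () => Unit)
  extends Iterator[T] {

  // Free the unroll memory exactly once, as soon as the prefix is drained,
  // and drop the reference so the unrolled values become collectible.
  private def maybeRelease(): Unit = {
    if (prefix != null && !prefix.hasNext) {
      release()
      prefix = null
    }
  }

  override def hasNext: Boolean = {
    maybeRelease()
    if (prefix != null) true else rest.hasNext
  }

  override def next(): T = {
    maybeRelease()
    if (prefix != null) prefix.next() else rest.next()
  }
}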
@@ -706,7 +713,7 @@ private[storage] class PartiallyUnrolledIterator[T](
/**
* A wrapper which allows an open [[OutputStream]] to be redirected to a different sink.
*/
-private class RedirectableOutputStream extends OutputStream {
+private[storage] class RedirectableOutputStream extends OutputStream {
private[this] var os: OutputStream = _
def setOutputStream(s: OutputStream): Unit = { os = s }
override def write(b: Int): Unit = os.write(b)
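
With the visibility widened to private[storage], the redirect pattern is reachable from tests. A small REPL-style usage sketch (the ByteArrayOutputStream sinks are illustrative):

import java.io.ByteArrayOutputStream

val redirectable = new RedirectableOutputStream
val first = new ByteArrayOutputStream()
redirectable.setOutputStream(first)
redirectable.write('a')               // lands in `first`

val second = new ByteArrayOutputStream()
redirectable.setOutputStream(second)  // swap the sink under the open stream
redirectable.write('b')               // lands in `second`

assert(first.toString == "a" && second.toString == "b")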
@@ -726,7 +733,8 @@ private class RedirectableOutputStream extends OutputStream {
* @param redirectableOutputStream an OutputStream which can be redirected to a different sink.
* @param unrollMemory the amount of unroll memory used by the values in `unrolled`.
* @param memoryMode whether the unroll memory is on- or off-heap
- * @param unrolled a byte buffer containing the partially-serialized values.
+ * @param bbos byte buffer output stream containing the partially-serialized values.
+ * [[redirectableOutputStream]] initially points to this output stream.
* @param rest the rest of the original iterator passed to
* [[MemoryStore.putIteratorAsValues()]].
* @param classTag the [[ClassTag]] for the block.
@@ -735,14 +743,19 @@ private[storage] class PartiallySerializedBlock[T](
memoryStore: MemoryStore,
serializerManager: SerializerManager,
blockId: BlockId,
- serializationStream: SerializationStream,
- redirectableOutputStream: RedirectableOutputStream,
- unrollMemory: Long,
+ private val serializationStream: SerializationStream,
+ private val redirectableOutputStream: RedirectableOutputStream,
+ val unrollMemory: Long,
memoryMode: MemoryMode,
- unrolled: ChunkedByteBuffer,
+ bbos: ChunkedByteBufferOutputStream,
rest: Iterator[T],
classTag: ClassTag[T]) {
+ private lazy val unrolledBuffer: ChunkedByteBuffer = {
+ bbos.close()
+ bbos.toChunkedByteBuffer
+ }
+
// If the task does not fully consume `valuesIterator` or otherwise fails to consume or dispose of
// this PartiallySerializedBlock then we risk leaking of direct buffers, so we use a task
// completion listener here in order to ensure that `unrolled.dispose()` is called at least once.
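
Together with the earlier change that passes `bbos` instead of `bbos.toChunkedByteBuffer`, the lazy val above defers closing the stream and materializing the buffer until the bytes are first needed. A sketch of that close-once semantics, with ByteArrayOutputStream standing in for ChunkedByteBufferOutputStream:

import java.io.ByteArrayOutputStream

class DeferredBuffer(bbos: ByteArrayOutputStream) {
  lazy val bytes: Array[Byte] = {
    bbos.close()      // first access closes the stream...
    bbos.toByteArray  // ...and snapshots its contents
  }                   // later accesses return the cached array; the body never reruns
}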
@@ -751,7 +764,23 @@ private[storage] class PartiallySerializedBlock[T](
taskContext.addTaskCompletionListener { _ =>
// When a task completes, its unroll memory will automatically be freed. Thus we do not call
// releaseUnrollMemoryForThisTask() here because we want to avoid double-freeing.
- unrolled.dispose()
+ unrolledBuffer.dispose()
+ }
+ }
+
+ // Exposed for testing
+ private[storage] def getUnrolledChunkedByteBuffer: ChunkedByteBuffer = unrolledBuffer
+
+ private[this] var discarded = false
+ private[this] var consumed = false
+
+ private def verifyNotConsumedAndNotDiscarded(): Unit = {
+ if (consumed) {
+ throw new IllegalStateException(
+ "Can only call one of finishWritingToStream() or valuesIterator() and can only call once.")
+ }
+ if (discarded) {
+ throw new IllegalStateException("Cannot call methods on a discarded PartiallySerializedBlock")
}
}
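
The new flags make the block a consume-once resource: finishWritingToStream() and valuesIterator() are mutually exclusive terminal operations, and neither may follow discard(). A self-contained sketch of that state machine (hypothetical class, bodies elided):

class ConsumeOnce {
  private var consumed = false
  private var discarded = false

  private def verify(): Unit = {
    if (consumed) throw new IllegalStateException("already consumed")
    if (discarded) throw new IllegalStateException("already discarded")
  }

  def finishWriting(): Unit = { verify(); consumed = true }   // terminal op 1
  def valuesIterator(): Unit = { verify(); consumed = true }  // terminal op 2
  def discard(): Unit = if (!discarded) { discarded = true }  // idempotent cleanup
}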
@@ -759,15 +788,18 @@ private[storage] class PartiallySerializedBlock[T](
* Called to dispose of this block and free its memory.
*/
def discard(): Unit = {
- try {
- // We want to close the output stream in order to free any resources associated with the
- // serializer itself (such as Kryo's internal buffers). close() might cause data to be
- // written, so redirect the output stream to discard that data.
- redirectableOutputStream.setOutputStream(ByteStreams.nullOutputStream())
- serializationStream.close()
- } finally {
- unrolled.dispose()
- memoryStore.releaseUnrollMemoryForThisTask(memoryMode, unrollMemory)
+ if (!discarded) {
+ try {
+ // We want to close the output stream in order to free any resources associated with the
+ // serializer itself (such as Kryo's internal buffers). close() might cause data to be
+ // written, so redirect the output stream to discard that data.
+ redirectableOutputStream.setOutputStream(ByteStreams.nullOutputStream())
+ serializationStream.close()
+ } finally {
+ discarded = true
+ unrolledBuffer.dispose()
+ memoryStore.releaseUnrollMemoryForThisTask(memoryMode, unrollMemory)
+ }
}
}
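
The guard also makes discard() idempotent: a second call sees discarded == true and skips the cleanup. The redirect-before-close trick itself can be illustrated with a DataOutputStream standing in for the SerializationStream (illustrative only; RedirectableOutputStream is the class shown earlier):

import java.io.{ByteArrayOutputStream, DataOutputStream}
import com.google.common.io.ByteStreams

val redirectable = new RedirectableOutputStream
redirectable.setOutputStream(new ByteArrayOutputStream())
val stream = new DataOutputStream(redirectable)
stream.writeInt(42)                                  // normal writes reach the real sink

redirectable.setOutputStream(ByteStreams.nullOutputStream())
stream.close()                                       // anything flushed on close is dropped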
@@ -776,8 +808,10 @@ private[storage] class PartiallySerializedBlock[T](
* and then serializing the values from the original input iterator.
*/
def finishWritingToStream(os: OutputStream): Unit = {
+ verifyNotConsumedAndNotDiscarded()
+ consumed = true
// `unrolled`'s underlying buffers will be freed once this input stream is fully read:
- ByteStreams.copy(unrolled.toInputStream(dispose = true), os)
+ ByteStreams.copy(unrolledBuffer.toInputStream(dispose = true), os)
memoryStore.releaseUnrollMemoryForThisTask(memoryMode, unrollMemory)
redirectableOutputStream.setOutputStream(os)
while (rest.hasNext) {
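
finishWritingToStream() is a two-phase write: copy the already-serialized prefix into the destination, then repoint the still-open serialization stream at the same destination and serialize the rest. A self-contained sketch of the pattern, again with DataOutputStream standing in for the SerializationStream:

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataOutputStream}
import com.google.common.io.ByteStreams

val redirectable = new RedirectableOutputStream
val buffered = new ByteArrayOutputStream()    // stand-in for the chunked buffer
redirectable.setOutputStream(buffered)
val out = new DataOutputStream(redirectable)
out.writeInt(1)                               // prefix serialized before the failure

val destination = new ByteArrayOutputStream()
// Phase 1: copy the buffered prefix to the destination.
ByteStreams.copy(new ByteArrayInputStream(buffered.toByteArray), destination)
// Phase 2: redirect the still-open stream and keep serializing.
redirectable.setOutputStream(destination)
out.writeInt(2)
out.close()
// `destination` now holds the complete stream: 1 followed by 2.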
@@ -794,13 +828,22 @@ private[storage] class PartiallySerializedBlock[T](
* `close()` on it to free its resources.
*/
def valuesIterator: PartiallyUnrolledIterator[T] = {
+ verifyNotConsumedAndNotDiscarded()
+ consumed = true
+ // Close the serialization stream so that the serializer's internal buffers are freed and any
+ // "end-of-stream" markers can be written out so that `unrolled` is a valid serialized stream.
+ serializationStream.close()
// `unrolled`'s underlying buffers will be freed once this input stream is fully read:
val unrolledIter = serializerManager.dataDeserializeStream(
- blockId, unrolled.toInputStream(dispose = true))(classTag)
+ blockId, unrolledBuffer.toInputStream(dispose = true))(classTag)
+ // The unroll memory will be freed once `unrolledIter` is fully consumed in
+ // PartiallyUnrolledIterator. If the iterator is not consumed by the end of the task then any
+ // extra unroll memory will automatically be freed by a `finally` block in `Task`.
new PartiallyUnrolledIterator(
memoryStore,
+ memoryMode,
unrollMemory,
- unrolled = CompletionIterator[T, Iterator[T]](unrolledIter, discard()),
+ unrolled = unrolledIter,
rest = rest)
}
}
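
Closing the serialization stream before deserializing matters because serializers may buffer data or emit end-of-stream markers (Kryo, for instance, keeps internal buffers). A stand-in demonstration with plain Java serialization:

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}

val buf = new ByteArrayOutputStream()
val oos = new ObjectOutputStream(buf)
oos.writeObject("partially unrolled value")
oos.close()  // flush buffered bytes so the stream is complete, like serializationStream.close()

val ois = new ObjectInputStream(new ByteArrayInputStream(buf.toByteArray))
assert(ois.readObject() == "partially unrolled value")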