author     Josh Rosen <joshrosen@databricks.com>   2016-09-17 11:46:15 -0700
committer  Josh Rosen <joshrosen@databricks.com>   2016-09-17 11:46:15 -0700
commit     8faa5217b44e8d52eab7eb2d53d0652abaaf43cd (patch)
tree       daf1a90737024c0dccd567f66a8b13ee0f2d3c1a /core/src/main
parent     86c2d393a56bf1e5114bc5a781253c0460efb8af (diff)
[SPARK-17491] Close serialization stream to fix wrong answer bug in putIteratorAsBytes()
## What changes were proposed in this pull request?

`MemoryStore.putIteratorAsBytes()` may silently lose values when used with `KryoSerializer` because it does not properly close the serialization stream before attempting to deserialize the already-serialized values, which can leave values buffered in Kryo's internal buffers unread.

This is the root cause of a user-reported "wrong answer" bug in PySpark caching reported by bennoleslie on the Spark user mailing list in a thread titled "pyspark persist MEMORY_ONLY vs MEMORY_AND_DISK". Because Spark 2.0 automatically uses KryoSerializer for "safe" types (such as byte arrays, primitives, etc.), this misuse of the serializer manifested as silent data corruption rather than a StreamCorrupted error (which you might get from JavaSerializer).

The minimal fix, implemented here, is to close the serialization stream before attempting to deserialize the written values. In addition, this patch adds several assertions and precondition checks to prevent misuse of `PartiallySerializedBlock` and `ChunkedByteBufferOutputStream`.

## How was this patch tested?

The original bug was masked by an invalid assert in the memory store test cases: the old assert compared two results record-by-record with `zip` but did not first check that the two collections had equal lengths, so missing records went unnoticed. The updated test case reproduces this bug. In addition, I added a new `PartiallySerializedBlockSuite` to unit test that component.

Author: Josh Rosen <joshrosen@databricks.com>

Closes #15043 from JoshRosen/partially-serialized-block-values-iterator-bugfix.
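For background, here is a minimal standalone Scala sketch of the Kryo behaviour behind this bug (illustrative only: `KryoCloseSketch` is a made-up name, and the sketch assumes the Kryo library is on the classpath). Kryo's `Output` buffers writes internally and only pushes bytes to the wrapped stream on `flush()`/`close()`, so reading the underlying bytes before closing can silently drop the buffered tail, which is the same effect that produced the missing records here.

```scala
import java.io.ByteArrayOutputStream

import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.io.{Input, Output}

// Illustrative sketch, not part of this patch.
object KryoCloseSketch {
  def main(args: Array[String]): Unit = {
    val kryo = new Kryo()
    val underlying = new ByteArrayOutputStream()
    val out = new Output(underlying)  // Output keeps an internal write buffer

    (1 to 10).foreach(i => kryo.writeObject(out, Integer.valueOf(i)))

    // The writes above may still be sitting in Output's buffer rather than in `underlying`.
    val bytesBeforeClose = underlying.size()
    out.close()  // close() flushes the buffered bytes into `underlying`
    val bytesAfterClose = underlying.size()
    println(s"bytes visible before close: $bytesBeforeClose, after close: $bytesAfterClose")

    // Only the closed, fully flushed byte stream round-trips to all ten values.
    val in = new Input(underlying.toByteArray)
    val roundTripped = (1 to 10).map(_ => kryo.readObject(in, classOf[Integer]))
    assert(roundTripped.map(_.intValue) == (1 to 10))
  }
}
```

In `MemoryStore`, the analogous fix is the `serializationStream.close()` call added in `valuesIterator` below, which lets the serializer flush its buffers and write any end-of-stream markers before the bytes are deserialized.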
Diffstat (limited to 'core/src/main')
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/Task.scala                        |  1
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala            | 89
-rw-r--r--  core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala           | 27
-rw-r--r--  core/src/main/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStream.scala | 12
4 files changed, 104 insertions, 25 deletions
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index 35c4dafe9c..1ed36bf069 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -230,6 +230,7 @@ private[spark] object Task {
dataOut.flush()
val taskBytes = serializer.serialize(task)
Utils.writeByteBuffer(taskBytes, out)
+ out.close()
out.toByteBuffer
}
diff --git a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala
index ec1b0f7149..205d469f48 100644
--- a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala
+++ b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala
@@ -33,7 +33,7 @@ import org.apache.spark.memory.{MemoryManager, MemoryMode}
import org.apache.spark.serializer.{SerializationStream, SerializerManager}
import org.apache.spark.storage.{BlockId, BlockInfoManager, StorageLevel}
import org.apache.spark.unsafe.Platform
-import org.apache.spark.util.{CompletionIterator, SizeEstimator, Utils}
+import org.apache.spark.util.{SizeEstimator, Utils}
import org.apache.spark.util.collection.SizeTrackingVector
import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream}
@@ -277,6 +277,7 @@ private[spark] class MemoryStore(
"released too much unroll memory")
Left(new PartiallyUnrolledIterator(
this,
+ MemoryMode.ON_HEAP,
unrollMemoryUsedByThisBlock,
unrolled = arrayValues.toIterator,
rest = Iterator.empty))
@@ -285,7 +286,11 @@ private[spark] class MemoryStore(
// We ran out of space while unrolling the values for this block
logUnrollFailureMessage(blockId, vector.estimateSize())
Left(new PartiallyUnrolledIterator(
- this, unrollMemoryUsedByThisBlock, unrolled = vector.iterator, rest = values))
+ this,
+ MemoryMode.ON_HEAP,
+ unrollMemoryUsedByThisBlock,
+ unrolled = vector.iterator,
+ rest = values))
}
}
@@ -394,7 +399,7 @@ private[spark] class MemoryStore(
redirectableStream,
unrollMemoryUsedByThisBlock,
memoryMode,
- bbos.toChunkedByteBuffer,
+ bbos,
values,
classTag))
}
@@ -655,6 +660,7 @@ private[spark] class MemoryStore(
* The result of a failed [[MemoryStore.putIteratorAsValues()]] call.
*
* @param memoryStore the memoryStore, used for freeing memory.
+ * @param memoryMode the memory mode (on- or off-heap).
* @param unrollMemory the amount of unroll memory used by the values in `unrolled`.
* @param unrolled an iterator for the partially-unrolled values.
* @param rest the rest of the original iterator passed to
@@ -662,13 +668,14 @@ private[spark] class MemoryStore(
*/
private[storage] class PartiallyUnrolledIterator[T](
memoryStore: MemoryStore,
+ memoryMode: MemoryMode,
unrollMemory: Long,
private[this] var unrolled: Iterator[T],
rest: Iterator[T])
extends Iterator[T] {
private def releaseUnrollMemory(): Unit = {
- memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP, unrollMemory)
+ memoryStore.releaseUnrollMemoryForThisTask(memoryMode, unrollMemory)
// SPARK-17503: Garbage collects the unrolling memory before the life end of
// PartiallyUnrolledIterator.
unrolled = null
@@ -706,7 +713,7 @@ private[storage] class PartiallyUnrolledIterator[T](
/**
* A wrapper which allows an open [[OutputStream]] to be redirected to a different sink.
*/
-private class RedirectableOutputStream extends OutputStream {
+private[storage] class RedirectableOutputStream extends OutputStream {
private[this] var os: OutputStream = _
def setOutputStream(s: OutputStream): Unit = { os = s }
override def write(b: Int): Unit = os.write(b)
@@ -726,7 +733,8 @@ private class RedirectableOutputStream extends OutputStream {
* @param redirectableOutputStream an OutputStream which can be redirected to a different sink.
* @param unrollMemory the amount of unroll memory used by the values in `unrolled`.
* @param memoryMode whether the unroll memory is on- or off-heap
- * @param unrolled a byte buffer containing the partially-serialized values.
+ * @param bbos byte buffer output stream containing the partially-serialized values.
+ * [[redirectableOutputStream]] initially points to this output stream.
* @param rest the rest of the original iterator passed to
* [[MemoryStore.putIteratorAsValues()]].
* @param classTag the [[ClassTag]] for the block.
@@ -735,14 +743,19 @@ private[storage] class PartiallySerializedBlock[T](
memoryStore: MemoryStore,
serializerManager: SerializerManager,
blockId: BlockId,
- serializationStream: SerializationStream,
- redirectableOutputStream: RedirectableOutputStream,
- unrollMemory: Long,
+ private val serializationStream: SerializationStream,
+ private val redirectableOutputStream: RedirectableOutputStream,
+ val unrollMemory: Long,
memoryMode: MemoryMode,
- unrolled: ChunkedByteBuffer,
+ bbos: ChunkedByteBufferOutputStream,
rest: Iterator[T],
classTag: ClassTag[T]) {
+ private lazy val unrolledBuffer: ChunkedByteBuffer = {
+ bbos.close()
+ bbos.toChunkedByteBuffer
+ }
+
// If the task does not fully consume `valuesIterator` or otherwise fails to consume or dispose of
// this PartiallySerializedBlock then we risk leaking of direct buffers, so we use a task
// completion listener here in order to ensure that `unrolled.dispose()` is called at least once.
@@ -751,7 +764,23 @@ private[storage] class PartiallySerializedBlock[T](
taskContext.addTaskCompletionListener { _ =>
// When a task completes, its unroll memory will automatically be freed. Thus we do not call
// releaseUnrollMemoryForThisTask() here because we want to avoid double-freeing.
- unrolled.dispose()
+ unrolledBuffer.dispose()
+ }
+ }
+
+ // Exposed for testing
+ private[storage] def getUnrolledChunkedByteBuffer: ChunkedByteBuffer = unrolledBuffer
+
+ private[this] var discarded = false
+ private[this] var consumed = false
+
+ private def verifyNotConsumedAndNotDiscarded(): Unit = {
+ if (consumed) {
+ throw new IllegalStateException(
+ "Can only call one of finishWritingToStream() or valuesIterator() and can only call once.")
+ }
+ if (discarded) {
+ throw new IllegalStateException("Cannot call methods on a discarded PartiallySerializedBlock")
}
}
@@ -759,15 +788,18 @@ private[storage] class PartiallySerializedBlock[T](
* Called to dispose of this block and free its memory.
*/
def discard(): Unit = {
- try {
- // We want to close the output stream in order to free any resources associated with the
- // serializer itself (such as Kryo's internal buffers). close() might cause data to be
- // written, so redirect the output stream to discard that data.
- redirectableOutputStream.setOutputStream(ByteStreams.nullOutputStream())
- serializationStream.close()
- } finally {
- unrolled.dispose()
- memoryStore.releaseUnrollMemoryForThisTask(memoryMode, unrollMemory)
+ if (!discarded) {
+ try {
+ // We want to close the output stream in order to free any resources associated with the
+ // serializer itself (such as Kryo's internal buffers). close() might cause data to be
+ // written, so redirect the output stream to discard that data.
+ redirectableOutputStream.setOutputStream(ByteStreams.nullOutputStream())
+ serializationStream.close()
+ } finally {
+ discarded = true
+ unrolledBuffer.dispose()
+ memoryStore.releaseUnrollMemoryForThisTask(memoryMode, unrollMemory)
+ }
}
}
@@ -776,8 +808,10 @@ private[storage] class PartiallySerializedBlock[T](
* and then serializing the values from the original input iterator.
*/
def finishWritingToStream(os: OutputStream): Unit = {
+ verifyNotConsumedAndNotDiscarded()
+ consumed = true
// `unrolled`'s underlying buffers will be freed once this input stream is fully read:
- ByteStreams.copy(unrolled.toInputStream(dispose = true), os)
+ ByteStreams.copy(unrolledBuffer.toInputStream(dispose = true), os)
memoryStore.releaseUnrollMemoryForThisTask(memoryMode, unrollMemory)
redirectableOutputStream.setOutputStream(os)
while (rest.hasNext) {
@@ -794,13 +828,22 @@ private[storage] class PartiallySerializedBlock[T](
* `close()` on it to free its resources.
*/
def valuesIterator: PartiallyUnrolledIterator[T] = {
+ verifyNotConsumedAndNotDiscarded()
+ consumed = true
+ // Close the serialization stream so that the serializer's internal buffers are freed and any
+ // "end-of-stream" markers can be written out so that `unrolled` is a valid serialized stream.
+ serializationStream.close()
// `unrolled`'s underlying buffers will be freed once this input stream is fully read:
val unrolledIter = serializerManager.dataDeserializeStream(
- blockId, unrolled.toInputStream(dispose = true))(classTag)
+ blockId, unrolledBuffer.toInputStream(dispose = true))(classTag)
+ // The unroll memory will be freed once `unrolledIter` is fully consumed in
+ // PartiallyUnrolledIterator. If the iterator is not consumed by the end of the task then any
+ // extra unroll memory will automatically be freed by a `finally` block in `Task`.
new PartiallyUnrolledIterator(
memoryStore,
+ memoryMode,
unrollMemory,
- unrolled = CompletionIterator[T, Iterator[T]](unrolledIter, discard()),
+ unrolled = unrolledIter,
rest = rest)
}
}
diff --git a/core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala b/core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala
index 09e7579ae9..9077b86f9b 100644
--- a/core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala
+++ b/core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala
@@ -29,7 +29,32 @@ private[spark] class ByteBufferOutputStream(capacity: Int) extends ByteArrayOutp
def getCount(): Int = count
+ private[this] var closed: Boolean = false
+
+ override def write(b: Int): Unit = {
+ require(!closed, "cannot write to a closed ByteBufferOutputStream")
+ super.write(b)
+ }
+
+ override def write(b: Array[Byte], off: Int, len: Int): Unit = {
+ require(!closed, "cannot write to a closed ByteBufferOutputStream")
+ super.write(b, off, len)
+ }
+
+ override def reset(): Unit = {
+ require(!closed, "cannot reset a closed ByteBufferOutputStream")
+ super.reset()
+ }
+
+ override def close(): Unit = {
+ if (!closed) {
+ super.close()
+ closed = true
+ }
+ }
+
def toByteBuffer: ByteBuffer = {
- return ByteBuffer.wrap(buf, 0, count)
+ require(closed, "can only call toByteBuffer() after ByteBufferOutputStream has been closed")
+ ByteBuffer.wrap(buf, 0, count)
}
}
diff --git a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStream.scala b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStream.scala
index 67b50d1e70..a625b32895 100644
--- a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStream.scala
+++ b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStream.scala
@@ -49,10 +49,19 @@ private[spark] class ChunkedByteBufferOutputStream(
*/
private[this] var position = chunkSize
private[this] var _size = 0
+ private[this] var closed: Boolean = false
def size: Long = _size
+ override def close(): Unit = {
+ if (!closed) {
+ super.close()
+ closed = true
+ }
+ }
+
override def write(b: Int): Unit = {
+ require(!closed, "cannot write to a closed ChunkedByteBufferOutputStream")
allocateNewChunkIfNeeded()
chunks(lastChunkIndex).put(b.toByte)
position += 1
@@ -60,6 +69,7 @@ private[spark] class ChunkedByteBufferOutputStream(
}
override def write(bytes: Array[Byte], off: Int, len: Int): Unit = {
+ require(!closed, "cannot write to a closed ChunkedByteBufferOutputStream")
var written = 0
while (written < len) {
allocateNewChunkIfNeeded()
@@ -73,7 +83,6 @@ private[spark] class ChunkedByteBufferOutputStream(
@inline
private def allocateNewChunkIfNeeded(): Unit = {
- require(!toChunkedByteBufferWasCalled, "cannot write after toChunkedByteBuffer() is called")
if (position == chunkSize) {
chunks += allocator(chunkSize)
lastChunkIndex += 1
@@ -82,6 +91,7 @@ private[spark] class ChunkedByteBufferOutputStream(
}
def toChunkedByteBuffer: ChunkedByteBuffer = {
+ require(closed, "cannot call toChunkedByteBuffer() unless close() has been called")
require(!toChunkedByteBufferWasCalled, "toChunkedByteBuffer() can only be called once")
toChunkedByteBufferWasCalled = true
if (lastChunkIndex == -1) {