aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJosh Rosen <joshrosen@databricks.com>2016-09-17 11:46:15 -0700
committerJosh Rosen <joshrosen@databricks.com>2016-09-17 11:46:15 -0700
commit8faa5217b44e8d52eab7eb2d53d0652abaaf43cd (patch)
treedaf1a90737024c0dccd567f66a8b13ee0f2d3c1a
parent86c2d393a56bf1e5114bc5a781253c0460efb8af (diff)
downloadspark-8faa5217b44e8d52eab7eb2d53d0652abaaf43cd.tar.gz
spark-8faa5217b44e8d52eab7eb2d53d0652abaaf43cd.tar.bz2
spark-8faa5217b44e8d52eab7eb2d53d0652abaaf43cd.zip
[SPARK-17491] Close serialization stream to fix wrong answer bug in putIteratorAsBytes()
## What changes were proposed in this pull request? `MemoryStore.putIteratorAsBytes()` may silently lose values when used with `KryoSerializer` because it does not properly close the serialization stream before attempting to deserialize the already-serialized values, which may cause values buffered in Kryo's internal buffers to not be read. This is the root cause behind a user-reported "wrong answer" bug in PySpark caching reported by bennoleslie on the Spark user mailing list in a thread titled "pyspark persist MEMORY_ONLY vs MEMORY_AND_DISK". Due to Spark 2.0's automatic use of KryoSerializer for "safe" types (such as byte arrays, primitives, etc.) this misuse of serializers manifested itself as silent data corruption rather than a StreamCorrupted error (which you might get from JavaSerializer). The minimal fix, implemented here, is to close the serialization stream before attempting to deserialize written values. In addition, this patch adds several additional assertions / precondition checks to prevent misuse of `PartiallySerializedBlock` and `ChunkedByteBufferOutputStream`. ## How was this patch tested? The original bug was masked by an invalid assert in the memory store test cases: the old assert compared two results record-by-record with `zip` but didn't first check that the lengths of the two collections were equal, causing missing records to go unnoticed. The updated test case reproduced this bug. In addition, I added a new `PartiallySerializedBlockSuite` to unit test that component. Author: Josh Rosen <joshrosen@databricks.com> Closes #15043 from JoshRosen/partially-serialized-block-values-iterator-bugfix.
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/Task.scala1
-rw-r--r--core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala89
-rw-r--r--core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala27
-rw-r--r--core/src/main/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStream.scala12
-rw-r--r--core/src/test/scala/org/apache/spark/storage/MemoryStoreSuite.scala34
-rw-r--r--core/src/test/scala/org/apache/spark/storage/PartiallySerializedBlockSuite.scala215
-rw-r--r--core/src/test/scala/org/apache/spark/storage/PartiallyUnrolledIteratorSuite.scala2
-rw-r--r--core/src/test/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStreamSuite.scala8
8 files changed, 344 insertions, 44 deletions
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index 35c4dafe9c..1ed36bf069 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -230,6 +230,7 @@ private[spark] object Task {
dataOut.flush()
val taskBytes = serializer.serialize(task)
Utils.writeByteBuffer(taskBytes, out)
+ out.close()
out.toByteBuffer
}
diff --git a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala
index ec1b0f7149..205d469f48 100644
--- a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala
+++ b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala
@@ -33,7 +33,7 @@ import org.apache.spark.memory.{MemoryManager, MemoryMode}
import org.apache.spark.serializer.{SerializationStream, SerializerManager}
import org.apache.spark.storage.{BlockId, BlockInfoManager, StorageLevel}
import org.apache.spark.unsafe.Platform
-import org.apache.spark.util.{CompletionIterator, SizeEstimator, Utils}
+import org.apache.spark.util.{SizeEstimator, Utils}
import org.apache.spark.util.collection.SizeTrackingVector
import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream}
@@ -277,6 +277,7 @@ private[spark] class MemoryStore(
"released too much unroll memory")
Left(new PartiallyUnrolledIterator(
this,
+ MemoryMode.ON_HEAP,
unrollMemoryUsedByThisBlock,
unrolled = arrayValues.toIterator,
rest = Iterator.empty))
@@ -285,7 +286,11 @@ private[spark] class MemoryStore(
// We ran out of space while unrolling the values for this block
logUnrollFailureMessage(blockId, vector.estimateSize())
Left(new PartiallyUnrolledIterator(
- this, unrollMemoryUsedByThisBlock, unrolled = vector.iterator, rest = values))
+ this,
+ MemoryMode.ON_HEAP,
+ unrollMemoryUsedByThisBlock,
+ unrolled = vector.iterator,
+ rest = values))
}
}
@@ -394,7 +399,7 @@ private[spark] class MemoryStore(
redirectableStream,
unrollMemoryUsedByThisBlock,
memoryMode,
- bbos.toChunkedByteBuffer,
+ bbos,
values,
classTag))
}
@@ -655,6 +660,7 @@ private[spark] class MemoryStore(
* The result of a failed [[MemoryStore.putIteratorAsValues()]] call.
*
* @param memoryStore the memoryStore, used for freeing memory.
+ * @param memoryMode the memory mode (on- or off-heap).
* @param unrollMemory the amount of unroll memory used by the values in `unrolled`.
* @param unrolled an iterator for the partially-unrolled values.
* @param rest the rest of the original iterator passed to
@@ -662,13 +668,14 @@ private[spark] class MemoryStore(
*/
private[storage] class PartiallyUnrolledIterator[T](
memoryStore: MemoryStore,
+ memoryMode: MemoryMode,
unrollMemory: Long,
private[this] var unrolled: Iterator[T],
rest: Iterator[T])
extends Iterator[T] {
private def releaseUnrollMemory(): Unit = {
- memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP, unrollMemory)
+ memoryStore.releaseUnrollMemoryForThisTask(memoryMode, unrollMemory)
// SPARK-17503: Garbage collects the unrolling memory before the life end of
// PartiallyUnrolledIterator.
unrolled = null
@@ -706,7 +713,7 @@ private[storage] class PartiallyUnrolledIterator[T](
/**
* A wrapper which allows an open [[OutputStream]] to be redirected to a different sink.
*/
-private class RedirectableOutputStream extends OutputStream {
+private[storage] class RedirectableOutputStream extends OutputStream {
private[this] var os: OutputStream = _
def setOutputStream(s: OutputStream): Unit = { os = s }
override def write(b: Int): Unit = os.write(b)
@@ -726,7 +733,8 @@ private class RedirectableOutputStream extends OutputStream {
* @param redirectableOutputStream an OutputStream which can be redirected to a different sink.
* @param unrollMemory the amount of unroll memory used by the values in `unrolled`.
* @param memoryMode whether the unroll memory is on- or off-heap
- * @param unrolled a byte buffer containing the partially-serialized values.
+ * @param bbos byte buffer output stream containing the partially-serialized values.
+ * [[redirectableOutputStream]] initially points to this output stream.
* @param rest the rest of the original iterator passed to
* [[MemoryStore.putIteratorAsValues()]].
* @param classTag the [[ClassTag]] for the block.
@@ -735,14 +743,19 @@ private[storage] class PartiallySerializedBlock[T](
memoryStore: MemoryStore,
serializerManager: SerializerManager,
blockId: BlockId,
- serializationStream: SerializationStream,
- redirectableOutputStream: RedirectableOutputStream,
- unrollMemory: Long,
+ private val serializationStream: SerializationStream,
+ private val redirectableOutputStream: RedirectableOutputStream,
+ val unrollMemory: Long,
memoryMode: MemoryMode,
- unrolled: ChunkedByteBuffer,
+ bbos: ChunkedByteBufferOutputStream,
rest: Iterator[T],
classTag: ClassTag[T]) {
+ private lazy val unrolledBuffer: ChunkedByteBuffer = {
+ bbos.close()
+ bbos.toChunkedByteBuffer
+ }
+
// If the task does not fully consume `valuesIterator` or otherwise fails to consume or dispose of
// this PartiallySerializedBlock then we risk leaking of direct buffers, so we use a task
// completion listener here in order to ensure that `unrolled.dispose()` is called at least once.
@@ -751,7 +764,23 @@ private[storage] class PartiallySerializedBlock[T](
taskContext.addTaskCompletionListener { _ =>
// When a task completes, its unroll memory will automatically be freed. Thus we do not call
// releaseUnrollMemoryForThisTask() here because we want to avoid double-freeing.
- unrolled.dispose()
+ unrolledBuffer.dispose()
+ }
+ }
+
+ // Exposed for testing
+ private[storage] def getUnrolledChunkedByteBuffer: ChunkedByteBuffer = unrolledBuffer
+
+ private[this] var discarded = false
+ private[this] var consumed = false
+
+ private def verifyNotConsumedAndNotDiscarded(): Unit = {
+ if (consumed) {
+ throw new IllegalStateException(
+ "Can only call one of finishWritingToStream() or valuesIterator() and can only call once.")
+ }
+ if (discarded) {
+ throw new IllegalStateException("Cannot call methods on a discarded PartiallySerializedBlock")
}
}
@@ -759,15 +788,18 @@ private[storage] class PartiallySerializedBlock[T](
* Called to dispose of this block and free its memory.
*/
def discard(): Unit = {
- try {
- // We want to close the output stream in order to free any resources associated with the
- // serializer itself (such as Kryo's internal buffers). close() might cause data to be
- // written, so redirect the output stream to discard that data.
- redirectableOutputStream.setOutputStream(ByteStreams.nullOutputStream())
- serializationStream.close()
- } finally {
- unrolled.dispose()
- memoryStore.releaseUnrollMemoryForThisTask(memoryMode, unrollMemory)
+ if (!discarded) {
+ try {
+ // We want to close the output stream in order to free any resources associated with the
+ // serializer itself (such as Kryo's internal buffers). close() might cause data to be
+ // written, so redirect the output stream to discard that data.
+ redirectableOutputStream.setOutputStream(ByteStreams.nullOutputStream())
+ serializationStream.close()
+ } finally {
+ discarded = true
+ unrolledBuffer.dispose()
+ memoryStore.releaseUnrollMemoryForThisTask(memoryMode, unrollMemory)
+ }
}
}
@@ -776,8 +808,10 @@ private[storage] class PartiallySerializedBlock[T](
* and then serializing the values from the original input iterator.
*/
def finishWritingToStream(os: OutputStream): Unit = {
+ verifyNotConsumedAndNotDiscarded()
+ consumed = true
// `unrolled`'s underlying buffers will be freed once this input stream is fully read:
- ByteStreams.copy(unrolled.toInputStream(dispose = true), os)
+ ByteStreams.copy(unrolledBuffer.toInputStream(dispose = true), os)
memoryStore.releaseUnrollMemoryForThisTask(memoryMode, unrollMemory)
redirectableOutputStream.setOutputStream(os)
while (rest.hasNext) {
@@ -794,13 +828,22 @@ private[storage] class PartiallySerializedBlock[T](
* `close()` on it to free its resources.
*/
def valuesIterator: PartiallyUnrolledIterator[T] = {
+ verifyNotConsumedAndNotDiscarded()
+ consumed = true
+ // Close the serialization stream so that the serializer's internal buffers are freed and any
+ // "end-of-stream" markers can be written out so that `unrolled` is a valid serialized stream.
+ serializationStream.close()
// `unrolled`'s underlying buffers will be freed once this input stream is fully read:
val unrolledIter = serializerManager.dataDeserializeStream(
- blockId, unrolled.toInputStream(dispose = true))(classTag)
+ blockId, unrolledBuffer.toInputStream(dispose = true))(classTag)
+ // The unroll memory will be freed once `unrolledIter` is fully consumed in
+ // PartiallyUnrolledIterator. If the iterator is not consumed by the end of the task then any
+ // extra unroll memory will automatically be freed by a `finally` block in `Task`.
new PartiallyUnrolledIterator(
memoryStore,
+ memoryMode,
unrollMemory,
- unrolled = CompletionIterator[T, Iterator[T]](unrolledIter, discard()),
+ unrolled = unrolledIter,
rest = rest)
}
}
diff --git a/core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala b/core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala
index 09e7579ae9..9077b86f9b 100644
--- a/core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala
+++ b/core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala
@@ -29,7 +29,32 @@ private[spark] class ByteBufferOutputStream(capacity: Int) extends ByteArrayOutp
def getCount(): Int = count
+ private[this] var closed: Boolean = false
+
+ override def write(b: Int): Unit = {
+ require(!closed, "cannot write to a closed ByteBufferOutputStream")
+ super.write(b)
+ }
+
+ override def write(b: Array[Byte], off: Int, len: Int): Unit = {
+ require(!closed, "cannot write to a closed ByteBufferOutputStream")
+ super.write(b, off, len)
+ }
+
+ override def reset(): Unit = {
+ require(!closed, "cannot reset a closed ByteBufferOutputStream")
+ super.reset()
+ }
+
+ override def close(): Unit = {
+ if (!closed) {
+ super.close()
+ closed = true
+ }
+ }
+
def toByteBuffer: ByteBuffer = {
- return ByteBuffer.wrap(buf, 0, count)
+ require(closed, "can only call toByteBuffer() after ByteBufferOutputStream has been closed")
+ ByteBuffer.wrap(buf, 0, count)
}
}
diff --git a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStream.scala b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStream.scala
index 67b50d1e70..a625b32895 100644
--- a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStream.scala
+++ b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStream.scala
@@ -49,10 +49,19 @@ private[spark] class ChunkedByteBufferOutputStream(
*/
private[this] var position = chunkSize
private[this] var _size = 0
+ private[this] var closed: Boolean = false
def size: Long = _size
+ override def close(): Unit = {
+ if (!closed) {
+ super.close()
+ closed = true
+ }
+ }
+
override def write(b: Int): Unit = {
+ require(!closed, "cannot write to a closed ChunkedByteBufferOutputStream")
allocateNewChunkIfNeeded()
chunks(lastChunkIndex).put(b.toByte)
position += 1
@@ -60,6 +69,7 @@ private[spark] class ChunkedByteBufferOutputStream(
}
override def write(bytes: Array[Byte], off: Int, len: Int): Unit = {
+ require(!closed, "cannot write to a closed ChunkedByteBufferOutputStream")
var written = 0
while (written < len) {
allocateNewChunkIfNeeded()
@@ -73,7 +83,6 @@ private[spark] class ChunkedByteBufferOutputStream(
@inline
private def allocateNewChunkIfNeeded(): Unit = {
- require(!toChunkedByteBufferWasCalled, "cannot write after toChunkedByteBuffer() is called")
if (position == chunkSize) {
chunks += allocator(chunkSize)
lastChunkIndex += 1
@@ -82,6 +91,7 @@ private[spark] class ChunkedByteBufferOutputStream(
}
def toChunkedByteBuffer: ChunkedByteBuffer = {
+ require(closed, "cannot call toChunkedByteBuffer() unless close() has been called")
require(!toChunkedByteBufferWasCalled, "toChunkedByteBuffer() can only be called once")
toChunkedByteBufferWasCalled = true
if (lastChunkIndex == -1) {
diff --git a/core/src/test/scala/org/apache/spark/storage/MemoryStoreSuite.scala b/core/src/test/scala/org/apache/spark/storage/MemoryStoreSuite.scala
index c11de82667..9929ea033a 100644
--- a/core/src/test/scala/org/apache/spark/storage/MemoryStoreSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/MemoryStoreSuite.scala
@@ -79,6 +79,13 @@ class MemoryStoreSuite
(memoryStore, blockInfoManager)
}
+ private def assertSameContents[T](expected: Seq[T], actual: Seq[T], hint: String): Unit = {
+ assert(actual.length === expected.length, s"wrong number of values returned in $hint")
+ expected.iterator.zip(actual.iterator).foreach { case (e, a) =>
+ assert(e === a, s"$hint did not return original values!")
+ }
+ }
+
test("reserve/release unroll memory") {
val (memoryStore, _) = makeMemoryStore(12000)
assert(memoryStore.currentUnrollMemory === 0)
@@ -137,9 +144,7 @@ class MemoryStoreSuite
var putResult = putIteratorAsValues("unroll", smallList.iterator, ClassTag.Any)
assert(putResult.isRight)
assert(memoryStore.currentUnrollMemoryForThisTask === 0)
- smallList.iterator.zip(memoryStore.getValues("unroll").get).foreach { case (e, a) =>
- assert(e === a, "getValues() did not return original values!")
- }
+ assertSameContents(smallList, memoryStore.getValues("unroll").get.toSeq, "getValues")
blockInfoManager.lockForWriting("unroll")
assert(memoryStore.remove("unroll"))
blockInfoManager.removeBlock("unroll")
@@ -152,9 +157,7 @@ class MemoryStoreSuite
assert(memoryStore.currentUnrollMemoryForThisTask === 0)
assert(memoryStore.contains("someBlock2"))
assert(!memoryStore.contains("someBlock1"))
- smallList.iterator.zip(memoryStore.getValues("unroll").get).foreach { case (e, a) =>
- assert(e === a, "getValues() did not return original values!")
- }
+ assertSameContents(smallList, memoryStore.getValues("unroll").get.toSeq, "getValues")
blockInfoManager.lockForWriting("unroll")
assert(memoryStore.remove("unroll"))
blockInfoManager.removeBlock("unroll")
@@ -167,9 +170,7 @@ class MemoryStoreSuite
assert(memoryStore.currentUnrollMemoryForThisTask > 0) // we returned an iterator
assert(!memoryStore.contains("someBlock2"))
assert(putResult.isLeft)
- bigList.iterator.zip(putResult.left.get).foreach { case (e, a) =>
- assert(e === a, "putIterator() did not return original values!")
- }
+ assertSameContents(bigList, putResult.left.get.toSeq, "putIterator")
// The unroll memory was freed once the iterator returned by putIterator() was fully traversed.
assert(memoryStore.currentUnrollMemoryForThisTask === 0)
}
@@ -316,9 +317,8 @@ class MemoryStoreSuite
assert(res.isLeft)
assert(memoryStore.currentUnrollMemoryForThisTask > 0)
val valuesReturnedFromFailedPut = res.left.get.valuesIterator.toSeq // force materialization
- valuesReturnedFromFailedPut.zip(bigList).foreach { case (e, a) =>
- assert(e === a, "PartiallySerializedBlock.valuesIterator() did not return original values!")
- }
+ assertSameContents(
+ bigList, valuesReturnedFromFailedPut, "PartiallySerializedBlock.valuesIterator()")
// The unroll memory was freed once the iterator was fully traversed.
assert(memoryStore.currentUnrollMemoryForThisTask === 0)
}
@@ -340,12 +340,10 @@ class MemoryStoreSuite
res.left.get.finishWritingToStream(bos)
// The unroll memory was freed once the block was fully written.
assert(memoryStore.currentUnrollMemoryForThisTask === 0)
- val deserializationStream = serializerManager.dataDeserializeStream[Any](
- "b1", new ByteBufferInputStream(bos.toByteBuffer))(ClassTag.Any)
- deserializationStream.zip(bigList.iterator).foreach { case (e, a) =>
- assert(e === a,
- "PartiallySerializedBlock.finishWritingtoStream() did not write original values!")
- }
+ val deserializedValues = serializerManager.dataDeserializeStream[Any](
+ "b1", new ByteBufferInputStream(bos.toByteBuffer))(ClassTag.Any).toSeq
+ assertSameContents(
+ bigList, deserializedValues, "PartiallySerializedBlock.finishWritingToStream()")
}
test("multiple unrolls by the same thread") {
diff --git a/core/src/test/scala/org/apache/spark/storage/PartiallySerializedBlockSuite.scala b/core/src/test/scala/org/apache/spark/storage/PartiallySerializedBlockSuite.scala
new file mode 100644
index 0000000000..ec4f2637fa
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/storage/PartiallySerializedBlockSuite.scala
@@ -0,0 +1,215 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.nio.ByteBuffer
+
+import scala.reflect.ClassTag
+
+import org.mockito.Mockito
+import org.mockito.Mockito.atLeastOnce
+import org.mockito.invocation.InvocationOnMock
+import org.mockito.stubbing.Answer
+import org.scalatest.{BeforeAndAfterEach, PrivateMethodTester}
+
+import org.apache.spark.{SparkConf, SparkFunSuite, TaskContext, TaskContextImpl}
+import org.apache.spark.memory.MemoryMode
+import org.apache.spark.serializer.{JavaSerializer, SerializationStream, SerializerManager}
+import org.apache.spark.storage.memory.{MemoryStore, PartiallySerializedBlock, RedirectableOutputStream}
+import org.apache.spark.util.{ByteBufferInputStream, ByteBufferOutputStream}
+import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream}
+
+class PartiallySerializedBlockSuite
+ extends SparkFunSuite
+ with BeforeAndAfterEach
+ with PrivateMethodTester {
+
+ private val blockId = new TestBlockId("test")
+ private val conf = new SparkConf()
+ private val memoryStore = Mockito.mock(classOf[MemoryStore], Mockito.RETURNS_SMART_NULLS)
+ private val serializerManager = new SerializerManager(new JavaSerializer(conf), conf)
+
+ private val getSerializationStream = PrivateMethod[SerializationStream]('serializationStream)
+ private val getRedirectableOutputStream =
+ PrivateMethod[RedirectableOutputStream]('redirectableOutputStream)
+
+ override protected def beforeEach(): Unit = {
+ super.beforeEach()
+ Mockito.reset(memoryStore)
+ }
+
+ private def partiallyUnroll[T: ClassTag](
+ iter: Iterator[T],
+ numItemsToBuffer: Int): PartiallySerializedBlock[T] = {
+
+ val bbos: ChunkedByteBufferOutputStream = {
+ val spy = Mockito.spy(new ChunkedByteBufferOutputStream(128, ByteBuffer.allocate))
+ Mockito.doAnswer(new Answer[ChunkedByteBuffer] {
+ override def answer(invocationOnMock: InvocationOnMock): ChunkedByteBuffer = {
+ Mockito.spy(invocationOnMock.callRealMethod().asInstanceOf[ChunkedByteBuffer])
+ }
+ }).when(spy).toChunkedByteBuffer
+ spy
+ }
+
+ val serializer = serializerManager.getSerializer(implicitly[ClassTag[T]]).newInstance()
+ val redirectableOutputStream = Mockito.spy(new RedirectableOutputStream)
+ redirectableOutputStream.setOutputStream(bbos)
+ val serializationStream = Mockito.spy(serializer.serializeStream(redirectableOutputStream))
+
+ (1 to numItemsToBuffer).foreach { _ =>
+ assert(iter.hasNext)
+ serializationStream.writeObject[T](iter.next())
+ }
+
+ val unrollMemory = bbos.size
+ new PartiallySerializedBlock[T](
+ memoryStore,
+ serializerManager,
+ blockId,
+ serializationStream = serializationStream,
+ redirectableOutputStream,
+ unrollMemory = unrollMemory,
+ memoryMode = MemoryMode.ON_HEAP,
+ bbos,
+ rest = iter,
+ classTag = implicitly[ClassTag[T]])
+ }
+
+ test("valuesIterator() and finishWritingToStream() cannot be called after discard() is called") {
+ val partiallySerializedBlock = partiallyUnroll((1 to 10).iterator, 2)
+ partiallySerializedBlock.discard()
+ intercept[IllegalStateException] {
+ partiallySerializedBlock.finishWritingToStream(null)
+ }
+ intercept[IllegalStateException] {
+ partiallySerializedBlock.valuesIterator
+ }
+ }
+
+ test("discard() can be called more than once") {
+ val partiallySerializedBlock = partiallyUnroll((1 to 10).iterator, 2)
+ partiallySerializedBlock.discard()
+ partiallySerializedBlock.discard()
+ }
+
+ test("cannot call valuesIterator() more than once") {
+ val partiallySerializedBlock = partiallyUnroll((1 to 10).iterator, 2)
+ partiallySerializedBlock.valuesIterator
+ intercept[IllegalStateException] {
+ partiallySerializedBlock.valuesIterator
+ }
+ }
+
+ test("cannot call finishWritingToStream() more than once") {
+ val partiallySerializedBlock = partiallyUnroll((1 to 10).iterator, 2)
+ partiallySerializedBlock.finishWritingToStream(new ByteBufferOutputStream())
+ intercept[IllegalStateException] {
+ partiallySerializedBlock.finishWritingToStream(new ByteBufferOutputStream())
+ }
+ }
+
+ test("cannot call finishWritingToStream() after valuesIterator()") {
+ val partiallySerializedBlock = partiallyUnroll((1 to 10).iterator, 2)
+ partiallySerializedBlock.valuesIterator
+ intercept[IllegalStateException] {
+ partiallySerializedBlock.finishWritingToStream(new ByteBufferOutputStream())
+ }
+ }
+
+ test("cannot call valuesIterator() after finishWritingToStream()") {
+ val partiallySerializedBlock = partiallyUnroll((1 to 10).iterator, 2)
+ partiallySerializedBlock.finishWritingToStream(new ByteBufferOutputStream())
+ intercept[IllegalStateException] {
+ partiallySerializedBlock.valuesIterator
+ }
+ }
+
+ test("buffers are deallocated in a TaskCompletionListener") {
+ try {
+ TaskContext.setTaskContext(TaskContext.empty())
+ val partiallySerializedBlock = partiallyUnroll((1 to 10).iterator, 2)
+ TaskContext.get().asInstanceOf[TaskContextImpl].markTaskCompleted()
+ Mockito.verify(partiallySerializedBlock.getUnrolledChunkedByteBuffer).dispose()
+ Mockito.verifyNoMoreInteractions(memoryStore)
+ } finally {
+ TaskContext.unset()
+ }
+ }
+
+ private def testUnroll[T: ClassTag](
+ testCaseName: String,
+ items: Seq[T],
+ numItemsToBuffer: Int): Unit = {
+
+ test(s"$testCaseName with discard() and numBuffered = $numItemsToBuffer") {
+ val partiallySerializedBlock = partiallyUnroll(items.iterator, numItemsToBuffer)
+ partiallySerializedBlock.discard()
+
+ Mockito.verify(memoryStore).releaseUnrollMemoryForThisTask(
+ MemoryMode.ON_HEAP, partiallySerializedBlock.unrollMemory)
+ Mockito.verify(partiallySerializedBlock.invokePrivate(getSerializationStream())).close()
+ Mockito.verify(partiallySerializedBlock.invokePrivate(getRedirectableOutputStream())).close()
+ Mockito.verifyNoMoreInteractions(memoryStore)
+ Mockito.verify(partiallySerializedBlock.getUnrolledChunkedByteBuffer, atLeastOnce).dispose()
+ }
+
+ test(s"$testCaseName with finishWritingToStream() and numBuffered = $numItemsToBuffer") {
+ val partiallySerializedBlock = partiallyUnroll(items.iterator, numItemsToBuffer)
+ val bbos = Mockito.spy(new ByteBufferOutputStream())
+ partiallySerializedBlock.finishWritingToStream(bbos)
+
+ Mockito.verify(memoryStore).releaseUnrollMemoryForThisTask(
+ MemoryMode.ON_HEAP, partiallySerializedBlock.unrollMemory)
+ Mockito.verify(partiallySerializedBlock.invokePrivate(getSerializationStream())).close()
+ Mockito.verify(partiallySerializedBlock.invokePrivate(getRedirectableOutputStream())).close()
+ Mockito.verify(bbos).close()
+ Mockito.verifyNoMoreInteractions(memoryStore)
+ Mockito.verify(partiallySerializedBlock.getUnrolledChunkedByteBuffer, atLeastOnce).dispose()
+
+ val serializer = serializerManager.getSerializer(implicitly[ClassTag[T]]).newInstance()
+ val deserialized =
+ serializer.deserializeStream(new ByteBufferInputStream(bbos.toByteBuffer)).asIterator.toSeq
+ assert(deserialized === items)
+ }
+
+ test(s"$testCaseName with valuesIterator() and numBuffered = $numItemsToBuffer") {
+ val partiallySerializedBlock = partiallyUnroll(items.iterator, numItemsToBuffer)
+ val valuesIterator = partiallySerializedBlock.valuesIterator
+ Mockito.verify(partiallySerializedBlock.invokePrivate(getSerializationStream())).close()
+ Mockito.verify(partiallySerializedBlock.invokePrivate(getRedirectableOutputStream())).close()
+
+ val deserializedItems = valuesIterator.toArray.toSeq
+ Mockito.verify(memoryStore).releaseUnrollMemoryForThisTask(
+ MemoryMode.ON_HEAP, partiallySerializedBlock.unrollMemory)
+ Mockito.verifyNoMoreInteractions(memoryStore)
+ Mockito.verify(partiallySerializedBlock.getUnrolledChunkedByteBuffer, atLeastOnce).dispose()
+ assert(deserializedItems === items)
+ }
+ }
+
+ testUnroll("basic numbers", 1 to 1000, numItemsToBuffer = 50)
+ testUnroll("basic numbers", 1 to 1000, numItemsToBuffer = 0)
+ testUnroll("basic numbers", 1 to 1000, numItemsToBuffer = 1000)
+ testUnroll("case classes", (1 to 1000).map(x => MyCaseClass(x.toString)), numItemsToBuffer = 50)
+ testUnroll("case classes", (1 to 1000).map(x => MyCaseClass(x.toString)), numItemsToBuffer = 0)
+ testUnroll("case classes", (1 to 1000).map(x => MyCaseClass(x.toString)), numItemsToBuffer = 1000)
+ testUnroll("empty iterator", Seq.empty[String], numItemsToBuffer = 0)
+}
+
+private case class MyCaseClass(str: String)
diff --git a/core/src/test/scala/org/apache/spark/storage/PartiallyUnrolledIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/PartiallyUnrolledIteratorSuite.scala
index 02c2331dc3..4253cc8ca4 100644
--- a/core/src/test/scala/org/apache/spark/storage/PartiallyUnrolledIteratorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/PartiallyUnrolledIteratorSuite.scala
@@ -33,7 +33,7 @@ class PartiallyUnrolledIteratorSuite extends SparkFunSuite with MockitoSugar {
val rest = (unrollSize until restSize + unrollSize).iterator
val memoryStore = mock[MemoryStore]
- val joinIterator = new PartiallyUnrolledIterator(memoryStore, unrollSize, unroll, rest)
+ val joinIterator = new PartiallyUnrolledIterator(memoryStore, ON_HEAP, unrollSize, unroll, rest)
// Firstly iterate over unrolling memory iterator
(0 until unrollSize).foreach { value =>
diff --git a/core/src/test/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStreamSuite.scala b/core/src/test/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStreamSuite.scala
index 226622075a..8696174567 100644
--- a/core/src/test/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStreamSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStreamSuite.scala
@@ -28,12 +28,14 @@ class ChunkedByteBufferOutputStreamSuite extends SparkFunSuite {
test("empty output") {
val o = new ChunkedByteBufferOutputStream(1024, ByteBuffer.allocate)
+ o.close()
assert(o.toChunkedByteBuffer.size === 0)
}
test("write a single byte") {
val o = new ChunkedByteBufferOutputStream(1024, ByteBuffer.allocate)
o.write(10)
+ o.close()
val chunkedByteBuffer = o.toChunkedByteBuffer
assert(chunkedByteBuffer.getChunks().length === 1)
assert(chunkedByteBuffer.getChunks().head.array().toSeq === Seq(10.toByte))
@@ -43,6 +45,7 @@ class ChunkedByteBufferOutputStreamSuite extends SparkFunSuite {
val o = new ChunkedByteBufferOutputStream(10, ByteBuffer.allocate)
o.write(new Array[Byte](9))
o.write(99)
+ o.close()
val chunkedByteBuffer = o.toChunkedByteBuffer
assert(chunkedByteBuffer.getChunks().length === 1)
assert(chunkedByteBuffer.getChunks().head.array()(9) === 99.toByte)
@@ -52,6 +55,7 @@ class ChunkedByteBufferOutputStreamSuite extends SparkFunSuite {
val o = new ChunkedByteBufferOutputStream(10, ByteBuffer.allocate)
o.write(new Array[Byte](10))
o.write(99)
+ o.close()
val arrays = o.toChunkedByteBuffer.getChunks().map(_.array())
assert(arrays.length === 2)
assert(arrays(1).length === 1)
@@ -63,6 +67,7 @@ class ChunkedByteBufferOutputStreamSuite extends SparkFunSuite {
Random.nextBytes(ref)
val o = new ChunkedByteBufferOutputStream(10, ByteBuffer.allocate)
o.write(ref)
+ o.close()
val arrays = o.toChunkedByteBuffer.getChunks().map(_.array())
assert(arrays.length === 1)
assert(arrays.head.length === ref.length)
@@ -74,6 +79,7 @@ class ChunkedByteBufferOutputStreamSuite extends SparkFunSuite {
Random.nextBytes(ref)
val o = new ChunkedByteBufferOutputStream(10, ByteBuffer.allocate)
o.write(ref)
+ o.close()
val arrays = o.toChunkedByteBuffer.getChunks().map(_.array())
assert(arrays.length === 1)
assert(arrays.head.length === ref.length)
@@ -85,6 +91,7 @@ class ChunkedByteBufferOutputStreamSuite extends SparkFunSuite {
Random.nextBytes(ref)
val o = new ChunkedByteBufferOutputStream(10, ByteBuffer.allocate)
o.write(ref)
+ o.close()
val arrays = o.toChunkedByteBuffer.getChunks().map(_.array())
assert(arrays.length === 3)
assert(arrays(0).length === 10)
@@ -101,6 +108,7 @@ class ChunkedByteBufferOutputStreamSuite extends SparkFunSuite {
Random.nextBytes(ref)
val o = new ChunkedByteBufferOutputStream(10, ByteBuffer.allocate)
o.write(ref)
+ o.close()
val arrays = o.toChunkedByteBuffer.getChunks().map(_.array())
assert(arrays.length === 3)
assert(arrays(0).length === 10)