-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/Task.scala                              |   1
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala                  |  89
-rw-r--r--  core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala                 |  27
-rw-r--r--  core/src/main/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStream.scala       |  12
-rw-r--r--  core/src/test/scala/org/apache/spark/storage/MemoryStoreSuite.scala                    |  34
-rw-r--r--  core/src/test/scala/org/apache/spark/storage/PartiallySerializedBlockSuite.scala       | 215
-rw-r--r--  core/src/test/scala/org/apache/spark/storage/PartiallyUnrolledIteratorSuite.scala      |   2
-rw-r--r--  core/src/test/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStreamSuite.scala  |   8
8 files changed, 344 insertions, 44 deletions
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index 35c4dafe9c..1ed36bf069 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -230,6 +230,7 @@ private[spark] object Task {
dataOut.flush()
val taskBytes = serializer.serialize(task)
Utils.writeByteBuffer(taskBytes, out)
+ out.close()
out.toByteBuffer
}
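
The added out.close() reflects the new ByteBufferOutputStream contract introduced later in this patch: toByteBuffer() may only be called once the stream is closed. A minimal sketch of that contract outside Task.scala (capacity and payload are illustrative):

    import org.apache.spark.util.ByteBufferOutputStream

    val out = new ByteBufferOutputStream(64)   // capacity is illustrative
    out.write(Array[Byte](1, 2, 3))
    out.close()                                // now required before toByteBuffer()
    val buf = out.toByteBuffer                 // wraps the internal array without copying
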
diff --git a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala
index ec1b0f7149..205d469f48 100644
--- a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala
+++ b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala
@@ -33,7 +33,7 @@ import org.apache.spark.memory.{MemoryManager, MemoryMode}
import org.apache.spark.serializer.{SerializationStream, SerializerManager}
import org.apache.spark.storage.{BlockId, BlockInfoManager, StorageLevel}
import org.apache.spark.unsafe.Platform
-import org.apache.spark.util.{CompletionIterator, SizeEstimator, Utils}
+import org.apache.spark.util.{SizeEstimator, Utils}
import org.apache.spark.util.collection.SizeTrackingVector
import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream}
@@ -277,6 +277,7 @@ private[spark] class MemoryStore(
"released too much unroll memory")
Left(new PartiallyUnrolledIterator(
this,
+ MemoryMode.ON_HEAP,
unrollMemoryUsedByThisBlock,
unrolled = arrayValues.toIterator,
rest = Iterator.empty))
@@ -285,7 +286,11 @@ private[spark] class MemoryStore(
// We ran out of space while unrolling the values for this block
logUnrollFailureMessage(blockId, vector.estimateSize())
Left(new PartiallyUnrolledIterator(
- this, unrollMemoryUsedByThisBlock, unrolled = vector.iterator, rest = values))
+ this,
+ MemoryMode.ON_HEAP,
+ unrollMemoryUsedByThisBlock,
+ unrolled = vector.iterator,
+ rest = values))
}
}
@@ -394,7 +399,7 @@ private[spark] class MemoryStore(
redirectableStream,
unrollMemoryUsedByThisBlock,
memoryMode,
- bbos.toChunkedByteBuffer,
+ bbos,
values,
classTag))
}
@@ -655,6 +660,7 @@ private[spark] class MemoryStore(
* The result of a failed [[MemoryStore.putIteratorAsValues()]] call.
*
* @param memoryStore the memoryStore, used for freeing memory.
+ * @param memoryMode the memory mode (on- or off-heap).
* @param unrollMemory the amount of unroll memory used by the values in `unrolled`.
* @param unrolled an iterator for the partially-unrolled values.
* @param rest the rest of the original iterator passed to
@@ -662,13 +668,14 @@ private[spark] class MemoryStore(
*/
private[storage] class PartiallyUnrolledIterator[T](
memoryStore: MemoryStore,
+ memoryMode: MemoryMode,
unrollMemory: Long,
private[this] var unrolled: Iterator[T],
rest: Iterator[T])
extends Iterator[T] {
private def releaseUnrollMemory(): Unit = {
- memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP, unrollMemory)
+ memoryStore.releaseUnrollMemoryForThisTask(memoryMode, unrollMemory)
// SPARK-17503: Garbage collects the unrolling memory before the life end of
// PartiallyUnrolledIterator.
unrolled = null
@@ -706,7 +713,7 @@ private[storage] class PartiallyUnrolledIterator[T](
/**
* A wrapper which allows an open [[OutputStream]] to be redirected to a different sink.
*/
-private class RedirectableOutputStream extends OutputStream {
+private[storage] class RedirectableOutputStream extends OutputStream {
private[this] var os: OutputStream = _
def setOutputStream(s: OutputStream): Unit = { os = s }
override def write(b: Int): Unit = os.write(b)
@@ -726,7 +733,8 @@ private class RedirectableOutputStream extends OutputStream {
* @param redirectableOutputStream an OutputStream which can be redirected to a different sink.
* @param unrollMemory the amount of unroll memory used by the values in `unrolled`.
* @param memoryMode whether the unroll memory is on- or off-heap
- * @param unrolled a byte buffer containing the partially-serialized values.
+ * @param bbos byte buffer output stream containing the partially-serialized values.
+ * [[redirectableOutputStream]] initially points to this output stream.
* @param rest the rest of the original iterator passed to
* [[MemoryStore.putIteratorAsValues()]].
* @param classTag the [[ClassTag]] for the block.
@@ -735,14 +743,19 @@ private[storage] class PartiallySerializedBlock[T](
memoryStore: MemoryStore,
serializerManager: SerializerManager,
blockId: BlockId,
- serializationStream: SerializationStream,
- redirectableOutputStream: RedirectableOutputStream,
- unrollMemory: Long,
+ private val serializationStream: SerializationStream,
+ private val redirectableOutputStream: RedirectableOutputStream,
+ val unrollMemory: Long,
memoryMode: MemoryMode,
- unrolled: ChunkedByteBuffer,
+ bbos: ChunkedByteBufferOutputStream,
rest: Iterator[T],
classTag: ClassTag[T]) {
+ private lazy val unrolledBuffer: ChunkedByteBuffer = {
+ bbos.close()
+ bbos.toChunkedByteBuffer
+ }
+
// If the task does not fully consume `valuesIterator` or otherwise fails to consume or dispose of
// this PartiallySerializedBlock then we risk leaking of direct buffers, so we use a task
// completion listener here in order to ensure that `unrolled.dispose()` is called at least once.
@@ -751,7 +764,23 @@ private[storage] class PartiallySerializedBlock[T](
taskContext.addTaskCompletionListener { _ =>
// When a task completes, its unroll memory will automatically be freed. Thus we do not call
// releaseUnrollMemoryForThisTask() here because we want to avoid double-freeing.
- unrolled.dispose()
+ unrolledBuffer.dispose()
+ }
+ }
+
+ // Exposed for testing
+ private[storage] def getUnrolledChunkedByteBuffer: ChunkedByteBuffer = unrolledBuffer
+
+ private[this] var discarded = false
+ private[this] var consumed = false
+
+ private def verifyNotConsumedAndNotDiscarded(): Unit = {
+ if (consumed) {
+ throw new IllegalStateException(
+ "Can only call one of finishWritingToStream() or valuesIterator() and can only call once.")
+ }
+ if (discarded) {
+ throw new IllegalStateException("Cannot call methods on a discarded PartiallySerializedBlock")
}
}
@@ -759,15 +788,18 @@ private[storage] class PartiallySerializedBlock[T](
* Called to dispose of this block and free its memory.
*/
def discard(): Unit = {
- try {
- // We want to close the output stream in order to free any resources associated with the
- // serializer itself (such as Kryo's internal buffers). close() might cause data to be
- // written, so redirect the output stream to discard that data.
- redirectableOutputStream.setOutputStream(ByteStreams.nullOutputStream())
- serializationStream.close()
- } finally {
- unrolled.dispose()
- memoryStore.releaseUnrollMemoryForThisTask(memoryMode, unrollMemory)
+ if (!discarded) {
+ try {
+ // We want to close the output stream in order to free any resources associated with the
+ // serializer itself (such as Kryo's internal buffers). close() might cause data to be
+ // written, so redirect the output stream to discard that data.
+ redirectableOutputStream.setOutputStream(ByteStreams.nullOutputStream())
+ serializationStream.close()
+ } finally {
+ discarded = true
+ unrolledBuffer.dispose()
+ memoryStore.releaseUnrollMemoryForThisTask(memoryMode, unrollMemory)
+ }
}
}
@@ -776,8 +808,10 @@ private[storage] class PartiallySerializedBlock[T](
* and then serializing the values from the original input iterator.
*/
def finishWritingToStream(os: OutputStream): Unit = {
+ verifyNotConsumedAndNotDiscarded()
+ consumed = true
// `unrolled`'s underlying buffers will be freed once this input stream is fully read:
- ByteStreams.copy(unrolled.toInputStream(dispose = true), os)
+ ByteStreams.copy(unrolledBuffer.toInputStream(dispose = true), os)
memoryStore.releaseUnrollMemoryForThisTask(memoryMode, unrollMemory)
redirectableOutputStream.setOutputStream(os)
while (rest.hasNext) {
@@ -794,13 +828,22 @@ private[storage] class PartiallySerializedBlock[T](
* `close()` on it to free its resources.
*/
def valuesIterator: PartiallyUnrolledIterator[T] = {
+ verifyNotConsumedAndNotDiscarded()
+ consumed = true
+ // Close the serialization stream so that the serializer's internal buffers are freed and any
+ // "end-of-stream" markers can be written out so that `unrolled` is a valid serialized stream.
+ serializationStream.close()
// `unrolled`'s underlying buffers will be freed once this input stream is fully read:
val unrolledIter = serializerManager.dataDeserializeStream(
- blockId, unrolled.toInputStream(dispose = true))(classTag)
+ blockId, unrolledBuffer.toInputStream(dispose = true))(classTag)
+ // The unroll memory will be freed once `unrolledIter` is fully consumed in
+ // PartiallyUnrolledIterator. If the iterator is not consumed by the end of the task then any
+ // extra unroll memory will automatically be freed by a `finally` block in `Task`.
new PartiallyUnrolledIterator(
memoryStore,
+ memoryMode,
unrollMemory,
- unrolled = CompletionIterator[T, Iterator[T]](unrolledIter, discard()),
+ unrolled = unrolledIter,
rest = rest)
}
}
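
A standalone sketch (an illustrative stand-in, not the Spark class itself) of the consume-once guard these hunks add to PartiallySerializedBlock: a block may be finished either by streaming it out or by re-reading its values, but not both, not twice, and not after discard(); discard() itself is idempotent. The real class additionally closes the serialization stream, disposes the chunked buffer, and releases unroll memory.

    import java.io.OutputStream

    class OneShotBlock(data: Array[Byte]) {
      private var consumed = false
      private var discarded = false

      private def verifyNotConsumedAndNotDiscarded(): Unit = {
        if (consumed) {
          throw new IllegalStateException(
            "Can only call one of finishWritingToStream() or valuesIterator() and can only call once.")
        }
        if (discarded) {
          throw new IllegalStateException("Cannot call methods on a discarded block")
        }
      }

      def finishWritingToStream(os: OutputStream): Unit = {
        verifyNotConsumedAndNotDiscarded()
        consumed = true
        os.write(data)            // in Spark: copy the unrolled bytes, then serialize the rest
      }

      def valuesIterator: Iterator[Byte] = {
        verifyNotConsumedAndNotDiscarded()
        consumed = true
        data.iterator             // in Spark: deserialize the unrolled bytes, then chain the rest
      }

      def discard(): Unit = {
        if (!discarded) {         // idempotent, mirroring the patched discard()
          discarded = true        // in Spark: close streams, dispose buffers, release unroll memory
        }
      }
    }
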
diff --git a/core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala b/core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala
index 09e7579ae9..9077b86f9b 100644
--- a/core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala
+++ b/core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala
@@ -29,7 +29,32 @@ private[spark] class ByteBufferOutputStream(capacity: Int) extends ByteArrayOutp
def getCount(): Int = count
+ private[this] var closed: Boolean = false
+
+ override def write(b: Int): Unit = {
+ require(!closed, "cannot write to a closed ByteBufferOutputStream")
+ super.write(b)
+ }
+
+ override def write(b: Array[Byte], off: Int, len: Int): Unit = {
+ require(!closed, "cannot write to a closed ByteBufferOutputStream")
+ super.write(b, off, len)
+ }
+
+ override def reset(): Unit = {
+ require(!closed, "cannot reset a closed ByteBufferOutputStream")
+ super.reset()
+ }
+
+ override def close(): Unit = {
+ if (!closed) {
+ super.close()
+ closed = true
+ }
+ }
+
def toByteBuffer: ByteBuffer = {
- return ByteBuffer.wrap(buf, 0, count)
+ require(closed, "can only call toByteBuffer() after ByteBufferOutputStream has been closed")
+ ByteBuffer.wrap(buf, 0, count)
}
}
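
A short sketch of the new closed-state guards (values illustrative): close() is idempotent, write() and reset() fail fast on a closed stream, and toByteBuffer() refuses to run until the stream has been closed.

    val bos = new ByteBufferOutputStream(16)
    bos.write(42)
    // bos.toByteBuffer           // would fail: stream not yet closed
    bos.close()
    bos.close()                   // no-op; close() is idempotent
    // bos.write(43)              // would fail: "cannot write to a closed ByteBufferOutputStream"
    assert(bos.toByteBuffer.get() == 42.toByte)
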
diff --git a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStream.scala b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStream.scala
index 67b50d1e70..a625b32895 100644
--- a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStream.scala
+++ b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStream.scala
@@ -49,10 +49,19 @@ private[spark] class ChunkedByteBufferOutputStream(
*/
private[this] var position = chunkSize
private[this] var _size = 0
+ private[this] var closed: Boolean = false
def size: Long = _size
+ override def close(): Unit = {
+ if (!closed) {
+ super.close()
+ closed = true
+ }
+ }
+
override def write(b: Int): Unit = {
+ require(!closed, "cannot write to a closed ChunkedByteBufferOutputStream")
allocateNewChunkIfNeeded()
chunks(lastChunkIndex).put(b.toByte)
position += 1
@@ -60,6 +69,7 @@ private[spark] class ChunkedByteBufferOutputStream(
}
override def write(bytes: Array[Byte], off: Int, len: Int): Unit = {
+ require(!closed, "cannot write to a closed ChunkedByteBufferOutputStream")
var written = 0
while (written < len) {
allocateNewChunkIfNeeded()
@@ -73,7 +83,6 @@ private[spark] class ChunkedByteBufferOutputStream(
@inline
private def allocateNewChunkIfNeeded(): Unit = {
- require(!toChunkedByteBufferWasCalled, "cannot write after toChunkedByteBuffer() is called")
if (position == chunkSize) {
chunks += allocator(chunkSize)
lastChunkIndex += 1
@@ -82,6 +91,7 @@ private[spark] class ChunkedByteBufferOutputStream(
}
def toChunkedByteBuffer: ChunkedByteBuffer = {
+ require(closed, "cannot call toChunkedByteBuffer() unless close() has been called")
require(!toChunkedByteBufferWasCalled, "toChunkedByteBuffer() can only be called once")
toChunkedByteBufferWasCalled = true
if (lastChunkIndex == -1) {
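
Under the new lifecycle, toChunkedByteBuffer() can only be called on a closed stream, and still only once. A minimal usage sketch with an illustrative chunk size:

    import java.nio.ByteBuffer
    import org.apache.spark.util.io.ChunkedByteBufferOutputStream

    val o = new ChunkedByteBufferOutputStream(10, ByteBuffer.allocate)  // 10-byte chunks, illustrative
    o.write(new Array[Byte](25))   // spans three chunks
    o.close()                      // must precede toChunkedByteBuffer()
    val chunked = o.toChunkedByteBuffer
    assert(chunked.size == 25)
    // o.toChunkedByteBuffer       // would fail: can only be called once
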
diff --git a/core/src/test/scala/org/apache/spark/storage/MemoryStoreSuite.scala b/core/src/test/scala/org/apache/spark/storage/MemoryStoreSuite.scala
index c11de82667..9929ea033a 100644
--- a/core/src/test/scala/org/apache/spark/storage/MemoryStoreSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/MemoryStoreSuite.scala
@@ -79,6 +79,13 @@ class MemoryStoreSuite
(memoryStore, blockInfoManager)
}
+ private def assertSameContents[T](expected: Seq[T], actual: Seq[T], hint: String): Unit = {
+ assert(actual.length === expected.length, s"wrong number of values returned in $hint")
+ expected.iterator.zip(actual.iterator).foreach { case (e, a) =>
+ assert(e === a, s"$hint did not return original values!")
+ }
+ }
+
test("reserve/release unroll memory") {
val (memoryStore, _) = makeMemoryStore(12000)
assert(memoryStore.currentUnrollMemory === 0)
@@ -137,9 +144,7 @@ class MemoryStoreSuite
var putResult = putIteratorAsValues("unroll", smallList.iterator, ClassTag.Any)
assert(putResult.isRight)
assert(memoryStore.currentUnrollMemoryForThisTask === 0)
- smallList.iterator.zip(memoryStore.getValues("unroll").get).foreach { case (e, a) =>
- assert(e === a, "getValues() did not return original values!")
- }
+ assertSameContents(smallList, memoryStore.getValues("unroll").get.toSeq, "getValues")
blockInfoManager.lockForWriting("unroll")
assert(memoryStore.remove("unroll"))
blockInfoManager.removeBlock("unroll")
@@ -152,9 +157,7 @@ class MemoryStoreSuite
assert(memoryStore.currentUnrollMemoryForThisTask === 0)
assert(memoryStore.contains("someBlock2"))
assert(!memoryStore.contains("someBlock1"))
- smallList.iterator.zip(memoryStore.getValues("unroll").get).foreach { case (e, a) =>
- assert(e === a, "getValues() did not return original values!")
- }
+ assertSameContents(smallList, memoryStore.getValues("unroll").get.toSeq, "getValues")
blockInfoManager.lockForWriting("unroll")
assert(memoryStore.remove("unroll"))
blockInfoManager.removeBlock("unroll")
@@ -167,9 +170,7 @@ class MemoryStoreSuite
assert(memoryStore.currentUnrollMemoryForThisTask > 0) // we returned an iterator
assert(!memoryStore.contains("someBlock2"))
assert(putResult.isLeft)
- bigList.iterator.zip(putResult.left.get).foreach { case (e, a) =>
- assert(e === a, "putIterator() did not return original values!")
- }
+ assertSameContents(bigList, putResult.left.get.toSeq, "putIterator")
// The unroll memory was freed once the iterator returned by putIterator() was fully traversed.
assert(memoryStore.currentUnrollMemoryForThisTask === 0)
}
@@ -316,9 +317,8 @@ class MemoryStoreSuite
assert(res.isLeft)
assert(memoryStore.currentUnrollMemoryForThisTask > 0)
val valuesReturnedFromFailedPut = res.left.get.valuesIterator.toSeq // force materialization
- valuesReturnedFromFailedPut.zip(bigList).foreach { case (e, a) =>
- assert(e === a, "PartiallySerializedBlock.valuesIterator() did not return original values!")
- }
+ assertSameContents(
+ bigList, valuesReturnedFromFailedPut, "PartiallySerializedBlock.valuesIterator()")
// The unroll memory was freed once the iterator was fully traversed.
assert(memoryStore.currentUnrollMemoryForThisTask === 0)
}
@@ -340,12 +340,10 @@ class MemoryStoreSuite
res.left.get.finishWritingToStream(bos)
// The unroll memory was freed once the block was fully written.
assert(memoryStore.currentUnrollMemoryForThisTask === 0)
- val deserializationStream = serializerManager.dataDeserializeStream[Any](
- "b1", new ByteBufferInputStream(bos.toByteBuffer))(ClassTag.Any)
- deserializationStream.zip(bigList.iterator).foreach { case (e, a) =>
- assert(e === a,
- "PartiallySerializedBlock.finishWritingtoStream() did not write original values!")
- }
+ val deserializedValues = serializerManager.dataDeserializeStream[Any](
+ "b1", new ByteBufferInputStream(bos.toByteBuffer))(ClassTag.Any).toSeq
+ assertSameContents(
+ bigList, deserializedValues, "PartiallySerializedBlock.finishWritingToStream()")
}
test("multiple unrolls by the same thread") {
diff --git a/core/src/test/scala/org/apache/spark/storage/PartiallySerializedBlockSuite.scala b/core/src/test/scala/org/apache/spark/storage/PartiallySerializedBlockSuite.scala
new file mode 100644
index 0000000000..ec4f2637fa
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/storage/PartiallySerializedBlockSuite.scala
@@ -0,0 +1,215 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.nio.ByteBuffer
+
+import scala.reflect.ClassTag
+
+import org.mockito.Mockito
+import org.mockito.Mockito.atLeastOnce
+import org.mockito.invocation.InvocationOnMock
+import org.mockito.stubbing.Answer
+import org.scalatest.{BeforeAndAfterEach, PrivateMethodTester}
+
+import org.apache.spark.{SparkConf, SparkFunSuite, TaskContext, TaskContextImpl}
+import org.apache.spark.memory.MemoryMode
+import org.apache.spark.serializer.{JavaSerializer, SerializationStream, SerializerManager}
+import org.apache.spark.storage.memory.{MemoryStore, PartiallySerializedBlock, RedirectableOutputStream}
+import org.apache.spark.util.{ByteBufferInputStream, ByteBufferOutputStream}
+import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream}
+
+class PartiallySerializedBlockSuite
+ extends SparkFunSuite
+ with BeforeAndAfterEach
+ with PrivateMethodTester {
+
+ private val blockId = new TestBlockId("test")
+ private val conf = new SparkConf()
+ private val memoryStore = Mockito.mock(classOf[MemoryStore], Mockito.RETURNS_SMART_NULLS)
+ private val serializerManager = new SerializerManager(new JavaSerializer(conf), conf)
+
+ private val getSerializationStream = PrivateMethod[SerializationStream]('serializationStream)
+ private val getRedirectableOutputStream =
+ PrivateMethod[RedirectableOutputStream]('redirectableOutputStream)
+
+ override protected def beforeEach(): Unit = {
+ super.beforeEach()
+ Mockito.reset(memoryStore)
+ }
+
+ private def partiallyUnroll[T: ClassTag](
+ iter: Iterator[T],
+ numItemsToBuffer: Int): PartiallySerializedBlock[T] = {
+
+ val bbos: ChunkedByteBufferOutputStream = {
+ val spy = Mockito.spy(new ChunkedByteBufferOutputStream(128, ByteBuffer.allocate))
+ Mockito.doAnswer(new Answer[ChunkedByteBuffer] {
+ override def answer(invocationOnMock: InvocationOnMock): ChunkedByteBuffer = {
+ Mockito.spy(invocationOnMock.callRealMethod().asInstanceOf[ChunkedByteBuffer])
+ }
+ }).when(spy).toChunkedByteBuffer
+ spy
+ }
+
+ val serializer = serializerManager.getSerializer(implicitly[ClassTag[T]]).newInstance()
+ val redirectableOutputStream = Mockito.spy(new RedirectableOutputStream)
+ redirectableOutputStream.setOutputStream(bbos)
+ val serializationStream = Mockito.spy(serializer.serializeStream(redirectableOutputStream))
+
+ (1 to numItemsToBuffer).foreach { _ =>
+ assert(iter.hasNext)
+ serializationStream.writeObject[T](iter.next())
+ }
+
+ val unrollMemory = bbos.size
+ new PartiallySerializedBlock[T](
+ memoryStore,
+ serializerManager,
+ blockId,
+ serializationStream = serializationStream,
+ redirectableOutputStream,
+ unrollMemory = unrollMemory,
+ memoryMode = MemoryMode.ON_HEAP,
+ bbos,
+ rest = iter,
+ classTag = implicitly[ClassTag[T]])
+ }
+
+ test("valuesIterator() and finishWritingToStream() cannot be called after discard() is called") {
+ val partiallySerializedBlock = partiallyUnroll((1 to 10).iterator, 2)
+ partiallySerializedBlock.discard()
+ intercept[IllegalStateException] {
+ partiallySerializedBlock.finishWritingToStream(null)
+ }
+ intercept[IllegalStateException] {
+ partiallySerializedBlock.valuesIterator
+ }
+ }
+
+ test("discard() can be called more than once") {
+ val partiallySerializedBlock = partiallyUnroll((1 to 10).iterator, 2)
+ partiallySerializedBlock.discard()
+ partiallySerializedBlock.discard()
+ }
+
+ test("cannot call valuesIterator() more than once") {
+ val partiallySerializedBlock = partiallyUnroll((1 to 10).iterator, 2)
+ partiallySerializedBlock.valuesIterator
+ intercept[IllegalStateException] {
+ partiallySerializedBlock.valuesIterator
+ }
+ }
+
+ test("cannot call finishWritingToStream() more than once") {
+ val partiallySerializedBlock = partiallyUnroll((1 to 10).iterator, 2)
+ partiallySerializedBlock.finishWritingToStream(new ByteBufferOutputStream())
+ intercept[IllegalStateException] {
+ partiallySerializedBlock.finishWritingToStream(new ByteBufferOutputStream())
+ }
+ }
+
+ test("cannot call finishWritingToStream() after valuesIterator()") {
+ val partiallySerializedBlock = partiallyUnroll((1 to 10).iterator, 2)
+ partiallySerializedBlock.valuesIterator
+ intercept[IllegalStateException] {
+ partiallySerializedBlock.finishWritingToStream(new ByteBufferOutputStream())
+ }
+ }
+
+ test("cannot call valuesIterator() after finishWritingToStream()") {
+ val partiallySerializedBlock = partiallyUnroll((1 to 10).iterator, 2)
+ partiallySerializedBlock.finishWritingToStream(new ByteBufferOutputStream())
+ intercept[IllegalStateException] {
+ partiallySerializedBlock.valuesIterator
+ }
+ }
+
+ test("buffers are deallocated in a TaskCompletionListener") {
+ try {
+ TaskContext.setTaskContext(TaskContext.empty())
+ val partiallySerializedBlock = partiallyUnroll((1 to 10).iterator, 2)
+ TaskContext.get().asInstanceOf[TaskContextImpl].markTaskCompleted()
+ Mockito.verify(partiallySerializedBlock.getUnrolledChunkedByteBuffer).dispose()
+ Mockito.verifyNoMoreInteractions(memoryStore)
+ } finally {
+ TaskContext.unset()
+ }
+ }
+
+ private def testUnroll[T: ClassTag](
+ testCaseName: String,
+ items: Seq[T],
+ numItemsToBuffer: Int): Unit = {
+
+ test(s"$testCaseName with discard() and numBuffered = $numItemsToBuffer") {
+ val partiallySerializedBlock = partiallyUnroll(items.iterator, numItemsToBuffer)
+ partiallySerializedBlock.discard()
+
+ Mockito.verify(memoryStore).releaseUnrollMemoryForThisTask(
+ MemoryMode.ON_HEAP, partiallySerializedBlock.unrollMemory)
+ Mockito.verify(partiallySerializedBlock.invokePrivate(getSerializationStream())).close()
+ Mockito.verify(partiallySerializedBlock.invokePrivate(getRedirectableOutputStream())).close()
+ Mockito.verifyNoMoreInteractions(memoryStore)
+ Mockito.verify(partiallySerializedBlock.getUnrolledChunkedByteBuffer, atLeastOnce).dispose()
+ }
+
+ test(s"$testCaseName with finishWritingToStream() and numBuffered = $numItemsToBuffer") {
+ val partiallySerializedBlock = partiallyUnroll(items.iterator, numItemsToBuffer)
+ val bbos = Mockito.spy(new ByteBufferOutputStream())
+ partiallySerializedBlock.finishWritingToStream(bbos)
+
+ Mockito.verify(memoryStore).releaseUnrollMemoryForThisTask(
+ MemoryMode.ON_HEAP, partiallySerializedBlock.unrollMemory)
+ Mockito.verify(partiallySerializedBlock.invokePrivate(getSerializationStream())).close()
+ Mockito.verify(partiallySerializedBlock.invokePrivate(getRedirectableOutputStream())).close()
+ Mockito.verify(bbos).close()
+ Mockito.verifyNoMoreInteractions(memoryStore)
+ Mockito.verify(partiallySerializedBlock.getUnrolledChunkedByteBuffer, atLeastOnce).dispose()
+
+ val serializer = serializerManager.getSerializer(implicitly[ClassTag[T]]).newInstance()
+ val deserialized =
+ serializer.deserializeStream(new ByteBufferInputStream(bbos.toByteBuffer)).asIterator.toSeq
+ assert(deserialized === items)
+ }
+
+ test(s"$testCaseName with valuesIterator() and numBuffered = $numItemsToBuffer") {
+ val partiallySerializedBlock = partiallyUnroll(items.iterator, numItemsToBuffer)
+ val valuesIterator = partiallySerializedBlock.valuesIterator
+ Mockito.verify(partiallySerializedBlock.invokePrivate(getSerializationStream())).close()
+ Mockito.verify(partiallySerializedBlock.invokePrivate(getRedirectableOutputStream())).close()
+
+ val deserializedItems = valuesIterator.toArray.toSeq
+ Mockito.verify(memoryStore).releaseUnrollMemoryForThisTask(
+ MemoryMode.ON_HEAP, partiallySerializedBlock.unrollMemory)
+ Mockito.verifyNoMoreInteractions(memoryStore)
+ Mockito.verify(partiallySerializedBlock.getUnrolledChunkedByteBuffer, atLeastOnce).dispose()
+ assert(deserializedItems === items)
+ }
+ }
+
+ testUnroll("basic numbers", 1 to 1000, numItemsToBuffer = 50)
+ testUnroll("basic numbers", 1 to 1000, numItemsToBuffer = 0)
+ testUnroll("basic numbers", 1 to 1000, numItemsToBuffer = 1000)
+ testUnroll("case classes", (1 to 1000).map(x => MyCaseClass(x.toString)), numItemsToBuffer = 50)
+ testUnroll("case classes", (1 to 1000).map(x => MyCaseClass(x.toString)), numItemsToBuffer = 0)
+ testUnroll("case classes", (1 to 1000).map(x => MyCaseClass(x.toString)), numItemsToBuffer = 1000)
+ testUnroll("empty iterator", Seq.empty[String], numItemsToBuffer = 0)
+}
+
+private case class MyCaseClass(str: String)
diff --git a/core/src/test/scala/org/apache/spark/storage/PartiallyUnrolledIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/PartiallyUnrolledIteratorSuite.scala
index 02c2331dc3..4253cc8ca4 100644
--- a/core/src/test/scala/org/apache/spark/storage/PartiallyUnrolledIteratorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/PartiallyUnrolledIteratorSuite.scala
@@ -33,7 +33,7 @@ class PartiallyUnrolledIteratorSuite extends SparkFunSuite with MockitoSugar {
val rest = (unrollSize until restSize + unrollSize).iterator
val memoryStore = mock[MemoryStore]
- val joinIterator = new PartiallyUnrolledIterator(memoryStore, unrollSize, unroll, rest)
+ val joinIterator = new PartiallyUnrolledIterator(memoryStore, ON_HEAP, unrollSize, unroll, rest)
// Firstly iterate over unrolling memory iterator
(0 until unrollSize).foreach { value =>
diff --git a/core/src/test/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStreamSuite.scala b/core/src/test/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStreamSuite.scala
index 226622075a..8696174567 100644
--- a/core/src/test/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStreamSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStreamSuite.scala
@@ -28,12 +28,14 @@ class ChunkedByteBufferOutputStreamSuite extends SparkFunSuite {
test("empty output") {
val o = new ChunkedByteBufferOutputStream(1024, ByteBuffer.allocate)
+ o.close()
assert(o.toChunkedByteBuffer.size === 0)
}
test("write a single byte") {
val o = new ChunkedByteBufferOutputStream(1024, ByteBuffer.allocate)
o.write(10)
+ o.close()
val chunkedByteBuffer = o.toChunkedByteBuffer
assert(chunkedByteBuffer.getChunks().length === 1)
assert(chunkedByteBuffer.getChunks().head.array().toSeq === Seq(10.toByte))
@@ -43,6 +45,7 @@ class ChunkedByteBufferOutputStreamSuite extends SparkFunSuite {
val o = new ChunkedByteBufferOutputStream(10, ByteBuffer.allocate)
o.write(new Array[Byte](9))
o.write(99)
+ o.close()
val chunkedByteBuffer = o.toChunkedByteBuffer
assert(chunkedByteBuffer.getChunks().length === 1)
assert(chunkedByteBuffer.getChunks().head.array()(9) === 99.toByte)
@@ -52,6 +55,7 @@ class ChunkedByteBufferOutputStreamSuite extends SparkFunSuite {
val o = new ChunkedByteBufferOutputStream(10, ByteBuffer.allocate)
o.write(new Array[Byte](10))
o.write(99)
+ o.close()
val arrays = o.toChunkedByteBuffer.getChunks().map(_.array())
assert(arrays.length === 2)
assert(arrays(1).length === 1)
@@ -63,6 +67,7 @@ class ChunkedByteBufferOutputStreamSuite extends SparkFunSuite {
Random.nextBytes(ref)
val o = new ChunkedByteBufferOutputStream(10, ByteBuffer.allocate)
o.write(ref)
+ o.close()
val arrays = o.toChunkedByteBuffer.getChunks().map(_.array())
assert(arrays.length === 1)
assert(arrays.head.length === ref.length)
@@ -74,6 +79,7 @@ class ChunkedByteBufferOutputStreamSuite extends SparkFunSuite {
Random.nextBytes(ref)
val o = new ChunkedByteBufferOutputStream(10, ByteBuffer.allocate)
o.write(ref)
+ o.close()
val arrays = o.toChunkedByteBuffer.getChunks().map(_.array())
assert(arrays.length === 1)
assert(arrays.head.length === ref.length)
@@ -85,6 +91,7 @@ class ChunkedByteBufferOutputStreamSuite extends SparkFunSuite {
Random.nextBytes(ref)
val o = new ChunkedByteBufferOutputStream(10, ByteBuffer.allocate)
o.write(ref)
+ o.close()
val arrays = o.toChunkedByteBuffer.getChunks().map(_.array())
assert(arrays.length === 3)
assert(arrays(0).length === 10)
@@ -101,6 +108,7 @@ class ChunkedByteBufferOutputStreamSuite extends SparkFunSuite {
Random.nextBytes(ref)
val o = new ChunkedByteBufferOutputStream(10, ByteBuffer.allocate)
o.write(ref)
+ o.close()
val arrays = o.toChunkedByteBuffer.getChunks().map(_.array())
assert(arrays.length === 3)
assert(arrays(0).length === 10)