path: root/core/src/test/scala/org
author     Josh Rosen <joshrosen@databricks.com>  2016-09-17 11:46:15 -0700
committer  Josh Rosen <joshrosen@databricks.com>  2016-09-17 11:46:15 -0700
commit     8faa5217b44e8d52eab7eb2d53d0652abaaf43cd (patch)
tree       daf1a90737024c0dccd567f66a8b13ee0f2d3c1a /core/src/test/scala/org
parent     86c2d393a56bf1e5114bc5a781253c0460efb8af (diff)
download   spark-8faa5217b44e8d52eab7eb2d53d0652abaaf43cd.tar.gz
           spark-8faa5217b44e8d52eab7eb2d53d0652abaaf43cd.tar.bz2
           spark-8faa5217b44e8d52eab7eb2d53d0652abaaf43cd.zip
[SPARK-17491] Close serialization stream to fix wrong answer bug in putIteratorAsBytes()
## What changes were proposed in this pull request?

`MemoryStore.putIteratorAsBytes()` may silently lose values when used with `KryoSerializer` because it does not properly close the serialization stream before attempting to deserialize the already-serialized values, which may cause values buffered in Kryo's internal buffers to not be read. This is the root cause behind a user-reported "wrong answer" bug in PySpark caching reported by bennoleslie on the Spark user mailing list in a thread titled "pyspark persist MEMORY_ONLY vs MEMORY_AND_DISK".

Due to Spark 2.0's automatic use of KryoSerializer for "safe" types (such as byte arrays, primitives, etc.), this misuse of serializers manifested itself as silent data corruption rather than a StreamCorrupted error (which you might get from JavaSerializer).

The minimal fix, implemented here, is to close the serialization stream before attempting to deserialize the written values. In addition, this patch adds several additional assertions / precondition checks to prevent misuse of `PartiallySerializedBlock` and `ChunkedByteBufferOutputStream`.

## How was this patch tested?

The original bug was masked by an invalid assert in the memory store test cases: the old assert compared two results record-by-record with `zip` but didn't first check that the lengths of the two collections were equal, causing missing records to go unnoticed. The updated test case reproduces this bug. In addition, I added a new `PartiallySerializedBlockSuite` to unit test that component.

Author: Josh Rosen <joshrosen@databricks.com>

Closes #15043 from JoshRosen/partially-serialized-block-values-iterator-bugfix.
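The fix itself lives in `MemoryStore.putIteratorAsBytes()`, but the underlying pitfall is easy to reproduce outside the block manager. Below is a minimal, hypothetical sketch (not part of this patch; the `CloseBeforeReadBack` object name is invented for illustration) that assumes only Spark's serializer API on the classpath. Removing the `serStream.close()` call is the shape of the bug: values still buffered inside Kryo never reach the output stream, so the read-back comes up short.

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream}

import org.apache.spark.SparkConf
import org.apache.spark.serializer.KryoSerializer

object CloseBeforeReadBack {
  def main(args: Array[String]): Unit = {
    val serializerInstance = new KryoSerializer(new SparkConf()).newInstance()
    val out = new ByteArrayOutputStream()
    val serStream = serializerInstance.serializeStream(out)
    (1 to 1000).foreach(i => serStream.writeObject(i))
    // Without this close(), values still sitting in Kryo's internal buffer are
    // never flushed to `out`, so the read-back below silently misses the tail
    // of the data -- the failure mode this patch fixes.
    serStream.close()
    val readBack = serializerInstance
      .deserializeStream(new ByteArrayInputStream(out.toByteArray))
      .asIterator
      .toSeq
    assert(readBack.size == 1000, s"only ${readBack.size} of 1000 values were read back")
  }
}
```

On the test side, the new `assertSameContents` helper matters because `Iterator.zip` truncates to the shorter of the two iterators (for example, `(1 to 3).iterator.zip((1 to 2).iterator).toSeq` is `Seq((1, 1), (2, 2))`), so a record-by-record `zip` comparison alone cannot detect missing trailing values; the explicit length check closes that gap.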
Diffstat (limited to 'core/src/test/scala/org')
-rw-r--r--  core/src/test/scala/org/apache/spark/storage/MemoryStoreSuite.scala                     34
-rw-r--r--  core/src/test/scala/org/apache/spark/storage/PartiallySerializedBlockSuite.scala        215
-rw-r--r--  core/src/test/scala/org/apache/spark/storage/PartiallyUnrolledIteratorSuite.scala         2
-rw-r--r--  core/src/test/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStreamSuite.scala     8
4 files changed, 240 insertions, 19 deletions
diff --git a/core/src/test/scala/org/apache/spark/storage/MemoryStoreSuite.scala b/core/src/test/scala/org/apache/spark/storage/MemoryStoreSuite.scala
index c11de82667..9929ea033a 100644
--- a/core/src/test/scala/org/apache/spark/storage/MemoryStoreSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/MemoryStoreSuite.scala
@@ -79,6 +79,13 @@ class MemoryStoreSuite
(memoryStore, blockInfoManager)
}
+ private def assertSameContents[T](expected: Seq[T], actual: Seq[T], hint: String): Unit = {
+ assert(actual.length === expected.length, s"wrong number of values returned in $hint")
+ expected.iterator.zip(actual.iterator).foreach { case (e, a) =>
+ assert(e === a, s"$hint did not return original values!")
+ }
+ }
+
test("reserve/release unroll memory") {
val (memoryStore, _) = makeMemoryStore(12000)
assert(memoryStore.currentUnrollMemory === 0)
@@ -137,9 +144,7 @@ class MemoryStoreSuite
var putResult = putIteratorAsValues("unroll", smallList.iterator, ClassTag.Any)
assert(putResult.isRight)
assert(memoryStore.currentUnrollMemoryForThisTask === 0)
- smallList.iterator.zip(memoryStore.getValues("unroll").get).foreach { case (e, a) =>
- assert(e === a, "getValues() did not return original values!")
- }
+ assertSameContents(smallList, memoryStore.getValues("unroll").get.toSeq, "getValues")
blockInfoManager.lockForWriting("unroll")
assert(memoryStore.remove("unroll"))
blockInfoManager.removeBlock("unroll")
@@ -152,9 +157,7 @@ class MemoryStoreSuite
assert(memoryStore.currentUnrollMemoryForThisTask === 0)
assert(memoryStore.contains("someBlock2"))
assert(!memoryStore.contains("someBlock1"))
- smallList.iterator.zip(memoryStore.getValues("unroll").get).foreach { case (e, a) =>
- assert(e === a, "getValues() did not return original values!")
- }
+ assertSameContents(smallList, memoryStore.getValues("unroll").get.toSeq, "getValues")
blockInfoManager.lockForWriting("unroll")
assert(memoryStore.remove("unroll"))
blockInfoManager.removeBlock("unroll")
@@ -167,9 +170,7 @@ class MemoryStoreSuite
assert(memoryStore.currentUnrollMemoryForThisTask > 0) // we returned an iterator
assert(!memoryStore.contains("someBlock2"))
assert(putResult.isLeft)
- bigList.iterator.zip(putResult.left.get).foreach { case (e, a) =>
- assert(e === a, "putIterator() did not return original values!")
- }
+ assertSameContents(bigList, putResult.left.get.toSeq, "putIterator")
// The unroll memory was freed once the iterator returned by putIterator() was fully traversed.
assert(memoryStore.currentUnrollMemoryForThisTask === 0)
}
@@ -316,9 +317,8 @@ class MemoryStoreSuite
assert(res.isLeft)
assert(memoryStore.currentUnrollMemoryForThisTask > 0)
val valuesReturnedFromFailedPut = res.left.get.valuesIterator.toSeq // force materialization
- valuesReturnedFromFailedPut.zip(bigList).foreach { case (e, a) =>
- assert(e === a, "PartiallySerializedBlock.valuesIterator() did not return original values!")
- }
+ assertSameContents(
+ bigList, valuesReturnedFromFailedPut, "PartiallySerializedBlock.valuesIterator()")
// The unroll memory was freed once the iterator was fully traversed.
assert(memoryStore.currentUnrollMemoryForThisTask === 0)
}
@@ -340,12 +340,10 @@ class MemoryStoreSuite
res.left.get.finishWritingToStream(bos)
// The unroll memory was freed once the block was fully written.
assert(memoryStore.currentUnrollMemoryForThisTask === 0)
- val deserializationStream = serializerManager.dataDeserializeStream[Any](
- "b1", new ByteBufferInputStream(bos.toByteBuffer))(ClassTag.Any)
- deserializationStream.zip(bigList.iterator).foreach { case (e, a) =>
- assert(e === a,
- "PartiallySerializedBlock.finishWritingtoStream() did not write original values!")
- }
+ val deserializedValues = serializerManager.dataDeserializeStream[Any](
+ "b1", new ByteBufferInputStream(bos.toByteBuffer))(ClassTag.Any).toSeq
+ assertSameContents(
+ bigList, deserializedValues, "PartiallySerializedBlock.finishWritingToStream()")
}
test("multiple unrolls by the same thread") {
diff --git a/core/src/test/scala/org/apache/spark/storage/PartiallySerializedBlockSuite.scala b/core/src/test/scala/org/apache/spark/storage/PartiallySerializedBlockSuite.scala
new file mode 100644
index 0000000000..ec4f2637fa
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/storage/PartiallySerializedBlockSuite.scala
@@ -0,0 +1,215 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.nio.ByteBuffer
+
+import scala.reflect.ClassTag
+
+import org.mockito.Mockito
+import org.mockito.Mockito.atLeastOnce
+import org.mockito.invocation.InvocationOnMock
+import org.mockito.stubbing.Answer
+import org.scalatest.{BeforeAndAfterEach, PrivateMethodTester}
+
+import org.apache.spark.{SparkConf, SparkFunSuite, TaskContext, TaskContextImpl}
+import org.apache.spark.memory.MemoryMode
+import org.apache.spark.serializer.{JavaSerializer, SerializationStream, SerializerManager}
+import org.apache.spark.storage.memory.{MemoryStore, PartiallySerializedBlock, RedirectableOutputStream}
+import org.apache.spark.util.{ByteBufferInputStream, ByteBufferOutputStream}
+import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream}
+
+class PartiallySerializedBlockSuite
+ extends SparkFunSuite
+ with BeforeAndAfterEach
+ with PrivateMethodTester {
+
+ private val blockId = new TestBlockId("test")
+ private val conf = new SparkConf()
+ private val memoryStore = Mockito.mock(classOf[MemoryStore], Mockito.RETURNS_SMART_NULLS)
+ private val serializerManager = new SerializerManager(new JavaSerializer(conf), conf)
+
+ private val getSerializationStream = PrivateMethod[SerializationStream]('serializationStream)
+ private val getRedirectableOutputStream =
+ PrivateMethod[RedirectableOutputStream]('redirectableOutputStream)
+
+ override protected def beforeEach(): Unit = {
+ super.beforeEach()
+ Mockito.reset(memoryStore)
+ }
+
+ private def partiallyUnroll[T: ClassTag](
+ iter: Iterator[T],
+ numItemsToBuffer: Int): PartiallySerializedBlock[T] = {
+
+ val bbos: ChunkedByteBufferOutputStream = {
+ val spy = Mockito.spy(new ChunkedByteBufferOutputStream(128, ByteBuffer.allocate))
+ Mockito.doAnswer(new Answer[ChunkedByteBuffer] {
+ override def answer(invocationOnMock: InvocationOnMock): ChunkedByteBuffer = {
+ Mockito.spy(invocationOnMock.callRealMethod().asInstanceOf[ChunkedByteBuffer])
+ }
+ }).when(spy).toChunkedByteBuffer
+ spy
+ }
+
+ val serializer = serializerManager.getSerializer(implicitly[ClassTag[T]]).newInstance()
+ val redirectableOutputStream = Mockito.spy(new RedirectableOutputStream)
+ redirectableOutputStream.setOutputStream(bbos)
+ val serializationStream = Mockito.spy(serializer.serializeStream(redirectableOutputStream))
+
+ (1 to numItemsToBuffer).foreach { _ =>
+ assert(iter.hasNext)
+ serializationStream.writeObject[T](iter.next())
+ }
+
+ val unrollMemory = bbos.size
+ new PartiallySerializedBlock[T](
+ memoryStore,
+ serializerManager,
+ blockId,
+ serializationStream = serializationStream,
+ redirectableOutputStream,
+ unrollMemory = unrollMemory,
+ memoryMode = MemoryMode.ON_HEAP,
+ bbos,
+ rest = iter,
+ classTag = implicitly[ClassTag[T]])
+ }
+
+ test("valuesIterator() and finishWritingToStream() cannot be called after discard() is called") {
+ val partiallySerializedBlock = partiallyUnroll((1 to 10).iterator, 2)
+ partiallySerializedBlock.discard()
+ intercept[IllegalStateException] {
+ partiallySerializedBlock.finishWritingToStream(null)
+ }
+ intercept[IllegalStateException] {
+ partiallySerializedBlock.valuesIterator
+ }
+ }
+
+ test("discard() can be called more than once") {
+ val partiallySerializedBlock = partiallyUnroll((1 to 10).iterator, 2)
+ partiallySerializedBlock.discard()
+ partiallySerializedBlock.discard()
+ }
+
+ test("cannot call valuesIterator() more than once") {
+ val partiallySerializedBlock = partiallyUnroll((1 to 10).iterator, 2)
+ partiallySerializedBlock.valuesIterator
+ intercept[IllegalStateException] {
+ partiallySerializedBlock.valuesIterator
+ }
+ }
+
+ test("cannot call finishWritingToStream() more than once") {
+ val partiallySerializedBlock = partiallyUnroll((1 to 10).iterator, 2)
+ partiallySerializedBlock.finishWritingToStream(new ByteBufferOutputStream())
+ intercept[IllegalStateException] {
+ partiallySerializedBlock.finishWritingToStream(new ByteBufferOutputStream())
+ }
+ }
+
+ test("cannot call finishWritingToStream() after valuesIterator()") {
+ val partiallySerializedBlock = partiallyUnroll((1 to 10).iterator, 2)
+ partiallySerializedBlock.valuesIterator
+ intercept[IllegalStateException] {
+ partiallySerializedBlock.finishWritingToStream(new ByteBufferOutputStream())
+ }
+ }
+
+ test("cannot call valuesIterator() after finishWritingToStream()") {
+ val partiallySerializedBlock = partiallyUnroll((1 to 10).iterator, 2)
+ partiallySerializedBlock.finishWritingToStream(new ByteBufferOutputStream())
+ intercept[IllegalStateException] {
+ partiallySerializedBlock.valuesIterator
+ }
+ }
+
+ test("buffers are deallocated in a TaskCompletionListener") {
+ try {
+ TaskContext.setTaskContext(TaskContext.empty())
+ val partiallySerializedBlock = partiallyUnroll((1 to 10).iterator, 2)
+ TaskContext.get().asInstanceOf[TaskContextImpl].markTaskCompleted()
+ Mockito.verify(partiallySerializedBlock.getUnrolledChunkedByteBuffer).dispose()
+ Mockito.verifyNoMoreInteractions(memoryStore)
+ } finally {
+ TaskContext.unset()
+ }
+ }
+
+ private def testUnroll[T: ClassTag](
+ testCaseName: String,
+ items: Seq[T],
+ numItemsToBuffer: Int): Unit = {
+
+ test(s"$testCaseName with discard() and numBuffered = $numItemsToBuffer") {
+ val partiallySerializedBlock = partiallyUnroll(items.iterator, numItemsToBuffer)
+ partiallySerializedBlock.discard()
+
+ Mockito.verify(memoryStore).releaseUnrollMemoryForThisTask(
+ MemoryMode.ON_HEAP, partiallySerializedBlock.unrollMemory)
+ Mockito.verify(partiallySerializedBlock.invokePrivate(getSerializationStream())).close()
+ Mockito.verify(partiallySerializedBlock.invokePrivate(getRedirectableOutputStream())).close()
+ Mockito.verifyNoMoreInteractions(memoryStore)
+ Mockito.verify(partiallySerializedBlock.getUnrolledChunkedByteBuffer, atLeastOnce).dispose()
+ }
+
+ test(s"$testCaseName with finishWritingToStream() and numBuffered = $numItemsToBuffer") {
+ val partiallySerializedBlock = partiallyUnroll(items.iterator, numItemsToBuffer)
+ val bbos = Mockito.spy(new ByteBufferOutputStream())
+ partiallySerializedBlock.finishWritingToStream(bbos)
+
+ Mockito.verify(memoryStore).releaseUnrollMemoryForThisTask(
+ MemoryMode.ON_HEAP, partiallySerializedBlock.unrollMemory)
+ Mockito.verify(partiallySerializedBlock.invokePrivate(getSerializationStream())).close()
+ Mockito.verify(partiallySerializedBlock.invokePrivate(getRedirectableOutputStream())).close()
+ Mockito.verify(bbos).close()
+ Mockito.verifyNoMoreInteractions(memoryStore)
+ Mockito.verify(partiallySerializedBlock.getUnrolledChunkedByteBuffer, atLeastOnce).dispose()
+
+ val serializer = serializerManager.getSerializer(implicitly[ClassTag[T]]).newInstance()
+ val deserialized =
+ serializer.deserializeStream(new ByteBufferInputStream(bbos.toByteBuffer)).asIterator.toSeq
+ assert(deserialized === items)
+ }
+
+ test(s"$testCaseName with valuesIterator() and numBuffered = $numItemsToBuffer") {
+ val partiallySerializedBlock = partiallyUnroll(items.iterator, numItemsToBuffer)
+ val valuesIterator = partiallySerializedBlock.valuesIterator
+ Mockito.verify(partiallySerializedBlock.invokePrivate(getSerializationStream())).close()
+ Mockito.verify(partiallySerializedBlock.invokePrivate(getRedirectableOutputStream())).close()
+
+ val deserializedItems = valuesIterator.toArray.toSeq
+ Mockito.verify(memoryStore).releaseUnrollMemoryForThisTask(
+ MemoryMode.ON_HEAP, partiallySerializedBlock.unrollMemory)
+ Mockito.verifyNoMoreInteractions(memoryStore)
+ Mockito.verify(partiallySerializedBlock.getUnrolledChunkedByteBuffer, atLeastOnce).dispose()
+ assert(deserializedItems === items)
+ }
+ }
+
+ testUnroll("basic numbers", 1 to 1000, numItemsToBuffer = 50)
+ testUnroll("basic numbers", 1 to 1000, numItemsToBuffer = 0)
+ testUnroll("basic numbers", 1 to 1000, numItemsToBuffer = 1000)
+ testUnroll("case classes", (1 to 1000).map(x => MyCaseClass(x.toString)), numItemsToBuffer = 50)
+ testUnroll("case classes", (1 to 1000).map(x => MyCaseClass(x.toString)), numItemsToBuffer = 0)
+ testUnroll("case classes", (1 to 1000).map(x => MyCaseClass(x.toString)), numItemsToBuffer = 1000)
+ testUnroll("empty iterator", Seq.empty[String], numItemsToBuffer = 0)
+}
+
+private case class MyCaseClass(str: String)
diff --git a/core/src/test/scala/org/apache/spark/storage/PartiallyUnrolledIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/PartiallyUnrolledIteratorSuite.scala
index 02c2331dc3..4253cc8ca4 100644
--- a/core/src/test/scala/org/apache/spark/storage/PartiallyUnrolledIteratorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/PartiallyUnrolledIteratorSuite.scala
@@ -33,7 +33,7 @@ class PartiallyUnrolledIteratorSuite extends SparkFunSuite with MockitoSugar {
val rest = (unrollSize until restSize + unrollSize).iterator
val memoryStore = mock[MemoryStore]
- val joinIterator = new PartiallyUnrolledIterator(memoryStore, unrollSize, unroll, rest)
+ val joinIterator = new PartiallyUnrolledIterator(memoryStore, ON_HEAP, unrollSize, unroll, rest)
// Firstly iterate over unrolling memory iterator
(0 until unrollSize).foreach { value =>
diff --git a/core/src/test/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStreamSuite.scala b/core/src/test/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStreamSuite.scala
index 226622075a..8696174567 100644
--- a/core/src/test/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStreamSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStreamSuite.scala
@@ -28,12 +28,14 @@ class ChunkedByteBufferOutputStreamSuite extends SparkFunSuite {
test("empty output") {
val o = new ChunkedByteBufferOutputStream(1024, ByteBuffer.allocate)
+ o.close()
assert(o.toChunkedByteBuffer.size === 0)
}
test("write a single byte") {
val o = new ChunkedByteBufferOutputStream(1024, ByteBuffer.allocate)
o.write(10)
+ o.close()
val chunkedByteBuffer = o.toChunkedByteBuffer
assert(chunkedByteBuffer.getChunks().length === 1)
assert(chunkedByteBuffer.getChunks().head.array().toSeq === Seq(10.toByte))
@@ -43,6 +45,7 @@ class ChunkedByteBufferOutputStreamSuite extends SparkFunSuite {
val o = new ChunkedByteBufferOutputStream(10, ByteBuffer.allocate)
o.write(new Array[Byte](9))
o.write(99)
+ o.close()
val chunkedByteBuffer = o.toChunkedByteBuffer
assert(chunkedByteBuffer.getChunks().length === 1)
assert(chunkedByteBuffer.getChunks().head.array()(9) === 99.toByte)
@@ -52,6 +55,7 @@ class ChunkedByteBufferOutputStreamSuite extends SparkFunSuite {
val o = new ChunkedByteBufferOutputStream(10, ByteBuffer.allocate)
o.write(new Array[Byte](10))
o.write(99)
+ o.close()
val arrays = o.toChunkedByteBuffer.getChunks().map(_.array())
assert(arrays.length === 2)
assert(arrays(1).length === 1)
@@ -63,6 +67,7 @@ class ChunkedByteBufferOutputStreamSuite extends SparkFunSuite {
Random.nextBytes(ref)
val o = new ChunkedByteBufferOutputStream(10, ByteBuffer.allocate)
o.write(ref)
+ o.close()
val arrays = o.toChunkedByteBuffer.getChunks().map(_.array())
assert(arrays.length === 1)
assert(arrays.head.length === ref.length)
@@ -74,6 +79,7 @@ class ChunkedByteBufferOutputStreamSuite extends SparkFunSuite {
Random.nextBytes(ref)
val o = new ChunkedByteBufferOutputStream(10, ByteBuffer.allocate)
o.write(ref)
+ o.close()
val arrays = o.toChunkedByteBuffer.getChunks().map(_.array())
assert(arrays.length === 1)
assert(arrays.head.length === ref.length)
@@ -85,6 +91,7 @@ class ChunkedByteBufferOutputStreamSuite extends SparkFunSuite {
Random.nextBytes(ref)
val o = new ChunkedByteBufferOutputStream(10, ByteBuffer.allocate)
o.write(ref)
+ o.close()
val arrays = o.toChunkedByteBuffer.getChunks().map(_.array())
assert(arrays.length === 3)
assert(arrays(0).length === 10)
@@ -101,6 +108,7 @@ class ChunkedByteBufferOutputStreamSuite extends SparkFunSuite {
Random.nextBytes(ref)
val o = new ChunkedByteBufferOutputStream(10, ByteBuffer.allocate)
o.write(ref)
+ o.close()
val arrays = o.toChunkedByteBuffer.getChunks().map(_.array())
assert(arrays.length === 3)
assert(arrays(0).length === 10)