From 1ca5e2e0b8d8d406c02a74c76ae9d7fc5637c8d3 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 22 Sep 2015 11:50:22 -0700 Subject: [SPARK-10704] Rename HashShuffleReader to BlockStoreShuffleReader The current shuffle code has an interface named ShuffleReader with only one implementation, HashShuffleReader. This naming is confusing, since the same read path code is used for both sort- and hash-based shuffle. This patch addresses this by renaming HashShuffleReader to BlockStoreShuffleReader. Author: Josh Rosen Closes #8825 from JoshRosen/shuffle-reader-cleanup. --- .../spark/shuffle/BlockStoreShuffleReader.scala | 111 +++++++++++++++ .../spark/shuffle/hash/HashShuffleManager.scala | 2 +- .../spark/shuffle/hash/HashShuffleReader.scala | 112 --------------- .../spark/shuffle/sort/SortShuffleManager.scala | 3 +- .../shuffle/BlockStoreShuffleReaderSuite.scala | 153 ++++++++++++++++++++ .../shuffle/hash/HashShuffleReaderSuite.scala | 154 --------------------- 6 files changed, 266 insertions(+), 269 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala delete mode 100644 core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala create mode 100644 core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala delete mode 100644 core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala new file mode 100644 index 0000000000..6dc9a16e58 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.shuffle + +import org.apache.spark._ +import org.apache.spark.serializer.Serializer +import org.apache.spark.storage.{BlockManager, ShuffleBlockFetcherIterator} +import org.apache.spark.util.CompletionIterator +import org.apache.spark.util.collection.ExternalSorter + +private[spark] class BlockStoreShuffleReader[K, C]( + handle: BaseShuffleHandle[K, _, C], + startPartition: Int, + endPartition: Int, + context: TaskContext, + blockManager: BlockManager = SparkEnv.get.blockManager, + mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker) + extends ShuffleReader[K, C] with Logging { + + require(endPartition == startPartition + 1, + "Hash shuffle currently only supports fetching one partition") + + private val dep = handle.dependency + + /** Read the combined key-values for this reduce task */ + override def read(): Iterator[Product2[K, C]] = { + val blockFetcherItr = new ShuffleBlockFetcherIterator( + context, + blockManager.shuffleClient, + blockManager, + mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition), + // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility + SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024) + + // Wrap the streams for compression based on configuration + val wrappedStreams = blockFetcherItr.map { case (blockId, inputStream) => + blockManager.wrapForCompression(blockId, inputStream) + } + + val ser = Serializer.getSerializer(dep.serializer) + val serializerInstance = ser.newInstance() + + // Create a key/value iterator for each stream + val recordIter = wrappedStreams.flatMap { wrappedStream => + // Note: the asKeyValueIterator below wraps a key/value iterator inside of a + // NextIterator. The NextIterator makes sure that close() is called on the + // underlying InputStream when all records have been read. + serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator + } + + // Update the context task metrics for each record read. + val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency() + val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]]( + recordIter.map(record => { + readMetrics.incRecordsRead(1) + record + }), + context.taskMetrics().updateShuffleReadMetrics()) + + // An interruptible iterator must be used here in order to support task cancellation + val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter) + + val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) { + if (dep.mapSideCombine) { + // We are reading values that are already combined + val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]] + dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context) + } else { + // We don't know the value type, but also don't care -- the dependency *should* + // have made sure its compatible w/ this aggregator, which will convert the value + // type to the combined type C + val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]] + dep.aggregator.get.combineValuesByKey(keyValuesIterator, context) + } + } else { + require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!") + interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]] + } + + // Sort the output if there is a sort ordering defined. + dep.keyOrdering match { + case Some(keyOrd: Ordering[K]) => + // Create an ExternalSorter to sort the data. 
Note that if spark.shuffle.spill is disabled, + // the ExternalSorter won't spill to disk. + val sorter = new ExternalSorter[K, C, C](ordering = Some(keyOrd), serializer = Some(ser)) + sorter.insertAll(aggregatedIter) + context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled) + context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled) + context.internalMetricsToAccumulators( + InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.peakMemoryUsedBytes) + sorter.iterator + case None => + aggregatedIter + } + } +} diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala index 0b46634b8b..d2e2fc4c11 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala @@ -51,7 +51,7 @@ private[spark] class HashShuffleManager(conf: SparkConf) extends ShuffleManager startPartition: Int, endPartition: Int, context: TaskContext): ShuffleReader[K, C] = { - new HashShuffleReader( + new BlockStoreShuffleReader( handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context) } diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala deleted file mode 100644 index 0c8f08f0f3..0000000000 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala +++ /dev/null @@ -1,112 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.shuffle.hash - -import org.apache.spark._ -import org.apache.spark.serializer.Serializer -import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader} -import org.apache.spark.storage.{BlockManager, ShuffleBlockFetcherIterator} -import org.apache.spark.util.CompletionIterator -import org.apache.spark.util.collection.ExternalSorter - -private[spark] class HashShuffleReader[K, C]( - handle: BaseShuffleHandle[K, _, C], - startPartition: Int, - endPartition: Int, - context: TaskContext, - blockManager: BlockManager = SparkEnv.get.blockManager, - mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker) - extends ShuffleReader[K, C] with Logging { - - require(endPartition == startPartition + 1, - "Hash shuffle currently only supports fetching one partition") - - private val dep = handle.dependency - - /** Read the combined key-values for this reduce task */ - override def read(): Iterator[Product2[K, C]] = { - val blockFetcherItr = new ShuffleBlockFetcherIterator( - context, - blockManager.shuffleClient, - blockManager, - mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition), - // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility - SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024) - - // Wrap the streams for compression based on configuration - val wrappedStreams = blockFetcherItr.map { case (blockId, inputStream) => - blockManager.wrapForCompression(blockId, inputStream) - } - - val ser = Serializer.getSerializer(dep.serializer) - val serializerInstance = ser.newInstance() - - // Create a key/value iterator for each stream - val recordIter = wrappedStreams.flatMap { wrappedStream => - // Note: the asKeyValueIterator below wraps a key/value iterator inside of a - // NextIterator. The NextIterator makes sure that close() is called on the - // underlying InputStream when all records have been read. - serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator - } - - // Update the context task metrics for each record read. - val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency() - val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]]( - recordIter.map(record => { - readMetrics.incRecordsRead(1) - record - }), - context.taskMetrics().updateShuffleReadMetrics()) - - // An interruptible iterator must be used here in order to support task cancellation - val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter) - - val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) { - if (dep.mapSideCombine) { - // We are reading values that are already combined - val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]] - dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context) - } else { - // We don't know the value type, but also don't care -- the dependency *should* - // have made sure its compatible w/ this aggregator, which will convert the value - // type to the combined type C - val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]] - dep.aggregator.get.combineValuesByKey(keyValuesIterator, context) - } - } else { - require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!") - interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]] - } - - // Sort the output if there is a sort ordering defined. 
- dep.keyOrdering match { - case Some(keyOrd: Ordering[K]) => - // Create an ExternalSorter to sort the data. Note that if spark.shuffle.spill is disabled, - // the ExternalSorter won't spill to disk. - val sorter = new ExternalSorter[K, C, C](ordering = Some(keyOrd), serializer = Some(ser)) - sorter.insertAll(aggregatedIter) - context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled) - context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled) - context.internalMetricsToAccumulators( - InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.peakMemoryUsedBytes) - sorter.iterator - case None => - aggregatedIter - } - } -} diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index 476cc1f303..9df4e55166 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -21,7 +21,6 @@ import java.util.concurrent.ConcurrentHashMap import org.apache.spark.{Logging, SparkConf, TaskContext, ShuffleDependency} import org.apache.spark.shuffle._ -import org.apache.spark.shuffle.hash.HashShuffleReader private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager with Logging { @@ -54,7 +53,7 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager endPartition: Int, context: TaskContext): ShuffleReader[K, C] = { // We currently use the same block store shuffle fetcher as the hash-based shuffle. - new HashShuffleReader( + new BlockStoreShuffleReader( handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context) } diff --git a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala new file mode 100644 index 0000000000..a5eafb1b55 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle + +import java.io.{ByteArrayOutputStream, InputStream} +import java.nio.ByteBuffer + +import org.mockito.Matchers.{eq => meq, _} +import org.mockito.Mockito.{mock, when} +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer + +import org.apache.spark._ +import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} +import org.apache.spark.serializer.JavaSerializer +import org.apache.spark.storage.{BlockManager, BlockManagerId, ShuffleBlockId} + +/** + * Wrapper for a managed buffer that keeps track of how many times retain and release are called. 
+ * + * We need to define this class ourselves instead of using a spy because the NioManagedBuffer class + * is final (final classes cannot be spied on). + */ +class RecordingManagedBuffer(underlyingBuffer: NioManagedBuffer) extends ManagedBuffer { + var callsToRetain = 0 + var callsToRelease = 0 + + override def size(): Long = underlyingBuffer.size() + override def nioByteBuffer(): ByteBuffer = underlyingBuffer.nioByteBuffer() + override def createInputStream(): InputStream = underlyingBuffer.createInputStream() + override def convertToNetty(): AnyRef = underlyingBuffer.convertToNetty() + + override def retain(): ManagedBuffer = { + callsToRetain += 1 + underlyingBuffer.retain() + } + override def release(): ManagedBuffer = { + callsToRelease += 1 + underlyingBuffer.release() + } +} + +class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext { + + /** + * This test makes sure that, when data is read from a BlockStoreShuffleReader, the underlying + * ManagedBuffers that contain the data are eventually released. + */ + test("read() releases resources on completion") { + val testConf = new SparkConf(false) + // Create a SparkContext as a convenient way of setting SparkEnv (needed because some of the + // shuffle code calls SparkEnv.get()). + sc = new SparkContext("local", "test", testConf) + + val reduceId = 15 + val shuffleId = 22 + val numMaps = 6 + val keyValuePairsPerMap = 10 + val serializer = new JavaSerializer(testConf) + + // Make a mock BlockManager that will return RecordingManagedBuffers of data, so that we + // can ensure retain() and release() are properly called. + val blockManager = mock(classOf[BlockManager]) + + // Create a return function to use for the mocked wrapForCompression method that just returns + // the original input stream. + val dummyCompressionFunction = new Answer[InputStream] { + override def answer(invocation: InvocationOnMock): InputStream = + invocation.getArguments()(1).asInstanceOf[InputStream] + } + + // Create a buffer with some randomly generated key-value pairs to use as the shuffle data + // from each mapper (all mappers return the same shuffle data). + val byteOutputStream = new ByteArrayOutputStream() + val serializationStream = serializer.newInstance().serializeStream(byteOutputStream) + (0 until keyValuePairsPerMap).foreach { i => + serializationStream.writeKey(i) + serializationStream.writeValue(2*i) + } + + // Setup the mocked BlockManager to return RecordingManagedBuffers. + val localBlockManagerId = BlockManagerId("test-client", "test-client", 1) + when(blockManager.blockManagerId).thenReturn(localBlockManagerId) + val buffers = (0 until numMaps).map { mapId => + // Create a ManagedBuffer with the shuffle data. + val nioBuffer = new NioManagedBuffer(ByteBuffer.wrap(byteOutputStream.toByteArray)) + val managedBuffer = new RecordingManagedBuffer(nioBuffer) + + // Setup the blockManager mock so the buffer gets returned when the shuffle code tries to + // fetch shuffle data. + val shuffleBlockId = ShuffleBlockId(shuffleId, mapId, reduceId) + when(blockManager.getBlockData(shuffleBlockId)).thenReturn(managedBuffer) + when(blockManager.wrapForCompression(meq(shuffleBlockId), isA(classOf[InputStream]))) + .thenAnswer(dummyCompressionFunction) + + managedBuffer + } + + // Make a mocked MapOutputTracker for the shuffle reader to use to determine what + // shuffle data to read.
+ val mapOutputTracker = mock(classOf[MapOutputTracker]) + when(mapOutputTracker.getMapSizesByExecutorId(shuffleId, reduceId)).thenReturn { + // Test a scenario where all data is local, to avoid creating a bunch of additional mocks + // for the code to read data over the network. + val shuffleBlockIdsAndSizes = (0 until numMaps).map { mapId => + val shuffleBlockId = ShuffleBlockId(shuffleId, mapId, reduceId) + (shuffleBlockId, byteOutputStream.size().toLong) + } + Seq((localBlockManagerId, shuffleBlockIdsAndSizes)) + } + + // Create a mocked shuffle handle to pass into BlockStoreShuffleReader. + val shuffleHandle = { + val dependency = mock(classOf[ShuffleDependency[Int, Int, Int]]) + when(dependency.serializer).thenReturn(Some(serializer)) + when(dependency.aggregator).thenReturn(None) + when(dependency.keyOrdering).thenReturn(None) + new BaseShuffleHandle(shuffleId, numMaps, dependency) + } + + val shuffleReader = new BlockStoreShuffleReader( + shuffleHandle, + reduceId, + reduceId + 1, + TaskContext.empty(), + blockManager, + mapOutputTracker) + + assert(shuffleReader.read().length === keyValuePairsPerMap * numMaps) + + // Calling .length above will have exhausted the iterator; make sure that exhausting the + // iterator caused retain and release to be called on each buffer. + buffers.foreach { buffer => + assert(buffer.callsToRetain === 1) + assert(buffer.callsToRelease === 1) + } + } +} diff --git a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala deleted file mode 100644 index 05b3afef5b..0000000000 --- a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala +++ /dev/null @@ -1,154 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.shuffle.hash - -import java.io.{ByteArrayOutputStream, InputStream} -import java.nio.ByteBuffer - -import org.mockito.Matchers.{eq => meq, _} -import org.mockito.Mockito.{mock, when} -import org.mockito.invocation.InvocationOnMock -import org.mockito.stubbing.Answer - -import org.apache.spark._ -import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} -import org.apache.spark.serializer.JavaSerializer -import org.apache.spark.shuffle.BaseShuffleHandle -import org.apache.spark.storage.{BlockManager, BlockManagerId, ShuffleBlockId} - -/** - * Wrapper for a managed buffer that keeps track of how many times retain and release are called. - * - * We need to define this class ourselves instead of using a spy because the NioManagedBuffer class - * is final (final classes cannot be spied on).
- */ -class RecordingManagedBuffer(underlyingBuffer: NioManagedBuffer) extends ManagedBuffer { - var callsToRetain = 0 - var callsToRelease = 0 - - override def size(): Long = underlyingBuffer.size() - override def nioByteBuffer(): ByteBuffer = underlyingBuffer.nioByteBuffer() - override def createInputStream(): InputStream = underlyingBuffer.createInputStream() - override def convertToNetty(): AnyRef = underlyingBuffer.convertToNetty() - - override def retain(): ManagedBuffer = { - callsToRetain += 1 - underlyingBuffer.retain() - } - override def release(): ManagedBuffer = { - callsToRelease += 1 - underlyingBuffer.release() - } -} - -class HashShuffleReaderSuite extends SparkFunSuite with LocalSparkContext { - - /** - * This test makes sure that, when data is read from a HashShuffleReader, the underlying - * ManagedBuffers that contain the data are eventually released. - */ - test("read() releases resources on completion") { - val testConf = new SparkConf(false) - // Create a SparkContext as a convenient way of setting SparkEnv (needed because some of the - // shuffle code calls SparkEnv.get()). - sc = new SparkContext("local", "test", testConf) - - val reduceId = 15 - val shuffleId = 22 - val numMaps = 6 - val keyValuePairsPerMap = 10 - val serializer = new JavaSerializer(testConf) - - // Make a mock BlockManager that will return RecordingManagedByteBuffers of data, so that we - // can ensure retain() and release() are properly called. - val blockManager = mock(classOf[BlockManager]) - - // Create a return function to use for the mocked wrapForCompression method that just returns - // the original input stream. - val dummyCompressionFunction = new Answer[InputStream] { - override def answer(invocation: InvocationOnMock): InputStream = - invocation.getArguments()(1).asInstanceOf[InputStream] - } - - // Create a buffer with some randomly generated key-value pairs to use as the shuffle data - // from each mappers (all mappers return the same shuffle data). - val byteOutputStream = new ByteArrayOutputStream() - val serializationStream = serializer.newInstance().serializeStream(byteOutputStream) - (0 until keyValuePairsPerMap).foreach { i => - serializationStream.writeKey(i) - serializationStream.writeValue(2*i) - } - - // Setup the mocked BlockManager to return RecordingManagedBuffers. - val localBlockManagerId = BlockManagerId("test-client", "test-client", 1) - when(blockManager.blockManagerId).thenReturn(localBlockManagerId) - val buffers = (0 until numMaps).map { mapId => - // Create a ManagedBuffer with the shuffle data. - val nioBuffer = new NioManagedBuffer(ByteBuffer.wrap(byteOutputStream.toByteArray)) - val managedBuffer = new RecordingManagedBuffer(nioBuffer) - - // Setup the blockManager mock so the buffer gets returned when the shuffle code tries to - // fetch shuffle data. - val shuffleBlockId = ShuffleBlockId(shuffleId, mapId, reduceId) - when(blockManager.getBlockData(shuffleBlockId)).thenReturn(managedBuffer) - when(blockManager.wrapForCompression(meq(shuffleBlockId), isA(classOf[InputStream]))) - .thenAnswer(dummyCompressionFunction) - - managedBuffer - } - - // Make a mocked MapOutputTracker for the shuffle reader to use to determine what - // shuffle data to read. - val mapOutputTracker = mock(classOf[MapOutputTracker]) - when(mapOutputTracker.getMapSizesByExecutorId(shuffleId, reduceId)).thenReturn { - // Test a scenario where all data is local, to avoid creating a bunch of additional mocks - // for the code to read data over the network. 
- val shuffleBlockIdsAndSizes = (0 until numMaps).map { mapId => - val shuffleBlockId = ShuffleBlockId(shuffleId, mapId, reduceId) - (shuffleBlockId, byteOutputStream.size().toLong) - } - Seq((localBlockManagerId, shuffleBlockIdsAndSizes)) - } - - // Create a mocked shuffle handle to pass into HashShuffleReader. - val shuffleHandle = { - val dependency = mock(classOf[ShuffleDependency[Int, Int, Int]]) - when(dependency.serializer).thenReturn(Some(serializer)) - when(dependency.aggregator).thenReturn(None) - when(dependency.keyOrdering).thenReturn(None) - new BaseShuffleHandle(shuffleId, numMaps, dependency) - } - - val shuffleReader = new HashShuffleReader( - shuffleHandle, - reduceId, - reduceId + 1, - TaskContext.empty(), - blockManager, - mapOutputTracker) - - assert(shuffleReader.read().length === keyValuePairsPerMap * numMaps) - - // Calling .length above will have exhausted the iterator; make sure that exhausting the - // iterator caused retain and release to be called on each buffer. - buffers.foreach { buffer => - assert(buffer.callsToRetain === 1) - assert(buffer.callsToRelease === 1) - } - } -} -- cgit v1.2.3
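The structure this rename is meant to make obvious is simply that both shuffle managers hand out the same reader implementation; only their write paths differ. Below is a minimal, self-contained Scala sketch of that shape. The simplified ShuffleReader and ShuffleManager definitions, the Seq-of-blocks constructor, and the demo objects are illustrative stand-ins only, not Spark's actual API, whose getReader takes a shuffle handle, a partition range, and a TaskContext as shown in the patch above.

    // Illustrative sketch only -- simplified stand-ins, not Spark's real classes.
    trait ShuffleReader[K, C] {
      def read(): Iterator[Product2[K, C]]
    }

    // One read-path implementation, named for what it does (reading blocks from the
    // block store) rather than for the hash-based shuffle that originally hosted it.
    class BlockStoreShuffleReader[K, C](blocks: Seq[Seq[Product2[K, C]]])
      extends ShuffleReader[K, C] {
      override def read(): Iterator[Product2[K, C]] = blocks.iterator.flatten
    }

    trait ShuffleManager {
      // Both managers return the same reader; the rename reflects that shared read path.
      def getReader[K, C](blocks: Seq[Seq[Product2[K, C]]]): ShuffleReader[K, C] =
        new BlockStoreShuffleReader(blocks)
    }

    object SortShuffleManager extends ShuffleManager
    object HashShuffleManager extends ShuffleManager

    object ReaderRenameDemo {
      def main(args: Array[String]): Unit = {
        val mapOutputs = Seq(Seq("a" -> 1, "b" -> 2), Seq("c" -> 3))
        // The same records come back through either manager's read path.
        assert(SortShuffleManager.getReader(mapOutputs).read().size == 3)
        assert(HashShuffleManager.getReader(mapOutputs).read().size == 3)
      }
    }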