From 3de24ae2ed6c58fc96a7e50832afe42fe7af34fb Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 23 Mar 2016 10:15:23 -0700 Subject: [SPARK-14075] Refactor MemoryStore to be testable independent of BlockManager This patch refactors the `MemoryStore` so that it can be tested without needing to construct / mock an entire `BlockManager`. - The block manager's serialization- and compression-related methods have been moved from `BlockManager` to `SerializerManager`. - `BlockInfoManager` is now passed directly to classes that need it, rather than being passed via the `BlockManager`. - The `MemoryStore` now calls `dropFromMemory` via a new `BlockEvictionHandler` interface rather than directly calling the `BlockManager`. This change helps to enforce a narrow interface between the `MemoryStore` and `BlockManager` functionality and makes this interface easier to mock in tests. - Several of the block unrolling tests have been moved from `BlockManagerSuite` into a new `MemoryStoreSuite`. Author: Josh Rosen Closes #11899 from JoshRosen/reduce-memorystore-blockmanager-coupling. --- .../apache/spark/unsafe/map/BytesToBytesMap.java | 7 +- .../unsafe/sort/UnsafeExternalSorter.java | 17 +- .../unsafe/sort/UnsafeSorterSpillReader.java | 6 +- .../unsafe/sort/UnsafeSorterSpillWriter.java | 5 +- .../spark/serializer/SerializerManager.scala | 90 +++++- .../spark/shuffle/BlockStoreShuffleReader.scala | 5 +- .../org/apache/spark/storage/BlockManager.scala | 118 ++------ .../spark/storage/BlockManagerManagedBuffer.scala | 6 +- .../apache/spark/storage/memory/MemoryStore.scala | 55 +++- .../util/collection/ExternalAppendOnlyMap.scala | 7 +- .../spark/util/collection/ExternalSorter.scala | 3 +- .../shuffle/sort/UnsafeShuffleWriterSuite.java | 32 +-- .../unsafe/map/AbstractBytesToBytesMapSuite.java | 17 +- .../unsafe/sort/UnsafeExternalSorterSuite.java | 12 +- .../scala/org/apache/spark/DistributedSuite.scala | 3 +- .../shuffle/BlockStoreShuffleReaderSuite.scala | 22 +- .../apache/spark/storage/BlockManagerSuite.scala | 197 +------------- .../apache/spark/storage/MemoryStoreSuite.scala | 302 +++++++++++++++++++++ .../sql/execution/UnsafeExternalRowSorter.java | 1 + .../execution/UnsafeFixedWidthAggregationMap.java | 8 +- .../sql/execution/UnsafeKVExternalSorter.java | 7 +- .../org/apache/spark/sql/execution/Window.scala | 1 + .../execution/datasources/WriterContainer.scala | 1 + .../sql/execution/joins/CartesianProduct.scala | 1 + .../execution/UnsafeKVExternalSorterSuite.scala | 2 +- .../spark/sql/hive/hiveWriterContainers.scala | 1 + .../rdd/WriteAheadLogBackedBlockRDD.scala | 3 +- .../streaming/receiver/ReceivedBlockHandler.scala | 6 +- .../receiver/ReceiverSupervisorImpl.scala | 2 +- .../streaming/ReceivedBlockHandlerSuite.scala | 11 +- .../rdd/WriteAheadLogBackedBlockRDDSuite.scala | 9 +- 31 files changed, 555 insertions(+), 402 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/storage/MemoryStoreSuite.scala diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index de36814ecc..9aacb084f6 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -32,6 +32,7 @@ import org.apache.spark.SparkEnv; import org.apache.spark.executor.ShuffleWriteMetrics; import org.apache.spark.memory.MemoryConsumer; import org.apache.spark.memory.TaskMemoryManager; +import org.apache.spark.serializer.SerializerManager; 
import org.apache.spark.storage.BlockManager; import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; @@ -163,12 +164,14 @@ public final class BytesToBytesMap extends MemoryConsumer { private long peakMemoryUsedBytes = 0L; private final BlockManager blockManager; + private final SerializerManager serializerManager; private volatile MapIterator destructiveIterator = null; private LinkedList spillWriters = new LinkedList<>(); public BytesToBytesMap( TaskMemoryManager taskMemoryManager, BlockManager blockManager, + SerializerManager serializerManager, int initialCapacity, double loadFactor, long pageSizeBytes, @@ -176,6 +179,7 @@ public final class BytesToBytesMap extends MemoryConsumer { super(taskMemoryManager, pageSizeBytes); this.taskMemoryManager = taskMemoryManager; this.blockManager = blockManager; + this.serializerManager = serializerManager; this.loadFactor = loadFactor; this.loc = new Location(); this.pageSizeBytes = pageSizeBytes; @@ -209,6 +213,7 @@ public final class BytesToBytesMap extends MemoryConsumer { this( taskMemoryManager, SparkEnv.get() != null ? SparkEnv.get().blockManager() : null, + SparkEnv.get() != null ? SparkEnv.get().serializerManager() : null, initialCapacity, 0.70, pageSizeBytes, @@ -271,7 +276,7 @@ public final class BytesToBytesMap extends MemoryConsumer { } try { Closeables.close(reader, /* swallowIOException = */ false); - reader = spillWriters.getFirst().getReader(blockManager); + reader = spillWriters.getFirst().getReader(serializerManager); recordsInPage = -1; } catch (IOException e) { // Scala iterator does not handle exception diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index 927b19c4e8..ded8f0472b 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -31,6 +31,7 @@ import org.apache.spark.TaskContext; import org.apache.spark.executor.ShuffleWriteMetrics; import org.apache.spark.memory.MemoryConsumer; import org.apache.spark.memory.TaskMemoryManager; +import org.apache.spark.serializer.SerializerManager; import org.apache.spark.storage.BlockManager; import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.LongArray; @@ -51,6 +52,7 @@ public final class UnsafeExternalSorter extends MemoryConsumer { private final RecordComparator recordComparator; private final TaskMemoryManager taskMemoryManager; private final BlockManager blockManager; + private final SerializerManager serializerManager; private final TaskContext taskContext; private ShuffleWriteMetrics writeMetrics; @@ -78,6 +80,7 @@ public final class UnsafeExternalSorter extends MemoryConsumer { public static UnsafeExternalSorter createWithExistingInMemorySorter( TaskMemoryManager taskMemoryManager, BlockManager blockManager, + SerializerManager serializerManager, TaskContext taskContext, RecordComparator recordComparator, PrefixComparator prefixComparator, @@ -85,7 +88,8 @@ public final class UnsafeExternalSorter extends MemoryConsumer { long pageSizeBytes, UnsafeInMemorySorter inMemorySorter) throws IOException { UnsafeExternalSorter sorter = new UnsafeExternalSorter(taskMemoryManager, blockManager, - taskContext, recordComparator, prefixComparator, initialSize, pageSizeBytes, inMemorySorter); + serializerManager, taskContext, 
recordComparator, prefixComparator, initialSize, + pageSizeBytes, inMemorySorter); sorter.spill(Long.MAX_VALUE, sorter); // The external sorter will be used to insert records, in-memory sorter is not needed. sorter.inMemSorter = null; @@ -95,18 +99,20 @@ public final class UnsafeExternalSorter extends MemoryConsumer { public static UnsafeExternalSorter create( TaskMemoryManager taskMemoryManager, BlockManager blockManager, + SerializerManager serializerManager, TaskContext taskContext, RecordComparator recordComparator, PrefixComparator prefixComparator, int initialSize, long pageSizeBytes) { - return new UnsafeExternalSorter(taskMemoryManager, blockManager, + return new UnsafeExternalSorter(taskMemoryManager, blockManager, serializerManager, taskContext, recordComparator, prefixComparator, initialSize, pageSizeBytes, null); } private UnsafeExternalSorter( TaskMemoryManager taskMemoryManager, BlockManager blockManager, + SerializerManager serializerManager, TaskContext taskContext, RecordComparator recordComparator, PrefixComparator prefixComparator, @@ -116,6 +122,7 @@ public final class UnsafeExternalSorter extends MemoryConsumer { super(taskMemoryManager, pageSizeBytes); this.taskMemoryManager = taskMemoryManager; this.blockManager = blockManager; + this.serializerManager = serializerManager; this.taskContext = taskContext; this.recordComparator = recordComparator; this.prefixComparator = prefixComparator; @@ -412,7 +419,7 @@ public final class UnsafeExternalSorter extends MemoryConsumer { final UnsafeSorterSpillMerger spillMerger = new UnsafeSorterSpillMerger(recordComparator, prefixComparator, spillWriters.size()); for (UnsafeSorterSpillWriter spillWriter : spillWriters) { - spillMerger.addSpillIfNotEmpty(spillWriter.getReader(blockManager)); + spillMerger.addSpillIfNotEmpty(spillWriter.getReader(serializerManager)); } if (inMemSorter != null) { readingIterator = new SpillableIterator(inMemSorter.getSortedIterator()); @@ -463,7 +470,7 @@ public final class UnsafeExternalSorter extends MemoryConsumer { } spillWriter.close(); spillWriters.add(spillWriter); - nextUpstream = spillWriter.getReader(blockManager); + nextUpstream = spillWriter.getReader(serializerManager); long released = 0L; synchronized (UnsafeExternalSorter.this) { @@ -549,7 +556,7 @@ public final class UnsafeExternalSorter extends MemoryConsumer { } else { LinkedList queue = new LinkedList<>(); for (UnsafeSorterSpillWriter spillWriter : spillWriters) { - queue.add(spillWriter.getReader(blockManager)); + queue.add(spillWriter.getReader(serializerManager)); } if (inMemSorter != null) { queue.add(inMemSorter.getSortedIterator()); diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java index 20ee1c8eb0..1d588c37c5 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java @@ -22,8 +22,8 @@ import java.io.*; import com.google.common.io.ByteStreams; import com.google.common.io.Closeables; +import org.apache.spark.serializer.SerializerManager; import org.apache.spark.storage.BlockId; -import org.apache.spark.storage.BlockManager; import org.apache.spark.unsafe.Platform; /** @@ -46,13 +46,13 @@ public final class UnsafeSorterSpillReader extends UnsafeSorterIterator implemen private final long baseOffset = Platform.BYTE_ARRAY_OFFSET; public 
UnsafeSorterSpillReader( - BlockManager blockManager, + SerializerManager serializerManager, File file, BlockId blockId) throws IOException { assert (file.length() > 0); final BufferedInputStream bs = new BufferedInputStream(new FileInputStream(file)); try { - this.in = blockManager.wrapForCompression(blockId, bs); + this.in = serializerManager.wrapForCompression(blockId, bs); this.din = new DataInputStream(this.in); numRecords = numRecordsRemaining = din.readInt(); } catch (IOException e) { diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java index 234e21140a..9ba760e842 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java @@ -20,6 +20,7 @@ package org.apache.spark.util.collection.unsafe.sort; import java.io.File; import java.io.IOException; +import org.apache.spark.serializer.SerializerManager; import scala.Tuple2; import org.apache.spark.executor.ShuffleWriteMetrics; @@ -144,7 +145,7 @@ public final class UnsafeSorterSpillWriter { return file; } - public UnsafeSorterSpillReader getReader(BlockManager blockManager) throws IOException { - return new UnsafeSorterSpillReader(blockManager, file, blockId); + public UnsafeSorterSpillReader getReader(SerializerManager serializerManager) throws IOException { + return new UnsafeSorterSpillReader(serializerManager, file, blockId); } } diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala index b9f115463a..27e5fa4c2b 100644 --- a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala +++ b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala @@ -17,17 +17,25 @@ package org.apache.spark.serializer +import java.io.{BufferedInputStream, BufferedOutputStream, InputStream, OutputStream} +import java.nio.ByteBuffer + import scala.reflect.ClassTag import org.apache.spark.SparkConf +import org.apache.spark.io.CompressionCodec +import org.apache.spark.storage._ +import org.apache.spark.util.io.{ByteArrayChunkOutputStream, ChunkedByteBuffer} /** - * Component that selects which [[Serializer]] to use for shuffles. + * Component which configures serialization and compression for various Spark components, including + * automatic selection of which [[Serializer]] to use for shuffles. 
*/ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: SparkConf) { private[this] val kryoSerializer = new KryoSerializer(conf) + private[this] val stringClassTag: ClassTag[String] = implicitly[ClassTag[String]] private[this] val primitiveAndPrimitiveArrayClassTags: Set[ClassTag[_]] = { val primitiveClassTags = Set[ClassTag[_]]( ClassTag.Boolean, @@ -44,7 +52,21 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: Spar primitiveClassTags ++ arrayClassTags } - private[this] val stringClassTag: ClassTag[String] = implicitly[ClassTag[String]] + // Whether to compress broadcast variables that are stored + private[this] val compressBroadcast = conf.getBoolean("spark.broadcast.compress", true) + // Whether to compress shuffle output that are stored + private[this] val compressShuffle = conf.getBoolean("spark.shuffle.compress", true) + // Whether to compress RDD partitions that are stored serialized + private[this] val compressRdds = conf.getBoolean("spark.rdd.compress", false) + // Whether to compress shuffle output temporarily spilled to disk + private[this] val compressShuffleSpill = conf.getBoolean("spark.shuffle.spill.compress", true) + + /* The compression codec to use. Note that the "lazy" val is necessary because we want to delay + * the initialization of the compression codec until it is first used. The reason is that a Spark + * program could be using a user-defined codec in a third party jar, which is loaded in + * Executor.updateDependencies. When the BlockManager is initialized, user level jars hasn't been + * loaded yet. */ + private lazy val compressionCodec: CompressionCodec = CompressionCodec.createCodec(conf) private def canUseKryo(ct: ClassTag[_]): Boolean = { primitiveAndPrimitiveArrayClassTags.contains(ct) || ct == stringClassTag @@ -68,4 +90,68 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: Spar defaultSerializer } } + + private def shouldCompress(blockId: BlockId): Boolean = { + blockId match { + case _: ShuffleBlockId => compressShuffle + case _: BroadcastBlockId => compressBroadcast + case _: RDDBlockId => compressRdds + case _: TempLocalBlockId => compressShuffleSpill + case _: TempShuffleBlockId => compressShuffle + case _ => false + } + } + + /** + * Wrap an output stream for compression if block compression is enabled for its block type + */ + def wrapForCompression(blockId: BlockId, s: OutputStream): OutputStream = { + if (shouldCompress(blockId)) compressionCodec.compressedOutputStream(s) else s + } + + /** + * Wrap an input stream for compression if block compression is enabled for its block type + */ + def wrapForCompression(blockId: BlockId, s: InputStream): InputStream = { + if (shouldCompress(blockId)) compressionCodec.compressedInputStream(s) else s + } + + /** Serializes into a stream. */ + def dataSerializeStream[T: ClassTag]( + blockId: BlockId, + outputStream: OutputStream, + values: Iterator[T]): Unit = { + val byteStream = new BufferedOutputStream(outputStream) + val ser = getSerializer(implicitly[ClassTag[T]]).newInstance() + ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close() + } + + /** Serializes into a chunked byte buffer. 
*/ + def dataSerialize[T: ClassTag](blockId: BlockId, values: Iterator[T]): ChunkedByteBuffer = { + val byteArrayChunkOutputStream = new ByteArrayChunkOutputStream(1024 * 1024 * 4) + dataSerializeStream(blockId, byteArrayChunkOutputStream, values) + new ChunkedByteBuffer(byteArrayChunkOutputStream.toArrays.map(ByteBuffer.wrap)) + } + + /** + * Deserializes a ByteBuffer into an iterator of values and disposes of it when the end of + * the iterator is reached. + */ + def dataDeserialize[T: ClassTag](blockId: BlockId, bytes: ChunkedByteBuffer): Iterator[T] = { + dataDeserializeStream[T](blockId, bytes.toInputStream(dispose = true)) + } + + /** + * Deserializes a InputStream into an iterator of values and disposes of it when the end of + * the iterator is reached. + */ + def dataDeserializeStream[T: ClassTag]( + blockId: BlockId, + inputStream: InputStream): Iterator[T] = { + val stream = new BufferedInputStream(inputStream) + getSerializer(implicitly[ClassTag[T]]) + .newInstance() + .deserializeStream(wrapForCompression(blockId, stream)) + .asIterator.asInstanceOf[Iterator[T]] + } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index 4054465c0f..637b2dfc19 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -19,7 +19,7 @@ package org.apache.spark.shuffle import org.apache.spark._ import org.apache.spark.internal.Logging -import org.apache.spark.serializer.Serializer +import org.apache.spark.serializer.SerializerManager import org.apache.spark.storage.{BlockManager, ShuffleBlockFetcherIterator} import org.apache.spark.util.CompletionIterator import org.apache.spark.util.collection.ExternalSorter @@ -33,6 +33,7 @@ private[spark] class BlockStoreShuffleReader[K, C]( startPartition: Int, endPartition: Int, context: TaskContext, + serializerManager: SerializerManager = SparkEnv.get.serializerManager, blockManager: BlockManager = SparkEnv.get.blockManager, mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker) extends ShuffleReader[K, C] with Logging { @@ -52,7 +53,7 @@ private[spark] class BlockStoreShuffleReader[K, C]( // Wrap the streams for compression based on configuration val wrappedStreams = blockFetcherItr.map { case (blockId, inputStream) => - blockManager.wrapForCompression(blockId, inputStream) + serializerManager.wrapForCompression(blockId, inputStream) } val serializerInstance = dep.serializer.newInstance() diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 83f8c5c37d..eebb43e245 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -18,7 +18,6 @@ package org.apache.spark.storage import java.io._ -import java.nio.ByteBuffer import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.concurrent.{Await, ExecutionContext, Future} @@ -30,7 +29,6 @@ import scala.util.control.NonFatal import org.apache.spark._ import org.apache.spark.executor.{DataReadMethod, ShuffleWriteMetrics} import org.apache.spark.internal.Logging -import org.apache.spark.io.CompressionCodec import org.apache.spark.memory.MemoryManager import org.apache.spark.network._ import org.apache.spark.network.buffer.{ManagedBuffer, NettyManagedBuffer} @@ -38,11 +36,11 @@ import 
org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.shuffle.ExternalShuffleClient import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo import org.apache.spark.rpc.RpcEnv -import org.apache.spark.serializer.{Serializer, SerializerInstance, SerializerManager} +import org.apache.spark.serializer.{SerializerInstance, SerializerManager} import org.apache.spark.shuffle.ShuffleManager import org.apache.spark.storage.memory._ import org.apache.spark.util._ -import org.apache.spark.util.io.{ByteArrayChunkOutputStream, ChunkedByteBuffer} +import org.apache.spark.util.io.ChunkedByteBuffer /* Class for returning a fetched block and associated metrics. */ private[spark] class BlockResult( @@ -68,7 +66,7 @@ private[spark] class BlockManager( blockTransferService: BlockTransferService, securityManager: SecurityManager, numUsableCores: Int) - extends BlockDataManager with Logging { + extends BlockDataManager with BlockEvictionHandler with Logging { private[spark] val externalShuffleServiceEnabled = conf.getBoolean("spark.shuffle.service.enabled", false) @@ -80,13 +78,15 @@ private[spark] class BlockManager( new DiskBlockManager(conf, deleteFilesOnStop) } + // Visible for testing private[storage] val blockInfoManager = new BlockInfoManager private val futureExecutionContext = ExecutionContext.fromExecutorService( ThreadUtils.newDaemonCachedThreadPool("block-manager-future", 128)) // Actual storage of where blocks are kept - private[spark] val memoryStore = new MemoryStore(conf, this, memoryManager) + private[spark] val memoryStore = + new MemoryStore(conf, blockInfoManager, serializerManager, memoryManager, this) private[spark] val diskStore = new DiskStore(conf, diskBlockManager) memoryManager.setMemoryStore(memoryStore) @@ -126,14 +126,6 @@ private[spark] class BlockManager( blockTransferService } - // Whether to compress broadcast variables that are stored - private val compressBroadcast = conf.getBoolean("spark.broadcast.compress", true) - // Whether to compress shuffle output that are stored - private val compressShuffle = conf.getBoolean("spark.shuffle.compress", true) - // Whether to compress RDD partitions that are stored serialized - private val compressRdds = conf.getBoolean("spark.rdd.compress", false) - // Whether to compress shuffle output temporarily spilled to disk - private val compressShuffleSpill = conf.getBoolean("spark.shuffle.spill.compress", true) // Max number of failures before this block manager refreshes the block locations from the driver private val maxFailuresBeforeLocationRefresh = conf.getInt("spark.block.failures.beforeLocationRefresh", 5) @@ -152,13 +144,6 @@ private[spark] class BlockManager( private val peerFetchLock = new Object private var lastPeerFetchTime = 0L - /* The compression codec to use. Note that the "lazy" val is necessary because we want to delay - * the initialization of the compression codec until it is first used. The reason is that a Spark - * program could be using a user-defined codec in a third party jar, which is loaded in - * Executor.updateDependencies. When the BlockManager is initialized, user level jars hasn't been - * loaded yet. */ - private lazy val compressionCodec: CompressionCodec = CompressionCodec.createCodec(conf) - /** * Initializes the BlockManager with the given appId. 
This is not performed in the constructor as * the appId may not be known at BlockManager instantiation time (in particular for the driver, @@ -286,7 +271,7 @@ private[spark] class BlockManager( shuffleManager.shuffleBlockResolver.getBlockData(blockId.asInstanceOf[ShuffleBlockId]) } else { getLocalBytes(blockId) match { - case Some(buffer) => new BlockManagerManagedBuffer(this, blockId, buffer) + case Some(buffer) => new BlockManagerManagedBuffer(blockInfoManager, blockId, buffer) case None => throw new BlockNotFoundException(blockId.toString) } } @@ -422,7 +407,8 @@ private[spark] class BlockManager( val iter: Iterator[Any] = if (level.deserialized) { memoryStore.getValues(blockId).get } else { - dataDeserialize(blockId, memoryStore.getBytes(blockId).get)(info.classTag) + serializerManager.dataDeserialize( + blockId, memoryStore.getBytes(blockId).get)(info.classTag) } val ci = CompletionIterator[Any, Iterator[Any]](iter, releaseLock(blockId)) Some(new BlockResult(ci, DataReadMethod.Memory, info.size)) @@ -430,11 +416,11 @@ private[spark] class BlockManager( val iterToReturn: Iterator[Any] = { val diskBytes = diskStore.getBytes(blockId) if (level.deserialized) { - val diskValues = dataDeserialize(blockId, diskBytes)(info.classTag) + val diskValues = serializerManager.dataDeserialize(blockId, diskBytes)(info.classTag) maybeCacheDiskValuesInMemory(info, blockId, level, diskValues) } else { val bytes = maybeCacheDiskBytesInMemory(info, blockId, level, diskBytes) - dataDeserialize(blockId, bytes)(info.classTag) + serializerManager.dataDeserialize(blockId, bytes)(info.classTag) } } val ci = CompletionIterator[Any, Iterator[Any]](iterToReturn, releaseLock(blockId)) @@ -486,7 +472,7 @@ private[spark] class BlockManager( diskStore.getBytes(blockId) } else if (level.useMemory && memoryStore.contains(blockId)) { // The block was not found on disk, so serialize an in-memory copy: - dataSerialize(blockId, memoryStore.getValues(blockId).get) + serializerManager.dataSerialize(blockId, memoryStore.getValues(blockId).get) } else { releaseLock(blockId) throw new SparkException(s"Block $blockId was not found even though it's read-locked") @@ -510,7 +496,8 @@ private[spark] class BlockManager( */ private def getRemoteValues(blockId: BlockId): Option[BlockResult] = { getRemoteBytes(blockId).map { data => - new BlockResult(dataDeserialize(blockId, data), DataReadMethod.Network, data.size) + new BlockResult( + serializerManager.dataDeserialize(blockId, data), DataReadMethod.Network, data.size) } } @@ -699,7 +686,8 @@ private[spark] class BlockManager( serializerInstance: SerializerInstance, bufferSize: Int, writeMetrics: ShuffleWriteMetrics): DiskBlockObjectWriter = { - val compressStream: OutputStream => OutputStream = wrapForCompression(blockId, _) + val compressStream: OutputStream => OutputStream = + serializerManager.wrapForCompression(blockId, _) val syncWrites = conf.getBoolean("spark.shuffle.sync", false) new DiskBlockObjectWriter(file, serializerInstance, bufferSize, compressStream, syncWrites, writeMetrics, blockId) @@ -757,7 +745,7 @@ private[spark] class BlockManager( // Put it in memory first, even if it also has useDisk set to true; // We will drop it to disk later if the memory store can't hold it. 
val putSucceeded = if (level.deserialized) { - val values = dataDeserialize(blockId, bytes)(classTag) + val values = serializerManager.dataDeserialize(blockId, bytes)(classTag) memoryStore.putIterator(blockId, values, level, classTag) match { case Right(_) => true case Left(iter) => @@ -896,7 +884,7 @@ private[spark] class BlockManager( if (level.useDisk) { logWarning(s"Persisting block $blockId to disk instead.") diskStore.put(blockId) { fileOutputStream => - dataSerializeStream(blockId, fileOutputStream, iter)(classTag) + serializerManager.dataSerializeStream(blockId, fileOutputStream, iter)(classTag) } size = diskStore.getSize(blockId) } else { @@ -905,7 +893,7 @@ private[spark] class BlockManager( } } else if (level.useDisk) { diskStore.put(blockId) { fileOutputStream => - dataSerializeStream(blockId, fileOutputStream, iterator())(classTag) + serializerManager.dataSerializeStream(blockId, fileOutputStream, iterator())(classTag) } size = diskStore.getSize(blockId) } @@ -1167,7 +1155,7 @@ private[spark] class BlockManager( * * @return the block's new effective StorageLevel. */ - private[storage] def dropFromMemory[T: ClassTag]( + private[storage] override def dropFromMemory[T: ClassTag]( blockId: BlockId, data: () => Either[Array[T], ChunkedByteBuffer]): StorageLevel = { logInfo(s"Dropping block $blockId from memory") @@ -1181,7 +1169,7 @@ private[spark] class BlockManager( data() match { case Left(elements) => diskStore.put(blockId) { fileOutputStream => - dataSerializeStream( + serializerManager.dataSerializeStream( blockId, fileOutputStream, elements.toIterator)(info.classTag.asInstanceOf[ClassTag[T]]) @@ -1264,70 +1252,6 @@ private[spark] class BlockManager( } } - private def shouldCompress(blockId: BlockId): Boolean = { - blockId match { - case _: ShuffleBlockId => compressShuffle - case _: BroadcastBlockId => compressBroadcast - case _: RDDBlockId => compressRdds - case _: TempLocalBlockId => compressShuffleSpill - case _: TempShuffleBlockId => compressShuffle - case _ => false - } - } - - /** - * Wrap an output stream for compression if block compression is enabled for its block type - */ - def wrapForCompression(blockId: BlockId, s: OutputStream): OutputStream = { - if (shouldCompress(blockId)) compressionCodec.compressedOutputStream(s) else s - } - - /** - * Wrap an input stream for compression if block compression is enabled for its block type - */ - def wrapForCompression(blockId: BlockId, s: InputStream): InputStream = { - if (shouldCompress(blockId)) compressionCodec.compressedInputStream(s) else s - } - - /** Serializes into a stream. */ - def dataSerializeStream[T: ClassTag]( - blockId: BlockId, - outputStream: OutputStream, - values: Iterator[T]): Unit = { - val byteStream = new BufferedOutputStream(outputStream) - val ser = serializerManager.getSerializer(implicitly[ClassTag[T]]).newInstance() - ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close() - } - - /** Serializes into a chunked byte buffer. */ - def dataSerialize[T: ClassTag](blockId: BlockId, values: Iterator[T]): ChunkedByteBuffer = { - val byteArrayChunkOutputStream = new ByteArrayChunkOutputStream(1024 * 1024 * 4) - dataSerializeStream(blockId, byteArrayChunkOutputStream, values) - new ChunkedByteBuffer(byteArrayChunkOutputStream.toArrays.map(ByteBuffer.wrap)) - } - - /** - * Deserializes a ByteBuffer into an iterator of values and disposes of it when the end of - * the iterator is reached. 
- */ - def dataDeserialize[T: ClassTag](blockId: BlockId, bytes: ChunkedByteBuffer): Iterator[T] = { - dataDeserializeStream[T](blockId, bytes.toInputStream(dispose = true)) - } - - /** - * Deserializes a InputStream into an iterator of values and disposes of it when the end of - * the iterator is reached. - */ - def dataDeserializeStream[T: ClassTag]( - blockId: BlockId, - inputStream: InputStream): Iterator[T] = { - val stream = new BufferedInputStream(inputStream) - serializerManager.getSerializer(implicitly[ClassTag[T]]) - .newInstance() - .deserializeStream(wrapForCompression(blockId, stream)) - .asIterator.asInstanceOf[Iterator[T]] - } - def stop(): Unit = { blockTransferService.close() if (shuffleClient ne blockTransferService) { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerManagedBuffer.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerManagedBuffer.scala index 12594e6a2b..f66f942798 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerManagedBuffer.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerManagedBuffer.scala @@ -29,19 +29,19 @@ import org.apache.spark.util.io.ChunkedByteBuffer * to the network layer's notion of retain / release counts. */ private[storage] class BlockManagerManagedBuffer( - blockManager: BlockManager, + blockInfoManager: BlockInfoManager, blockId: BlockId, chunkedBuffer: ChunkedByteBuffer) extends NettyManagedBuffer(chunkedBuffer.toNetty) { override def retain(): ManagedBuffer = { super.retain() - val locked = blockManager.blockInfoManager.lockForReading(blockId, blocking = false) + val locked = blockInfoManager.lockForReading(blockId, blocking = false) assert(locked.isDefined) this } override def release(): ManagedBuffer = { - blockManager.releaseLock(blockId) + blockInfoManager.unlock(blockId) super.release() } } diff --git a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala index d370ee912a..90016cbeb8 100644 --- a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala @@ -26,7 +26,8 @@ import scala.reflect.ClassTag import org.apache.spark.{SparkConf, TaskContext} import org.apache.spark.internal.Logging import org.apache.spark.memory.MemoryManager -import org.apache.spark.storage.{BlockId, BlockManager, StorageLevel} +import org.apache.spark.serializer.SerializerManager +import org.apache.spark.storage.{BlockId, BlockInfoManager, StorageLevel} import org.apache.spark.util.{CompletionIterator, SizeEstimator, Utils} import org.apache.spark.util.collection.SizeTrackingVector import org.apache.spark.util.io.ChunkedByteBuffer @@ -44,14 +45,33 @@ private case class SerializedMemoryEntry[T]( size: Long, classTag: ClassTag[T]) extends MemoryEntry[T] +private[storage] trait BlockEvictionHandler { + /** + * Drop a block from memory, possibly putting it on disk if applicable. Called when the memory + * store reaches its limit and needs to free up space. + * + * If `data` is not put on disk, it won't be created. + * + * The caller of this method must hold a write lock on the block before calling this method. + * This method does not release the write lock. + * + * @return the block's new effective StorageLevel. 
+ */ + private[storage] def dropFromMemory[T: ClassTag]( + blockId: BlockId, + data: () => Either[Array[T], ChunkedByteBuffer]): StorageLevel +} + /** * Stores blocks in memory, either as Arrays of deserialized Java objects or as * serialized ByteBuffers. */ private[spark] class MemoryStore( conf: SparkConf, - blockManager: BlockManager, - memoryManager: MemoryManager) + blockInfoManager: BlockInfoManager, + serializerManager: SerializerManager, + memoryManager: MemoryManager, + blockEvictionHandler: BlockEvictionHandler) extends Logging { // Note: all changes to memory allocations, notably putting blocks, evicting blocks, and @@ -117,7 +137,7 @@ private[spark] class MemoryStore( entries.put(blockId, entry) } logInfo("Block %s stored as bytes in memory (estimated size %s, free %s)".format( - blockId, Utils.bytesToString(size), Utils.bytesToString(blocksMemoryUsed))) + blockId, Utils.bytesToString(size), Utils.bytesToString(maxMemory - blocksMemoryUsed))) true } else { false @@ -201,7 +221,7 @@ private[spark] class MemoryStore( val entry = if (level.deserialized) { new DeserializedMemoryEntry[T](arrayValues, SizeEstimator.estimate(arrayValues), classTag) } else { - val bytes = blockManager.dataSerialize(blockId, arrayValues.iterator)(classTag) + val bytes = serializerManager.dataSerialize(blockId, arrayValues.iterator)(classTag) new SerializedMemoryEntry[T](bytes, bytes.size, classTag) } val size = entry.size @@ -237,7 +257,10 @@ private[spark] class MemoryStore( } val bytesOrValues = if (level.deserialized) "values" else "bytes" logInfo("Block %s stored as %s in memory (estimated size %s, free %s)".format( - blockId, bytesOrValues, Utils.bytesToString(size), Utils.bytesToString(blocksMemoryUsed))) + blockId, + bytesOrValues, + Utils.bytesToString(size), + Utils.bytesToString(maxMemory - blocksMemoryUsed))) Right(size) } else { assert(currentUnrollMemoryForThisTask >= currentUnrollMemoryForThisTask, @@ -284,7 +307,7 @@ private[spark] class MemoryStore( } if (entry != null) { memoryManager.releaseStorageMemory(entry.size) - logDebug(s"Block $blockId of size ${entry.size} dropped " + + logInfo(s"Block $blockId of size ${entry.size} dropped " + s"from memory (free ${maxMemory - blocksMemoryUsed})") true } else { @@ -339,7 +362,7 @@ private[spark] class MemoryStore( // We don't want to evict blocks which are currently being read, so we need to obtain // an exclusive write lock on blocks which are candidates for eviction. 
We perform a // non-blocking "tryLock" here in order to ignore blocks which are locked for reading: - if (blockManager.blockInfoManager.lockForWriting(blockId, blocking = false).isDefined) { + if (blockInfoManager.lockForWriting(blockId, blocking = false).isDefined) { selectedBlocks += blockId freedMemory += pair.getValue.size } @@ -353,20 +376,21 @@ private[spark] class MemoryStore( case SerializedMemoryEntry(buffer, _, _) => Right(buffer) } val newEffectiveStorageLevel = - blockManager.dropFromMemory(blockId, () => data)(entry.classTag) + blockEvictionHandler.dropFromMemory(blockId, () => data)(entry.classTag) if (newEffectiveStorageLevel.isValid) { // The block is still present in at least one store, so release the lock // but don't delete the block info - blockManager.releaseLock(blockId) + blockInfoManager.unlock(blockId) } else { // The block isn't present in any store, so delete the block info so that the // block can be stored again - blockManager.blockInfoManager.removeBlock(blockId) + blockInfoManager.removeBlock(blockId) } } if (freedMemory >= space) { - logInfo(s"${selectedBlocks.size} blocks selected for dropping") + logInfo(s"${selectedBlocks.size} blocks selected for dropping " + + s"(${Utils.bytesToString(freedMemory)} bytes)") for (blockId <- selectedBlocks) { val entry = entries.synchronized { entries.get(blockId) } // This should never be null as only one task should be dropping @@ -376,14 +400,15 @@ private[spark] class MemoryStore( dropBlock(blockId, entry) } } + logInfo(s"After dropping ${selectedBlocks.size} blocks, " + + s"free memory is ${Utils.bytesToString(maxMemory - blocksMemoryUsed)}") freedMemory } else { blockId.foreach { id => - logInfo(s"Will not store $id as it would require dropping another block " + - "from the same RDD") + logInfo(s"Will not store $id") } selectedBlocks.foreach { id => - blockManager.releaseLock(id) + blockInfoManager.unlock(id) } 0L } diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index 531f1c4dd2..95351e9826 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -31,7 +31,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.internal.Logging import org.apache.spark.memory.TaskMemoryManager -import org.apache.spark.serializer.{DeserializationStream, Serializer} +import org.apache.spark.serializer.{DeserializationStream, Serializer, SerializerManager} import org.apache.spark.storage.{BlockId, BlockManager} import org.apache.spark.util.CompletionIterator import org.apache.spark.util.collection.ExternalAppendOnlyMap.HashComparator @@ -59,7 +59,8 @@ class ExternalAppendOnlyMap[K, V, C]( mergeCombiners: (C, C) => C, serializer: Serializer = SparkEnv.get.serializer, blockManager: BlockManager = SparkEnv.get.blockManager, - context: TaskContext = TaskContext.get()) + context: TaskContext = TaskContext.get(), + serializerManager: SerializerManager = SparkEnv.get.serializerManager) extends Iterable[(K, C)] with Serializable with Logging @@ -458,7 +459,7 @@ class ExternalAppendOnlyMap[K, V, C]( ", batchOffsets = " + batchOffsets.mkString("[", ", ", "]")) val bufferedStream = new BufferedInputStream(ByteStreams.limit(fileStream, end - start)) - val compressedStream = blockManager.wrapForCompression(blockId, 
bufferedStream) + val compressedStream = serializerManager.wrapForCompression(blockId, bufferedStream) ser.deserializeStream(compressedStream) } else { // No more batches left diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 8cdc4663e6..561ba22df5 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -108,6 +108,7 @@ private[spark] class ExternalSorter[K, V, C]( private val blockManager = SparkEnv.get.blockManager private val diskBlockManager = blockManager.diskBlockManager + private val serializerManager = SparkEnv.get.serializerManager private val serInstance = serializer.newInstance() // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided @@ -503,7 +504,7 @@ private[spark] class ExternalSorter[K, V, C]( ", batchOffsets = " + batchOffsets.mkString("[", ", ", "]")) val bufferedStream = new BufferedInputStream(ByteStreams.limit(fileStream, end - start)) - val compressedStream = blockManager.wrapForCompression(spill.blockId, bufferedStream) + val compressedStream = serializerManager.wrapForCompression(spill.blockId, bufferedStream) serInstance.deserializeStream(compressedStream) } else { // No more batches left diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index 47c695ad4e..44733dcdaf 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -70,6 +70,7 @@ public class UnsafeShuffleWriterSuite { final LinkedList spillFilesCreated = new LinkedList<>(); SparkConf conf; final Serializer serializer = new KryoSerializer(new SparkConf()); + final SerializerManager serializerManager = new SerializerManager(serializer, new SparkConf()); TaskMetrics taskMetrics; @Mock(answer = RETURNS_SMART_NULLS) BlockManager blockManager; @@ -111,7 +112,7 @@ public class UnsafeShuffleWriterSuite { .set("spark.memory.offHeap.enabled", "false"); taskMetrics = new TaskMetrics(); memoryManager = new TestMemoryManager(conf); - taskMemoryManager = new TaskMemoryManager(memoryManager, 0); + taskMemoryManager = new TaskMemoryManager(memoryManager, 0); when(blockManager.diskBlockManager()).thenReturn(diskBlockManager); when(blockManager.getDiskWriter( @@ -135,35 +136,6 @@ public class UnsafeShuffleWriterSuite { ); } }); - when(blockManager.wrapForCompression(any(BlockId.class), any(InputStream.class))).thenAnswer( - new Answer() { - @Override - public InputStream answer(InvocationOnMock invocation) throws Throwable { - assertTrue(invocation.getArguments()[0] instanceof TempShuffleBlockId); - InputStream is = (InputStream) invocation.getArguments()[1]; - if (conf.getBoolean("spark.shuffle.compress", true)) { - return CompressionCodec$.MODULE$.createCodec(conf).compressedInputStream(is); - } else { - return is; - } - } - } - ); - - when(blockManager.wrapForCompression(any(BlockId.class), any(OutputStream.class))).thenAnswer( - new Answer() { - @Override - public OutputStream answer(InvocationOnMock invocation) throws Throwable { - assertTrue(invocation.getArguments()[0] instanceof TempShuffleBlockId); - OutputStream os = (OutputStream) invocation.getArguments()[1]; - if (conf.getBoolean("spark.shuffle.compress", 
true)) { - return CompressionCodec$.MODULE$.createCodec(conf).compressedOutputStream(os); - } else { - return os; - } - } - } - ); when(shuffleBlockResolver.getDataFile(anyInt(), anyInt())).thenReturn(mergedOutputFile); doAnswer(new Answer() { diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java index 6667179b9d..449fb45c30 100644 --- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java +++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java @@ -19,7 +19,6 @@ package org.apache.spark.unsafe.map; import java.io.File; import java.io.IOException; -import java.io.InputStream; import java.io.OutputStream; import java.nio.ByteBuffer; import java.util.*; @@ -42,7 +41,9 @@ import org.apache.spark.executor.ShuffleWriteMetrics; import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.memory.TestMemoryManager; import org.apache.spark.network.util.JavaUtils; +import org.apache.spark.serializer.JavaSerializer; import org.apache.spark.serializer.SerializerInstance; +import org.apache.spark.serializer.SerializerManager; import org.apache.spark.storage.*; import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; @@ -51,7 +52,6 @@ import org.apache.spark.util.Utils; import static org.hamcrest.Matchers.greaterThan; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; -import static org.mockito.AdditionalAnswers.returnsSecondArg; import static org.mockito.Answers.RETURNS_SMART_NULLS; import static org.mockito.Matchers.any; import static org.mockito.Matchers.anyInt; @@ -64,6 +64,9 @@ public abstract class AbstractBytesToBytesMapSuite { private TestMemoryManager memoryManager; private TaskMemoryManager taskMemoryManager; + private SerializerManager serializerManager = new SerializerManager( + new JavaSerializer(new SparkConf()), + new SparkConf().set("spark.shuffle.spill.compress", "false")); private static final long PAGE_SIZE_BYTES = 1L << 26; // 64 megabytes final LinkedList spillFilesCreated = new LinkedList<>(); @@ -85,7 +88,9 @@ public abstract class AbstractBytesToBytesMapSuite { new TestMemoryManager( new SparkConf() .set("spark.memory.offHeap.enabled", "" + useOffHeapMemoryAllocator()) - .set("spark.memory.offHeap.size", "256mb")); + .set("spark.memory.offHeap.size", "256mb") + .set("spark.shuffle.spill.compress", "false") + .set("spark.shuffle.compress", "false")); taskMemoryManager = new TaskMemoryManager(memoryManager, 0); tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "unsafe-test"); @@ -124,8 +129,6 @@ public abstract class AbstractBytesToBytesMapSuite { ); } }); - when(blockManager.wrapForCompression(any(BlockId.class), any(InputStream.class))) - .then(returnsSecondArg()); } @After @@ -546,8 +549,8 @@ public abstract class AbstractBytesToBytesMapSuite { @Test public void spillInIterator() throws IOException { - BytesToBytesMap map = - new BytesToBytesMap(taskMemoryManager, blockManager, 1, 0.75, 1024, false); + BytesToBytesMap map = new BytesToBytesMap( + taskMemoryManager, blockManager, serializerManager, 1, 0.75, 1024, false); try { int i; for (i = 0; i < 1024; i++) { diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java index db50e551f2..a2253d8559 
100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java @@ -19,7 +19,6 @@ package org.apache.spark.util.collection.unsafe.sort; import java.io.File; import java.io.IOException; -import java.io.InputStream; import java.io.OutputStream; import java.util.Arrays; import java.util.LinkedList; @@ -43,14 +42,15 @@ import org.apache.spark.executor.ShuffleWriteMetrics; import org.apache.spark.executor.TaskMetrics; import org.apache.spark.memory.TestMemoryManager; import org.apache.spark.memory.TaskMemoryManager; +import org.apache.spark.serializer.JavaSerializer; import org.apache.spark.serializer.SerializerInstance; +import org.apache.spark.serializer.SerializerManager; import org.apache.spark.storage.*; import org.apache.spark.unsafe.Platform; import org.apache.spark.util.Utils; import static org.hamcrest.Matchers.greaterThanOrEqualTo; import static org.junit.Assert.*; -import static org.mockito.AdditionalAnswers.returnsSecondArg; import static org.mockito.Answers.RETURNS_SMART_NULLS; import static org.mockito.Mockito.*; @@ -60,6 +60,9 @@ public class UnsafeExternalSorterSuite { final TestMemoryManager memoryManager = new TestMemoryManager(new SparkConf().set("spark.memory.offHeap.enabled", "false")); final TaskMemoryManager taskMemoryManager = new TaskMemoryManager(memoryManager, 0); + final SerializerManager serializerManager = new SerializerManager( + new JavaSerializer(new SparkConf()), + new SparkConf().set("spark.shuffle.spill.compress", "false")); // Use integer comparison for comparing prefixes (which are partition ids, in this case) final PrefixComparator prefixComparator = new PrefixComparator() { @Override @@ -135,8 +138,6 @@ public class UnsafeExternalSorterSuite { ); } }); - when(blockManager.wrapForCompression(any(BlockId.class), any(InputStream.class))) - .then(returnsSecondArg()); } @After @@ -172,6 +173,7 @@ public class UnsafeExternalSorterSuite { return UnsafeExternalSorter.create( taskMemoryManager, blockManager, + serializerManager, taskContext, recordComparator, prefixComparator, @@ -374,6 +376,7 @@ public class UnsafeExternalSorterSuite { final UnsafeExternalSorter sorter = UnsafeExternalSorter.create( taskMemoryManager, blockManager, + serializerManager, taskContext, null, null, @@ -408,6 +411,7 @@ public class UnsafeExternalSorterSuite { final UnsafeExternalSorter sorter = UnsafeExternalSorter.create( taskMemoryManager, blockManager, + serializerManager, taskContext, recordComparator, prefixComparator, diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala index 2732cd6749..3dded4d486 100644 --- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala @@ -194,10 +194,11 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex val blockId = blockIds(0) val blockManager = SparkEnv.get.blockManager val blockTransfer = SparkEnv.get.blockTransferService + val serializerManager = SparkEnv.get.serializerManager blockManager.master.getLocations(blockId).foreach { cmId => val bytes = blockTransfer.fetchBlockSync(cmId.host, cmId.port, cmId.executorId, blockId.toString) - val deserialized = blockManager.dataDeserialize[Int](blockId, + val deserialized = serializerManager.dataDeserialize[Int](blockId, new ChunkedByteBuffer(bytes.nioByteBuffer())).toList 
assert(deserialized === (1 to 100).toList) } diff --git a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala index 08f52c92e1..dba1172d5f 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala @@ -20,14 +20,11 @@ package org.apache.spark.shuffle import java.io.{ByteArrayOutputStream, InputStream} import java.nio.ByteBuffer -import org.mockito.Matchers.{eq => meq, _} import org.mockito.Mockito.{mock, when} -import org.mockito.invocation.InvocationOnMock -import org.mockito.stubbing.Answer import org.apache.spark._ import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} -import org.apache.spark.serializer.JavaSerializer +import org.apache.spark.serializer.{JavaSerializer, SerializerManager} import org.apache.spark.storage.{BlockManager, BlockManagerId, ShuffleBlockId} /** @@ -77,13 +74,6 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext // can ensure retain() and release() are properly called. val blockManager = mock(classOf[BlockManager]) - // Create a return function to use for the mocked wrapForCompression method that just returns - // the original input stream. - val dummyCompressionFunction = new Answer[InputStream] { - override def answer(invocation: InvocationOnMock): InputStream = - invocation.getArguments()(1).asInstanceOf[InputStream] - } - // Create a buffer with some randomly generated key-value pairs to use as the shuffle data // from each mappers (all mappers return the same shuffle data). val byteOutputStream = new ByteArrayOutputStream() @@ -105,9 +95,6 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext // fetch shuffle data. 
val shuffleBlockId = ShuffleBlockId(shuffleId, mapId, reduceId) when(blockManager.getBlockData(shuffleBlockId)).thenReturn(managedBuffer) - when(blockManager.wrapForCompression(meq(shuffleBlockId), isA(classOf[InputStream]))) - .thenAnswer(dummyCompressionFunction) - managedBuffer } @@ -133,11 +120,18 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext new BaseShuffleHandle(shuffleId, numMaps, dependency) } + val serializerManager = new SerializerManager( + serializer, + new SparkConf() + .set("spark.shuffle.compress", "false") + .set("spark.shuffle.spill.compress", "false")) + val shuffleReader = new BlockStoreShuffleReader( shuffleHandle, reduceId, reduceId + 1, TaskContext.empty(), + serializerManager, blockManager, mapOutputTracker) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 9419dfaa00..94f6f87740 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -1033,138 +1033,6 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(!store.memoryStore.contains(rdd(1, 0)), "rdd_1_0 was in store") } - test("reserve/release unroll memory") { - store = makeBlockManager(12000) - val memoryStore = store.memoryStore - assert(memoryStore.currentUnrollMemory === 0) - assert(memoryStore.currentUnrollMemoryForThisTask === 0) - - def reserveUnrollMemoryForThisTask(memory: Long): Boolean = { - memoryStore.reserveUnrollMemoryForThisTask(TestBlockId(""), memory) - } - - // Reserve - assert(reserveUnrollMemoryForThisTask(100)) - assert(memoryStore.currentUnrollMemoryForThisTask === 100) - assert(reserveUnrollMemoryForThisTask(200)) - assert(memoryStore.currentUnrollMemoryForThisTask === 300) - assert(reserveUnrollMemoryForThisTask(500)) - assert(memoryStore.currentUnrollMemoryForThisTask === 800) - assert(!reserveUnrollMemoryForThisTask(1000000)) - assert(memoryStore.currentUnrollMemoryForThisTask === 800) // not granted - // Release - memoryStore.releaseUnrollMemoryForThisTask(100) - assert(memoryStore.currentUnrollMemoryForThisTask === 700) - memoryStore.releaseUnrollMemoryForThisTask(100) - assert(memoryStore.currentUnrollMemoryForThisTask === 600) - // Reserve again - assert(reserveUnrollMemoryForThisTask(4400)) - assert(memoryStore.currentUnrollMemoryForThisTask === 5000) - assert(!reserveUnrollMemoryForThisTask(20000)) - assert(memoryStore.currentUnrollMemoryForThisTask === 5000) // not granted - // Release again - memoryStore.releaseUnrollMemoryForThisTask(1000) - assert(memoryStore.currentUnrollMemoryForThisTask === 4000) - memoryStore.releaseUnrollMemoryForThisTask() // release all - assert(memoryStore.currentUnrollMemoryForThisTask === 0) - } - - test("safely unroll blocks") { - store = makeBlockManager(12000) - val smallList = List.fill(40)(new Array[Byte](100)) - val bigList = List.fill(40)(new Array[Byte](1000)) - val memoryStore = store.memoryStore - assert(memoryStore.currentUnrollMemoryForThisTask === 0) - - // Unroll with all the space in the world. This should succeed. 
- var putResult = - memoryStore.putIterator("unroll", smallList.iterator, StorageLevel.MEMORY_ONLY, ClassTag.Any) - assert(putResult.isRight) - assert(memoryStore.currentUnrollMemoryForThisTask === 0) - smallList.iterator.zip(memoryStore.getValues("unroll").get).foreach { case (e, a) => - assert(e === a, "getValues() did not return original values!") - } - assert(memoryStore.remove("unroll")) - - // Unroll with not enough space. This should succeed after kicking out someBlock1. - assert(store.putIterator("someBlock1", smallList.iterator, StorageLevel.MEMORY_ONLY)) - assert(store.putIterator("someBlock2", smallList.iterator, StorageLevel.MEMORY_ONLY)) - putResult = - memoryStore.putIterator("unroll", smallList.iterator, StorageLevel.MEMORY_ONLY, ClassTag.Any) - assert(putResult.isRight) - assert(memoryStore.currentUnrollMemoryForThisTask === 0) - assert(memoryStore.contains("someBlock2")) - assert(!memoryStore.contains("someBlock1")) - smallList.iterator.zip(memoryStore.getValues("unroll").get).foreach { case (e, a) => - assert(e === a, "getValues() did not return original values!") - } - assert(memoryStore.remove("unroll")) - - // Unroll huge block with not enough space. Even after ensuring free space of 12000 * 0.4 = - // 4800 bytes, there is still not enough room to unroll this block. This returns an iterator. - // In the mean time, however, we kicked out someBlock2 before giving up. - assert(store.putIterator("someBlock3", smallList.iterator, StorageLevel.MEMORY_ONLY)) - putResult = - memoryStore.putIterator("unroll", bigList.iterator, StorageLevel.MEMORY_ONLY, ClassTag.Any) - assert(memoryStore.currentUnrollMemoryForThisTask > 0) // we returned an iterator - assert(!memoryStore.contains("someBlock2")) - assert(putResult.isLeft) - bigList.iterator.zip(putResult.left.get).foreach { case (e, a) => - assert(e === a, "putIterator() did not return original values!") - } - // The unroll memory was freed once the iterator returned by putIterator() was fully traversed. - assert(memoryStore.currentUnrollMemoryForThisTask === 0) - } - - test("safely unroll blocks through putIterator") { - store = makeBlockManager(12000) - val memOnly = StorageLevel.MEMORY_ONLY - val memoryStore = store.memoryStore - val smallList = List.fill(40)(new Array[Byte](100)) - val bigList = List.fill(40)(new Array[Byte](1000)) - def smallIterator: Iterator[Any] = smallList.iterator.asInstanceOf[Iterator[Any]] - def bigIterator: Iterator[Any] = bigList.iterator.asInstanceOf[Iterator[Any]] - assert(memoryStore.currentUnrollMemoryForThisTask === 0) - - // Unroll with plenty of space. This should succeed and cache both blocks. - val result1 = memoryStore.putIterator("b1", smallIterator, memOnly, ClassTag.Any) - val result2 = memoryStore.putIterator("b2", smallIterator, memOnly, ClassTag.Any) - assert(memoryStore.contains("b1")) - assert(memoryStore.contains("b2")) - assert(result1.isRight) // unroll was successful - assert(result2.isRight) - assert(memoryStore.currentUnrollMemoryForThisTask === 0) - - // Re-put these two blocks so block manager knows about them too. Otherwise, block manager - // would not know how to drop them from memory later. - memoryStore.remove("b1") - memoryStore.remove("b2") - store.putIterator("b1", smallIterator, memOnly) - store.putIterator("b2", smallIterator, memOnly) - - // Unroll with not enough space. This should succeed but kick out b1 in the process. 
- val result3 = memoryStore.putIterator("b3", smallIterator, memOnly, ClassTag.Any) - assert(result3.isRight) - assert(!memoryStore.contains("b1")) - assert(memoryStore.contains("b2")) - assert(memoryStore.contains("b3")) - assert(memoryStore.currentUnrollMemoryForThisTask === 0) - memoryStore.remove("b3") - store.putIterator("b3", smallIterator, memOnly) - - // Unroll huge block with not enough space. This should fail and kick out b2 in the process. - val result4 = memoryStore.putIterator("b4", bigIterator, memOnly, ClassTag.Any) - assert(result4.isLeft) // unroll was unsuccessful - assert(!memoryStore.contains("b1")) - assert(!memoryStore.contains("b2")) - assert(memoryStore.contains("b3")) - assert(!memoryStore.contains("b4")) - assert(memoryStore.currentUnrollMemoryForThisTask > 0) // we returned an iterator - } - - /** - * This test is essentially identical to the preceding one, except that it uses MEMORY_AND_DISK. - */ test("safely unroll blocks through putIterator (disk)") { store = makeBlockManager(12000) val memAndDisk = StorageLevel.MEMORY_AND_DISK @@ -1203,72 +1071,9 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(!memoryStore.contains("b2")) assert(memoryStore.contains("b3")) assert(!memoryStore.contains("b4")) - assert(memoryStore.currentUnrollMemoryForThisTask > 0) // we returned an iterator - result4.left.get.close() - assert(memoryStore.currentUnrollMemoryForThisTask === 0) // close released the unroll memory - } - - test("multiple unrolls by the same thread") { - store = makeBlockManager(12000) - val memOnly = StorageLevel.MEMORY_ONLY - val memoryStore = store.memoryStore - val smallList = List.fill(40)(new Array[Byte](100)) - def smallIterator: Iterator[Any] = smallList.iterator.asInstanceOf[Iterator[Any]] - assert(memoryStore.currentUnrollMemoryForThisTask === 0) - - // All unroll memory used is released because putIterator did not return an iterator - assert(memoryStore.putIterator("b1", smallIterator, memOnly, ClassTag.Any).isRight) - assert(memoryStore.currentUnrollMemoryForThisTask === 0) - assert(memoryStore.putIterator("b2", smallIterator, memOnly, ClassTag.Any).isRight) - assert(memoryStore.currentUnrollMemoryForThisTask === 0) - - // Unroll memory is not released because putIterator returned an iterator - // that still depends on the underlying vector used in the process - assert(memoryStore.putIterator("b3", smallIterator, memOnly, ClassTag.Any).isLeft) - val unrollMemoryAfterB3 = memoryStore.currentUnrollMemoryForThisTask - assert(unrollMemoryAfterB3 > 0) - - // The unroll memory owned by this thread builds on top of its value after the previous unrolls - assert(memoryStore.putIterator("b4", smallIterator, memOnly, ClassTag.Any).isLeft) - val unrollMemoryAfterB4 = memoryStore.currentUnrollMemoryForThisTask - assert(unrollMemoryAfterB4 > unrollMemoryAfterB3) - - // ... 
but only to a certain extent (until we run out of free space to grant new unroll memory) - assert(memoryStore.putIterator("b5", smallIterator, memOnly, ClassTag.Any).isLeft) - val unrollMemoryAfterB5 = memoryStore.currentUnrollMemoryForThisTask - assert(memoryStore.putIterator("b6", smallIterator, memOnly, ClassTag.Any).isLeft) - val unrollMemoryAfterB6 = memoryStore.currentUnrollMemoryForThisTask - assert(memoryStore.putIterator("b7", smallIterator, memOnly, ClassTag.Any).isLeft) - val unrollMemoryAfterB7 = memoryStore.currentUnrollMemoryForThisTask - assert(unrollMemoryAfterB5 === unrollMemoryAfterB4) - assert(unrollMemoryAfterB6 === unrollMemoryAfterB4) - assert(unrollMemoryAfterB7 === unrollMemoryAfterB4) - } - - test("lazily create a big ByteBuffer to avoid OOM if it cannot be put into MemoryStore") { - store = makeBlockManager(12000) - val memoryStore = store.memoryStore - val blockId = BlockId("rdd_3_10") - store.blockInfoManager.lockNewBlockForWriting( - blockId, new BlockInfo(StorageLevel.MEMORY_ONLY, ClassTag.Any, tellMaster = false)) - memoryStore.putBytes(blockId, 13000, () => { - fail("A big ByteBuffer that cannot be put into MemoryStore should not be created") - }) - } - - test("put a small ByteBuffer to MemoryStore") { - store = makeBlockManager(12000) - val memoryStore = store.memoryStore - val blockId = BlockId("rdd_3_10") - var bytes: ChunkedByteBuffer = null - memoryStore.putBytes(blockId, 10000, () => { - bytes = new ChunkedByteBuffer(ByteBuffer.allocate(10000)) - bytes - }) - assert(memoryStore.getSize(blockId) === 10000) } - test("read-locked blocks cannot be evicted from the MemoryStore") { + test("read-locked blocks cannot be evicted from memory") { store = makeBlockManager(12000) val arr = new Array[Byte](4000) // First store a1 and a2, both in memory, and a3, on disk only diff --git a/core/src/test/scala/org/apache/spark/storage/MemoryStoreSuite.scala b/core/src/test/scala/org/apache/spark/storage/MemoryStoreSuite.scala new file mode 100644 index 0000000000..b4ab67ca15 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/storage/MemoryStoreSuite.scala @@ -0,0 +1,302 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.storage + +import java.nio.ByteBuffer + +import scala.language.implicitConversions +import scala.language.postfixOps +import scala.language.reflectiveCalls +import scala.reflect.ClassTag + +import org.scalatest._ + +import org.apache.spark._ +import org.apache.spark.memory.StaticMemoryManager +import org.apache.spark.serializer.{KryoSerializer, SerializerManager} +import org.apache.spark.storage.memory.{BlockEvictionHandler, MemoryStore, PartiallyUnrolledIterator} +import org.apache.spark.util._ +import org.apache.spark.util.io.ChunkedByteBuffer + +class MemoryStoreSuite + extends SparkFunSuite + with PrivateMethodTester + with BeforeAndAfterEach + with ResetSystemProperties { + + var conf: SparkConf = new SparkConf(false) + .set("spark.test.useCompressedOops", "true") + .set("spark.storage.unrollFraction", "0.4") + .set("spark.storage.unrollMemoryThreshold", "512") + + // Reuse a serializer across tests to avoid creating a new thread-local buffer on each test + val serializer = new KryoSerializer(new SparkConf(false).set("spark.kryoserializer.buffer", "1m")) + + // Implicitly convert strings to BlockIds for test clarity. + implicit def StringToBlockId(value: String): BlockId = new TestBlockId(value) + def rdd(rddId: Int, splitId: Int): RDDBlockId = RDDBlockId(rddId, splitId) + + override def beforeEach(): Unit = { + super.beforeEach() + // Set the arch to 64-bit and compressedOops to true to get a deterministic test-case + System.setProperty("os.arch", "amd64") + val initialize = PrivateMethod[Unit]('initialize) + SizeEstimator invokePrivate initialize() + } + + def makeMemoryStore(maxMem: Long): (MemoryStore, BlockInfoManager) = { + val memManager = new StaticMemoryManager(conf, Long.MaxValue, maxMem, numCores = 1) + val serializerManager = new SerializerManager(serializer, conf) + val blockInfoManager = new BlockInfoManager + val blockEvictionHandler = new BlockEvictionHandler { + var memoryStore: MemoryStore = _ + override private[storage] def dropFromMemory[T: ClassTag]( + blockId: BlockId, + data: () => Either[Array[T], ChunkedByteBuffer]): StorageLevel = { + memoryStore.remove(blockId) + StorageLevel.NONE + } + } + val memoryStore = + new MemoryStore(conf, blockInfoManager, serializerManager, memManager, blockEvictionHandler) + memManager.setMemoryStore(memoryStore) + blockEvictionHandler.memoryStore = memoryStore + (memoryStore, blockInfoManager) + } + + test("reserve/release unroll memory") { + val (memoryStore, _) = makeMemoryStore(12000) + assert(memoryStore.currentUnrollMemory === 0) + assert(memoryStore.currentUnrollMemoryForThisTask === 0) + + def reserveUnrollMemoryForThisTask(memory: Long): Boolean = { + memoryStore.reserveUnrollMemoryForThisTask(TestBlockId(""), memory) + } + + // Reserve + assert(reserveUnrollMemoryForThisTask(100)) + assert(memoryStore.currentUnrollMemoryForThisTask === 100) + assert(reserveUnrollMemoryForThisTask(200)) + assert(memoryStore.currentUnrollMemoryForThisTask === 300) + assert(reserveUnrollMemoryForThisTask(500)) + assert(memoryStore.currentUnrollMemoryForThisTask === 800) + assert(!reserveUnrollMemoryForThisTask(1000000)) + assert(memoryStore.currentUnrollMemoryForThisTask === 800) // not granted + // Release + memoryStore.releaseUnrollMemoryForThisTask(100) + assert(memoryStore.currentUnrollMemoryForThisTask === 700) + memoryStore.releaseUnrollMemoryForThisTask(100) + assert(memoryStore.currentUnrollMemoryForThisTask === 600) + // Reserve again + assert(reserveUnrollMemoryForThisTask(4400)) + 
assert(memoryStore.currentUnrollMemoryForThisTask === 5000) + assert(!reserveUnrollMemoryForThisTask(20000)) + assert(memoryStore.currentUnrollMemoryForThisTask === 5000) // not granted + // Release again + memoryStore.releaseUnrollMemoryForThisTask(1000) + assert(memoryStore.currentUnrollMemoryForThisTask === 4000) + memoryStore.releaseUnrollMemoryForThisTask() // release all + assert(memoryStore.currentUnrollMemoryForThisTask === 0) + } + + test("safely unroll blocks") { + val smallList = List.fill(40)(new Array[Byte](100)) + val bigList = List.fill(40)(new Array[Byte](1000)) + val ct = implicitly[ClassTag[Array[Byte]]] + val (memoryStore, blockInfoManager) = makeMemoryStore(12000) + assert(memoryStore.currentUnrollMemoryForThisTask === 0) + + def putIterator[T]( + blockId: BlockId, + iter: Iterator[T], + classTag: ClassTag[T]): Either[PartiallyUnrolledIterator[T], Long] = { + assert(blockInfoManager.lockNewBlockForWriting( + blockId, + new BlockInfo(StorageLevel.MEMORY_ONLY, classTag, tellMaster = false))) + val res = memoryStore.putIterator(blockId, iter, StorageLevel.MEMORY_ONLY, classTag) + blockInfoManager.unlock(blockId) + res + } + + // Unroll with all the space in the world. This should succeed. + var putResult = putIterator("unroll", smallList.iterator, ClassTag.Any) + assert(putResult.isRight) + assert(memoryStore.currentUnrollMemoryForThisTask === 0) + smallList.iterator.zip(memoryStore.getValues("unroll").get).foreach { case (e, a) => + assert(e === a, "getValues() did not return original values!") + } + blockInfoManager.lockForWriting("unroll") + assert(memoryStore.remove("unroll")) + blockInfoManager.removeBlock("unroll") + + // Unroll with not enough space. This should succeed after kicking out someBlock1. + assert(putIterator("someBlock1", smallList.iterator, ct).isRight) + assert(putIterator("someBlock2", smallList.iterator, ct).isRight) + putResult = putIterator("unroll", smallList.iterator, ClassTag.Any) + assert(putResult.isRight) + assert(memoryStore.currentUnrollMemoryForThisTask === 0) + assert(memoryStore.contains("someBlock2")) + assert(!memoryStore.contains("someBlock1")) + smallList.iterator.zip(memoryStore.getValues("unroll").get).foreach { case (e, a) => + assert(e === a, "getValues() did not return original values!") + } + blockInfoManager.lockForWriting("unroll") + assert(memoryStore.remove("unroll")) + blockInfoManager.removeBlock("unroll") + + // Unroll huge block with not enough space. Even after ensuring free space of 12000 * 0.4 = + // 4800 bytes, there is still not enough room to unroll this block. This returns an iterator. + // In the meantime, however, we kicked out someBlock2 before giving up. + assert(putIterator("someBlock3", smallList.iterator, ct).isRight) + putResult = putIterator("unroll", bigList.iterator, ClassTag.Any) + assert(memoryStore.currentUnrollMemoryForThisTask > 0) // we returned an iterator + assert(!memoryStore.contains("someBlock2")) + assert(putResult.isLeft) + bigList.iterator.zip(putResult.left.get).foreach { case (e, a) => + assert(e === a, "putIterator() did not return original values!") + } + // The unroll memory was freed once the iterator returned by putIterator() was fully traversed. 
+ assert(memoryStore.currentUnrollMemoryForThisTask === 0) + } + + test("safely unroll blocks through putIterator") { + val (memoryStore, blockInfoManager) = makeMemoryStore(12000) + val smallList = List.fill(40)(new Array[Byte](100)) + val bigList = List.fill(40)(new Array[Byte](1000)) + def smallIterator: Iterator[Any] = smallList.iterator.asInstanceOf[Iterator[Any]] + def bigIterator: Iterator[Any] = bigList.iterator.asInstanceOf[Iterator[Any]] + assert(memoryStore.currentUnrollMemoryForThisTask === 0) + + def putIterator[T]( + blockId: BlockId, + iter: Iterator[T], + classTag: ClassTag[T]): Either[PartiallyUnrolledIterator[T], Long] = { + assert(blockInfoManager.lockNewBlockForWriting( + blockId, + new BlockInfo(StorageLevel.MEMORY_ONLY, classTag, tellMaster = false))) + val res = memoryStore.putIterator(blockId, iter, StorageLevel.MEMORY_ONLY, classTag) + blockInfoManager.unlock(blockId) + res + } + + // Unroll with plenty of space. This should succeed and cache both blocks. + val result1 = putIterator("b1", smallIterator, ClassTag.Any) + val result2 = putIterator("b2", smallIterator, ClassTag.Any) + assert(memoryStore.contains("b1")) + assert(memoryStore.contains("b2")) + assert(result1.isRight) // unroll was successful + assert(result2.isRight) + assert(memoryStore.currentUnrollMemoryForThisTask === 0) + + // Re-put these two blocks so block manager knows about them too. Otherwise, block manager + // would not know how to drop them from memory later. + blockInfoManager.lockForWriting("b1") + memoryStore.remove("b1") + blockInfoManager.removeBlock("b1") + blockInfoManager.lockForWriting("b2") + memoryStore.remove("b2") + blockInfoManager.removeBlock("b2") + putIterator("b1", smallIterator, ClassTag.Any) + putIterator("b2", smallIterator, ClassTag.Any) + + // Unroll with not enough space. This should succeed but kick out b1 in the process. + val result3 = putIterator("b3", smallIterator, ClassTag.Any) + assert(result3.isRight) + assert(!memoryStore.contains("b1")) + assert(memoryStore.contains("b2")) + assert(memoryStore.contains("b3")) + assert(memoryStore.currentUnrollMemoryForThisTask === 0) + blockInfoManager.lockForWriting("b3") + assert(memoryStore.remove("b3")) + blockInfoManager.removeBlock("b3") + putIterator("b3", smallIterator, ClassTag.Any) + + // Unroll huge block with not enough space. This should fail and kick out b2 in the process. 
+ val result4 = putIterator("b4", bigIterator, ClassTag.Any) + assert(result4.isLeft) // unroll was unsuccessful + assert(!memoryStore.contains("b1")) + assert(!memoryStore.contains("b2")) + assert(memoryStore.contains("b3")) + assert(!memoryStore.contains("b4")) + assert(memoryStore.currentUnrollMemoryForThisTask > 0) // we returned an iterator + result4.left.get.close() + assert(memoryStore.currentUnrollMemoryForThisTask === 0) // close released the unroll memory + } + + test("multiple unrolls by the same thread") { + val (memoryStore, _) = makeMemoryStore(12000) + val smallList = List.fill(40)(new Array[Byte](100)) + def smallIterator: Iterator[Any] = smallList.iterator.asInstanceOf[Iterator[Any]] + assert(memoryStore.currentUnrollMemoryForThisTask === 0) + + def putIterator( + blockId: BlockId, + iter: Iterator[Any]): Either[PartiallyUnrolledIterator[Any], Long] = { + memoryStore.putIterator(blockId, iter, StorageLevel.MEMORY_ONLY, ClassTag.Any) + } + + // All unroll memory used is released because putIterator did not return an iterator + assert(putIterator("b1", smallIterator).isRight) + assert(memoryStore.currentUnrollMemoryForThisTask === 0) + assert(putIterator("b2", smallIterator).isRight) + assert(memoryStore.currentUnrollMemoryForThisTask === 0) + + // Unroll memory is not released because putIterator returned an iterator + // that still depends on the underlying vector used in the process + assert(putIterator("b3", smallIterator).isLeft) + val unrollMemoryAfterB3 = memoryStore.currentUnrollMemoryForThisTask + assert(unrollMemoryAfterB3 > 0) + + // The unroll memory owned by this thread builds on top of its value after the previous unrolls + assert(putIterator("b4", smallIterator).isLeft) + val unrollMemoryAfterB4 = memoryStore.currentUnrollMemoryForThisTask + assert(unrollMemoryAfterB4 > unrollMemoryAfterB3) + + // ... 
but only to a certain extent (until we run out of free space to grant new unroll memory) + assert(putIterator("b5", smallIterator).isLeft) + val unrollMemoryAfterB5 = memoryStore.currentUnrollMemoryForThisTask + assert(putIterator("b6", smallIterator).isLeft) + val unrollMemoryAfterB6 = memoryStore.currentUnrollMemoryForThisTask + assert(putIterator("b7", smallIterator).isLeft) + val unrollMemoryAfterB7 = memoryStore.currentUnrollMemoryForThisTask + assert(unrollMemoryAfterB5 === unrollMemoryAfterB4) + assert(unrollMemoryAfterB6 === unrollMemoryAfterB4) + assert(unrollMemoryAfterB7 === unrollMemoryAfterB4) + } + + test("lazily create a big ByteBuffer to avoid OOM if it cannot be put into MemoryStore") { + val (memoryStore, blockInfoManager) = makeMemoryStore(12000) + val blockId = BlockId("rdd_3_10") + blockInfoManager.lockNewBlockForWriting( + blockId, new BlockInfo(StorageLevel.MEMORY_ONLY, ClassTag.Any, tellMaster = false)) + memoryStore.putBytes(blockId, 13000, () => { + fail("A big ByteBuffer that cannot be put into MemoryStore should not be created") + }) + } + + test("put a small ByteBuffer to MemoryStore") { + val (memoryStore, _) = makeMemoryStore(12000) + val blockId = BlockId("rdd_3_10") + var bytes: ChunkedByteBuffer = null + memoryStore.putBytes(blockId, 10000, () => { + bytes = new ChunkedByteBuffer(ByteBuffer.allocate(10000)) + bytes + }) + assert(memoryStore.getSize(blockId) === 10000) + } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java index d85147e961..aa7fc2121e 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java @@ -67,6 +67,7 @@ public final class UnsafeExternalRowSorter { sorter = UnsafeExternalSorter.create( taskContext.taskMemoryManager(), sparkEnv.blockManager(), + sparkEnv.serializerManager(), taskContext, new RowComparator(ordering, schema.length()), prefixComparator, diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java index acf6c583bb..8882903bbf 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java @@ -241,7 +241,11 @@ public final class UnsafeFixedWidthAggregationMap { */ public UnsafeKVExternalSorter destructAndCreateExternalSorter() throws IOException { return new UnsafeKVExternalSorter( - groupingKeySchema, aggregationBufferSchema, - SparkEnv.get().blockManager(), map.getPageSizeBytes(), map); + groupingKeySchema, + aggregationBufferSchema, + SparkEnv.get().blockManager(), + SparkEnv.get().serializerManager(), + map.getPageSizeBytes(), + map); } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java index 9e08675c3e..d3bfb00b3f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java @@ -24,6 +24,7 @@ import com.google.common.annotations.VisibleForTesting; import org.apache.spark.TaskContext; import org.apache.spark.memory.TaskMemoryManager; 
+import org.apache.spark.serializer.SerializerManager; import org.apache.spark.sql.catalyst.expressions.UnsafeRow; import org.apache.spark.sql.catalyst.expressions.codegen.BaseOrdering; import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering; @@ -52,14 +53,16 @@ public final class UnsafeKVExternalSorter { StructType keySchema, StructType valueSchema, BlockManager blockManager, + SerializerManager serializerManager, long pageSizeBytes) throws IOException { - this(keySchema, valueSchema, blockManager, pageSizeBytes, null); + this(keySchema, valueSchema, blockManager, serializerManager, pageSizeBytes, null); } public UnsafeKVExternalSorter( StructType keySchema, StructType valueSchema, BlockManager blockManager, + SerializerManager serializerManager, long pageSizeBytes, @Nullable BytesToBytesMap map) throws IOException { this.keySchema = keySchema; @@ -77,6 +80,7 @@ public final class UnsafeKVExternalSorter { sorter = UnsafeExternalSorter.create( taskMemoryManager, blockManager, + serializerManager, taskContext, recordComparator, prefixComparator, @@ -116,6 +120,7 @@ public final class UnsafeKVExternalSorter { sorter = UnsafeExternalSorter.createWithExistingInMemorySorter( taskMemoryManager, blockManager, + serializerManager, taskContext, new KVComparator(ordering, keySchema.length()), prefixComparator, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala index a4c0e1c9fb..270c09aff3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala @@ -339,6 +339,7 @@ case class Window( sorter = UnsafeExternalSorter.create( TaskContext.get().taskMemoryManager(), SparkEnv.get.blockManager, + SparkEnv.get.serializerManager, TaskContext.get(), null, null, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala index c74ac8a282..233ac263aa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala @@ -399,6 +399,7 @@ private[sql] class DynamicPartitionWriterContainer( sortingKeySchema, StructType.fromAttributes(dataColumns), SparkEnv.get.blockManager, + SparkEnv.get.serializerManager, TaskContext.get().taskMemoryManager().pageSizeBytes) while (iterator.hasNext) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala index fabd2fbe1e..fb65b50da8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala @@ -41,6 +41,7 @@ class UnsafeCartesianRDD(left : RDD[UnsafeRow], right : RDD[UnsafeRow], numField val sorter = UnsafeExternalSorter.create( context.taskMemoryManager(), SparkEnv.get.blockManager, + SparkEnv.get.serializerManager, context, null, null, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala index e03bd6a3e7..476d93fc2a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala @@ -120,7 +120,7 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext { metricsSystem = null)) val sorter = new UnsafeKVExternalSorter( - keySchema, valueSchema, SparkEnv.get.blockManager, pageSize) + keySchema, valueSchema, SparkEnv.get.blockManager, SparkEnv.get.serializerManager, pageSize) // Insert the keys and values into the sorter inputData.foreach { case (k, v) => diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala index a29d55ee25..794fe264ea 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala @@ -279,6 +279,7 @@ private[spark] class SparkHiveDynamicPartitionWriterContainer( StructType.fromAttributes(partitionOutput), StructType.fromAttributes(dataOutput), SparkEnv.get.blockManager, + SparkEnv.get.serializerManager, TaskContext.get().taskMemoryManager().pageSizeBytes) while (iterator.hasNext) { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala index ace67a639c..c56520b1e2 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala @@ -115,6 +115,7 @@ class WriteAheadLogBackedBlockRDD[T: ClassTag]( assertValid() val hadoopConf = broadcastedHadoopConf.value val blockManager = SparkEnv.get.blockManager + val serializerManager = SparkEnv.get.serializerManager val partition = split.asInstanceOf[WriteAheadLogBackedBlockRDDPartition] val blockId = partition.blockId @@ -161,7 +162,7 @@ class WriteAheadLogBackedBlockRDD[T: ClassTag]( logDebug(s"Stored partition data of $this into block manager with level $storageLevel") dataRead.rewind() } - blockManager.dataDeserialize(blockId, new ChunkedByteBuffer(dataRead)) + serializerManager.dataDeserialize(blockId, new ChunkedByteBuffer(dataRead)) .asInstanceOf[Iterator[T]] } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala index 6d4f4b99c1..85350ff658 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala @@ -26,6 +26,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.internal.Logging +import org.apache.spark.serializer.SerializerManager import org.apache.spark.storage._ import org.apache.spark.streaming.receiver.WriteAheadLogBasedBlockHandler._ import org.apache.spark.streaming.util.{WriteAheadLogRecordHandle, WriteAheadLogUtils} @@ -123,6 +124,7 @@ private[streaming] case class WriteAheadLogBasedStoreResult( */ private[streaming] class WriteAheadLogBasedBlockHandler( blockManager: BlockManager, + serializerManager: SerializerManager, streamId: Int, storageLevel: StorageLevel, conf: SparkConf, @@ -173,10 +175,10 @@ private[streaming] class WriteAheadLogBasedBlockHandler( val serializedBlock = block match { case ArrayBufferBlock(arrayBuffer) => numRecords = 
Some(arrayBuffer.size.toLong) - blockManager.dataSerialize(blockId, arrayBuffer.iterator) + serializerManager.dataSerialize(blockId, arrayBuffer.iterator) case IteratorBlock(iterator) => val countIterator = new CountingIterator(iterator) - val serializedBlock = blockManager.dataSerialize(blockId, countIterator) + val serializedBlock = serializerManager.dataSerialize(blockId, countIterator) numRecords = countIterator.count serializedBlock case ByteBufferBlock(byteBuffer) => diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala index e41fd11963..4fb0f8caac 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala @@ -60,7 +60,7 @@ private[streaming] class ReceiverSupervisorImpl( "Please use streamingContext.checkpoint() to set the checkpoint directory. " + "See documentation for more details.") } - new WriteAheadLogBasedBlockHandler(env.blockManager, receiver.streamId, + new WriteAheadLogBasedBlockHandler(env.blockManager, env.serializerManager, receiver.streamId, receiver.storageLevel, env.conf, hadoopConf, checkpointDirOption.get) } else { new BlockManagerBasedBlockHandler(env.blockManager, receiver.storageLevel) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala index 122ca0627f..4e77cd6347 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala @@ -60,6 +60,7 @@ class ReceivedBlockHandlerSuite val mapOutputTracker = new MapOutputTrackerMaster(conf) val shuffleManager = new HashShuffleManager(conf) val serializer = new KryoSerializer(conf) + var serializerManager = new SerializerManager(serializer, conf) val manualClock = new ManualClock val blockManagerSize = 10000000 val blockManagerBuffer = new ArrayBuffer[BlockManager]() @@ -156,7 +157,7 @@ class ReceivedBlockHandlerSuite val reader = new FileBasedWriteAheadLogRandomReader(fileSegment.path, hadoopConf) val bytes = reader.read(fileSegment) reader.close() - blockManager.dataDeserialize(generateBlockId(), new ChunkedByteBuffer(bytes)).toList + serializerManager.dataDeserialize(generateBlockId(), new ChunkedByteBuffer(bytes)).toList } loggedData shouldEqual data } @@ -265,7 +266,6 @@ class ReceivedBlockHandlerSuite name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = { val memManager = new StaticMemoryManager(conf, Long.MaxValue, maxMem, numCores = 1) val transfer = new NettyBlockTransferService(conf, securityMgr, numCores = 1) - val serializerManager = new SerializerManager(serializer, conf) val blockManager = new BlockManager(name, rpcEnv, blockManagerMaster, serializerManager, conf, memManager, mapOutputTracker, shuffleManager, transfer, securityMgr, 0) memManager.setMemoryStore(blockManager.memoryStore) @@ -335,7 +335,8 @@ class ReceivedBlockHandlerSuite } } - def dataToByteBuffer(b: Seq[String]) = blockManager.dataSerialize(generateBlockId, b.iterator) + def dataToByteBuffer(b: Seq[String]) = + serializerManager.dataSerialize(generateBlockId, b.iterator) val blocks = data.grouped(10).toSeq @@ -367,8 +368,8 @@ class ReceivedBlockHandlerSuite /** Instantiate a WriteAheadLogBasedBlockHandler 
and run a code with it */ private def withWriteAheadLogBasedBlockHandler(body: WriteAheadLogBasedBlockHandler => Unit) { require(WriteAheadLogUtils.getRollingIntervalSecs(conf, isDriver = false) === 1) - val receivedBlockHandler = new WriteAheadLogBasedBlockHandler(blockManager, 1, - storageLevel, conf, hadoopConf, tempDirectory.toString, manualClock) + val receivedBlockHandler = new WriteAheadLogBasedBlockHandler(blockManager, serializerManager, + 1, storageLevel, conf, hadoopConf, tempDirectory.toString, manualClock) try { body(receivedBlockHandler) } finally { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala index c4bf42d0f2..ce5a6e00fb 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala @@ -24,6 +24,7 @@ import org.apache.hadoop.conf.Configuration import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach} import org.apache.spark.{SparkConf, SparkContext, SparkException, SparkFunSuite} +import org.apache.spark.serializer.SerializerManager import org.apache.spark.storage.{BlockId, BlockManager, StorageLevel, StreamBlockId} import org.apache.spark.streaming.util.{FileBasedWriteAheadLogSegment, FileBasedWriteAheadLogWriter} import org.apache.spark.util.Utils @@ -39,6 +40,7 @@ class WriteAheadLogBackedBlockRDDSuite var sparkContext: SparkContext = null var blockManager: BlockManager = null + var serializerManager: SerializerManager = null var dir: File = null override def beforeEach(): Unit = { @@ -58,6 +60,7 @@ class WriteAheadLogBackedBlockRDDSuite super.beforeAll() sparkContext = new SparkContext(conf) blockManager = sparkContext.env.blockManager + serializerManager = sparkContext.env.serializerManager } override def afterAll(): Unit = { @@ -65,6 +68,8 @@ class WriteAheadLogBackedBlockRDDSuite try { sparkContext.stop() System.clearProperty("spark.driver.port") + blockManager = null + serializerManager = null } finally { super.afterAll() } @@ -107,8 +112,6 @@ class WriteAheadLogBackedBlockRDDSuite * It can also test if the partitions that were read from the log were again stored in * block manager. * - * - * * @param numPartitions Number of partitions in RDD * @param numPartitionsInBM Number of partitions to write to the BlockManager. * Partitions 0 to (numPartitionsInBM-1) will be written to BlockManager @@ -223,7 +226,7 @@ class WriteAheadLogBackedBlockRDDSuite require(blockData.size === blockIds.size) val writer = new FileBasedWriteAheadLogWriter(new File(dir, "logFile").toString, hadoopConf) val segments = blockData.zip(blockIds).map { case (data, id) => - writer.write(blockManager.dataSerialize(id, data.iterator).toByteBuffer) + writer.write(serializerManager.dataSerialize(id, data.iterator).toByteBuffer) } writer.close() segments -- cgit v1.2.3
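
Editor's note: the heart of this patch is that `MemoryStore` now evicts blocks through the narrow `BlockEvictionHandler` interface instead of calling back into `BlockManager`, which is what lets the new `MemoryStoreSuite` build a store in isolation. The sketch below is a condensed, standalone rendering of the suite's `makeMemoryStore` helper, intended only to show the wiring; the object name `MemoryStoreTestHelper` and the single conf setting are illustrative additions and are not part of the patch itself.

package org.apache.spark.storage

import scala.language.reflectiveCalls
import scala.reflect.ClassTag

import org.apache.spark.SparkConf
import org.apache.spark.memory.StaticMemoryManager
import org.apache.spark.serializer.{KryoSerializer, SerializerManager}
import org.apache.spark.storage.memory.{BlockEvictionHandler, MemoryStore}
import org.apache.spark.util.io.ChunkedByteBuffer

// Hypothetical helper (not in the patch): condenses MemoryStoreSuite.makeMemoryStore
// to show how a MemoryStore is assembled without constructing or mocking a BlockManager.
object MemoryStoreTestHelper {

  def makeMemoryStore(maxMem: Long): (MemoryStore, BlockInfoManager) = {
    val conf = new SparkConf(false).set("spark.storage.unrollMemoryThreshold", "512")
    val memManager = new StaticMemoryManager(conf, Long.MaxValue, maxMem, numCores = 1)
    val serializerManager = new SerializerManager(new KryoSerializer(conf), conf)
    val blockInfoManager = new BlockInfoManager

    // Stub eviction handler: instead of asking a BlockManager to drop the block to
    // disk, evicted blocks are simply removed from the store (StorageLevel.NONE).
    val evictionHandler = new BlockEvictionHandler {
      var memoryStore: MemoryStore = _
      override private[storage] def dropFromMemory[T: ClassTag](
          blockId: BlockId,
          data: () => Either[Array[T], ChunkedByteBuffer]): StorageLevel = {
        memoryStore.remove(blockId)
        StorageLevel.NONE
      }
    }

    val memoryStore =
      new MemoryStore(conf, blockInfoManager, serializerManager, memManager, evictionHandler)
    memManager.setMemoryStore(memoryStore)
    evictionHandler.memoryStore = memoryStore
    (memoryStore, blockInfoManager)
  }
}

With this pair of objects a test can drive `putIterator`, `putBytes`, and the unroll-memory accounting directly, as the new `MemoryStoreSuite` does, instead of going through `BlockManagerSuite` and a fully constructed `BlockManager`.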