author     Josh Rosen <joshrosen@databricks.com>   2016-03-23 10:15:23 -0700
committer  Josh Rosen <joshrosen@databricks.com>   2016-03-23 10:15:23 -0700
commit     3de24ae2ed6c58fc96a7e50832afe42fe7af34fb (patch)
tree       0eb9f5d7100301195e6d0c1b77114e2398f6edb3
parent     6ce008ba46aa1fc8a5c222ce0f25a6d81f53588e (diff)
[SPARK-14075] Refactor MemoryStore to be testable independent of BlockManager
This patch refactors the `MemoryStore` so that it can be tested without needing to construct / mock an entire `BlockManager`.

- The block manager's serialization- and compression-related methods have been moved from `BlockManager` to `SerializerManager`.
- `BlockInfoManager` is now passed directly to classes that need it, rather than being passed via the `BlockManager`.
- The `MemoryStore` now calls `dropFromMemory` via a new `BlockEvictionHandler` interface rather than directly calling the `BlockManager`. This change helps to enforce a narrow interface between the `MemoryStore` and `BlockManager` functionality and makes this interface easier to mock in tests.
- Several of the block unrolling tests have been moved from `BlockManagerSuite` into a new `MemoryStoreSuite`.

Author: Josh Rosen <joshrosen@databricks.com>

Closes #11899 from JoshRosen/reduce-memorystore-blockmanager-coupling.
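To illustrate the new seams, here is a minimal, hypothetical sketch of constructing a `MemoryStore` in a test with a stub `BlockEvictionHandler`, using only the constructor and trait signatures visible in the diff below. The memory-manager settings, block id, and assertion are illustrative placeholders rather than the exact code in `MemoryStoreSuite`; the snippet assumes it lives in the `org.apache.spark.storage` package (as `MemoryStoreSuite` does) so the `private[storage]` members are accessible.

```scala
package org.apache.spark.storage

import scala.reflect.ClassTag

import org.apache.spark.SparkConf
import org.apache.spark.memory.StaticMemoryManager
import org.apache.spark.serializer.{KryoSerializer, SerializerManager}
import org.apache.spark.storage.memory.{BlockEvictionHandler, MemoryStore}
import org.apache.spark.util.io.ChunkedByteBuffer

object MemoryStoreSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf(false)
    val serializerManager = new SerializerManager(new KryoSerializer(conf), conf)
    val blockInfoManager = new BlockInfoManager
    // Illustrative setup: any MemoryManager will do; a fixed 12000-byte storage budget keeps it simple.
    val memoryManager = new StaticMemoryManager(conf, Long.MaxValue, 12000L, numCores = 1)

    // Stub eviction handler: instead of routing the drop through a BlockManager,
    // report that an evicted block is no longer stored anywhere.
    val evictionHandler = new BlockEvictionHandler {
      override private[storage] def dropFromMemory[T: ClassTag](
          blockId: BlockId,
          data: () => Either[Array[T], ChunkedByteBuffer]): StorageLevel = {
        StorageLevel.NONE
      }
    }

    val memoryStore =
      new MemoryStore(conf, blockInfoManager, serializerManager, memoryManager, evictionHandler)
    memoryManager.setMemoryStore(memoryStore)

    // The store can now be exercised directly, without a BlockManager:
    val values: Iterator[Any] = List.fill(40)(new Array[Byte](100)).iterator
    val result =
      memoryStore.putIterator(TestBlockId("b1"), values, StorageLevel.MEMORY_ONLY, ClassTag.Any)
    assert(result.isRight) // the iterator was fully unrolled and cached in memory
  }
}
```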
-rw-r--r--  core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java | 7
-rw-r--r--  core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java | 17
-rw-r--r--  core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java | 6
-rw-r--r--  core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java | 5
-rw-r--r--  core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala | 90
-rw-r--r--  core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala | 5
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockManager.scala | 118
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockManagerManagedBuffer.scala | 6
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala | 55
-rw-r--r--  core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala | 7
-rw-r--r--  core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala | 3
-rw-r--r--  core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java | 32
-rw-r--r--  core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java | 17
-rw-r--r--  core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java | 12
-rw-r--r--  core/src/test/scala/org/apache/spark/DistributedSuite.scala | 3
-rw-r--r--  core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala | 22
-rw-r--r--  core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala | 197
-rw-r--r--  core/src/test/scala/org/apache/spark/storage/MemoryStoreSuite.scala | 302
-rw-r--r--  sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java | 1
-rw-r--r--  sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java | 8
-rw-r--r--  sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java | 7
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala | 1
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala | 1
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala | 1
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala | 2
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala | 1
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala | 3
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala | 6
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala | 2
-rw-r--r--  streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala | 11
-rw-r--r--  streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala | 9
31 files changed, 555 insertions(+), 402 deletions(-)
diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
index de36814ecc..9aacb084f6 100644
--- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
+++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
@@ -32,6 +32,7 @@ import org.apache.spark.SparkEnv;
import org.apache.spark.executor.ShuffleWriteMetrics;
import org.apache.spark.memory.MemoryConsumer;
import org.apache.spark.memory.TaskMemoryManager;
+import org.apache.spark.serializer.SerializerManager;
import org.apache.spark.storage.BlockManager;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.array.ByteArrayMethods;
@@ -163,12 +164,14 @@ public final class BytesToBytesMap extends MemoryConsumer {
private long peakMemoryUsedBytes = 0L;
private final BlockManager blockManager;
+ private final SerializerManager serializerManager;
private volatile MapIterator destructiveIterator = null;
private LinkedList<UnsafeSorterSpillWriter> spillWriters = new LinkedList<>();
public BytesToBytesMap(
TaskMemoryManager taskMemoryManager,
BlockManager blockManager,
+ SerializerManager serializerManager,
int initialCapacity,
double loadFactor,
long pageSizeBytes,
@@ -176,6 +179,7 @@ public final class BytesToBytesMap extends MemoryConsumer {
super(taskMemoryManager, pageSizeBytes);
this.taskMemoryManager = taskMemoryManager;
this.blockManager = blockManager;
+ this.serializerManager = serializerManager;
this.loadFactor = loadFactor;
this.loc = new Location();
this.pageSizeBytes = pageSizeBytes;
@@ -209,6 +213,7 @@ public final class BytesToBytesMap extends MemoryConsumer {
this(
taskMemoryManager,
SparkEnv.get() != null ? SparkEnv.get().blockManager() : null,
+ SparkEnv.get() != null ? SparkEnv.get().serializerManager() : null,
initialCapacity,
0.70,
pageSizeBytes,
@@ -271,7 +276,7 @@ public final class BytesToBytesMap extends MemoryConsumer {
}
try {
Closeables.close(reader, /* swallowIOException = */ false);
- reader = spillWriters.getFirst().getReader(blockManager);
+ reader = spillWriters.getFirst().getReader(serializerManager);
recordsInPage = -1;
} catch (IOException e) {
// Scala iterator does not handle exception
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
index 927b19c4e8..ded8f0472b 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
@@ -31,6 +31,7 @@ import org.apache.spark.TaskContext;
import org.apache.spark.executor.ShuffleWriteMetrics;
import org.apache.spark.memory.MemoryConsumer;
import org.apache.spark.memory.TaskMemoryManager;
+import org.apache.spark.serializer.SerializerManager;
import org.apache.spark.storage.BlockManager;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.array.LongArray;
@@ -51,6 +52,7 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
private final RecordComparator recordComparator;
private final TaskMemoryManager taskMemoryManager;
private final BlockManager blockManager;
+ private final SerializerManager serializerManager;
private final TaskContext taskContext;
private ShuffleWriteMetrics writeMetrics;
@@ -78,6 +80,7 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
public static UnsafeExternalSorter createWithExistingInMemorySorter(
TaskMemoryManager taskMemoryManager,
BlockManager blockManager,
+ SerializerManager serializerManager,
TaskContext taskContext,
RecordComparator recordComparator,
PrefixComparator prefixComparator,
@@ -85,7 +88,8 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
long pageSizeBytes,
UnsafeInMemorySorter inMemorySorter) throws IOException {
UnsafeExternalSorter sorter = new UnsafeExternalSorter(taskMemoryManager, blockManager,
- taskContext, recordComparator, prefixComparator, initialSize, pageSizeBytes, inMemorySorter);
+ serializerManager, taskContext, recordComparator, prefixComparator, initialSize,
+ pageSizeBytes, inMemorySorter);
sorter.spill(Long.MAX_VALUE, sorter);
// The external sorter will be used to insert records, in-memory sorter is not needed.
sorter.inMemSorter = null;
@@ -95,18 +99,20 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
public static UnsafeExternalSorter create(
TaskMemoryManager taskMemoryManager,
BlockManager blockManager,
+ SerializerManager serializerManager,
TaskContext taskContext,
RecordComparator recordComparator,
PrefixComparator prefixComparator,
int initialSize,
long pageSizeBytes) {
- return new UnsafeExternalSorter(taskMemoryManager, blockManager,
+ return new UnsafeExternalSorter(taskMemoryManager, blockManager, serializerManager,
taskContext, recordComparator, prefixComparator, initialSize, pageSizeBytes, null);
}
private UnsafeExternalSorter(
TaskMemoryManager taskMemoryManager,
BlockManager blockManager,
+ SerializerManager serializerManager,
TaskContext taskContext,
RecordComparator recordComparator,
PrefixComparator prefixComparator,
@@ -116,6 +122,7 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
super(taskMemoryManager, pageSizeBytes);
this.taskMemoryManager = taskMemoryManager;
this.blockManager = blockManager;
+ this.serializerManager = serializerManager;
this.taskContext = taskContext;
this.recordComparator = recordComparator;
this.prefixComparator = prefixComparator;
@@ -412,7 +419,7 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
final UnsafeSorterSpillMerger spillMerger =
new UnsafeSorterSpillMerger(recordComparator, prefixComparator, spillWriters.size());
for (UnsafeSorterSpillWriter spillWriter : spillWriters) {
- spillMerger.addSpillIfNotEmpty(spillWriter.getReader(blockManager));
+ spillMerger.addSpillIfNotEmpty(spillWriter.getReader(serializerManager));
}
if (inMemSorter != null) {
readingIterator = new SpillableIterator(inMemSorter.getSortedIterator());
@@ -463,7 +470,7 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
}
spillWriter.close();
spillWriters.add(spillWriter);
- nextUpstream = spillWriter.getReader(blockManager);
+ nextUpstream = spillWriter.getReader(serializerManager);
long released = 0L;
synchronized (UnsafeExternalSorter.this) {
@@ -549,7 +556,7 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
} else {
LinkedList<UnsafeSorterIterator> queue = new LinkedList<>();
for (UnsafeSorterSpillWriter spillWriter : spillWriters) {
- queue.add(spillWriter.getReader(blockManager));
+ queue.add(spillWriter.getReader(serializerManager));
}
if (inMemSorter != null) {
queue.add(inMemSorter.getSortedIterator());
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
index 20ee1c8eb0..1d588c37c5 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
@@ -22,8 +22,8 @@ import java.io.*;
import com.google.common.io.ByteStreams;
import com.google.common.io.Closeables;
+import org.apache.spark.serializer.SerializerManager;
import org.apache.spark.storage.BlockId;
-import org.apache.spark.storage.BlockManager;
import org.apache.spark.unsafe.Platform;
/**
@@ -46,13 +46,13 @@ public final class UnsafeSorterSpillReader extends UnsafeSorterIterator implemen
private final long baseOffset = Platform.BYTE_ARRAY_OFFSET;
public UnsafeSorterSpillReader(
- BlockManager blockManager,
+ SerializerManager serializerManager,
File file,
BlockId blockId) throws IOException {
assert (file.length() > 0);
final BufferedInputStream bs = new BufferedInputStream(new FileInputStream(file));
try {
- this.in = blockManager.wrapForCompression(blockId, bs);
+ this.in = serializerManager.wrapForCompression(blockId, bs);
this.din = new DataInputStream(this.in);
numRecords = numRecordsRemaining = din.readInt();
} catch (IOException e) {
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java
index 234e21140a..9ba760e842 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java
@@ -20,6 +20,7 @@ package org.apache.spark.util.collection.unsafe.sort;
import java.io.File;
import java.io.IOException;
+import org.apache.spark.serializer.SerializerManager;
import scala.Tuple2;
import org.apache.spark.executor.ShuffleWriteMetrics;
@@ -144,7 +145,7 @@ public final class UnsafeSorterSpillWriter {
return file;
}
- public UnsafeSorterSpillReader getReader(BlockManager blockManager) throws IOException {
- return new UnsafeSorterSpillReader(blockManager, file, blockId);
+ public UnsafeSorterSpillReader getReader(SerializerManager serializerManager) throws IOException {
+ return new UnsafeSorterSpillReader(serializerManager, file, blockId);
}
}
diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala
index b9f115463a..27e5fa4c2b 100644
--- a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala
@@ -17,17 +17,25 @@
package org.apache.spark.serializer
+import java.io.{BufferedInputStream, BufferedOutputStream, InputStream, OutputStream}
+import java.nio.ByteBuffer
+
import scala.reflect.ClassTag
import org.apache.spark.SparkConf
+import org.apache.spark.io.CompressionCodec
+import org.apache.spark.storage._
+import org.apache.spark.util.io.{ByteArrayChunkOutputStream, ChunkedByteBuffer}
/**
- * Component that selects which [[Serializer]] to use for shuffles.
+ * Component which configures serialization and compression for various Spark components, including
+ * automatic selection of which [[Serializer]] to use for shuffles.
*/
private[spark] class SerializerManager(defaultSerializer: Serializer, conf: SparkConf) {
private[this] val kryoSerializer = new KryoSerializer(conf)
+ private[this] val stringClassTag: ClassTag[String] = implicitly[ClassTag[String]]
private[this] val primitiveAndPrimitiveArrayClassTags: Set[ClassTag[_]] = {
val primitiveClassTags = Set[ClassTag[_]](
ClassTag.Boolean,
@@ -44,7 +52,21 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: Spar
primitiveClassTags ++ arrayClassTags
}
- private[this] val stringClassTag: ClassTag[String] = implicitly[ClassTag[String]]
+ // Whether to compress broadcast variables that are stored
+ private[this] val compressBroadcast = conf.getBoolean("spark.broadcast.compress", true)
+ // Whether to compress shuffle output that are stored
+ private[this] val compressShuffle = conf.getBoolean("spark.shuffle.compress", true)
+ // Whether to compress RDD partitions that are stored serialized
+ private[this] val compressRdds = conf.getBoolean("spark.rdd.compress", false)
+ // Whether to compress shuffle output temporarily spilled to disk
+ private[this] val compressShuffleSpill = conf.getBoolean("spark.shuffle.spill.compress", true)
+
+ /* The compression codec to use. Note that the "lazy" val is necessary because we want to delay
+ * the initialization of the compression codec until it is first used. The reason is that a Spark
+ * program could be using a user-defined codec in a third party jar, which is loaded in
+ * Executor.updateDependencies. When the BlockManager is initialized, user level jars hasn't been
+ * loaded yet. */
+ private lazy val compressionCodec: CompressionCodec = CompressionCodec.createCodec(conf)
private def canUseKryo(ct: ClassTag[_]): Boolean = {
primitiveAndPrimitiveArrayClassTags.contains(ct) || ct == stringClassTag
@@ -68,4 +90,68 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: Spar
defaultSerializer
}
}
+
+ private def shouldCompress(blockId: BlockId): Boolean = {
+ blockId match {
+ case _: ShuffleBlockId => compressShuffle
+ case _: BroadcastBlockId => compressBroadcast
+ case _: RDDBlockId => compressRdds
+ case _: TempLocalBlockId => compressShuffleSpill
+ case _: TempShuffleBlockId => compressShuffle
+ case _ => false
+ }
+ }
+
+ /**
+ * Wrap an output stream for compression if block compression is enabled for its block type
+ */
+ def wrapForCompression(blockId: BlockId, s: OutputStream): OutputStream = {
+ if (shouldCompress(blockId)) compressionCodec.compressedOutputStream(s) else s
+ }
+
+ /**
+ * Wrap an input stream for compression if block compression is enabled for its block type
+ */
+ def wrapForCompression(blockId: BlockId, s: InputStream): InputStream = {
+ if (shouldCompress(blockId)) compressionCodec.compressedInputStream(s) else s
+ }
+
+ /** Serializes into a stream. */
+ def dataSerializeStream[T: ClassTag](
+ blockId: BlockId,
+ outputStream: OutputStream,
+ values: Iterator[T]): Unit = {
+ val byteStream = new BufferedOutputStream(outputStream)
+ val ser = getSerializer(implicitly[ClassTag[T]]).newInstance()
+ ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close()
+ }
+
+ /** Serializes into a chunked byte buffer. */
+ def dataSerialize[T: ClassTag](blockId: BlockId, values: Iterator[T]): ChunkedByteBuffer = {
+ val byteArrayChunkOutputStream = new ByteArrayChunkOutputStream(1024 * 1024 * 4)
+ dataSerializeStream(blockId, byteArrayChunkOutputStream, values)
+ new ChunkedByteBuffer(byteArrayChunkOutputStream.toArrays.map(ByteBuffer.wrap))
+ }
+
+ /**
+ * Deserializes a ByteBuffer into an iterator of values and disposes of it when the end of
+ * the iterator is reached.
+ */
+ def dataDeserialize[T: ClassTag](blockId: BlockId, bytes: ChunkedByteBuffer): Iterator[T] = {
+ dataDeserializeStream[T](blockId, bytes.toInputStream(dispose = true))
+ }
+
+ /**
+ * Deserializes a InputStream into an iterator of values and disposes of it when the end of
+ * the iterator is reached.
+ */
+ def dataDeserializeStream[T: ClassTag](
+ blockId: BlockId,
+ inputStream: InputStream): Iterator[T] = {
+ val stream = new BufferedInputStream(inputStream)
+ getSerializer(implicitly[ClassTag[T]])
+ .newInstance()
+ .deserializeStream(wrapForCompression(blockId, stream))
+ .asIterator.asInstanceOf[Iterator[T]]
+ }
}
diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
index 4054465c0f..637b2dfc19 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
@@ -19,7 +19,7 @@ package org.apache.spark.shuffle
import org.apache.spark._
import org.apache.spark.internal.Logging
-import org.apache.spark.serializer.Serializer
+import org.apache.spark.serializer.SerializerManager
import org.apache.spark.storage.{BlockManager, ShuffleBlockFetcherIterator}
import org.apache.spark.util.CompletionIterator
import org.apache.spark.util.collection.ExternalSorter
@@ -33,6 +33,7 @@ private[spark] class BlockStoreShuffleReader[K, C](
startPartition: Int,
endPartition: Int,
context: TaskContext,
+ serializerManager: SerializerManager = SparkEnv.get.serializerManager,
blockManager: BlockManager = SparkEnv.get.blockManager,
mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker)
extends ShuffleReader[K, C] with Logging {
@@ -52,7 +53,7 @@ private[spark] class BlockStoreShuffleReader[K, C](
// Wrap the streams for compression based on configuration
val wrappedStreams = blockFetcherItr.map { case (blockId, inputStream) =>
- blockManager.wrapForCompression(blockId, inputStream)
+ serializerManager.wrapForCompression(blockId, inputStream)
}
val serializerInstance = dep.serializer.newInstance()
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 83f8c5c37d..eebb43e245 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -18,7 +18,6 @@
package org.apache.spark.storage
import java.io._
-import java.nio.ByteBuffer
import scala.collection.mutable.{ArrayBuffer, HashMap}
import scala.concurrent.{Await, ExecutionContext, Future}
@@ -30,7 +29,6 @@ import scala.util.control.NonFatal
import org.apache.spark._
import org.apache.spark.executor.{DataReadMethod, ShuffleWriteMetrics}
import org.apache.spark.internal.Logging
-import org.apache.spark.io.CompressionCodec
import org.apache.spark.memory.MemoryManager
import org.apache.spark.network._
import org.apache.spark.network.buffer.{ManagedBuffer, NettyManagedBuffer}
@@ -38,11 +36,11 @@ import org.apache.spark.network.netty.SparkTransportConf
import org.apache.spark.network.shuffle.ExternalShuffleClient
import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo
import org.apache.spark.rpc.RpcEnv
-import org.apache.spark.serializer.{Serializer, SerializerInstance, SerializerManager}
+import org.apache.spark.serializer.{SerializerInstance, SerializerManager}
import org.apache.spark.shuffle.ShuffleManager
import org.apache.spark.storage.memory._
import org.apache.spark.util._
-import org.apache.spark.util.io.{ByteArrayChunkOutputStream, ChunkedByteBuffer}
+import org.apache.spark.util.io.ChunkedByteBuffer
/* Class for returning a fetched block and associated metrics. */
private[spark] class BlockResult(
@@ -68,7 +66,7 @@ private[spark] class BlockManager(
blockTransferService: BlockTransferService,
securityManager: SecurityManager,
numUsableCores: Int)
- extends BlockDataManager with Logging {
+ extends BlockDataManager with BlockEvictionHandler with Logging {
private[spark] val externalShuffleServiceEnabled =
conf.getBoolean("spark.shuffle.service.enabled", false)
@@ -80,13 +78,15 @@ private[spark] class BlockManager(
new DiskBlockManager(conf, deleteFilesOnStop)
}
+ // Visible for testing
private[storage] val blockInfoManager = new BlockInfoManager
private val futureExecutionContext = ExecutionContext.fromExecutorService(
ThreadUtils.newDaemonCachedThreadPool("block-manager-future", 128))
// Actual storage of where blocks are kept
- private[spark] val memoryStore = new MemoryStore(conf, this, memoryManager)
+ private[spark] val memoryStore =
+ new MemoryStore(conf, blockInfoManager, serializerManager, memoryManager, this)
private[spark] val diskStore = new DiskStore(conf, diskBlockManager)
memoryManager.setMemoryStore(memoryStore)
@@ -126,14 +126,6 @@ private[spark] class BlockManager(
blockTransferService
}
- // Whether to compress broadcast variables that are stored
- private val compressBroadcast = conf.getBoolean("spark.broadcast.compress", true)
- // Whether to compress shuffle output that are stored
- private val compressShuffle = conf.getBoolean("spark.shuffle.compress", true)
- // Whether to compress RDD partitions that are stored serialized
- private val compressRdds = conf.getBoolean("spark.rdd.compress", false)
- // Whether to compress shuffle output temporarily spilled to disk
- private val compressShuffleSpill = conf.getBoolean("spark.shuffle.spill.compress", true)
// Max number of failures before this block manager refreshes the block locations from the driver
private val maxFailuresBeforeLocationRefresh =
conf.getInt("spark.block.failures.beforeLocationRefresh", 5)
@@ -152,13 +144,6 @@ private[spark] class BlockManager(
private val peerFetchLock = new Object
private var lastPeerFetchTime = 0L
- /* The compression codec to use. Note that the "lazy" val is necessary because we want to delay
- * the initialization of the compression codec until it is first used. The reason is that a Spark
- * program could be using a user-defined codec in a third party jar, which is loaded in
- * Executor.updateDependencies. When the BlockManager is initialized, user level jars hasn't been
- * loaded yet. */
- private lazy val compressionCodec: CompressionCodec = CompressionCodec.createCodec(conf)
-
/**
* Initializes the BlockManager with the given appId. This is not performed in the constructor as
* the appId may not be known at BlockManager instantiation time (in particular for the driver,
@@ -286,7 +271,7 @@ private[spark] class BlockManager(
shuffleManager.shuffleBlockResolver.getBlockData(blockId.asInstanceOf[ShuffleBlockId])
} else {
getLocalBytes(blockId) match {
- case Some(buffer) => new BlockManagerManagedBuffer(this, blockId, buffer)
+ case Some(buffer) => new BlockManagerManagedBuffer(blockInfoManager, blockId, buffer)
case None => throw new BlockNotFoundException(blockId.toString)
}
}
@@ -422,7 +407,8 @@ private[spark] class BlockManager(
val iter: Iterator[Any] = if (level.deserialized) {
memoryStore.getValues(blockId).get
} else {
- dataDeserialize(blockId, memoryStore.getBytes(blockId).get)(info.classTag)
+ serializerManager.dataDeserialize(
+ blockId, memoryStore.getBytes(blockId).get)(info.classTag)
}
val ci = CompletionIterator[Any, Iterator[Any]](iter, releaseLock(blockId))
Some(new BlockResult(ci, DataReadMethod.Memory, info.size))
@@ -430,11 +416,11 @@ private[spark] class BlockManager(
val iterToReturn: Iterator[Any] = {
val diskBytes = diskStore.getBytes(blockId)
if (level.deserialized) {
- val diskValues = dataDeserialize(blockId, diskBytes)(info.classTag)
+ val diskValues = serializerManager.dataDeserialize(blockId, diskBytes)(info.classTag)
maybeCacheDiskValuesInMemory(info, blockId, level, diskValues)
} else {
val bytes = maybeCacheDiskBytesInMemory(info, blockId, level, diskBytes)
- dataDeserialize(blockId, bytes)(info.classTag)
+ serializerManager.dataDeserialize(blockId, bytes)(info.classTag)
}
}
val ci = CompletionIterator[Any, Iterator[Any]](iterToReturn, releaseLock(blockId))
@@ -486,7 +472,7 @@ private[spark] class BlockManager(
diskStore.getBytes(blockId)
} else if (level.useMemory && memoryStore.contains(blockId)) {
// The block was not found on disk, so serialize an in-memory copy:
- dataSerialize(blockId, memoryStore.getValues(blockId).get)
+ serializerManager.dataSerialize(blockId, memoryStore.getValues(blockId).get)
} else {
releaseLock(blockId)
throw new SparkException(s"Block $blockId was not found even though it's read-locked")
@@ -510,7 +496,8 @@ private[spark] class BlockManager(
*/
private def getRemoteValues(blockId: BlockId): Option[BlockResult] = {
getRemoteBytes(blockId).map { data =>
- new BlockResult(dataDeserialize(blockId, data), DataReadMethod.Network, data.size)
+ new BlockResult(
+ serializerManager.dataDeserialize(blockId, data), DataReadMethod.Network, data.size)
}
}
@@ -699,7 +686,8 @@ private[spark] class BlockManager(
serializerInstance: SerializerInstance,
bufferSize: Int,
writeMetrics: ShuffleWriteMetrics): DiskBlockObjectWriter = {
- val compressStream: OutputStream => OutputStream = wrapForCompression(blockId, _)
+ val compressStream: OutputStream => OutputStream =
+ serializerManager.wrapForCompression(blockId, _)
val syncWrites = conf.getBoolean("spark.shuffle.sync", false)
new DiskBlockObjectWriter(file, serializerInstance, bufferSize, compressStream,
syncWrites, writeMetrics, blockId)
@@ -757,7 +745,7 @@ private[spark] class BlockManager(
// Put it in memory first, even if it also has useDisk set to true;
// We will drop it to disk later if the memory store can't hold it.
val putSucceeded = if (level.deserialized) {
- val values = dataDeserialize(blockId, bytes)(classTag)
+ val values = serializerManager.dataDeserialize(blockId, bytes)(classTag)
memoryStore.putIterator(blockId, values, level, classTag) match {
case Right(_) => true
case Left(iter) =>
@@ -896,7 +884,7 @@ private[spark] class BlockManager(
if (level.useDisk) {
logWarning(s"Persisting block $blockId to disk instead.")
diskStore.put(blockId) { fileOutputStream =>
- dataSerializeStream(blockId, fileOutputStream, iter)(classTag)
+ serializerManager.dataSerializeStream(blockId, fileOutputStream, iter)(classTag)
}
size = diskStore.getSize(blockId)
} else {
@@ -905,7 +893,7 @@ private[spark] class BlockManager(
}
} else if (level.useDisk) {
diskStore.put(blockId) { fileOutputStream =>
- dataSerializeStream(blockId, fileOutputStream, iterator())(classTag)
+ serializerManager.dataSerializeStream(blockId, fileOutputStream, iterator())(classTag)
}
size = diskStore.getSize(blockId)
}
@@ -1167,7 +1155,7 @@ private[spark] class BlockManager(
*
* @return the block's new effective StorageLevel.
*/
- private[storage] def dropFromMemory[T: ClassTag](
+ private[storage] override def dropFromMemory[T: ClassTag](
blockId: BlockId,
data: () => Either[Array[T], ChunkedByteBuffer]): StorageLevel = {
logInfo(s"Dropping block $blockId from memory")
@@ -1181,7 +1169,7 @@ private[spark] class BlockManager(
data() match {
case Left(elements) =>
diskStore.put(blockId) { fileOutputStream =>
- dataSerializeStream(
+ serializerManager.dataSerializeStream(
blockId,
fileOutputStream,
elements.toIterator)(info.classTag.asInstanceOf[ClassTag[T]])
@@ -1264,70 +1252,6 @@ private[spark] class BlockManager(
}
}
- private def shouldCompress(blockId: BlockId): Boolean = {
- blockId match {
- case _: ShuffleBlockId => compressShuffle
- case _: BroadcastBlockId => compressBroadcast
- case _: RDDBlockId => compressRdds
- case _: TempLocalBlockId => compressShuffleSpill
- case _: TempShuffleBlockId => compressShuffle
- case _ => false
- }
- }
-
- /**
- * Wrap an output stream for compression if block compression is enabled for its block type
- */
- def wrapForCompression(blockId: BlockId, s: OutputStream): OutputStream = {
- if (shouldCompress(blockId)) compressionCodec.compressedOutputStream(s) else s
- }
-
- /**
- * Wrap an input stream for compression if block compression is enabled for its block type
- */
- def wrapForCompression(blockId: BlockId, s: InputStream): InputStream = {
- if (shouldCompress(blockId)) compressionCodec.compressedInputStream(s) else s
- }
-
- /** Serializes into a stream. */
- def dataSerializeStream[T: ClassTag](
- blockId: BlockId,
- outputStream: OutputStream,
- values: Iterator[T]): Unit = {
- val byteStream = new BufferedOutputStream(outputStream)
- val ser = serializerManager.getSerializer(implicitly[ClassTag[T]]).newInstance()
- ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close()
- }
-
- /** Serializes into a chunked byte buffer. */
- def dataSerialize[T: ClassTag](blockId: BlockId, values: Iterator[T]): ChunkedByteBuffer = {
- val byteArrayChunkOutputStream = new ByteArrayChunkOutputStream(1024 * 1024 * 4)
- dataSerializeStream(blockId, byteArrayChunkOutputStream, values)
- new ChunkedByteBuffer(byteArrayChunkOutputStream.toArrays.map(ByteBuffer.wrap))
- }
-
- /**
- * Deserializes a ByteBuffer into an iterator of values and disposes of it when the end of
- * the iterator is reached.
- */
- def dataDeserialize[T: ClassTag](blockId: BlockId, bytes: ChunkedByteBuffer): Iterator[T] = {
- dataDeserializeStream[T](blockId, bytes.toInputStream(dispose = true))
- }
-
- /**
- * Deserializes a InputStream into an iterator of values and disposes of it when the end of
- * the iterator is reached.
- */
- def dataDeserializeStream[T: ClassTag](
- blockId: BlockId,
- inputStream: InputStream): Iterator[T] = {
- val stream = new BufferedInputStream(inputStream)
- serializerManager.getSerializer(implicitly[ClassTag[T]])
- .newInstance()
- .deserializeStream(wrapForCompression(blockId, stream))
- .asIterator.asInstanceOf[Iterator[T]]
- }
-
def stop(): Unit = {
blockTransferService.close()
if (shuffleClient ne blockTransferService) {
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerManagedBuffer.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerManagedBuffer.scala
index 12594e6a2b..f66f942798 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerManagedBuffer.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerManagedBuffer.scala
@@ -29,19 +29,19 @@ import org.apache.spark.util.io.ChunkedByteBuffer
* to the network layer's notion of retain / release counts.
*/
private[storage] class BlockManagerManagedBuffer(
- blockManager: BlockManager,
+ blockInfoManager: BlockInfoManager,
blockId: BlockId,
chunkedBuffer: ChunkedByteBuffer) extends NettyManagedBuffer(chunkedBuffer.toNetty) {
override def retain(): ManagedBuffer = {
super.retain()
- val locked = blockManager.blockInfoManager.lockForReading(blockId, blocking = false)
+ val locked = blockInfoManager.lockForReading(blockId, blocking = false)
assert(locked.isDefined)
this
}
override def release(): ManagedBuffer = {
- blockManager.releaseLock(blockId)
+ blockInfoManager.unlock(blockId)
super.release()
}
}
diff --git a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala
index d370ee912a..90016cbeb8 100644
--- a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala
+++ b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala
@@ -26,7 +26,8 @@ import scala.reflect.ClassTag
import org.apache.spark.{SparkConf, TaskContext}
import org.apache.spark.internal.Logging
import org.apache.spark.memory.MemoryManager
-import org.apache.spark.storage.{BlockId, BlockManager, StorageLevel}
+import org.apache.spark.serializer.SerializerManager
+import org.apache.spark.storage.{BlockId, BlockInfoManager, StorageLevel}
import org.apache.spark.util.{CompletionIterator, SizeEstimator, Utils}
import org.apache.spark.util.collection.SizeTrackingVector
import org.apache.spark.util.io.ChunkedByteBuffer
@@ -44,14 +45,33 @@ private case class SerializedMemoryEntry[T](
size: Long,
classTag: ClassTag[T]) extends MemoryEntry[T]
+private[storage] trait BlockEvictionHandler {
+ /**
+ * Drop a block from memory, possibly putting it on disk if applicable. Called when the memory
+ * store reaches its limit and needs to free up space.
+ *
+ * If `data` is not put on disk, it won't be created.
+ *
+ * The caller of this method must hold a write lock on the block before calling this method.
+ * This method does not release the write lock.
+ *
+ * @return the block's new effective StorageLevel.
+ */
+ private[storage] def dropFromMemory[T: ClassTag](
+ blockId: BlockId,
+ data: () => Either[Array[T], ChunkedByteBuffer]): StorageLevel
+}
+
/**
* Stores blocks in memory, either as Arrays of deserialized Java objects or as
* serialized ByteBuffers.
*/
private[spark] class MemoryStore(
conf: SparkConf,
- blockManager: BlockManager,
- memoryManager: MemoryManager)
+ blockInfoManager: BlockInfoManager,
+ serializerManager: SerializerManager,
+ memoryManager: MemoryManager,
+ blockEvictionHandler: BlockEvictionHandler)
extends Logging {
// Note: all changes to memory allocations, notably putting blocks, evicting blocks, and
@@ -117,7 +137,7 @@ private[spark] class MemoryStore(
entries.put(blockId, entry)
}
logInfo("Block %s stored as bytes in memory (estimated size %s, free %s)".format(
- blockId, Utils.bytesToString(size), Utils.bytesToString(blocksMemoryUsed)))
+ blockId, Utils.bytesToString(size), Utils.bytesToString(maxMemory - blocksMemoryUsed)))
true
} else {
false
@@ -201,7 +221,7 @@ private[spark] class MemoryStore(
val entry = if (level.deserialized) {
new DeserializedMemoryEntry[T](arrayValues, SizeEstimator.estimate(arrayValues), classTag)
} else {
- val bytes = blockManager.dataSerialize(blockId, arrayValues.iterator)(classTag)
+ val bytes = serializerManager.dataSerialize(blockId, arrayValues.iterator)(classTag)
new SerializedMemoryEntry[T](bytes, bytes.size, classTag)
}
val size = entry.size
@@ -237,7 +257,10 @@ private[spark] class MemoryStore(
}
val bytesOrValues = if (level.deserialized) "values" else "bytes"
logInfo("Block %s stored as %s in memory (estimated size %s, free %s)".format(
- blockId, bytesOrValues, Utils.bytesToString(size), Utils.bytesToString(blocksMemoryUsed)))
+ blockId,
+ bytesOrValues,
+ Utils.bytesToString(size),
+ Utils.bytesToString(maxMemory - blocksMemoryUsed)))
Right(size)
} else {
assert(currentUnrollMemoryForThisTask >= currentUnrollMemoryForThisTask,
@@ -284,7 +307,7 @@ private[spark] class MemoryStore(
}
if (entry != null) {
memoryManager.releaseStorageMemory(entry.size)
- logDebug(s"Block $blockId of size ${entry.size} dropped " +
+ logInfo(s"Block $blockId of size ${entry.size} dropped " +
s"from memory (free ${maxMemory - blocksMemoryUsed})")
true
} else {
@@ -339,7 +362,7 @@ private[spark] class MemoryStore(
// We don't want to evict blocks which are currently being read, so we need to obtain
// an exclusive write lock on blocks which are candidates for eviction. We perform a
// non-blocking "tryLock" here in order to ignore blocks which are locked for reading:
- if (blockManager.blockInfoManager.lockForWriting(blockId, blocking = false).isDefined) {
+ if (blockInfoManager.lockForWriting(blockId, blocking = false).isDefined) {
selectedBlocks += blockId
freedMemory += pair.getValue.size
}
@@ -353,20 +376,21 @@ private[spark] class MemoryStore(
case SerializedMemoryEntry(buffer, _, _) => Right(buffer)
}
val newEffectiveStorageLevel =
- blockManager.dropFromMemory(blockId, () => data)(entry.classTag)
+ blockEvictionHandler.dropFromMemory(blockId, () => data)(entry.classTag)
if (newEffectiveStorageLevel.isValid) {
// The block is still present in at least one store, so release the lock
// but don't delete the block info
- blockManager.releaseLock(blockId)
+ blockInfoManager.unlock(blockId)
} else {
// The block isn't present in any store, so delete the block info so that the
// block can be stored again
- blockManager.blockInfoManager.removeBlock(blockId)
+ blockInfoManager.removeBlock(blockId)
}
}
if (freedMemory >= space) {
- logInfo(s"${selectedBlocks.size} blocks selected for dropping")
+ logInfo(s"${selectedBlocks.size} blocks selected for dropping " +
+ s"(${Utils.bytesToString(freedMemory)} bytes)")
for (blockId <- selectedBlocks) {
val entry = entries.synchronized { entries.get(blockId) }
// This should never be null as only one task should be dropping
@@ -376,14 +400,15 @@ private[spark] class MemoryStore(
dropBlock(blockId, entry)
}
}
+ logInfo(s"After dropping ${selectedBlocks.size} blocks, " +
+ s"free memory is ${Utils.bytesToString(maxMemory - blocksMemoryUsed)}")
freedMemory
} else {
blockId.foreach { id =>
- logInfo(s"Will not store $id as it would require dropping another block " +
- "from the same RDD")
+ logInfo(s"Will not store $id")
}
selectedBlocks.foreach { id =>
- blockManager.releaseLock(id)
+ blockInfoManager.unlock(id)
}
0L
}
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
index 531f1c4dd2..95351e9826 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
@@ -31,7 +31,7 @@ import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.internal.Logging
import org.apache.spark.memory.TaskMemoryManager
-import org.apache.spark.serializer.{DeserializationStream, Serializer}
+import org.apache.spark.serializer.{DeserializationStream, Serializer, SerializerManager}
import org.apache.spark.storage.{BlockId, BlockManager}
import org.apache.spark.util.CompletionIterator
import org.apache.spark.util.collection.ExternalAppendOnlyMap.HashComparator
@@ -59,7 +59,8 @@ class ExternalAppendOnlyMap[K, V, C](
mergeCombiners: (C, C) => C,
serializer: Serializer = SparkEnv.get.serializer,
blockManager: BlockManager = SparkEnv.get.blockManager,
- context: TaskContext = TaskContext.get())
+ context: TaskContext = TaskContext.get(),
+ serializerManager: SerializerManager = SparkEnv.get.serializerManager)
extends Iterable[(K, C)]
with Serializable
with Logging
@@ -458,7 +459,7 @@ class ExternalAppendOnlyMap[K, V, C](
", batchOffsets = " + batchOffsets.mkString("[", ", ", "]"))
val bufferedStream = new BufferedInputStream(ByteStreams.limit(fileStream, end - start))
- val compressedStream = blockManager.wrapForCompression(blockId, bufferedStream)
+ val compressedStream = serializerManager.wrapForCompression(blockId, bufferedStream)
ser.deserializeStream(compressedStream)
} else {
// No more batches left
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
index 8cdc4663e6..561ba22df5 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -108,6 +108,7 @@ private[spark] class ExternalSorter[K, V, C](
private val blockManager = SparkEnv.get.blockManager
private val diskBlockManager = blockManager.diskBlockManager
+ private val serializerManager = SparkEnv.get.serializerManager
private val serInstance = serializer.newInstance()
// Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
@@ -503,7 +504,7 @@ private[spark] class ExternalSorter[K, V, C](
", batchOffsets = " + batchOffsets.mkString("[", ", ", "]"))
val bufferedStream = new BufferedInputStream(ByteStreams.limit(fileStream, end - start))
- val compressedStream = blockManager.wrapForCompression(spill.blockId, bufferedStream)
+ val compressedStream = serializerManager.wrapForCompression(spill.blockId, bufferedStream)
serInstance.deserializeStream(compressedStream)
} else {
// No more batches left
diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
index 47c695ad4e..44733dcdaf 100644
--- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
@@ -70,6 +70,7 @@ public class UnsafeShuffleWriterSuite {
final LinkedList<File> spillFilesCreated = new LinkedList<>();
SparkConf conf;
final Serializer serializer = new KryoSerializer(new SparkConf());
+ final SerializerManager serializerManager = new SerializerManager(serializer, new SparkConf());
TaskMetrics taskMetrics;
@Mock(answer = RETURNS_SMART_NULLS) BlockManager blockManager;
@@ -111,7 +112,7 @@ public class UnsafeShuffleWriterSuite {
.set("spark.memory.offHeap.enabled", "false");
taskMetrics = new TaskMetrics();
memoryManager = new TestMemoryManager(conf);
- taskMemoryManager = new TaskMemoryManager(memoryManager, 0);
+ taskMemoryManager = new TaskMemoryManager(memoryManager, 0);
when(blockManager.diskBlockManager()).thenReturn(diskBlockManager);
when(blockManager.getDiskWriter(
@@ -135,35 +136,6 @@ public class UnsafeShuffleWriterSuite {
);
}
});
- when(blockManager.wrapForCompression(any(BlockId.class), any(InputStream.class))).thenAnswer(
- new Answer<InputStream>() {
- @Override
- public InputStream answer(InvocationOnMock invocation) throws Throwable {
- assertTrue(invocation.getArguments()[0] instanceof TempShuffleBlockId);
- InputStream is = (InputStream) invocation.getArguments()[1];
- if (conf.getBoolean("spark.shuffle.compress", true)) {
- return CompressionCodec$.MODULE$.createCodec(conf).compressedInputStream(is);
- } else {
- return is;
- }
- }
- }
- );
-
- when(blockManager.wrapForCompression(any(BlockId.class), any(OutputStream.class))).thenAnswer(
- new Answer<OutputStream>() {
- @Override
- public OutputStream answer(InvocationOnMock invocation) throws Throwable {
- assertTrue(invocation.getArguments()[0] instanceof TempShuffleBlockId);
- OutputStream os = (OutputStream) invocation.getArguments()[1];
- if (conf.getBoolean("spark.shuffle.compress", true)) {
- return CompressionCodec$.MODULE$.createCodec(conf).compressedOutputStream(os);
- } else {
- return os;
- }
- }
- }
- );
when(shuffleBlockResolver.getDataFile(anyInt(), anyInt())).thenReturn(mergedOutputFile);
doAnswer(new Answer<Void>() {
diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
index 6667179b9d..449fb45c30 100644
--- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
+++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
@@ -19,7 +19,6 @@ package org.apache.spark.unsafe.map;
import java.io.File;
import java.io.IOException;
-import java.io.InputStream;
import java.io.OutputStream;
import java.nio.ByteBuffer;
import java.util.*;
@@ -42,7 +41,9 @@ import org.apache.spark.executor.ShuffleWriteMetrics;
import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.memory.TestMemoryManager;
import org.apache.spark.network.util.JavaUtils;
+import org.apache.spark.serializer.JavaSerializer;
import org.apache.spark.serializer.SerializerInstance;
+import org.apache.spark.serializer.SerializerManager;
import org.apache.spark.storage.*;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.array.ByteArrayMethods;
@@ -51,7 +52,6 @@ import org.apache.spark.util.Utils;
import static org.hamcrest.Matchers.greaterThan;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
-import static org.mockito.AdditionalAnswers.returnsSecondArg;
import static org.mockito.Answers.RETURNS_SMART_NULLS;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyInt;
@@ -64,6 +64,9 @@ public abstract class AbstractBytesToBytesMapSuite {
private TestMemoryManager memoryManager;
private TaskMemoryManager taskMemoryManager;
+ private SerializerManager serializerManager = new SerializerManager(
+ new JavaSerializer(new SparkConf()),
+ new SparkConf().set("spark.shuffle.spill.compress", "false"));
private static final long PAGE_SIZE_BYTES = 1L << 26; // 64 megabytes
final LinkedList<File> spillFilesCreated = new LinkedList<>();
@@ -85,7 +88,9 @@ public abstract class AbstractBytesToBytesMapSuite {
new TestMemoryManager(
new SparkConf()
.set("spark.memory.offHeap.enabled", "" + useOffHeapMemoryAllocator())
- .set("spark.memory.offHeap.size", "256mb"));
+ .set("spark.memory.offHeap.size", "256mb")
+ .set("spark.shuffle.spill.compress", "false")
+ .set("spark.shuffle.compress", "false"));
taskMemoryManager = new TaskMemoryManager(memoryManager, 0);
tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "unsafe-test");
@@ -124,8 +129,6 @@ public abstract class AbstractBytesToBytesMapSuite {
);
}
});
- when(blockManager.wrapForCompression(any(BlockId.class), any(InputStream.class)))
- .then(returnsSecondArg());
}
@After
@@ -546,8 +549,8 @@ public abstract class AbstractBytesToBytesMapSuite {
@Test
public void spillInIterator() throws IOException {
- BytesToBytesMap map =
- new BytesToBytesMap(taskMemoryManager, blockManager, 1, 0.75, 1024, false);
+ BytesToBytesMap map = new BytesToBytesMap(
+ taskMemoryManager, blockManager, serializerManager, 1, 0.75, 1024, false);
try {
int i;
for (i = 0; i < 1024; i++) {
diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
index db50e551f2..a2253d8559 100644
--- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
+++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
@@ -19,7 +19,6 @@ package org.apache.spark.util.collection.unsafe.sort;
import java.io.File;
import java.io.IOException;
-import java.io.InputStream;
import java.io.OutputStream;
import java.util.Arrays;
import java.util.LinkedList;
@@ -43,14 +42,15 @@ import org.apache.spark.executor.ShuffleWriteMetrics;
import org.apache.spark.executor.TaskMetrics;
import org.apache.spark.memory.TestMemoryManager;
import org.apache.spark.memory.TaskMemoryManager;
+import org.apache.spark.serializer.JavaSerializer;
import org.apache.spark.serializer.SerializerInstance;
+import org.apache.spark.serializer.SerializerManager;
import org.apache.spark.storage.*;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.util.Utils;
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
import static org.junit.Assert.*;
-import static org.mockito.AdditionalAnswers.returnsSecondArg;
import static org.mockito.Answers.RETURNS_SMART_NULLS;
import static org.mockito.Mockito.*;
@@ -60,6 +60,9 @@ public class UnsafeExternalSorterSuite {
final TestMemoryManager memoryManager =
new TestMemoryManager(new SparkConf().set("spark.memory.offHeap.enabled", "false"));
final TaskMemoryManager taskMemoryManager = new TaskMemoryManager(memoryManager, 0);
+ final SerializerManager serializerManager = new SerializerManager(
+ new JavaSerializer(new SparkConf()),
+ new SparkConf().set("spark.shuffle.spill.compress", "false"));
// Use integer comparison for comparing prefixes (which are partition ids, in this case)
final PrefixComparator prefixComparator = new PrefixComparator() {
@Override
@@ -135,8 +138,6 @@ public class UnsafeExternalSorterSuite {
);
}
});
- when(blockManager.wrapForCompression(any(BlockId.class), any(InputStream.class)))
- .then(returnsSecondArg());
}
@After
@@ -172,6 +173,7 @@ public class UnsafeExternalSorterSuite {
return UnsafeExternalSorter.create(
taskMemoryManager,
blockManager,
+ serializerManager,
taskContext,
recordComparator,
prefixComparator,
@@ -374,6 +376,7 @@ public class UnsafeExternalSorterSuite {
final UnsafeExternalSorter sorter = UnsafeExternalSorter.create(
taskMemoryManager,
blockManager,
+ serializerManager,
taskContext,
null,
null,
@@ -408,6 +411,7 @@ public class UnsafeExternalSorterSuite {
final UnsafeExternalSorter sorter = UnsafeExternalSorter.create(
taskMemoryManager,
blockManager,
+ serializerManager,
taskContext,
recordComparator,
prefixComparator,
diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala
index 2732cd6749..3dded4d486 100644
--- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala
+++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala
@@ -194,10 +194,11 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex
val blockId = blockIds(0)
val blockManager = SparkEnv.get.blockManager
val blockTransfer = SparkEnv.get.blockTransferService
+ val serializerManager = SparkEnv.get.serializerManager
blockManager.master.getLocations(blockId).foreach { cmId =>
val bytes = blockTransfer.fetchBlockSync(cmId.host, cmId.port, cmId.executorId,
blockId.toString)
- val deserialized = blockManager.dataDeserialize[Int](blockId,
+ val deserialized = serializerManager.dataDeserialize[Int](blockId,
new ChunkedByteBuffer(bytes.nioByteBuffer())).toList
assert(deserialized === (1 to 100).toList)
}
diff --git a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala
index 08f52c92e1..dba1172d5f 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala
@@ -20,14 +20,11 @@ package org.apache.spark.shuffle
import java.io.{ByteArrayOutputStream, InputStream}
import java.nio.ByteBuffer
-import org.mockito.Matchers.{eq => meq, _}
import org.mockito.Mockito.{mock, when}
-import org.mockito.invocation.InvocationOnMock
-import org.mockito.stubbing.Answer
import org.apache.spark._
import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
-import org.apache.spark.serializer.JavaSerializer
+import org.apache.spark.serializer.{JavaSerializer, SerializerManager}
import org.apache.spark.storage.{BlockManager, BlockManagerId, ShuffleBlockId}
/**
@@ -77,13 +74,6 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext
// can ensure retain() and release() are properly called.
val blockManager = mock(classOf[BlockManager])
- // Create a return function to use for the mocked wrapForCompression method that just returns
- // the original input stream.
- val dummyCompressionFunction = new Answer[InputStream] {
- override def answer(invocation: InvocationOnMock): InputStream =
- invocation.getArguments()(1).asInstanceOf[InputStream]
- }
-
// Create a buffer with some randomly generated key-value pairs to use as the shuffle data
// from each mappers (all mappers return the same shuffle data).
val byteOutputStream = new ByteArrayOutputStream()
@@ -105,9 +95,6 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext
// fetch shuffle data.
val shuffleBlockId = ShuffleBlockId(shuffleId, mapId, reduceId)
when(blockManager.getBlockData(shuffleBlockId)).thenReturn(managedBuffer)
- when(blockManager.wrapForCompression(meq(shuffleBlockId), isA(classOf[InputStream])))
- .thenAnswer(dummyCompressionFunction)
-
managedBuffer
}
@@ -133,11 +120,18 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext
new BaseShuffleHandle(shuffleId, numMaps, dependency)
}
+ val serializerManager = new SerializerManager(
+ serializer,
+ new SparkConf()
+ .set("spark.shuffle.compress", "false")
+ .set("spark.shuffle.spill.compress", "false"))
+
val shuffleReader = new BlockStoreShuffleReader(
shuffleHandle,
reduceId,
reduceId + 1,
TaskContext.empty(),
+ serializerManager,
blockManager,
mapOutputTracker)
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
index 9419dfaa00..94f6f87740 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
@@ -1033,138 +1033,6 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
assert(!store.memoryStore.contains(rdd(1, 0)), "rdd_1_0 was in store")
}
- test("reserve/release unroll memory") {
- store = makeBlockManager(12000)
- val memoryStore = store.memoryStore
- assert(memoryStore.currentUnrollMemory === 0)
- assert(memoryStore.currentUnrollMemoryForThisTask === 0)
-
- def reserveUnrollMemoryForThisTask(memory: Long): Boolean = {
- memoryStore.reserveUnrollMemoryForThisTask(TestBlockId(""), memory)
- }
-
- // Reserve
- assert(reserveUnrollMemoryForThisTask(100))
- assert(memoryStore.currentUnrollMemoryForThisTask === 100)
- assert(reserveUnrollMemoryForThisTask(200))
- assert(memoryStore.currentUnrollMemoryForThisTask === 300)
- assert(reserveUnrollMemoryForThisTask(500))
- assert(memoryStore.currentUnrollMemoryForThisTask === 800)
- assert(!reserveUnrollMemoryForThisTask(1000000))
- assert(memoryStore.currentUnrollMemoryForThisTask === 800) // not granted
- // Release
- memoryStore.releaseUnrollMemoryForThisTask(100)
- assert(memoryStore.currentUnrollMemoryForThisTask === 700)
- memoryStore.releaseUnrollMemoryForThisTask(100)
- assert(memoryStore.currentUnrollMemoryForThisTask === 600)
- // Reserve again
- assert(reserveUnrollMemoryForThisTask(4400))
- assert(memoryStore.currentUnrollMemoryForThisTask === 5000)
- assert(!reserveUnrollMemoryForThisTask(20000))
- assert(memoryStore.currentUnrollMemoryForThisTask === 5000) // not granted
- // Release again
- memoryStore.releaseUnrollMemoryForThisTask(1000)
- assert(memoryStore.currentUnrollMemoryForThisTask === 4000)
- memoryStore.releaseUnrollMemoryForThisTask() // release all
- assert(memoryStore.currentUnrollMemoryForThisTask === 0)
- }
-
- test("safely unroll blocks") {
- store = makeBlockManager(12000)
- val smallList = List.fill(40)(new Array[Byte](100))
- val bigList = List.fill(40)(new Array[Byte](1000))
- val memoryStore = store.memoryStore
- assert(memoryStore.currentUnrollMemoryForThisTask === 0)
-
- // Unroll with all the space in the world. This should succeed.
- var putResult =
- memoryStore.putIterator("unroll", smallList.iterator, StorageLevel.MEMORY_ONLY, ClassTag.Any)
- assert(putResult.isRight)
- assert(memoryStore.currentUnrollMemoryForThisTask === 0)
- smallList.iterator.zip(memoryStore.getValues("unroll").get).foreach { case (e, a) =>
- assert(e === a, "getValues() did not return original values!")
- }
- assert(memoryStore.remove("unroll"))
-
- // Unroll with not enough space. This should succeed after kicking out someBlock1.
- assert(store.putIterator("someBlock1", smallList.iterator, StorageLevel.MEMORY_ONLY))
- assert(store.putIterator("someBlock2", smallList.iterator, StorageLevel.MEMORY_ONLY))
- putResult =
- memoryStore.putIterator("unroll", smallList.iterator, StorageLevel.MEMORY_ONLY, ClassTag.Any)
- assert(putResult.isRight)
- assert(memoryStore.currentUnrollMemoryForThisTask === 0)
- assert(memoryStore.contains("someBlock2"))
- assert(!memoryStore.contains("someBlock1"))
- smallList.iterator.zip(memoryStore.getValues("unroll").get).foreach { case (e, a) =>
- assert(e === a, "getValues() did not return original values!")
- }
- assert(memoryStore.remove("unroll"))
-
- // Unroll huge block with not enough space. Even after ensuring free space of 12000 * 0.4 =
- // 4800 bytes, there is still not enough room to unroll this block. This returns an iterator.
- // In the mean time, however, we kicked out someBlock2 before giving up.
- assert(store.putIterator("someBlock3", smallList.iterator, StorageLevel.MEMORY_ONLY))
- putResult =
- memoryStore.putIterator("unroll", bigList.iterator, StorageLevel.MEMORY_ONLY, ClassTag.Any)
- assert(memoryStore.currentUnrollMemoryForThisTask > 0) // we returned an iterator
- assert(!memoryStore.contains("someBlock2"))
- assert(putResult.isLeft)
- bigList.iterator.zip(putResult.left.get).foreach { case (e, a) =>
- assert(e === a, "putIterator() did not return original values!")
- }
- // The unroll memory was freed once the iterator returned by putIterator() was fully traversed.
- assert(memoryStore.currentUnrollMemoryForThisTask === 0)
- }
-
- test("safely unroll blocks through putIterator") {
- store = makeBlockManager(12000)
- val memOnly = StorageLevel.MEMORY_ONLY
- val memoryStore = store.memoryStore
- val smallList = List.fill(40)(new Array[Byte](100))
- val bigList = List.fill(40)(new Array[Byte](1000))
- def smallIterator: Iterator[Any] = smallList.iterator.asInstanceOf[Iterator[Any]]
- def bigIterator: Iterator[Any] = bigList.iterator.asInstanceOf[Iterator[Any]]
- assert(memoryStore.currentUnrollMemoryForThisTask === 0)
-
- // Unroll with plenty of space. This should succeed and cache both blocks.
- val result1 = memoryStore.putIterator("b1", smallIterator, memOnly, ClassTag.Any)
- val result2 = memoryStore.putIterator("b2", smallIterator, memOnly, ClassTag.Any)
- assert(memoryStore.contains("b1"))
- assert(memoryStore.contains("b2"))
- assert(result1.isRight) // unroll was successful
- assert(result2.isRight)
- assert(memoryStore.currentUnrollMemoryForThisTask === 0)
-
- // Re-put these two blocks so block manager knows about them too. Otherwise, block manager
- // would not know how to drop them from memory later.
- memoryStore.remove("b1")
- memoryStore.remove("b2")
- store.putIterator("b1", smallIterator, memOnly)
- store.putIterator("b2", smallIterator, memOnly)
-
- // Unroll with not enough space. This should succeed but kick out b1 in the process.
- val result3 = memoryStore.putIterator("b3", smallIterator, memOnly, ClassTag.Any)
- assert(result3.isRight)
- assert(!memoryStore.contains("b1"))
- assert(memoryStore.contains("b2"))
- assert(memoryStore.contains("b3"))
- assert(memoryStore.currentUnrollMemoryForThisTask === 0)
- memoryStore.remove("b3")
- store.putIterator("b3", smallIterator, memOnly)
-
- // Unroll huge block with not enough space. This should fail and kick out b2 in the process.
- val result4 = memoryStore.putIterator("b4", bigIterator, memOnly, ClassTag.Any)
- assert(result4.isLeft) // unroll was unsuccessful
- assert(!memoryStore.contains("b1"))
- assert(!memoryStore.contains("b2"))
- assert(memoryStore.contains("b3"))
- assert(!memoryStore.contains("b4"))
- assert(memoryStore.currentUnrollMemoryForThisTask > 0) // we returned an iterator
- }
-
- /**
- * This test is essentially identical to the preceding one, except that it uses MEMORY_AND_DISK.
- */
test("safely unroll blocks through putIterator (disk)") {
store = makeBlockManager(12000)
val memAndDisk = StorageLevel.MEMORY_AND_DISK
@@ -1203,72 +1071,9 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
assert(!memoryStore.contains("b2"))
assert(memoryStore.contains("b3"))
assert(!memoryStore.contains("b4"))
- assert(memoryStore.currentUnrollMemoryForThisTask > 0) // we returned an iterator
- result4.left.get.close()
- assert(memoryStore.currentUnrollMemoryForThisTask === 0) // close released the unroll memory
- }
-
- test("multiple unrolls by the same thread") {
- store = makeBlockManager(12000)
- val memOnly = StorageLevel.MEMORY_ONLY
- val memoryStore = store.memoryStore
- val smallList = List.fill(40)(new Array[Byte](100))
- def smallIterator: Iterator[Any] = smallList.iterator.asInstanceOf[Iterator[Any]]
- assert(memoryStore.currentUnrollMemoryForThisTask === 0)
-
- // All unroll memory used is released because putIterator did not return an iterator
- assert(memoryStore.putIterator("b1", smallIterator, memOnly, ClassTag.Any).isRight)
- assert(memoryStore.currentUnrollMemoryForThisTask === 0)
- assert(memoryStore.putIterator("b2", smallIterator, memOnly, ClassTag.Any).isRight)
- assert(memoryStore.currentUnrollMemoryForThisTask === 0)
-
- // Unroll memory is not released because putIterator returned an iterator
- // that still depends on the underlying vector used in the process
- assert(memoryStore.putIterator("b3", smallIterator, memOnly, ClassTag.Any).isLeft)
- val unrollMemoryAfterB3 = memoryStore.currentUnrollMemoryForThisTask
- assert(unrollMemoryAfterB3 > 0)
-
- // The unroll memory owned by this thread builds on top of its value after the previous unrolls
- assert(memoryStore.putIterator("b4", smallIterator, memOnly, ClassTag.Any).isLeft)
- val unrollMemoryAfterB4 = memoryStore.currentUnrollMemoryForThisTask
- assert(unrollMemoryAfterB4 > unrollMemoryAfterB3)
-
- // ... but only to a certain extent (until we run out of free space to grant new unroll memory)
- assert(memoryStore.putIterator("b5", smallIterator, memOnly, ClassTag.Any).isLeft)
- val unrollMemoryAfterB5 = memoryStore.currentUnrollMemoryForThisTask
- assert(memoryStore.putIterator("b6", smallIterator, memOnly, ClassTag.Any).isLeft)
- val unrollMemoryAfterB6 = memoryStore.currentUnrollMemoryForThisTask
- assert(memoryStore.putIterator("b7", smallIterator, memOnly, ClassTag.Any).isLeft)
- val unrollMemoryAfterB7 = memoryStore.currentUnrollMemoryForThisTask
- assert(unrollMemoryAfterB5 === unrollMemoryAfterB4)
- assert(unrollMemoryAfterB6 === unrollMemoryAfterB4)
- assert(unrollMemoryAfterB7 === unrollMemoryAfterB4)
- }
-
- test("lazily create a big ByteBuffer to avoid OOM if it cannot be put into MemoryStore") {
- store = makeBlockManager(12000)
- val memoryStore = store.memoryStore
- val blockId = BlockId("rdd_3_10")
- store.blockInfoManager.lockNewBlockForWriting(
- blockId, new BlockInfo(StorageLevel.MEMORY_ONLY, ClassTag.Any, tellMaster = false))
- memoryStore.putBytes(blockId, 13000, () => {
- fail("A big ByteBuffer that cannot be put into MemoryStore should not be created")
- })
- }
-
- test("put a small ByteBuffer to MemoryStore") {
- store = makeBlockManager(12000)
- val memoryStore = store.memoryStore
- val blockId = BlockId("rdd_3_10")
- var bytes: ChunkedByteBuffer = null
- memoryStore.putBytes(blockId, 10000, () => {
- bytes = new ChunkedByteBuffer(ByteBuffer.allocate(10000))
- bytes
- })
- assert(memoryStore.getSize(blockId) === 10000)
}
- test("read-locked blocks cannot be evicted from the MemoryStore") {
+ test("read-locked blocks cannot be evicted from memory") {
store = makeBlockManager(12000)
val arr = new Array[Byte](4000)
// First store a1 and a2, both in memory, and a3, on disk only
diff --git a/core/src/test/scala/org/apache/spark/storage/MemoryStoreSuite.scala b/core/src/test/scala/org/apache/spark/storage/MemoryStoreSuite.scala
new file mode 100644
index 0000000000..b4ab67ca15
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/storage/MemoryStoreSuite.scala
@@ -0,0 +1,302 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.nio.ByteBuffer
+
+import scala.language.implicitConversions
+import scala.language.postfixOps
+import scala.language.reflectiveCalls
+import scala.reflect.ClassTag
+
+import org.scalatest._
+
+import org.apache.spark._
+import org.apache.spark.memory.StaticMemoryManager
+import org.apache.spark.serializer.{KryoSerializer, SerializerManager}
+import org.apache.spark.storage.memory.{BlockEvictionHandler, MemoryStore, PartiallyUnrolledIterator}
+import org.apache.spark.util._
+import org.apache.spark.util.io.ChunkedByteBuffer
+
+class MemoryStoreSuite
+ extends SparkFunSuite
+ with PrivateMethodTester
+ with BeforeAndAfterEach
+ with ResetSystemProperties {
+
+ var conf: SparkConf = new SparkConf(false)
+ .set("spark.test.useCompressedOops", "true")
+ .set("spark.storage.unrollFraction", "0.4")
+ .set("spark.storage.unrollMemoryThreshold", "512")
+
+ // Reuse a serializer across tests to avoid creating a new thread-local buffer on each test
+ val serializer = new KryoSerializer(new SparkConf(false).set("spark.kryoserializer.buffer", "1m"))
+
+ // Implicitly convert strings to BlockIds for test clarity.
+ implicit def StringToBlockId(value: String): BlockId = new TestBlockId(value)
+ def rdd(rddId: Int, splitId: Int): RDDBlockId = RDDBlockId(rddId, splitId)
+
+ override def beforeEach(): Unit = {
+ super.beforeEach()
+ // Set the arch to 64-bit and compressedOops to true to get a deterministic test-case
+ System.setProperty("os.arch", "amd64")
+ val initialize = PrivateMethod[Unit]('initialize)
+ SizeEstimator invokePrivate initialize()
+ }
+
+ def makeMemoryStore(maxMem: Long): (MemoryStore, BlockInfoManager) = {
+ val memManager = new StaticMemoryManager(conf, Long.MaxValue, maxMem, numCores = 1)
+ val serializerManager = new SerializerManager(serializer, conf)
+ val blockInfoManager = new BlockInfoManager
+ val blockEvictionHandler = new BlockEvictionHandler {
+ var memoryStore: MemoryStore = _
+ override private[storage] def dropFromMemory[T: ClassTag](
+ blockId: BlockId,
+ data: () => Either[Array[T], ChunkedByteBuffer]): StorageLevel = {
+ memoryStore.remove(blockId)
+ StorageLevel.NONE
+ }
+ }
+ val memoryStore =
+ new MemoryStore(conf, blockInfoManager, serializerManager, memManager, blockEvictionHandler)
+ memManager.setMemoryStore(memoryStore)
+ blockEvictionHandler.memoryStore = memoryStore
+ (memoryStore, blockInfoManager)
+ }
+
+ test("reserve/release unroll memory") {
+ val (memoryStore, _) = makeMemoryStore(12000)
+ assert(memoryStore.currentUnrollMemory === 0)
+ assert(memoryStore.currentUnrollMemoryForThisTask === 0)
+
+ def reserveUnrollMemoryForThisTask(memory: Long): Boolean = {
+ memoryStore.reserveUnrollMemoryForThisTask(TestBlockId(""), memory)
+ }
+
+ // Reserve
+ assert(reserveUnrollMemoryForThisTask(100))
+ assert(memoryStore.currentUnrollMemoryForThisTask === 100)
+ assert(reserveUnrollMemoryForThisTask(200))
+ assert(memoryStore.currentUnrollMemoryForThisTask === 300)
+ assert(reserveUnrollMemoryForThisTask(500))
+ assert(memoryStore.currentUnrollMemoryForThisTask === 800)
+ assert(!reserveUnrollMemoryForThisTask(1000000))
+ assert(memoryStore.currentUnrollMemoryForThisTask === 800) // not granted
+ // Release
+ memoryStore.releaseUnrollMemoryForThisTask(100)
+ assert(memoryStore.currentUnrollMemoryForThisTask === 700)
+ memoryStore.releaseUnrollMemoryForThisTask(100)
+ assert(memoryStore.currentUnrollMemoryForThisTask === 600)
+ // Reserve again
+ assert(reserveUnrollMemoryForThisTask(4400))
+ assert(memoryStore.currentUnrollMemoryForThisTask === 5000)
+ assert(!reserveUnrollMemoryForThisTask(20000))
+ assert(memoryStore.currentUnrollMemoryForThisTask === 5000) // not granted
+ // Release again
+ memoryStore.releaseUnrollMemoryForThisTask(1000)
+ assert(memoryStore.currentUnrollMemoryForThisTask === 4000)
+ memoryStore.releaseUnrollMemoryForThisTask() // release all
+ assert(memoryStore.currentUnrollMemoryForThisTask === 0)
+ }
+
+ test("safely unroll blocks") {
+ val smallList = List.fill(40)(new Array[Byte](100))
+ val bigList = List.fill(40)(new Array[Byte](1000))
+ val ct = implicitly[ClassTag[Array[Byte]]]
+ val (memoryStore, blockInfoManager) = makeMemoryStore(12000)
+ assert(memoryStore.currentUnrollMemoryForThisTask === 0)
+
+ def putIterator[T](
+ blockId: BlockId,
+ iter: Iterator[T],
+ classTag: ClassTag[T]): Either[PartiallyUnrolledIterator[T], Long] = {
+ assert(blockInfoManager.lockNewBlockForWriting(
+ blockId,
+ new BlockInfo(StorageLevel.MEMORY_ONLY, classTag, tellMaster = false)))
+ val res = memoryStore.putIterator(blockId, iter, StorageLevel.MEMORY_ONLY, classTag)
+ blockInfoManager.unlock(blockId)
+ res
+ }
+
+ // Unroll with all the space in the world. This should succeed.
+ var putResult = putIterator("unroll", smallList.iterator, ClassTag.Any)
+ assert(putResult.isRight)
+ assert(memoryStore.currentUnrollMemoryForThisTask === 0)
+ smallList.iterator.zip(memoryStore.getValues("unroll").get).foreach { case (e, a) =>
+ assert(e === a, "getValues() did not return original values!")
+ }
+ blockInfoManager.lockForWriting("unroll")
+ assert(memoryStore.remove("unroll"))
+ blockInfoManager.removeBlock("unroll")
+
+ // Unroll with not enough space. This should succeed after kicking out someBlock1.
+ assert(putIterator("someBlock1", smallList.iterator, ct).isRight)
+ assert(putIterator("someBlock2", smallList.iterator, ct).isRight)
+ putResult = putIterator("unroll", smallList.iterator, ClassTag.Any)
+ assert(putResult.isRight)
+ assert(memoryStore.currentUnrollMemoryForThisTask === 0)
+ assert(memoryStore.contains("someBlock2"))
+ assert(!memoryStore.contains("someBlock1"))
+ smallList.iterator.zip(memoryStore.getValues("unroll").get).foreach { case (e, a) =>
+ assert(e === a, "getValues() did not return original values!")
+ }
+ blockInfoManager.lockForWriting("unroll")
+ assert(memoryStore.remove("unroll"))
+ blockInfoManager.removeBlock("unroll")
+
+ // Unroll huge block with not enough space. Even after ensuring free space of 12000 * 0.4 =
+ // 4800 bytes, there is still not enough room to unroll this block. This returns an iterator.
+ // In the meantime, however, we kicked out someBlock2 before giving up.
+ assert(putIterator("someBlock3", smallList.iterator, ct).isRight)
+ putResult = putIterator("unroll", bigList.iterator, ClassTag.Any)
+ assert(memoryStore.currentUnrollMemoryForThisTask > 0) // we returned an iterator
+ assert(!memoryStore.contains("someBlock2"))
+ assert(putResult.isLeft)
+ bigList.iterator.zip(putResult.left.get).foreach { case (e, a) =>
+ assert(e === a, "putIterator() did not return original values!")
+ }
+ // The unroll memory was freed once the iterator returned by putIterator() was fully traversed.
+ assert(memoryStore.currentUnrollMemoryForThisTask === 0)
+ }
+
+ test("safely unroll blocks through putIterator") {
+ val (memoryStore, blockInfoManager) = makeMemoryStore(12000)
+ val smallList = List.fill(40)(new Array[Byte](100))
+ val bigList = List.fill(40)(new Array[Byte](1000))
+ def smallIterator: Iterator[Any] = smallList.iterator.asInstanceOf[Iterator[Any]]
+ def bigIterator: Iterator[Any] = bigList.iterator.asInstanceOf[Iterator[Any]]
+ assert(memoryStore.currentUnrollMemoryForThisTask === 0)
+
+ def putIterator[T](
+ blockId: BlockId,
+ iter: Iterator[T],
+ classTag: ClassTag[T]): Either[PartiallyUnrolledIterator[T], Long] = {
+ assert(blockInfoManager.lockNewBlockForWriting(
+ blockId,
+ new BlockInfo(StorageLevel.MEMORY_ONLY, classTag, tellMaster = false)))
+ val res = memoryStore.putIterator(blockId, iter, StorageLevel.MEMORY_ONLY, classTag)
+ blockInfoManager.unlock(blockId)
+ res
+ }
+
+ // Unroll with plenty of space. This should succeed and cache both blocks.
+ val result1 = putIterator("b1", smallIterator, ClassTag.Any)
+ val result2 = putIterator("b2", smallIterator, ClassTag.Any)
+ assert(memoryStore.contains("b1"))
+ assert(memoryStore.contains("b2"))
+ assert(result1.isRight) // unroll was successful
+ assert(result2.isRight)
+ assert(memoryStore.currentUnrollMemoryForThisTask === 0)
+
+ // Re-put these two blocks so the block info manager knows about them too. Otherwise,
+ // they could not be dropped from memory later.
+ blockInfoManager.lockForWriting("b1")
+ memoryStore.remove("b1")
+ blockInfoManager.removeBlock("b1")
+ blockInfoManager.lockForWriting("b2")
+ memoryStore.remove("b2")
+ blockInfoManager.removeBlock("b2")
+ putIterator("b1", smallIterator, ClassTag.Any)
+ putIterator("b2", smallIterator, ClassTag.Any)
+
+ // Unroll with not enough space. This should succeed but kick out b1 in the process.
+ val result3 = putIterator("b3", smallIterator, ClassTag.Any)
+ assert(result3.isRight)
+ assert(!memoryStore.contains("b1"))
+ assert(memoryStore.contains("b2"))
+ assert(memoryStore.contains("b3"))
+ assert(memoryStore.currentUnrollMemoryForThisTask === 0)
+ blockInfoManager.lockForWriting("b3")
+ assert(memoryStore.remove("b3"))
+ blockInfoManager.removeBlock("b3")
+ putIterator("b3", smallIterator, ClassTag.Any)
+
+ // Unroll huge block with not enough space. This should fail and kick out b2 in the process.
+ val result4 = putIterator("b4", bigIterator, ClassTag.Any)
+ assert(result4.isLeft) // unroll was unsuccessful
+ assert(!memoryStore.contains("b1"))
+ assert(!memoryStore.contains("b2"))
+ assert(memoryStore.contains("b3"))
+ assert(!memoryStore.contains("b4"))
+ assert(memoryStore.currentUnrollMemoryForThisTask > 0) // we returned an iterator
+ result4.left.get.close()
+ assert(memoryStore.currentUnrollMemoryForThisTask === 0) // close released the unroll memory
+ }
+
+ test("multiple unrolls by the same thread") {
+ val (memoryStore, _) = makeMemoryStore(12000)
+ val smallList = List.fill(40)(new Array[Byte](100))
+ def smallIterator: Iterator[Any] = smallList.iterator.asInstanceOf[Iterator[Any]]
+ assert(memoryStore.currentUnrollMemoryForThisTask === 0)
+
+ def putIterator(
+ blockId: BlockId,
+ iter: Iterator[Any]): Either[PartiallyUnrolledIterator[Any], Long] = {
+ memoryStore.putIterator(blockId, iter, StorageLevel.MEMORY_ONLY, ClassTag.Any)
+ }
+
+ // All unroll memory used is released because putIterator did not return an iterator
+ assert(putIterator("b1", smallIterator).isRight)
+ assert(memoryStore.currentUnrollMemoryForThisTask === 0)
+ assert(putIterator("b2", smallIterator).isRight)
+ assert(memoryStore.currentUnrollMemoryForThisTask === 0)
+
+ // Unroll memory is not released because putIterator returned an iterator
+ // that still depends on the underlying vector used in the process
+ assert(putIterator("b3", smallIterator).isLeft)
+ val unrollMemoryAfterB3 = memoryStore.currentUnrollMemoryForThisTask
+ assert(unrollMemoryAfterB3 > 0)
+
+ // The unroll memory owned by this thread builds on top of its value after the previous unrolls
+ assert(putIterator("b4", smallIterator).isLeft)
+ val unrollMemoryAfterB4 = memoryStore.currentUnrollMemoryForThisTask
+ assert(unrollMemoryAfterB4 > unrollMemoryAfterB3)
+
+ // ... but only to a certain extent (until we run out of free space to grant new unroll memory)
+ assert(putIterator("b5", smallIterator).isLeft)
+ val unrollMemoryAfterB5 = memoryStore.currentUnrollMemoryForThisTask
+ assert(putIterator("b6", smallIterator).isLeft)
+ val unrollMemoryAfterB6 = memoryStore.currentUnrollMemoryForThisTask
+ assert(putIterator("b7", smallIterator).isLeft)
+ val unrollMemoryAfterB7 = memoryStore.currentUnrollMemoryForThisTask
+ assert(unrollMemoryAfterB5 === unrollMemoryAfterB4)
+ assert(unrollMemoryAfterB6 === unrollMemoryAfterB4)
+ assert(unrollMemoryAfterB7 === unrollMemoryAfterB4)
+ }
+
+ test("lazily create a big ByteBuffer to avoid OOM if it cannot be put into MemoryStore") {
+ val (memoryStore, blockInfoManager) = makeMemoryStore(12000)
+ val blockId = BlockId("rdd_3_10")
+ blockInfoManager.lockNewBlockForWriting(
+ blockId, new BlockInfo(StorageLevel.MEMORY_ONLY, ClassTag.Any, tellMaster = false))
+ memoryStore.putBytes(blockId, 13000, () => {
+ fail("A big ByteBuffer that cannot be put into MemoryStore should not be created")
+ })
+ }
+
+ test("put a small ByteBuffer to MemoryStore") {
+ val (memoryStore, _) = makeMemoryStore(12000)
+ val blockId = BlockId("rdd_3_10")
+ var bytes: ChunkedByteBuffer = null
+ memoryStore.putBytes(blockId, 10000, () => {
+ bytes = new ChunkedByteBuffer(ByteBuffer.allocate(10000))
+ bytes
+ })
+ assert(memoryStore.getSize(blockId) === 10000)
+ }
+}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
index d85147e961..aa7fc2121e 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
@@ -67,6 +67,7 @@ public final class UnsafeExternalRowSorter {
sorter = UnsafeExternalSorter.create(
taskContext.taskMemoryManager(),
sparkEnv.blockManager(),
+ sparkEnv.serializerManager(),
taskContext,
new RowComparator(ordering, schema.length()),
prefixComparator,
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
index acf6c583bb..8882903bbf 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
@@ -241,7 +241,11 @@ public final class UnsafeFixedWidthAggregationMap {
*/
public UnsafeKVExternalSorter destructAndCreateExternalSorter() throws IOException {
return new UnsafeKVExternalSorter(
- groupingKeySchema, aggregationBufferSchema,
- SparkEnv.get().blockManager(), map.getPageSizeBytes(), map);
+ groupingKeySchema,
+ aggregationBufferSchema,
+ SparkEnv.get().blockManager(),
+ SparkEnv.get().serializerManager(),
+ map.getPageSizeBytes(),
+ map);
}
}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
index 9e08675c3e..d3bfb00b3f 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
@@ -24,6 +24,7 @@ import com.google.common.annotations.VisibleForTesting;
import org.apache.spark.TaskContext;
import org.apache.spark.memory.TaskMemoryManager;
+import org.apache.spark.serializer.SerializerManager;
import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
import org.apache.spark.sql.catalyst.expressions.codegen.BaseOrdering;
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering;
@@ -52,14 +53,16 @@ public final class UnsafeKVExternalSorter {
StructType keySchema,
StructType valueSchema,
BlockManager blockManager,
+ SerializerManager serializerManager,
long pageSizeBytes) throws IOException {
- this(keySchema, valueSchema, blockManager, pageSizeBytes, null);
+ this(keySchema, valueSchema, blockManager, serializerManager, pageSizeBytes, null);
}
public UnsafeKVExternalSorter(
StructType keySchema,
StructType valueSchema,
BlockManager blockManager,
+ SerializerManager serializerManager,
long pageSizeBytes,
@Nullable BytesToBytesMap map) throws IOException {
this.keySchema = keySchema;
@@ -77,6 +80,7 @@ public final class UnsafeKVExternalSorter {
sorter = UnsafeExternalSorter.create(
taskMemoryManager,
blockManager,
+ serializerManager,
taskContext,
recordComparator,
prefixComparator,
@@ -116,6 +120,7 @@ public final class UnsafeKVExternalSorter {
sorter = UnsafeExternalSorter.createWithExistingInMemorySorter(
taskMemoryManager,
blockManager,
+ serializerManager,
taskContext,
new KVComparator(ordering, keySchema.length()),
prefixComparator,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
index a4c0e1c9fb..270c09aff3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
@@ -339,6 +339,7 @@ case class Window(
sorter = UnsafeExternalSorter.create(
TaskContext.get().taskMemoryManager(),
SparkEnv.get.blockManager,
+ SparkEnv.get.serializerManager,
TaskContext.get(),
null,
null,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala
index c74ac8a282..233ac263aa 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala
@@ -399,6 +399,7 @@ private[sql] class DynamicPartitionWriterContainer(
sortingKeySchema,
StructType.fromAttributes(dataColumns),
SparkEnv.get.blockManager,
+ SparkEnv.get.serializerManager,
TaskContext.get().taskMemoryManager().pageSizeBytes)
while (iterator.hasNext) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala
index fabd2fbe1e..fb65b50da8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala
@@ -41,6 +41,7 @@ class UnsafeCartesianRDD(left : RDD[UnsafeRow], right : RDD[UnsafeRow], numField
val sorter = UnsafeExternalSorter.create(
context.taskMemoryManager(),
SparkEnv.get.blockManager,
+ SparkEnv.get.serializerManager,
context,
null,
null,
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
index e03bd6a3e7..476d93fc2a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
@@ -120,7 +120,7 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext {
metricsSystem = null))
val sorter = new UnsafeKVExternalSorter(
- keySchema, valueSchema, SparkEnv.get.blockManager, pageSize)
+ keySchema, valueSchema, SparkEnv.get.blockManager, SparkEnv.get.serializerManager, pageSize)
// Insert the keys and values into the sorter
inputData.foreach { case (k, v) =>
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala
index a29d55ee25..794fe264ea 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala
@@ -279,6 +279,7 @@ private[spark] class SparkHiveDynamicPartitionWriterContainer(
StructType.fromAttributes(partitionOutput),
StructType.fromAttributes(dataOutput),
SparkEnv.get.blockManager,
+ SparkEnv.get.serializerManager,
TaskContext.get().taskMemoryManager().pageSizeBytes)
while (iterator.hasNext) {
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala
index ace67a639c..c56520b1e2 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala
@@ -115,6 +115,7 @@ class WriteAheadLogBackedBlockRDD[T: ClassTag](
assertValid()
val hadoopConf = broadcastedHadoopConf.value
val blockManager = SparkEnv.get.blockManager
+ val serializerManager = SparkEnv.get.serializerManager
val partition = split.asInstanceOf[WriteAheadLogBackedBlockRDDPartition]
val blockId = partition.blockId
@@ -161,7 +162,7 @@ class WriteAheadLogBackedBlockRDD[T: ClassTag](
logDebug(s"Stored partition data of $this into block manager with level $storageLevel")
dataRead.rewind()
}
- blockManager.dataDeserialize(blockId, new ChunkedByteBuffer(dataRead))
+ serializerManager.dataDeserialize(blockId, new ChunkedByteBuffer(dataRead))
.asInstanceOf[Iterator[T]]
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala
index 6d4f4b99c1..85350ff658 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala
@@ -26,6 +26,7 @@ import org.apache.hadoop.fs.Path
import org.apache.spark.{SparkConf, SparkException}
import org.apache.spark.internal.Logging
+import org.apache.spark.serializer.SerializerManager
import org.apache.spark.storage._
import org.apache.spark.streaming.receiver.WriteAheadLogBasedBlockHandler._
import org.apache.spark.streaming.util.{WriteAheadLogRecordHandle, WriteAheadLogUtils}
@@ -123,6 +124,7 @@ private[streaming] case class WriteAheadLogBasedStoreResult(
*/
private[streaming] class WriteAheadLogBasedBlockHandler(
blockManager: BlockManager,
+ serializerManager: SerializerManager,
streamId: Int,
storageLevel: StorageLevel,
conf: SparkConf,
@@ -173,10 +175,10 @@ private[streaming] class WriteAheadLogBasedBlockHandler(
val serializedBlock = block match {
case ArrayBufferBlock(arrayBuffer) =>
numRecords = Some(arrayBuffer.size.toLong)
- blockManager.dataSerialize(blockId, arrayBuffer.iterator)
+ serializerManager.dataSerialize(blockId, arrayBuffer.iterator)
case IteratorBlock(iterator) =>
val countIterator = new CountingIterator(iterator)
- val serializedBlock = blockManager.dataSerialize(blockId, countIterator)
+ val serializedBlock = serializerManager.dataSerialize(blockId, countIterator)
numRecords = countIterator.count
serializedBlock
case ByteBufferBlock(byteBuffer) =>
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala
index e41fd11963..4fb0f8caac 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala
@@ -60,7 +60,7 @@ private[streaming] class ReceiverSupervisorImpl(
"Please use streamingContext.checkpoint() to set the checkpoint directory. " +
"See documentation for more details.")
}
- new WriteAheadLogBasedBlockHandler(env.blockManager, receiver.streamId,
+ new WriteAheadLogBasedBlockHandler(env.blockManager, env.serializerManager, receiver.streamId,
receiver.storageLevel, env.conf, hadoopConf, checkpointDirOption.get)
} else {
new BlockManagerBasedBlockHandler(env.blockManager, receiver.storageLevel)
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala
index 122ca0627f..4e77cd6347 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala
@@ -60,6 +60,7 @@ class ReceivedBlockHandlerSuite
val mapOutputTracker = new MapOutputTrackerMaster(conf)
val shuffleManager = new HashShuffleManager(conf)
val serializer = new KryoSerializer(conf)
+ var serializerManager = new SerializerManager(serializer, conf)
val manualClock = new ManualClock
val blockManagerSize = 10000000
val blockManagerBuffer = new ArrayBuffer[BlockManager]()
@@ -156,7 +157,7 @@ class ReceivedBlockHandlerSuite
val reader = new FileBasedWriteAheadLogRandomReader(fileSegment.path, hadoopConf)
val bytes = reader.read(fileSegment)
reader.close()
- blockManager.dataDeserialize(generateBlockId(), new ChunkedByteBuffer(bytes)).toList
+ serializerManager.dataDeserialize(generateBlockId(), new ChunkedByteBuffer(bytes)).toList
}
loggedData shouldEqual data
}
@@ -265,7 +266,6 @@ class ReceivedBlockHandlerSuite
name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = {
val memManager = new StaticMemoryManager(conf, Long.MaxValue, maxMem, numCores = 1)
val transfer = new NettyBlockTransferService(conf, securityMgr, numCores = 1)
- val serializerManager = new SerializerManager(serializer, conf)
val blockManager = new BlockManager(name, rpcEnv, blockManagerMaster, serializerManager, conf,
memManager, mapOutputTracker, shuffleManager, transfer, securityMgr, 0)
memManager.setMemoryStore(blockManager.memoryStore)
@@ -335,7 +335,8 @@ class ReceivedBlockHandlerSuite
}
}
- def dataToByteBuffer(b: Seq[String]) = blockManager.dataSerialize(generateBlockId, b.iterator)
+ def dataToByteBuffer(b: Seq[String]) =
+ serializerManager.dataSerialize(generateBlockId, b.iterator)
val blocks = data.grouped(10).toSeq
@@ -367,8 +368,8 @@ class ReceivedBlockHandlerSuite
/** Instantiate a WriteAheadLogBasedBlockHandler and run a code with it */
private def withWriteAheadLogBasedBlockHandler(body: WriteAheadLogBasedBlockHandler => Unit) {
require(WriteAheadLogUtils.getRollingIntervalSecs(conf, isDriver = false) === 1)
- val receivedBlockHandler = new WriteAheadLogBasedBlockHandler(blockManager, 1,
- storageLevel, conf, hadoopConf, tempDirectory.toString, manualClock)
+ val receivedBlockHandler = new WriteAheadLogBasedBlockHandler(blockManager, serializerManager,
+ 1, storageLevel, conf, hadoopConf, tempDirectory.toString, manualClock)
try {
body(receivedBlockHandler)
} finally {
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala
index c4bf42d0f2..ce5a6e00fb 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala
@@ -24,6 +24,7 @@ import org.apache.hadoop.conf.Configuration
import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach}
import org.apache.spark.{SparkConf, SparkContext, SparkException, SparkFunSuite}
+import org.apache.spark.serializer.SerializerManager
import org.apache.spark.storage.{BlockId, BlockManager, StorageLevel, StreamBlockId}
import org.apache.spark.streaming.util.{FileBasedWriteAheadLogSegment, FileBasedWriteAheadLogWriter}
import org.apache.spark.util.Utils
@@ -39,6 +40,7 @@ class WriteAheadLogBackedBlockRDDSuite
var sparkContext: SparkContext = null
var blockManager: BlockManager = null
+ var serializerManager: SerializerManager = null
var dir: File = null
override def beforeEach(): Unit = {
@@ -58,6 +60,7 @@ class WriteAheadLogBackedBlockRDDSuite
super.beforeAll()
sparkContext = new SparkContext(conf)
blockManager = sparkContext.env.blockManager
+ serializerManager = sparkContext.env.serializerManager
}
override def afterAll(): Unit = {
@@ -65,6 +68,8 @@ class WriteAheadLogBackedBlockRDDSuite
try {
sparkContext.stop()
System.clearProperty("spark.driver.port")
+ blockManager = null
+ serializerManager = null
} finally {
super.afterAll()
}
@@ -107,8 +112,6 @@ class WriteAheadLogBackedBlockRDDSuite
* It can also test if the partitions that were read from the log were again stored in
* block manager.
*
- *
- *
* @param numPartitions Number of partitions in RDD
* @param numPartitionsInBM Number of partitions to write to the BlockManager.
* Partitions 0 to (numPartitionsInBM-1) will be written to BlockManager
@@ -223,7 +226,7 @@ class WriteAheadLogBackedBlockRDDSuite
require(blockData.size === blockIds.size)
val writer = new FileBasedWriteAheadLogWriter(new File(dir, "logFile").toString, hadoopConf)
val segments = blockData.zip(blockIds).map { case (data, id) =>
- writer.write(blockManager.dataSerialize(id, data.iterator).toByteBuffer)
+ writer.write(serializerManager.dataSerialize(id, data.iterator).toByteBuffer)
}
writer.close()
segments