Diffstat (limited to 'core/src/main')
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockManager.scala               |  49
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala         | 229
-rw-r--r--  core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala        |   2
-rw-r--r--  core/src/main/scala/org/apache/spark/util/io/ByteArrayChunkOutputStream.scala |  11
4 files changed, 248 insertions(+), 43 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index eebb43e245..30d2e23efd 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -746,7 +746,7 @@ private[spark] class BlockManager(
// We will drop it to disk later if the memory store can't hold it.
val putSucceeded = if (level.deserialized) {
val values = serializerManager.dataDeserialize(blockId, bytes)(classTag)
- memoryStore.putIterator(blockId, values, level, classTag) match {
+ memoryStore.putIteratorAsValues(blockId, values, classTag) match {
case Right(_) => true
case Left(iter) =>
// If putting deserialized values in memory failed, we will put the bytes directly to
@@ -876,21 +876,40 @@ private[spark] class BlockManager(
if (level.useMemory) {
// Put it in memory first, even if it also has useDisk set to true;
// We will drop it to disk later if the memory store can't hold it.
- memoryStore.putIterator(blockId, iterator(), level, classTag) match {
- case Right(s) =>
- size = s
- case Left(iter) =>
- // Not enough space to unroll this block; drop to disk if applicable
- if (level.useDisk) {
- logWarning(s"Persisting block $blockId to disk instead.")
- diskStore.put(blockId) { fileOutputStream =>
- serializerManager.dataSerializeStream(blockId, fileOutputStream, iter)(classTag)
+ if (level.deserialized) {
+ memoryStore.putIteratorAsValues(blockId, iterator(), classTag) match {
+ case Right(s) =>
+ size = s
+ case Left(iter) =>
+ // Not enough space to unroll this block; drop to disk if applicable
+ if (level.useDisk) {
+ logWarning(s"Persisting block $blockId to disk instead.")
+ diskStore.put(blockId) { fileOutputStream =>
+ serializerManager.dataSerializeStream(blockId, fileOutputStream, iter)(classTag)
+ }
+ size = diskStore.getSize(blockId)
+ } else {
+ iteratorFromFailedMemoryStorePut = Some(iter)
}
- size = diskStore.getSize(blockId)
- } else {
- iteratorFromFailedMemoryStorePut = Some(iter)
- }
+ }
+ } else { // !level.deserialized
+ memoryStore.putIteratorAsBytes(blockId, iterator(), classTag) match {
+ case Right(s) =>
+ size = s
+ case Left(partiallySerializedValues) =>
+ // Not enough space to unroll this block; drop to disk if applicable
+ if (level.useDisk) {
+ logWarning(s"Persisting block $blockId to disk instead.")
+ diskStore.put(blockId) { fileOutputStream =>
+ partiallySerializedValues.finishWritingToStream(fileOutputStream)
+ }
+ size = diskStore.getSize(blockId)
+ } else {
+ iteratorFromFailedMemoryStorePut = Some(partiallySerializedValues.valuesIterator)
+ }
+ }
}
+
} else if (level.useDisk) {
diskStore.put(blockId) { fileOutputStream =>
serializerManager.dataSerializeStream(blockId, fileOutputStream, iterator())(classTag)
@@ -991,7 +1010,7 @@ private[spark] class BlockManager(
// Note: if we had a means to discard the disk iterator, we would do that here.
memoryStore.getValues(blockId).get
} else {
- memoryStore.putIterator(blockId, diskIterator, level, classTag) match {
+ memoryStore.putIteratorAsValues(blockId, diskIterator, classTag) match {
case Left(iter) =>
// The memory store put() failed, so it returned the iterator back to us:
iter
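
For orientation, here is a minimal sketch of the caller-side contract the hunks above implement. It is a toy wrapper, not BlockManager's actual code: `spillToDisk` and `spillSerializedToDisk` are hypothetical helpers standing in for the diskStore.put(...) fallbacks shown in the diff.

    import scala.reflect.ClassTag
    import org.apache.spark.storage.{BlockId, StorageLevel}

    def cacheBlock[T](
        blockId: BlockId,
        level: StorageLevel,
        makeIterator: () => Iterator[T],
        classTag: ClassTag[T]): Long = {
      if (level.deserialized) {
        memoryStore.putIteratorAsValues(blockId, makeIterator(), classTag) match {
          case Right(size) => size
          case Left(partiallyUnrolled) =>
            // Serializing the returned iterator to disk consumes it, which
            // releases the unroll memory it holds.
            spillToDisk(blockId, partiallyUnrolled, classTag)
        }
      } else {
        memoryStore.putIteratorAsBytes(blockId, makeIterator(), classTag) match {
          case Right(size) => size
          case Left(partiallySerialized) =>
            // finishWritingToStream() reuses the bytes that were already
            // serialized in memory rather than re-serializing from scratch.
            spillSerializedToDisk(blockId, partiallySerialized)
        }
      }
    }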
diff --git a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala
index 90016cbeb8..1a78c9c010 100644
--- a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala
+++ b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala
@@ -17,20 +17,24 @@
package org.apache.spark.storage.memory
+import java.io.OutputStream
+import java.nio.ByteBuffer
import java.util.LinkedHashMap
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag
+import com.google.common.io.ByteStreams
+
import org.apache.spark.{SparkConf, TaskContext}
import org.apache.spark.internal.Logging
import org.apache.spark.memory.MemoryManager
-import org.apache.spark.serializer.SerializerManager
+import org.apache.spark.serializer.{SerializationStream, SerializerManager}
import org.apache.spark.storage.{BlockId, BlockInfoManager, StorageLevel}
import org.apache.spark.util.{CompletionIterator, SizeEstimator, Utils}
import org.apache.spark.util.collection.SizeTrackingVector
-import org.apache.spark.util.io.ChunkedByteBuffer
+import org.apache.spark.util.io.{ByteArrayChunkOutputStream, ChunkedByteBuffer}
private sealed trait MemoryEntry[T] {
def size: Long
@@ -42,8 +46,9 @@ private case class DeserializedMemoryEntry[T](
classTag: ClassTag[T]) extends MemoryEntry[T]
private case class SerializedMemoryEntry[T](
buffer: ChunkedByteBuffer,
- size: Long,
- classTag: ClassTag[T]) extends MemoryEntry[T]
+ classTag: ClassTag[T]) extends MemoryEntry[T] {
+ def size: Long = buffer.size
+}
private[storage] trait BlockEvictionHandler {
/**
@@ -132,7 +137,7 @@ private[spark] class MemoryStore(
// We acquired enough memory for the block, so go ahead and put it
val bytes = _bytes()
assert(bytes.size == size)
- val entry = new SerializedMemoryEntry[T](bytes, size, implicitly[ClassTag[T]])
+ val entry = new SerializedMemoryEntry[T](bytes, implicitly[ClassTag[T]])
entries.synchronized {
entries.put(blockId, entry)
}
@@ -145,7 +150,7 @@ private[spark] class MemoryStore(
}
/**
- * Attempt to put the given block in memory store.
+ * Attempt to put the given block in memory store as values.
*
* It's possible that the iterator is too large to materialize and store in memory. To avoid
* OOM exceptions, this method will gradually unroll the iterator while periodically checking
@@ -160,10 +165,9 @@ private[spark] class MemoryStore(
* iterator or call `close()` on it in order to free the storage memory consumed by the
* partially-unrolled block.
*/
- private[storage] def putIterator[T](
+ private[storage] def putIteratorAsValues[T](
blockId: BlockId,
values: Iterator[T],
- level: StorageLevel,
classTag: ClassTag[T]): Either[PartiallyUnrolledIterator[T], Long] = {
require(!contains(blockId), s"Block $blockId is already present in the MemoryStore")
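
The Right/Left contract described in the comment above, as a hedged usage sketch (`memoryStore`, `blockId`, `values`, and `classTag` are assumed to be in scope):

    memoryStore.putIteratorAsValues(blockId, values, classTag) match {
      case Right(size) =>
        // Fully unrolled and stored; size is the estimated in-memory footprint.
        println(s"stored $blockId, estimated size $size bytes")
      case Left(partiallyUnrolled) =>
        // The caller now owns the partially-unrolled values: either consume
        // the iterator fully or close() it to release the unroll memory.
        partiallyUnrolled.close()
    }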
@@ -218,12 +222,8 @@ private[spark] class MemoryStore(
// We successfully unrolled the entirety of this block
val arrayValues = vector.toArray
vector = null
- val entry = if (level.deserialized) {
+ val entry =
new DeserializedMemoryEntry[T](arrayValues, SizeEstimator.estimate(arrayValues), classTag)
- } else {
- val bytes = serializerManager.dataSerialize(blockId, arrayValues.iterator)(classTag)
- new SerializedMemoryEntry[T](bytes, bytes.size, classTag)
- }
val size = entry.size
def transferUnrollToStorage(amount: Long): Unit = {
// Synchronize so that transfer is atomic
@@ -255,12 +255,8 @@ private[spark] class MemoryStore(
entries.synchronized {
entries.put(blockId, entry)
}
- val bytesOrValues = if (level.deserialized) "values" else "bytes"
- logInfo("Block %s stored as %s in memory (estimated size %s, free %s)".format(
- blockId,
- bytesOrValues,
- Utils.bytesToString(size),
- Utils.bytesToString(maxMemory - blocksMemoryUsed)))
+ logInfo("Block %s stored as values in memory (estimated size %s, free %s)".format(
+ blockId, Utils.bytesToString(size), Utils.bytesToString(maxMemory - blocksMemoryUsed)))
Right(size)
} else {
assert(currentUnrollMemoryForThisTask >= unrollMemoryUsedByThisBlock,
@@ -279,13 +275,117 @@ private[spark] class MemoryStore(
}
}
+ /**
+ * Attempt to put the given block in memory store as bytes.
+ *
+ * It's possible that the iterator is too large to materialize and store in memory. To avoid
+ * OOM exceptions, this method will gradually unroll the iterator while periodically checking
+ * whether there is enough free memory. If the block is successfully materialized, then the
+ * temporary unroll memory used during the materialization is "transferred" to storage memory,
+ * so we won't acquire more memory than is actually needed to store the block.
+ *
+   * @return in case of success, the estimated size of the stored data. In case of
+ * failure, return a handle which allows the caller to either finish the serialization
+ * by spilling to disk or to deserialize the partially-serialized block and reconstruct
+   *         the original input iterator. The caller must either finish writing the block
+   *         to disk, fully consume the returned values iterator, or call `discard()` on
+   *         the handle in order to free the storage memory consumed by the
+   *         partially-serialized block.
+ */
+ private[storage] def putIteratorAsBytes[T](
+ blockId: BlockId,
+ values: Iterator[T],
+ classTag: ClassTag[T]): Either[PartiallySerializedBlock[T], Long] = {
+
+ require(!contains(blockId), s"Block $blockId is already present in the MemoryStore")
+
+ // Whether there is still enough memory for us to continue unrolling this block
+ var keepUnrolling = true
+ // Initial per-task memory to request for unrolling blocks (bytes).
+ val initialMemoryThreshold = unrollMemoryThreshold
+ // Keep track of unroll memory used by this particular block / putIterator() operation
+ var unrollMemoryUsedByThisBlock = 0L
+ // Underlying buffer for unrolling the block
+ val redirectableStream = new RedirectableOutputStream
+ val byteArrayChunkOutputStream = new ByteArrayChunkOutputStream(initialMemoryThreshold.toInt)
+ redirectableStream.setOutputStream(byteArrayChunkOutputStream)
+ val serializationStream: SerializationStream = {
+ val ser = serializerManager.getSerializer(classTag).newInstance()
+ ser.serializeStream(serializerManager.wrapForCompression(blockId, redirectableStream))
+ }
+
+ // Request enough memory to begin unrolling
+ keepUnrolling = reserveUnrollMemoryForThisTask(blockId, initialMemoryThreshold)
+
+ if (!keepUnrolling) {
+ logWarning(s"Failed to reserve initial memory threshold of " +
+ s"${Utils.bytesToString(initialMemoryThreshold)} for computing block $blockId in memory.")
+ } else {
+ unrollMemoryUsedByThisBlock += initialMemoryThreshold
+ }
+
+ def reserveAdditionalMemoryIfNecessary(): Unit = {
+ if (byteArrayChunkOutputStream.size > unrollMemoryUsedByThisBlock) {
+ val amountToRequest = byteArrayChunkOutputStream.size - unrollMemoryUsedByThisBlock
+ keepUnrolling = reserveUnrollMemoryForThisTask(blockId, amountToRequest)
+ if (keepUnrolling) {
+ unrollMemoryUsedByThisBlock += amountToRequest
+ }
+ }
+ }
+
+ // Unroll this block safely, checking whether we have exceeded our threshold
+ while (values.hasNext && keepUnrolling) {
+ serializationStream.writeObject(values.next())(classTag)
+ reserveAdditionalMemoryIfNecessary()
+ }
+
+ // Make sure that we have enough memory to store the block. By this point, it is possible that
+ // the block's actual memory usage has exceeded the unroll memory by a small amount, so we
+ // perform one final call to attempt to allocate additional memory if necessary.
+ if (keepUnrolling) {
+ serializationStream.close()
+ reserveAdditionalMemoryIfNecessary()
+ }
+
+ if (keepUnrolling) {
+ val entry = SerializedMemoryEntry[T](
+ new ChunkedByteBuffer(byteArrayChunkOutputStream.toArrays.map(ByteBuffer.wrap)), classTag)
+ // Synchronize so that transfer is atomic
+ memoryManager.synchronized {
+ releaseUnrollMemoryForThisTask(unrollMemoryUsedByThisBlock)
+ val success = memoryManager.acquireStorageMemory(blockId, entry.size)
+ assert(success, "transferring unroll memory to storage memory failed")
+ }
+ entries.synchronized {
+ entries.put(blockId, entry)
+ }
+ logInfo("Block %s stored as bytes in memory (estimated size %s, free %s)".format(
+      blockId, Utils.bytesToString(entry.size), Utils.bytesToString(maxMemory - blocksMemoryUsed)))
+ Right(entry.size)
+ } else {
+ // We ran out of space while unrolling the values for this block
+ logUnrollFailureMessage(blockId, byteArrayChunkOutputStream.size)
+ Left(
+ new PartiallySerializedBlock(
+ this,
+ serializerManager,
+ blockId,
+ serializationStream,
+ redirectableStream,
+ unrollMemoryUsedByThisBlock,
+ new ChunkedByteBuffer(byteArrayChunkOutputStream.toArrays.map(ByteBuffer.wrap)),
+ values,
+ classTag))
+ }
+ }
+
def getBytes(blockId: BlockId): Option[ChunkedByteBuffer] = {
val entry = entries.synchronized { entries.get(blockId) }
entry match {
case null => None
case e: DeserializedMemoryEntry[_] =>
throw new IllegalArgumentException("should only call getBytes on serialized blocks")
- case SerializedMemoryEntry(bytes, _, _) => Some(bytes)
+ case SerializedMemoryEntry(bytes, _) => Some(bytes)
}
}
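
The reserve-as-you-unroll loop in putIteratorAsBytes generalizes to a small reusable pattern. A standalone sketch under toy assumptions, where `reserve` stands in for reserveUnrollMemoryForThisTask:

    // Returns Right(bytesReserved) on success, Left(bytesReserved) if the
    // memory reservation could not keep up with the serialized output.
    def unrollWhileMemoryLasts[T](
        values: Iterator[T],
        write: T => Unit,          // e.g. serializationStream.writeObject(_)
        bytesWritten: () => Long,  // e.g. byteArrayChunkOutputStream.size
        reserve: Long => Boolean): Either[Long, Long] = {
      var reserved = 0L
      var keepUnrolling = true
      def topUp(): Unit = {
        val used = bytesWritten()
        if (used > reserved) {
          keepUnrolling = reserve(used - reserved)
          if (keepUnrolling) reserved = used
        }
      }
      while (values.hasNext && keepUnrolling) {
        write(values.next())
        topUp() // re-check after every element so over-allocation stays small
      }
      if (keepUnrolling) Right(reserved) else Left(reserved)
    }

The real method additionally closes the serialization stream and performs one final top-up before committing the block, since closing the stream can flush more bytes.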
@@ -373,7 +473,7 @@ private[spark] class MemoryStore(
def dropBlock[T](blockId: BlockId, entry: MemoryEntry[T]): Unit = {
val data = entry match {
case DeserializedMemoryEntry(values, _, _) => Left(values)
- case SerializedMemoryEntry(buffer, _, _) => Right(buffer)
+ case SerializedMemoryEntry(buffer, _) => Right(buffer)
}
val newEffectiveStorageLevel =
blockEvictionHandler.dropFromMemory(blockId, () => data)(entry.classTag)
@@ -507,12 +607,13 @@ private[spark] class MemoryStore(
}
/**
- * The result of a failed [[MemoryStore.putIterator()]] call.
+ * The result of a failed [[MemoryStore.putIteratorAsValues()]] call.
*
- * @param memoryStore the memoryStore, used for freeing memory.
+ * @param memoryStore the memoryStore, used for freeing memory.
* @param unrollMemory the amount of unroll memory used by the values in `unrolled`.
- * @param unrolled an iterator for the partially-unrolled values.
- * @param rest the rest of the original iterator passed to [[MemoryStore.putIterator()]].
+ * @param unrolled an iterator for the partially-unrolled values.
+ * @param rest the rest of the original iterator passed to
+ * [[MemoryStore.putIteratorAsValues()]].
*/
private[storage] class PartiallyUnrolledIterator[T](
memoryStore: MemoryStore,
@@ -544,3 +645,81 @@ private[storage] class PartiallyUnrolledIterator[T](
iter = null
}
}
+
+/**
+ * A wrapper which allows an open [[OutputStream]] to be redirected to a different sink.
+ */
+private class RedirectableOutputStream extends OutputStream {
+ private[this] var os: OutputStream = _
+ def setOutputStream(s: OutputStream): Unit = { os = s }
+ override def write(b: Int): Unit = os.write(b)
+ override def write(b: Array[Byte]): Unit = os.write(b)
+ override def write(b: Array[Byte], off: Int, len: Int): Unit = os.write(b, off, len)
+ override def flush(): Unit = os.flush()
+ override def close(): Unit = os.close()
+}
+
+/**
+ * The result of a failed [[MemoryStore.putIteratorAsBytes()]] call.
+ *
+ * @param memoryStore the MemoryStore, used for freeing memory.
+ * @param serializerManager the SerializerManager, used for deserializing values.
+ * @param blockId the block id.
+ * @param serializationStream a serialization stream which writes to [[redirectableOutputStream]].
+ * @param redirectableOutputStream an OutputStream which can be redirected to a different sink.
+ * @param unrollMemory the amount of unroll memory used by the values in `unrolled`.
+ * @param unrolled a byte buffer containing the partially-serialized values.
+ * @param rest the rest of the original iterator passed to
+ *             [[MemoryStore.putIteratorAsBytes()]].
+ * @param classTag the [[ClassTag]] for the block.
+ */
+private[storage] class PartiallySerializedBlock[T](
+ memoryStore: MemoryStore,
+ serializerManager: SerializerManager,
+ blockId: BlockId,
+ serializationStream: SerializationStream,
+ redirectableOutputStream: RedirectableOutputStream,
+ unrollMemory: Long,
+ unrolled: ChunkedByteBuffer,
+ rest: Iterator[T],
+ classTag: ClassTag[T]) {
+
+ /**
+ * Called to dispose of this block and free its memory.
+ */
+ def discard(): Unit = {
+ try {
+ serializationStream.close()
+ } finally {
+ memoryStore.releaseUnrollMemoryForThisTask(unrollMemory)
+ }
+ }
+
+ /**
+ * Finish writing this block to the given output stream by first writing the serialized values
+ * and then serializing the values from the original input iterator.
+ */
+ def finishWritingToStream(os: OutputStream): Unit = {
+ ByteStreams.copy(unrolled.toInputStream(), os)
+ redirectableOutputStream.setOutputStream(os)
+ while (rest.hasNext) {
+ serializationStream.writeObject(rest.next())(classTag)
+ }
+ discard()
+ }
+
+ /**
+ * Returns an iterator over the values in this block by first deserializing the serialized
+ * values and then consuming the rest of the original input iterator.
+ *
+ * If the caller does not plan to fully consume the resulting iterator then they must call
+ * `close()` on it to free its resources.
+ */
+ def valuesIterator: PartiallyUnrolledIterator[T] = {
+ new PartiallyUnrolledIterator(
+ memoryStore,
+ unrollMemory,
+ unrolled = serializerManager.dataDeserialize(blockId, unrolled)(classTag),
+ rest = rest)
+ }
+}
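
A hedged sketch of how a caller might spill a PartiallySerializedBlock; the file path is illustrative, since BlockManager above hands finishWritingToStream the stream provided by diskStore.put instead:

    import java.io.{BufferedOutputStream, FileOutputStream}

    def spillPartial[T](partial: PartiallySerializedBlock[T], path: String): Unit = {
      val out = new BufferedOutputStream(new FileOutputStream(path))
      try {
        // Copies the already-serialized bytes, redirects the still-open
        // serialization stream at `out`, serializes the remaining values, and
        // finally calls discard() to release the unroll memory. No value is
        // serialized twice.
        partial.finishWritingToStream(out)
      } finally {
        out.close()
      }
    }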
diff --git a/core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala b/core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala
index 8527e3ae69..09e7579ae9 100644
--- a/core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala
+++ b/core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala
@@ -27,6 +27,8 @@ private[spark] class ByteBufferOutputStream(capacity: Int) extends ByteArrayOutp
def this() = this(32)
+ def getCount(): Int = count
+
def toByteBuffer: ByteBuffer = {
return ByteBuffer.wrap(buf, 0, count)
}
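
A small usage sketch of the new accessor (assumed usage; write() comes from the ByteArrayOutputStream superclass):

    val out = new ByteBufferOutputStream(64)
    out.write(Array[Byte](1, 2, 3))
    assert(out.getCount() == 3)   // bytes written so far, without copying
    val buf = out.toByteBuffer    // wraps the first getCount() bytes of the buffer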
diff --git a/core/src/main/scala/org/apache/spark/util/io/ByteArrayChunkOutputStream.scala b/core/src/main/scala/org/apache/spark/util/io/ByteArrayChunkOutputStream.scala
index daac6f971e..16fe3be303 100644
--- a/core/src/main/scala/org/apache/spark/util/io/ByteArrayChunkOutputStream.scala
+++ b/core/src/main/scala/org/apache/spark/util/io/ByteArrayChunkOutputStream.scala
@@ -30,10 +30,10 @@ import scala.collection.mutable.ArrayBuffer
private[spark]
class ByteArrayChunkOutputStream(chunkSize: Int) extends OutputStream {
- private val chunks = new ArrayBuffer[Array[Byte]]
+ private[this] val chunks = new ArrayBuffer[Array[Byte]]
/** Index of the last chunk. Starting with -1 when the chunks array is empty. */
- private var lastChunkIndex = -1
+ private[this] var lastChunkIndex = -1
/**
* Next position to write in the last chunk.
@@ -41,12 +41,16 @@ class ByteArrayChunkOutputStream(chunkSize: Int) extends OutputStream {
* If this equals chunkSize, it means for next write we need to allocate a new chunk.
* This can also never be 0.
*/
- private var position = chunkSize
+ private[this] var position = chunkSize
+ private[this] var _size = 0
+
+ def size: Long = _size
override def write(b: Int): Unit = {
allocateNewChunkIfNeeded()
chunks(lastChunkIndex)(position) = b.toByte
position += 1
+ _size += 1
}
override def write(bytes: Array[Byte], off: Int, len: Int): Unit = {
@@ -58,6 +62,7 @@ class ByteArrayChunkOutputStream(chunkSize: Int) extends OutputStream {
written += thisBatch
position += thisBatch
}
+ _size += len
}
@inline
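
With the running _size counter, size becomes an O(1) read instead of a pass over the chunks, which matters because putIteratorAsBytes polls it after every element written. A hedged sketch:

    val out = new ByteArrayChunkOutputStream(chunkSize = 8)
    out.write(Array.fill[Byte](20)(42))
    assert(out.size == 20)            // 8 + 8 + 4 bytes across three chunks
    val chunks = out.toArrays         // one Array[Byte] per chunk, last one trimmed
    assert(chunks.map(_.length).sum == out.size)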