author    Marcelo Vanzin <vanzin@cloudera.com>    2017-03-29 20:27:41 +0800
committer Wenchen Fan <wenchen@databricks.com>    2017-03-29 20:27:41 +0800
commit    b56ad2b1ec19fd60fa9d4926d12244fd3f56aca4 (patch)
tree      da4c6117196cbcccd8f94469c7ed322aef474ca8
parent    9712bd3954c029de5c828f27b57d46e4a6325a38 (diff)
[SPARK-19556][CORE] Do not encrypt block manager data in memory.
This change modifies the way block data is encrypted to make the more common cases faster, while penalizing an edge case. As a side effect of the change, all data that goes through the block manager is now encrypted only when needed, including the previous path (broadcast variables) where that did not happen.

The way the change works is by not encrypting data that is stored in memory; so if a serialized block is in memory, it will only be encrypted once it is evicted to disk.

The penalty comes when transferring that encrypted data from disk. If the data ends up in memory again, it is as efficient as before; but if the evicted block needs to be transferred directly to a remote executor, then there's now a performance penalty, since the code now uses a custom FileRegion implementation to decrypt the data before transferring.

This also means that block data transferred between executors now is not encrypted (and thus relies on the network library encryption support for secrecy). Shuffle blocks are still transferred in encrypted form, since they're handled in a slightly different way by the code. This also keeps compatibility with existing external shuffle services, which transfer encrypted shuffle blocks, and avoids having to make the external service aware of encryption at all.

The serialization and deserialization APIs in the SerializerManager now do not do encryption automatically; callers need to explicitly wrap their streams with an appropriate crypto stream before using those.

As a result of these changes, some of the workarounds added in SPARK-19520 are removed here.

Testing: a new trait ("EncryptionFunSuite") was added that provides an easy way to run a test twice, with encryption on and off; broadcast, block manager and caching tests were modified to use this new trait so that the existing tests exercise both encrypted and non-encrypted paths. I also ran some applications with encryption turned on to verify that they still work, including streaming tests that failed without the fix for SPARK-19520.

Author: Marcelo Vanzin <vanzin@cloudera.com>

Closes #17295 from vanzin/SPARK-19556.
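The key API consequence is the last point above: SerializerManager no longer encrypts or decrypts transparently. A minimal caller-side sketch (Scala, illustrative only and not taken verbatim from the patch), assuming a serializerManager, a blockId and the pre-existing wrapForEncryption helpers are in scope:

import java.io.{FileInputStream, FileOutputStream}
import scala.reflect.ClassTag

// Serialize to an unencrypted sink, encrypting explicitly when an I/O key is configured.
def writeEncrypted[T: ClassTag](path: String, values: Iterator[T]): Unit = {
  val out = serializerManager.wrapForEncryption(new FileOutputStream(path))
  try {
    serializerManager.dataSerializeStream(blockId, out, values)
  } finally {
    out.close()
  }
}

// Read the block back; decryption is likewise the caller's responsibility now.
def readEncrypted[T](path: String)(implicit ct: ClassTag[T]): Iterator[T] = {
  val in = serializerManager.wrapForEncryption(new FileInputStream(path))
  serializerManager.dataDeserializeStream(blockId, in)(ct)
}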
-rw-r--r--  common/network-common/src/main/java/org/apache/spark/network/util/JavaUtils.java | 15
-rw-r--r--  core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala | 35
-rw-r--r--  core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala | 87
-rw-r--r--  core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala | 24
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockManager.scala | 172
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockManagerManagedBuffer.scala | 33
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/DiskStore.scala | 236
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/StorageUtils.scala | 32
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala | 2
-rw-r--r--  core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala | 14
-rw-r--r--  core/src/test/scala/org/apache/spark/DistributedSuite.scala | 12
-rw-r--r--  core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala | 16
-rw-r--r--  core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala | 46
-rw-r--r--  core/src/test/scala/org/apache/spark/security/EncryptionFunSuite.scala | 39
-rw-r--r--  core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala | 77
-rw-r--r--  core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala | 115
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala | 6
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala | 11
-rw-r--r--  streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala | 5
-rw-r--r--  streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala | 3
20 files changed, 710 insertions(+), 270 deletions(-)
diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/JavaUtils.java b/common/network-common/src/main/java/org/apache/spark/network/util/JavaUtils.java
index f3eaf22c01..51d7fda0cb 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/util/JavaUtils.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/util/JavaUtils.java
@@ -18,9 +18,11 @@
package org.apache.spark.network.util;
import java.io.Closeable;
+import java.io.EOFException;
import java.io.File;
import java.io.IOException;
import java.nio.ByteBuffer;
+import java.nio.channels.ReadableByteChannel;
import java.nio.charset.StandardCharsets;
import java.util.concurrent.TimeUnit;
import java.util.regex.Matcher;
@@ -344,4 +346,17 @@ public class JavaUtils {
}
}
+ /**
+ * Fills a buffer with data read from the channel.
+ */
+ public static void readFully(ReadableByteChannel channel, ByteBuffer dst) throws IOException {
+ int expected = dst.remaining();
+ while (dst.hasRemaining()) {
+ if (channel.read(dst) < 0) {
+ throw new EOFException(String.format("Not enough bytes in channel (expected %d).",
+ expected));
+ }
+ }
+ }
+
}
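A hedged usage sketch for the helper above (Scala; the file name and buffer size are made up for illustration): it loops until the buffer is full and fails fast with an EOFException on a short read.

import java.io.FileInputStream
import java.nio.ByteBuffer
import org.apache.spark.network.util.JavaUtils

val channel = new FileInputStream("/tmp/block.dat").getChannel()
try {
  val buf = ByteBuffer.allocate(4096)  // expected record size (illustrative)
  JavaUtils.readFully(channel, buf)    // throws EOFException if fewer than 4096 bytes remain
  buf.flip()                           // make the bytes readable
} finally {
  channel.close()
}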
diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
index 22d01c47e6..039df75ce7 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
@@ -29,7 +29,7 @@ import org.apache.spark._
import org.apache.spark.internal.Logging
import org.apache.spark.io.CompressionCodec
import org.apache.spark.serializer.Serializer
-import org.apache.spark.storage.{BlockId, BroadcastBlockId, StorageLevel}
+import org.apache.spark.storage._
import org.apache.spark.util.{ByteBufferInputStream, Utils}
import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream}
@@ -141,10 +141,10 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
}
/** Fetch torrent blocks from the driver and/or other executors. */
- private def readBlocks(): Array[ChunkedByteBuffer] = {
+ private def readBlocks(): Array[BlockData] = {
// Fetch chunks of data. Note that all these chunks are stored in the BlockManager and reported
// to the driver, so other executors can pull these chunks from this executor as well.
- val blocks = new Array[ChunkedByteBuffer](numBlocks)
+ val blocks = new Array[BlockData](numBlocks)
val bm = SparkEnv.get.blockManager
for (pid <- Random.shuffle(Seq.range(0, numBlocks))) {
@@ -173,7 +173,7 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
throw new SparkException(
s"Failed to store $pieceId of $broadcastId in local BlockManager")
}
- blocks(pid) = b
+ blocks(pid) = new ByteBufferBlockData(b, true)
case None =>
throw new SparkException(s"Failed to get $pieceId of $broadcastId")
}
@@ -219,18 +219,22 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
case None =>
logInfo("Started reading broadcast variable " + id)
val startTimeMs = System.currentTimeMillis()
- val blocks = readBlocks().flatMap(_.getChunks())
+ val blocks = readBlocks()
logInfo("Reading broadcast variable " + id + " took" + Utils.getUsedTimeMs(startTimeMs))
- val obj = TorrentBroadcast.unBlockifyObject[T](
- blocks, SparkEnv.get.serializer, compressionCodec)
- // Store the merged copy in BlockManager so other tasks on this executor don't
- // need to re-fetch it.
- val storageLevel = StorageLevel.MEMORY_AND_DISK
- if (!blockManager.putSingle(broadcastId, obj, storageLevel, tellMaster = false)) {
- throw new SparkException(s"Failed to store $broadcastId in BlockManager")
+ try {
+ val obj = TorrentBroadcast.unBlockifyObject[T](
+ blocks.map(_.toInputStream()), SparkEnv.get.serializer, compressionCodec)
+ // Store the merged copy in BlockManager so other tasks on this executor don't
+ // need to re-fetch it.
+ val storageLevel = StorageLevel.MEMORY_AND_DISK
+ if (!blockManager.putSingle(broadcastId, obj, storageLevel, tellMaster = false)) {
+ throw new SparkException(s"Failed to store $broadcastId in BlockManager")
+ }
+ obj
+ } finally {
+ blocks.foreach(_.dispose())
}
- obj
}
}
}
@@ -277,12 +281,11 @@ private object TorrentBroadcast extends Logging {
}
def unBlockifyObject[T: ClassTag](
- blocks: Array[ByteBuffer],
+ blocks: Array[InputStream],
serializer: Serializer,
compressionCodec: Option[CompressionCodec]): T = {
require(blocks.nonEmpty, "Cannot unblockify an empty array of blocks")
- val is = new SequenceInputStream(
- blocks.iterator.map(new ByteBufferInputStream(_)).asJavaEnumeration)
+ val is = new SequenceInputStream(blocks.iterator.asJavaEnumeration)
val in: InputStream = compressionCodec.map(c => c.compressedInputStream(is)).getOrElse(is)
val ser = serializer.newInstance()
val serIn = ser.deserializeStream(in)
diff --git a/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala b/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala
index cdd3b8d851..78dabb42ac 100644
--- a/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala
+++ b/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala
@@ -16,20 +16,23 @@
*/
package org.apache.spark.security
-import java.io.{InputStream, OutputStream}
+import java.io.{EOFException, InputStream, OutputStream}
+import java.nio.ByteBuffer
+import java.nio.channels.{ReadableByteChannel, WritableByteChannel}
import java.util.Properties
import javax.crypto.KeyGenerator
import javax.crypto.spec.{IvParameterSpec, SecretKeySpec}
import scala.collection.JavaConverters._
+import com.google.common.io.ByteStreams
import org.apache.commons.crypto.random._
import org.apache.commons.crypto.stream._
import org.apache.spark.SparkConf
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config._
-import org.apache.spark.network.util.CryptoUtils
+import org.apache.spark.network.util.{CryptoUtils, JavaUtils}
/**
* A util class for manipulating IO encryption and decryption streams.
@@ -48,12 +51,27 @@ private[spark] object CryptoStreamUtils extends Logging {
os: OutputStream,
sparkConf: SparkConf,
key: Array[Byte]): OutputStream = {
- val properties = toCryptoConf(sparkConf)
- val iv = createInitializationVector(properties)
+ val params = new CryptoParams(key, sparkConf)
+ val iv = createInitializationVector(params.conf)
os.write(iv)
- val transformationStr = sparkConf.get(IO_CRYPTO_CIPHER_TRANSFORMATION)
- new CryptoOutputStream(transformationStr, properties, os,
- new SecretKeySpec(key, "AES"), new IvParameterSpec(iv))
+ new CryptoOutputStream(params.transformation, params.conf, os, params.keySpec,
+ new IvParameterSpec(iv))
+ }
+
+ /**
+ * Wrap a `WritableByteChannel` for encryption.
+ */
+ def createWritableChannel(
+ channel: WritableByteChannel,
+ sparkConf: SparkConf,
+ key: Array[Byte]): WritableByteChannel = {
+ val params = new CryptoParams(key, sparkConf)
+ val iv = createInitializationVector(params.conf)
+ val helper = new CryptoHelperChannel(channel)
+
+ helper.write(ByteBuffer.wrap(iv))
+ new CryptoOutputStream(params.transformation, params.conf, helper, params.keySpec,
+ new IvParameterSpec(iv))
}
/**
@@ -63,12 +81,27 @@ private[spark] object CryptoStreamUtils extends Logging {
is: InputStream,
sparkConf: SparkConf,
key: Array[Byte]): InputStream = {
- val properties = toCryptoConf(sparkConf)
val iv = new Array[Byte](IV_LENGTH_IN_BYTES)
- is.read(iv, 0, iv.length)
- val transformationStr = sparkConf.get(IO_CRYPTO_CIPHER_TRANSFORMATION)
- new CryptoInputStream(transformationStr, properties, is,
- new SecretKeySpec(key, "AES"), new IvParameterSpec(iv))
+ ByteStreams.readFully(is, iv)
+ val params = new CryptoParams(key, sparkConf)
+ new CryptoInputStream(params.transformation, params.conf, is, params.keySpec,
+ new IvParameterSpec(iv))
+ }
+
+ /**
+ * Wrap a `ReadableByteChannel` for decryption.
+ */
+ def createReadableChannel(
+ channel: ReadableByteChannel,
+ sparkConf: SparkConf,
+ key: Array[Byte]): ReadableByteChannel = {
+ val iv = new Array[Byte](IV_LENGTH_IN_BYTES)
+ val buf = ByteBuffer.wrap(iv)
+ JavaUtils.readFully(channel, buf)
+
+ val params = new CryptoParams(key, sparkConf)
+ new CryptoInputStream(params.transformation, params.conf, channel, params.keySpec,
+ new IvParameterSpec(iv))
}
def toCryptoConf(conf: SparkConf): Properties = {
@@ -102,4 +135,34 @@ private[spark] object CryptoStreamUtils extends Logging {
}
iv
}
+
+ /**
+ * This class is a workaround for CRYPTO-125, that forces all bytes to be written to the
+ * underlying channel. Since the callers of this API are using blocking I/O, there are no
+ * concerns with regards to CPU usage here.
+ */
+ private class CryptoHelperChannel(sink: WritableByteChannel) extends WritableByteChannel {
+
+ override def write(src: ByteBuffer): Int = {
+ val count = src.remaining()
+ while (src.hasRemaining()) {
+ sink.write(src)
+ }
+ count
+ }
+
+ override def isOpen(): Boolean = sink.isOpen()
+
+ override def close(): Unit = sink.close()
+
+ }
+
+ private class CryptoParams(key: Array[Byte], sparkConf: SparkConf) {
+
+ val keySpec = new SecretKeySpec(key, "AES")
+ val transformation = sparkConf.get(IO_CRYPTO_CIPHER_TRANSFORMATION)
+ val conf = toCryptoConf(sparkConf)
+
+ }
+
}
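A minimal round-trip sketch for the new channel wrappers, closely mirroring the test added to CryptoStreamUtilsSuite later in this patch (the file path is illustrative):

import java.io.{FileInputStream, FileOutputStream}
import java.nio.ByteBuffer
import java.nio.channels.Channels

import com.google.common.io.ByteStreams

import org.apache.spark.SparkConf
import org.apache.spark.security.CryptoStreamUtils._

val conf = new SparkConf()
val key = createKey(conf)  // same helper the block manager tests use

// Encrypt: the IV is written to the channel before the CryptoOutputStream takes over.
val out = createWritableChannel(new FileOutputStream("/tmp/enc.bin").getChannel(), conf, key)
try {
  out.write(ByteBuffer.wrap("some block data".getBytes("UTF-8")))
} finally {
  out.close()
}

// Decrypt: the IV is read back from the channel before the CryptoInputStream is created.
val in = createReadableChannel(new FileInputStream("/tmp/enc.bin").getChannel(), conf, key)
try {
  val plain = new String(ByteStreams.toByteArray(Channels.newInputStream(in)), "UTF-8")
  assert(plain == "some block data")
} finally {
  in.close()
}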
diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala
index 96b288b9cf..bb7ed8709b 100644
--- a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala
@@ -148,14 +148,14 @@ private[spark] class SerializerManager(
/**
* Wrap an output stream for compression if block compression is enabled for its block type
*/
- private[this] def wrapForCompression(blockId: BlockId, s: OutputStream): OutputStream = {
+ def wrapForCompression(blockId: BlockId, s: OutputStream): OutputStream = {
if (shouldCompress(blockId)) compressionCodec.compressedOutputStream(s) else s
}
/**
* Wrap an input stream for compression if block compression is enabled for its block type
*/
- private[this] def wrapForCompression(blockId: BlockId, s: InputStream): InputStream = {
+ def wrapForCompression(blockId: BlockId, s: InputStream): InputStream = {
if (shouldCompress(blockId)) compressionCodec.compressedInputStream(s) else s
}
@@ -167,30 +167,26 @@ private[spark] class SerializerManager(
val byteStream = new BufferedOutputStream(outputStream)
val autoPick = !blockId.isInstanceOf[StreamBlockId]
val ser = getSerializer(implicitly[ClassTag[T]], autoPick).newInstance()
- ser.serializeStream(wrapStream(blockId, byteStream)).writeAll(values).close()
+ ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close()
}
/** Serializes into a chunked byte buffer. */
def dataSerialize[T: ClassTag](
blockId: BlockId,
- values: Iterator[T],
- allowEncryption: Boolean = true): ChunkedByteBuffer = {
- dataSerializeWithExplicitClassTag(blockId, values, implicitly[ClassTag[T]],
- allowEncryption = allowEncryption)
+ values: Iterator[T]): ChunkedByteBuffer = {
+ dataSerializeWithExplicitClassTag(blockId, values, implicitly[ClassTag[T]])
}
/** Serializes into a chunked byte buffer. */
def dataSerializeWithExplicitClassTag(
blockId: BlockId,
values: Iterator[_],
- classTag: ClassTag[_],
- allowEncryption: Boolean = true): ChunkedByteBuffer = {
+ classTag: ClassTag[_]): ChunkedByteBuffer = {
val bbos = new ChunkedByteBufferOutputStream(1024 * 1024 * 4, ByteBuffer.allocate)
val byteStream = new BufferedOutputStream(bbos)
val autoPick = !blockId.isInstanceOf[StreamBlockId]
val ser = getSerializer(classTag, autoPick).newInstance()
- val encrypted = if (allowEncryption) wrapForEncryption(byteStream) else byteStream
- ser.serializeStream(wrapForCompression(blockId, encrypted)).writeAll(values).close()
+ ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close()
bbos.toChunkedByteBuffer
}
@@ -200,15 +196,13 @@ private[spark] class SerializerManager(
*/
def dataDeserializeStream[T](
blockId: BlockId,
- inputStream: InputStream,
- maybeEncrypted: Boolean = true)
+ inputStream: InputStream)
(classTag: ClassTag[T]): Iterator[T] = {
val stream = new BufferedInputStream(inputStream)
val autoPick = !blockId.isInstanceOf[StreamBlockId]
- val decrypted = if (maybeEncrypted) wrapForEncryption(inputStream) else inputStream
getSerializer(classTag, autoPick)
.newInstance()
- .deserializeStream(wrapForCompression(blockId, decrypted))
+ .deserializeStream(wrapForCompression(blockId, inputStream))
.asIterator.asInstanceOf[Iterator[T]]
}
}
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 991346a40a..fcda9fa653 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -19,6 +19,7 @@ package org.apache.spark.storage
import java.io._
import java.nio.ByteBuffer
+import java.nio.channels.Channels
import scala.collection.mutable
import scala.collection.mutable.HashMap
@@ -35,7 +36,7 @@ import org.apache.spark.executor.{DataReadMethod, ShuffleWriteMetrics}
import org.apache.spark.internal.Logging
import org.apache.spark.memory.{MemoryManager, MemoryMode}
import org.apache.spark.network._
-import org.apache.spark.network.buffer.{ManagedBuffer, NettyManagedBuffer}
+import org.apache.spark.network.buffer.ManagedBuffer
import org.apache.spark.network.netty.SparkTransportConf
import org.apache.spark.network.shuffle.ExternalShuffleClient
import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo
@@ -56,6 +57,55 @@ private[spark] class BlockResult(
val bytes: Long)
/**
+ * Abstracts away how blocks are stored and provides different ways to read the underlying block
+ * data. Callers should call [[dispose()]] when they're done with the block.
+ */
+private[spark] trait BlockData {
+
+ def toInputStream(): InputStream
+
+ /**
+ * Returns a Netty-friendly wrapper for the block's data.
+ *
+ * @see [[ManagedBuffer#convertToNetty()]]
+ */
+ def toNetty(): Object
+
+ def toChunkedByteBuffer(allocator: Int => ByteBuffer): ChunkedByteBuffer
+
+ def toByteBuffer(): ByteBuffer
+
+ def size: Long
+
+ def dispose(): Unit
+
+}
+
+private[spark] class ByteBufferBlockData(
+ val buffer: ChunkedByteBuffer,
+ val shouldDispose: Boolean) extends BlockData {
+
+ override def toInputStream(): InputStream = buffer.toInputStream(dispose = false)
+
+ override def toNetty(): Object = buffer.toNetty
+
+ override def toChunkedByteBuffer(allocator: Int => ByteBuffer): ChunkedByteBuffer = {
+ buffer.copy(allocator)
+ }
+
+ override def toByteBuffer(): ByteBuffer = buffer.toByteBuffer
+
+ override def size: Long = buffer.size
+
+ override def dispose(): Unit = {
+ if (shouldDispose) {
+ buffer.dispose()
+ }
+ }
+
+}
+
+/**
* Manager running on every node (driver and executors) which provides interfaces for putting and
* retrieving blocks both locally and remotely into various stores (memory, disk, and off-heap).
*
@@ -94,7 +144,7 @@ private[spark] class BlockManager(
// Actual storage of where blocks are kept
private[spark] val memoryStore =
new MemoryStore(conf, blockInfoManager, serializerManager, memoryManager, this)
- private[spark] val diskStore = new DiskStore(conf, diskBlockManager)
+ private[spark] val diskStore = new DiskStore(conf, diskBlockManager, securityManager)
memoryManager.setMemoryStore(memoryStore)
// Note: depending on the memory manager, `maxMemory` may actually vary over time.
@@ -304,7 +354,8 @@ private[spark] class BlockManager(
shuffleManager.shuffleBlockResolver.getBlockData(blockId.asInstanceOf[ShuffleBlockId])
} else {
getLocalBytes(blockId) match {
- case Some(buffer) => new BlockManagerManagedBuffer(blockInfoManager, blockId, buffer)
+ case Some(blockData) =>
+ new BlockManagerManagedBuffer(blockInfoManager, blockId, blockData, true)
case None =>
// If this block manager receives a request for a block that it doesn't have then it's
// likely that the master has outdated block statuses for this block. Therefore, we send
@@ -463,21 +514,22 @@ private[spark] class BlockManager(
val ci = CompletionIterator[Any, Iterator[Any]](iter, releaseLock(blockId))
Some(new BlockResult(ci, DataReadMethod.Memory, info.size))
} else if (level.useDisk && diskStore.contains(blockId)) {
+ val diskData = diskStore.getBytes(blockId)
val iterToReturn: Iterator[Any] = {
- val diskBytes = diskStore.getBytes(blockId)
if (level.deserialized) {
val diskValues = serializerManager.dataDeserializeStream(
blockId,
- diskBytes.toInputStream(dispose = true))(info.classTag)
+ diskData.toInputStream())(info.classTag)
maybeCacheDiskValuesInMemory(info, blockId, level, diskValues)
} else {
- val stream = maybeCacheDiskBytesInMemory(info, blockId, level, diskBytes)
- .map {_.toInputStream(dispose = false)}
- .getOrElse { diskBytes.toInputStream(dispose = true) }
+ val stream = maybeCacheDiskBytesInMemory(info, blockId, level, diskData)
+ .map { _.toInputStream(dispose = false) }
+ .getOrElse { diskData.toInputStream() }
serializerManager.dataDeserializeStream(blockId, stream)(info.classTag)
}
}
- val ci = CompletionIterator[Any, Iterator[Any]](iterToReturn, releaseLock(blockId))
+ val ci = CompletionIterator[Any, Iterator[Any]](iterToReturn,
+ releaseLockAndDispose(blockId, diskData))
Some(new BlockResult(ci, DataReadMethod.Disk, info.size))
} else {
handleLocalReadFailure(blockId)
@@ -488,7 +540,7 @@ private[spark] class BlockManager(
/**
* Get block from the local block manager as serialized bytes.
*/
- def getLocalBytes(blockId: BlockId): Option[ChunkedByteBuffer] = {
+ def getLocalBytes(blockId: BlockId): Option[BlockData] = {
logDebug(s"Getting local block $blockId as bytes")
// As an optimization for map output fetches, if the block is for a shuffle, return it
// without acquiring a lock; the disk store never deletes (recent) items so this should work
@@ -496,9 +548,9 @@ private[spark] class BlockManager(
val shuffleBlockResolver = shuffleManager.shuffleBlockResolver
// TODO: This should gracefully handle case where local block is not available. Currently
// downstream code will throw an exception.
- Option(
- new ChunkedByteBuffer(
- shuffleBlockResolver.getBlockData(blockId.asInstanceOf[ShuffleBlockId]).nioByteBuffer()))
+ val buf = new ChunkedByteBuffer(
+ shuffleBlockResolver.getBlockData(blockId.asInstanceOf[ShuffleBlockId]).nioByteBuffer())
+ Some(new ByteBufferBlockData(buf, true))
} else {
blockInfoManager.lockForReading(blockId).map { info => doGetLocalBytes(blockId, info) }
}
@@ -510,7 +562,7 @@ private[spark] class BlockManager(
* Must be called while holding a read lock on the block.
* Releases the read lock upon exception; keeps the read lock upon successful return.
*/
- private def doGetLocalBytes(blockId: BlockId, info: BlockInfo): ChunkedByteBuffer = {
+ private def doGetLocalBytes(blockId: BlockId, info: BlockInfo): BlockData = {
val level = info.level
logDebug(s"Level for block $blockId is $level")
// In order, try to read the serialized bytes from memory, then from disk, then fall back to
@@ -525,17 +577,19 @@ private[spark] class BlockManager(
diskStore.getBytes(blockId)
} else if (level.useMemory && memoryStore.contains(blockId)) {
// The block was not found on disk, so serialize an in-memory copy:
- serializerManager.dataSerializeWithExplicitClassTag(
- blockId, memoryStore.getValues(blockId).get, info.classTag)
+ new ByteBufferBlockData(serializerManager.dataSerializeWithExplicitClassTag(
+ blockId, memoryStore.getValues(blockId).get, info.classTag), true)
} else {
handleLocalReadFailure(blockId)
}
} else { // storage level is serialized
if (level.useMemory && memoryStore.contains(blockId)) {
- memoryStore.getBytes(blockId).get
+ new ByteBufferBlockData(memoryStore.getBytes(blockId).get, false)
} else if (level.useDisk && diskStore.contains(blockId)) {
- val diskBytes = diskStore.getBytes(blockId)
- maybeCacheDiskBytesInMemory(info, blockId, level, diskBytes).getOrElse(diskBytes)
+ val diskData = diskStore.getBytes(blockId)
+ maybeCacheDiskBytesInMemory(info, blockId, level, diskData)
+ .map(new ByteBufferBlockData(_, false))
+ .getOrElse(diskData)
} else {
handleLocalReadFailure(blockId)
}
@@ -761,43 +815,15 @@ private[spark] class BlockManager(
* '''Important!''' Callers must not mutate or release the data buffer underlying `bytes`. Doing
* so may corrupt or change the data stored by the `BlockManager`.
*
- * @param encrypt If true, asks the block manager to encrypt the data block before storing,
- * when I/O encryption is enabled. This is required for blocks that have been
- * read from unencrypted sources, since all the BlockManager read APIs
- * automatically do decryption.
* @return true if the block was stored or false if an error occurred.
*/
def putBytes[T: ClassTag](
blockId: BlockId,
bytes: ChunkedByteBuffer,
level: StorageLevel,
- tellMaster: Boolean = true,
- encrypt: Boolean = false): Boolean = {
+ tellMaster: Boolean = true): Boolean = {
require(bytes != null, "Bytes is null")
-
- val bytesToStore =
- if (encrypt && securityManager.ioEncryptionKey.isDefined) {
- try {
- val data = bytes.toByteBuffer
- val in = new ByteBufferInputStream(data)
- val byteBufOut = new ByteBufferOutputStream(data.remaining())
- val out = CryptoStreamUtils.createCryptoOutputStream(byteBufOut, conf,
- securityManager.ioEncryptionKey.get)
- try {
- ByteStreams.copy(in, out)
- } finally {
- in.close()
- out.close()
- }
- new ChunkedByteBuffer(byteBufOut.toByteBuffer)
- } finally {
- bytes.dispose()
- }
- } else {
- bytes
- }
-
- doPutBytes(blockId, bytesToStore, level, implicitly[ClassTag[T]], tellMaster)
+ doPutBytes(blockId, bytes, level, implicitly[ClassTag[T]], tellMaster)
}
/**
@@ -828,8 +854,9 @@ private[spark] class BlockManager(
val replicationFuture = if (level.replication > 1) {
Future {
// This is a blocking action and should run in futureExecutionContext which is a cached
- // thread pool
- replicate(blockId, bytes, level, classTag)
+ // thread pool. The ByteBufferBlockData wrapper is not disposed of to avoid releasing
+ // buffers that are owned by the caller.
+ replicate(blockId, new ByteBufferBlockData(bytes, false), level, classTag)
}(futureExecutionContext)
} else {
null
@@ -1008,8 +1035,9 @@ private[spark] class BlockManager(
// Not enough space to unroll this block; drop to disk if applicable
if (level.useDisk) {
logWarning(s"Persisting block $blockId to disk instead.")
- diskStore.put(blockId) { fileOutputStream =>
- serializerManager.dataSerializeStream(blockId, fileOutputStream, iter)(classTag)
+ diskStore.put(blockId) { channel =>
+ val out = Channels.newOutputStream(channel)
+ serializerManager.dataSerializeStream(blockId, out, iter)(classTag)
}
size = diskStore.getSize(blockId)
} else {
@@ -1024,8 +1052,9 @@ private[spark] class BlockManager(
// Not enough space to unroll this block; drop to disk if applicable
if (level.useDisk) {
logWarning(s"Persisting block $blockId to disk instead.")
- diskStore.put(blockId) { fileOutputStream =>
- partiallySerializedValues.finishWritingToStream(fileOutputStream)
+ diskStore.put(blockId) { channel =>
+ val out = Channels.newOutputStream(channel)
+ partiallySerializedValues.finishWritingToStream(out)
}
size = diskStore.getSize(blockId)
} else {
@@ -1035,8 +1064,9 @@ private[spark] class BlockManager(
}
} else if (level.useDisk) {
- diskStore.put(blockId) { fileOutputStream =>
- serializerManager.dataSerializeStream(blockId, fileOutputStream, iterator())(classTag)
+ diskStore.put(blockId) { channel =>
+ val out = Channels.newOutputStream(channel)
+ serializerManager.dataSerializeStream(blockId, out, iterator())(classTag)
}
size = diskStore.getSize(blockId)
}
@@ -1065,7 +1095,7 @@ private[spark] class BlockManager(
try {
replicate(blockId, bytesToReplicate, level, remoteClassTag)
} finally {
- bytesToReplicate.unmap()
+ bytesToReplicate.dispose()
}
logDebug("Put block %s remotely took %s"
.format(blockId, Utils.getUsedTimeMs(remoteStartTime)))
@@ -1089,29 +1119,29 @@ private[spark] class BlockManager(
blockInfo: BlockInfo,
blockId: BlockId,
level: StorageLevel,
- diskBytes: ChunkedByteBuffer): Option[ChunkedByteBuffer] = {
+ diskData: BlockData): Option[ChunkedByteBuffer] = {
require(!level.deserialized)
if (level.useMemory) {
// Synchronize on blockInfo to guard against a race condition where two readers both try to
// put values read from disk into the MemoryStore.
blockInfo.synchronized {
if (memoryStore.contains(blockId)) {
- diskBytes.dispose()
+ diskData.dispose()
Some(memoryStore.getBytes(blockId).get)
} else {
val allocator = level.memoryMode match {
case MemoryMode.ON_HEAP => ByteBuffer.allocate _
case MemoryMode.OFF_HEAP => Platform.allocateDirectBuffer _
}
- val putSucceeded = memoryStore.putBytes(blockId, diskBytes.size, level.memoryMode, () => {
+ val putSucceeded = memoryStore.putBytes(blockId, diskData.size, level.memoryMode, () => {
// https://issues.apache.org/jira/browse/SPARK-6076
// If the file size is bigger than the free memory, OOM will happen. So if we
// cannot put it into MemoryStore, copyForMemory should not be created. That's why
// this action is put into a `() => ChunkedByteBuffer` and created lazily.
- diskBytes.copy(allocator)
+ diskData.toChunkedByteBuffer(allocator)
})
if (putSucceeded) {
- diskBytes.dispose()
+ diskData.dispose()
Some(memoryStore.getBytes(blockId).get)
} else {
None
@@ -1203,7 +1233,7 @@ private[spark] class BlockManager(
replicate(blockId, data, storageLevel, info.classTag, existingReplicas)
} finally {
logDebug(s"Releasing lock for $blockId")
- releaseLock(blockId)
+ releaseLockAndDispose(blockId, data)
}
}
}
@@ -1214,7 +1244,7 @@ private[spark] class BlockManager(
*/
private def replicate(
blockId: BlockId,
- data: ChunkedByteBuffer,
+ data: BlockData,
level: StorageLevel,
classTag: ClassTag[_],
existingReplicas: Set[BlockManagerId] = Set.empty): Unit = {
@@ -1256,7 +1286,7 @@ private[spark] class BlockManager(
peer.port,
peer.executorId,
blockId,
- new NettyManagedBuffer(data.toNetty),
+ new BlockManagerManagedBuffer(blockInfoManager, blockId, data, false),
tLevel,
classTag)
logTrace(s"Replicated $blockId of ${data.size} bytes to $peer" +
@@ -1339,10 +1369,11 @@ private[spark] class BlockManager(
logInfo(s"Writing block $blockId to disk")
data() match {
case Left(elements) =>
- diskStore.put(blockId) { fileOutputStream =>
+ diskStore.put(blockId) { channel =>
+ val out = Channels.newOutputStream(channel)
serializerManager.dataSerializeStream(
blockId,
- fileOutputStream,
+ out,
elements.toIterator)(info.classTag.asInstanceOf[ClassTag[T]])
}
case Right(bytes) =>
@@ -1434,6 +1465,11 @@ private[spark] class BlockManager(
}
}
+ def releaseLockAndDispose(blockId: BlockId, data: BlockData): Unit = {
+ blockInfoManager.unlock(blockId)
+ data.dispose()
+ }
+
def stop(): Unit = {
blockTransferService.close()
if (shuffleClient ne blockTransferService) {
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerManagedBuffer.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerManagedBuffer.scala
index f66f942798..1ea0d378cb 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerManagedBuffer.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerManagedBuffer.scala
@@ -17,31 +17,52 @@
package org.apache.spark.storage
-import org.apache.spark.network.buffer.{ManagedBuffer, NettyManagedBuffer}
+import java.io.InputStream
+import java.nio.ByteBuffer
+import java.util.concurrent.atomic.AtomicInteger
+
+import org.apache.spark.network.buffer.ManagedBuffer
import org.apache.spark.util.io.ChunkedByteBuffer
/**
- * This [[ManagedBuffer]] wraps a [[ChunkedByteBuffer]] retrieved from the [[BlockManager]]
+ * This [[ManagedBuffer]] wraps a [[BlockData]] instance retrieved from the [[BlockManager]]
* so that the corresponding block's read lock can be released once this buffer's references
* are released.
*
+ * If `dispose` is set to true, the [[BlockData]] will be disposed when the buffer's reference
+ * count drops to zero.
+ *
* This is effectively a wrapper / bridge to connect the BlockManager's notion of read locks
* to the network layer's notion of retain / release counts.
*/
private[storage] class BlockManagerManagedBuffer(
blockInfoManager: BlockInfoManager,
blockId: BlockId,
- chunkedBuffer: ChunkedByteBuffer) extends NettyManagedBuffer(chunkedBuffer.toNetty) {
+ data: BlockData,
+ dispose: Boolean) extends ManagedBuffer {
+
+ private val refCount = new AtomicInteger(1)
+
+ override def size(): Long = data.size
+
+ override def nioByteBuffer(): ByteBuffer = data.toByteBuffer()
+
+ override def createInputStream(): InputStream = data.toInputStream()
+
+ override def convertToNetty(): Object = data.toNetty()
override def retain(): ManagedBuffer = {
- super.retain()
+ refCount.incrementAndGet()
val locked = blockInfoManager.lockForReading(blockId, blocking = false)
assert(locked.isDefined)
this
- }
+ }
override def release(): ManagedBuffer = {
blockInfoManager.unlock(blockId)
- super.release()
+ if (refCount.decrementAndGet() == 0 && dispose) {
+ data.dispose()
+ }
+ this
}
}
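A hedged sketch of the reference-counting contract above (the constructor is private[storage], so this is illustrative only; blockInfoManager, blockId and data are assumed to be in scope, with a read lock on blockId already held by the caller):

val buf = new BlockManagerManagedBuffer(blockInfoManager, blockId, data, true)

buf.retain()    // refCount 1 -> 2, takes an additional read lock on the block
buf.release()   // refCount 2 -> 1, releases one read lock
buf.release()   // refCount 1 -> 0, releases the last lock and disposes the BlockData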
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
index ca23e2391e..c6656341fc 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
@@ -17,48 +17,67 @@
package org.apache.spark.storage
-import java.io.{FileOutputStream, IOException, RandomAccessFile}
+import java.io._
import java.nio.ByteBuffer
+import java.nio.channels.{Channels, ReadableByteChannel, WritableByteChannel}
import java.nio.channels.FileChannel.MapMode
+import java.nio.charset.StandardCharsets.UTF_8
+import java.util.concurrent.ConcurrentHashMap
-import com.google.common.io.Closeables
+import scala.collection.mutable.ListBuffer
-import org.apache.spark.SparkConf
+import com.google.common.io.{ByteStreams, Closeables, Files}
+import io.netty.channel.FileRegion
+import io.netty.util.AbstractReferenceCounted
+
+import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.internal.Logging
-import org.apache.spark.util.Utils
+import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.network.util.JavaUtils
+import org.apache.spark.security.CryptoStreamUtils
+import org.apache.spark.util.{ByteBufferInputStream, Utils}
import org.apache.spark.util.io.ChunkedByteBuffer
/**
* Stores BlockManager blocks on disk.
*/
-private[spark] class DiskStore(conf: SparkConf, diskManager: DiskBlockManager) extends Logging {
+private[spark] class DiskStore(
+ conf: SparkConf,
+ diskManager: DiskBlockManager,
+ securityManager: SecurityManager) extends Logging {
private val minMemoryMapBytes = conf.getSizeAsBytes("spark.storage.memoryMapThreshold", "2m")
+ private val blockSizes = new ConcurrentHashMap[String, Long]()
- def getSize(blockId: BlockId): Long = {
- diskManager.getFile(blockId.name).length
- }
+ def getSize(blockId: BlockId): Long = blockSizes.get(blockId.name)
/**
* Invokes the provided callback function to write the specific block.
*
* @throws IllegalStateException if the block already exists in the disk store.
*/
- def put(blockId: BlockId)(writeFunc: FileOutputStream => Unit): Unit = {
+ def put(blockId: BlockId)(writeFunc: WritableByteChannel => Unit): Unit = {
if (contains(blockId)) {
throw new IllegalStateException(s"Block $blockId is already present in the disk store")
}
logDebug(s"Attempting to put block $blockId")
val startTime = System.currentTimeMillis
val file = diskManager.getFile(blockId)
- val fileOutputStream = new FileOutputStream(file)
+ val out = new CountingWritableChannel(openForWrite(file))
var threwException: Boolean = true
try {
- writeFunc(fileOutputStream)
+ writeFunc(out)
+ blockSizes.put(blockId.name, out.getCount)
threwException = false
} finally {
try {
- Closeables.close(fileOutputStream, threwException)
+ out.close()
+ } catch {
+ case ioe: IOException =>
+ if (!threwException) {
+ threwException = true
+ throw ioe
+ }
} finally {
if (threwException) {
remove(blockId)
@@ -73,41 +92,46 @@ private[spark] class DiskStore(conf: SparkConf, diskManager: DiskBlockManager) e
}
def putBytes(blockId: BlockId, bytes: ChunkedByteBuffer): Unit = {
- put(blockId) { fileOutputStream =>
- val channel = fileOutputStream.getChannel
- Utils.tryWithSafeFinally {
- bytes.writeFully(channel)
- } {
- channel.close()
- }
+ put(blockId) { channel =>
+ bytes.writeFully(channel)
}
}
- def getBytes(blockId: BlockId): ChunkedByteBuffer = {
+ def getBytes(blockId: BlockId): BlockData = {
val file = diskManager.getFile(blockId.name)
- val channel = new RandomAccessFile(file, "r").getChannel
- Utils.tryWithSafeFinally {
- // For small files, directly read rather than memory map
- if (file.length < minMemoryMapBytes) {
- val buf = ByteBuffer.allocate(file.length.toInt)
- channel.position(0)
- while (buf.remaining() != 0) {
- if (channel.read(buf) == -1) {
- throw new IOException("Reached EOF before filling buffer\n" +
- s"offset=0\nfile=${file.getAbsolutePath}\nbuf.remaining=${buf.remaining}")
+ val blockSize = getSize(blockId)
+
+ securityManager.getIOEncryptionKey() match {
+ case Some(key) =>
+ // Encrypted blocks cannot be memory mapped; return a special object that does decryption
+ // and provides InputStream / FileRegion implementations for reading the data.
+ new EncryptedBlockData(file, blockSize, conf, key)
+
+ case _ =>
+ val channel = new FileInputStream(file).getChannel()
+ if (blockSize < minMemoryMapBytes) {
+ // For small files, directly read rather than memory map.
+ Utils.tryWithSafeFinally {
+ val buf = ByteBuffer.allocate(blockSize.toInt)
+ JavaUtils.readFully(channel, buf)
+ buf.flip()
+ new ByteBufferBlockData(new ChunkedByteBuffer(buf), true)
+ } {
+ channel.close()
+ }
+ } else {
+ Utils.tryWithSafeFinally {
+ new ByteBufferBlockData(
+ new ChunkedByteBuffer(channel.map(MapMode.READ_ONLY, 0, file.length)), true)
+ } {
+ channel.close()
}
}
- buf.flip()
- new ChunkedByteBuffer(buf)
- } else {
- new ChunkedByteBuffer(channel.map(MapMode.READ_ONLY, 0, file.length))
- }
- } {
- channel.close()
}
}
def remove(blockId: BlockId): Boolean = {
+ blockSizes.remove(blockId.name)
val file = diskManager.getFile(blockId.name)
if (file.exists()) {
val ret = file.delete()
@@ -124,4 +148,142 @@ private[spark] class DiskStore(conf: SparkConf, diskManager: DiskBlockManager) e
val file = diskManager.getFile(blockId.name)
file.exists()
}
+
+ private def openForWrite(file: File): WritableByteChannel = {
+ val out = new FileOutputStream(file).getChannel()
+ try {
+ securityManager.getIOEncryptionKey().map { key =>
+ CryptoStreamUtils.createWritableChannel(out, conf, key)
+ }.getOrElse(out)
+ } catch {
+ case e: Exception =>
+ Closeables.close(out, true)
+ file.delete()
+ throw e
+ }
+ }
+
+}
+
+private class EncryptedBlockData(
+ file: File,
+ blockSize: Long,
+ conf: SparkConf,
+ key: Array[Byte]) extends BlockData {
+
+ override def toInputStream(): InputStream = Channels.newInputStream(open())
+
+ override def toNetty(): Object = new ReadableChannelFileRegion(open(), blockSize)
+
+ override def toChunkedByteBuffer(allocator: Int => ByteBuffer): ChunkedByteBuffer = {
+ val source = open()
+ try {
+ var remaining = blockSize
+ val chunks = new ListBuffer[ByteBuffer]()
+ while (remaining > 0) {
+ val chunkSize = math.min(remaining, Int.MaxValue)
+ val chunk = allocator(chunkSize.toInt)
+ remaining -= chunkSize
+ JavaUtils.readFully(source, chunk)
+ chunk.flip()
+ chunks += chunk
+ }
+
+ new ChunkedByteBuffer(chunks.toArray)
+ } finally {
+ source.close()
+ }
+ }
+
+ override def toByteBuffer(): ByteBuffer = {
+ // This is used by the block transfer service to replicate blocks. The upload code reads
+ // all bytes into memory to send the block to the remote executor, so it's ok to do this
+ // as long as the block fits in a Java array.
+ assert(blockSize <= Int.MaxValue, "Block is too large to be wrapped in a byte buffer.")
+ val dst = ByteBuffer.allocate(blockSize.toInt)
+ val in = open()
+ try {
+ JavaUtils.readFully(in, dst)
+ dst.flip()
+ dst
+ } finally {
+ Closeables.close(in, true)
+ }
+ }
+
+ override def size: Long = blockSize
+
+ override def dispose(): Unit = { }
+
+ private def open(): ReadableByteChannel = {
+ val channel = new FileInputStream(file).getChannel()
+ try {
+ CryptoStreamUtils.createReadableChannel(channel, conf, key)
+ } catch {
+ case e: Exception =>
+ Closeables.close(channel, true)
+ throw e
+ }
+ }
+
+}
+
+private class ReadableChannelFileRegion(source: ReadableByteChannel, blockSize: Long)
+ extends AbstractReferenceCounted with FileRegion {
+
+ private var _transferred = 0L
+
+ private val buffer = ByteBuffer.allocateDirect(64 * 1024)
+ buffer.flip()
+
+ override def count(): Long = blockSize
+
+ override def position(): Long = 0
+
+ override def transfered(): Long = _transferred
+
+ override def transferTo(target: WritableByteChannel, pos: Long): Long = {
+ assert(pos == transfered(), "Invalid position.")
+
+ var written = 0L
+ var lastWrite = -1L
+ while (lastWrite != 0) {
+ if (!buffer.hasRemaining()) {
+ buffer.clear()
+ source.read(buffer)
+ buffer.flip()
+ }
+ if (buffer.hasRemaining()) {
+ lastWrite = target.write(buffer)
+ written += lastWrite
+ } else {
+ lastWrite = 0
+ }
+ }
+
+ _transferred += written
+ written
+ }
+
+ override def deallocate(): Unit = source.close()
+}
+
+private class CountingWritableChannel(sink: WritableByteChannel) extends WritableByteChannel {
+
+ private var count = 0L
+
+ def getCount: Long = count
+
+ override def write(src: ByteBuffer): Int = {
+ val written = sink.write(src)
+ if (written > 0) {
+ count += written
+ }
+ written
+ }
+
+ override def isOpen(): Boolean = sink.isOpen()
+
+ override def close(): Unit = sink.close()
+
}
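The put/getSize contract changes accordingly: callers now write through a WritableByteChannel that is already wrapped for encryption when an I/O encryption key is configured, and block sizes are tracked by CountingWritableChannel rather than read from File.length, since with encryption the on-disk length (IV plus ciphertext) no longer matches the logical block size. A hedged sketch of a call site, with diskStore, serializerManager, blockId and values assumed in scope as in BlockManager:

import java.nio.channels.Channels

diskStore.put(blockId) { channel =>
  // The channel transparently encrypts when an I/O encryption key is configured.
  val out = Channels.newOutputStream(channel)
  serializerManager.dataSerializeStream(blockId, out, values)
}
val logicalSize = diskStore.getSize(blockId)  // plaintext bytes written, not the file length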
diff --git a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala
index 5efdd23f79..241aacd74b 100644
--- a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala
+++ b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala
@@ -236,14 +236,6 @@ class StorageStatus(val blockManagerId: BlockManagerId, val maxMem: Long) {
/** Helper methods for storage-related objects. */
private[spark] object StorageUtils extends Logging {
- // Ewwww... Reflection!!! See the unmap method for justification
- private val memoryMappedBufferFileDescriptorField = {
- val mappedBufferClass = classOf[java.nio.MappedByteBuffer]
- val fdField = mappedBufferClass.getDeclaredField("fd")
- fdField.setAccessible(true)
- fdField
- }
-
/**
* Attempt to clean up a ByteBuffer if it is direct or memory-mapped. This uses an *unsafe* Sun
* API that will cause errors if one attempts to read from the disposed buffer. However, neither
@@ -251,8 +243,6 @@ private[spark] object StorageUtils extends Logging {
* pressure on the garbage collector. Waiting for garbage collection may lead to the depletion of
* off-heap memory or huge numbers of open files. There's unfortunately no standard API to
* manually dispose of these kinds of buffers.
- *
- * See also [[unmap]]
*/
def dispose(buffer: ByteBuffer): Unit = {
if (buffer != null && buffer.isInstanceOf[MappedByteBuffer]) {
@@ -261,28 +251,6 @@ private[spark] object StorageUtils extends Logging {
}
}
- /**
- * Attempt to unmap a ByteBuffer if it is memory-mapped. This uses an *unsafe* Sun API that will
- * cause errors if one attempts to read from the unmapped buffer. However, the file descriptors of
- * memory-mapped buffers do not put pressure on the garbage collector. Waiting for garbage
- * collection may lead to huge numbers of open files. There's unfortunately no standard API to
- * manually unmap memory-mapped buffers.
- *
- * See also [[dispose]]
- */
- def unmap(buffer: ByteBuffer): Unit = {
- if (buffer != null && buffer.isInstanceOf[MappedByteBuffer]) {
- // Note that direct buffers are instances of MappedByteBuffer. As things stand in Java 8, the
- // JDK does not provide a public API to distinguish between direct buffers and memory-mapped
- // buffers. As an alternative, we peek beneath the curtains and look for a non-null file
- // descriptor in mappedByteBuffer
- if (memoryMappedBufferFileDescriptorField.get(buffer) != null) {
- logTrace(s"Unmapping $buffer")
- cleanDirectBuffer(buffer.asInstanceOf[DirectBuffer])
- }
- }
- }
-
private def cleanDirectBuffer(buffer: DirectBuffer) = {
val cleaner = buffer.cleaner()
if (cleaner != null) {
diff --git a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala
index fb54dd66a3..90e3af2d0e 100644
--- a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala
+++ b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala
@@ -344,7 +344,7 @@ private[spark] class MemoryStore(
val serializationStream: SerializationStream = {
val autoPick = !blockId.isInstanceOf[StreamBlockId]
val ser = serializerManager.getSerializer(classTag, autoPick).newInstance()
- ser.serializeStream(serializerManager.wrapStream(blockId, redirectableStream))
+ ser.serializeStream(serializerManager.wrapForCompression(blockId, redirectableStream))
}
// Request enough memory to begin unrolling
diff --git a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala
index 1667516663..2f905c8af0 100644
--- a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala
+++ b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala
@@ -138,8 +138,6 @@ private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) {
/**
* Attempt to clean up any ByteBuffer in this ChunkedByteBuffer which is direct or memory-mapped.
* See [[StorageUtils.dispose]] for more information.
- *
- * See also [[unmap]]
*/
def dispose(): Unit = {
if (!disposed) {
@@ -148,18 +146,6 @@ private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) {
}
}
- /**
- * Attempt to unmap any ByteBuffer in this ChunkedByteBuffer if it is memory-mapped. See
- * [[StorageUtils.unmap]] for more information.
- *
- * See also [[dispose]]
- */
- def unmap(): Unit = {
- if (!disposed) {
- chunks.foreach(StorageUtils.unmap)
- disposed = true
- }
- }
}
/**
diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala
index 4e36adc8ba..84f7f1fc8e 100644
--- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala
+++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala
@@ -21,6 +21,7 @@ import org.scalatest.concurrent.Timeouts._
import org.scalatest.Matchers
import org.scalatest.time.{Millis, Span}
+import org.apache.spark.security.EncryptionFunSuite
import org.apache.spark.storage.{RDDBlockId, StorageLevel}
import org.apache.spark.util.io.ChunkedByteBuffer
@@ -28,7 +29,8 @@ class NotSerializableClass
class NotSerializableExn(val notSer: NotSerializableClass) extends Throwable() {}
-class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContext {
+class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContext
+ with EncryptionFunSuite {
val clusterUrl = "local-cluster[2,1,1024]"
@@ -149,8 +151,8 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex
sc.parallelize(1 to 10).count()
}
- private def testCaching(storageLevel: StorageLevel): Unit = {
- sc = new SparkContext(clusterUrl, "test")
+ private def testCaching(conf: SparkConf, storageLevel: StorageLevel): Unit = {
+ sc = new SparkContext(conf.setMaster(clusterUrl).setAppName("test"))
sc.jobProgressListener.waitUntilExecutorsUp(2, 30000)
val data = sc.parallelize(1 to 1000, 10)
val cachedData = data.persist(storageLevel)
@@ -187,8 +189,8 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex
"caching in memory and disk, replicated" -> StorageLevel.MEMORY_AND_DISK_2,
"caching in memory and disk, serialized, replicated" -> StorageLevel.MEMORY_AND_DISK_SER_2
).foreach { case (testName, storageLevel) =>
- test(testName) {
- testCaching(storageLevel)
+ encryptionTest(testName) { conf =>
+ testCaching(conf, storageLevel)
}
}
diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala
index 6646068d50..82760fe92f 100644
--- a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala
+++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala
@@ -24,8 +24,10 @@ import org.scalatest.Assertions
import org.apache.spark._
import org.apache.spark.io.SnappyCompressionCodec
import org.apache.spark.rdd.RDD
+import org.apache.spark.security.EncryptionFunSuite
import org.apache.spark.serializer.JavaSerializer
import org.apache.spark.storage._
+import org.apache.spark.util.io.ChunkedByteBuffer
// Dummy class that creates a broadcast variable but doesn't use it
class DummyBroadcastClass(rdd: RDD[Int]) extends Serializable {
@@ -43,7 +45,7 @@ class DummyBroadcastClass(rdd: RDD[Int]) extends Serializable {
}
}
-class BroadcastSuite extends SparkFunSuite with LocalSparkContext {
+class BroadcastSuite extends SparkFunSuite with LocalSparkContext with EncryptionFunSuite {
test("Using TorrentBroadcast locally") {
sc = new SparkContext("local", "test")
@@ -61,9 +63,8 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext {
assert(results.collect().toSet === (1 to 10).map(x => (x, 10)).toSet)
}
- test("Accessing TorrentBroadcast variables in a local cluster") {
+ encryptionTest("Accessing TorrentBroadcast variables in a local cluster") { conf =>
val numSlaves = 4
- val conf = new SparkConf
conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
conf.set("spark.broadcast.compress", "true")
sc = new SparkContext("local-cluster[%d, 1, 1024]".format(numSlaves), "test", conf)
@@ -85,7 +86,9 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext {
val size = 1 + rand.nextInt(1024 * 10)
val data: Array[Byte] = new Array[Byte](size)
rand.nextBytes(data)
- val blocks = blockifyObject(data, blockSize, serializer, compressionCodec)
+ val blocks = blockifyObject(data, blockSize, serializer, compressionCodec).map { b =>
+ new ChunkedByteBuffer(b).toInputStream(dispose = true)
+ }
val unblockified = unBlockifyObject[Array[Byte]](blocks, serializer, compressionCodec)
assert(unblockified === data)
}
@@ -137,9 +140,8 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext {
sc.stop()
}
- test("Cache broadcast to disk") {
- val conf = new SparkConf()
- .setMaster("local")
+ encryptionTest("Cache broadcast to disk") { conf =>
+ conf.setMaster("local")
.setAppName("test")
.set("spark.memory.useLegacyMode", "true")
.set("spark.storage.memoryFraction", "0.0")
diff --git a/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala b/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala
index 0f3a4a0361..608052f5ed 100644
--- a/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala
@@ -16,9 +16,11 @@
*/
package org.apache.spark.security
-import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream, FileInputStream, FileOutputStream}
+import java.nio.channels.Channels
import java.nio.charset.StandardCharsets.UTF_8
-import java.util.UUID
+import java.nio.file.Files
+import java.util.{Arrays, Random, UUID}
import com.google.common.io.ByteStreams
@@ -121,6 +123,46 @@ class CryptoStreamUtilsSuite extends SparkFunSuite {
}
}
+ test("crypto stream wrappers") {
+ val testData = new Array[Byte](128 * 1024)
+ new Random().nextBytes(testData)
+
+ val conf = createConf()
+ val key = createKey(conf)
+ val file = Files.createTempFile("crypto", ".test").toFile()
+
+ val outStream = createCryptoOutputStream(new FileOutputStream(file), conf, key)
+ try {
+ ByteStreams.copy(new ByteArrayInputStream(testData), outStream)
+ } finally {
+ outStream.close()
+ }
+
+ val inStream = createCryptoInputStream(new FileInputStream(file), conf, key)
+ try {
+ val inStreamData = ByteStreams.toByteArray(inStream)
+ assert(Arrays.equals(inStreamData, testData))
+ } finally {
+ inStream.close()
+ }
+
+ val outChannel = createWritableChannel(new FileOutputStream(file).getChannel(), conf, key)
+ try {
+ val inByteChannel = Channels.newChannel(new ByteArrayInputStream(testData))
+ ByteStreams.copy(inByteChannel, outChannel)
+ } finally {
+ outChannel.close()
+ }
+
+ val inChannel = createReadableChannel(new FileInputStream(file).getChannel(), conf, key)
+ try {
+ val inChannelData = ByteStreams.toByteArray(Channels.newInputStream(inChannel))
+ assert(Arrays.equals(inChannelData, testData))
+ } finally {
+ inChannel.close()
+ }
+ }
+
private def createConf(extra: (String, String)*): SparkConf = {
val conf = new SparkConf()
extra.foreach { case (k, v) => conf.set(k, v) }
diff --git a/core/src/test/scala/org/apache/spark/security/EncryptionFunSuite.scala b/core/src/test/scala/org/apache/spark/security/EncryptionFunSuite.scala
new file mode 100644
index 0000000000..3f52dc41ab
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/security/EncryptionFunSuite.scala
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.security
+
+import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.internal.config._
+
+trait EncryptionFunSuite {
+
+ this: SparkFunSuite =>
+
+ /**
+ * Runs a test twice, initializing a SparkConf object with encryption off, then on. It's ok
+ * for the test to modify the provided SparkConf.
+ */
+ final protected def encryptionTest(name: String)(fn: SparkConf => Unit) {
+ Seq(false, true).foreach { encrypt =>
+ test(s"$name (encryption = ${ if (encrypt) "on" else "off" })") {
+ val conf = new SparkConf().set(IO_ENCRYPTION_ENABLED, encrypt)
+ fn(conf)
+ }
+ }
+ }
+
+}
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
index 64a67b4c4c..a8b9604899 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
@@ -35,6 +35,7 @@ import org.scalatest.concurrent.Timeouts._
import org.apache.spark._
import org.apache.spark.broadcast.BroadcastManager
import org.apache.spark.executor.DataReadMethod
+import org.apache.spark.internal.config._
import org.apache.spark.memory.UnifiedMemoryManager
import org.apache.spark.network.{BlockDataManager, BlockTransferService}
import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
@@ -42,6 +43,7 @@ import org.apache.spark.network.netty.NettyBlockTransferService
import org.apache.spark.network.shuffle.BlockFetchingListener
import org.apache.spark.rpc.RpcEnv
import org.apache.spark.scheduler.LiveListenerBus
+import org.apache.spark.security.{CryptoStreamUtils, EncryptionFunSuite}
import org.apache.spark.serializer.{JavaSerializer, KryoSerializer, SerializerManager}
import org.apache.spark.shuffle.sort.SortShuffleManager
import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat
@@ -49,7 +51,8 @@ import org.apache.spark.util._
import org.apache.spark.util.io.ChunkedByteBuffer
class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterEach
- with PrivateMethodTester with LocalSparkContext with ResetSystemProperties {
+ with PrivateMethodTester with LocalSparkContext with ResetSystemProperties
+ with EncryptionFunSuite {
import BlockManagerSuite._
@@ -75,16 +78,24 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
maxMem: Long,
name: String = SparkContext.DRIVER_IDENTIFIER,
master: BlockManagerMaster = this.master,
- transferService: Option[BlockTransferService] = Option.empty): BlockManager = {
- conf.set("spark.testing.memory", maxMem.toString)
- conf.set("spark.memory.offHeap.size", maxMem.toString)
- val serializer = new KryoSerializer(conf)
+ transferService: Option[BlockTransferService] = Option.empty,
+ testConf: Option[SparkConf] = None): BlockManager = {
+ val bmConf = testConf.map(_.setAll(conf.getAll)).getOrElse(conf)
+ bmConf.set("spark.testing.memory", maxMem.toString)
+ bmConf.set("spark.memory.offHeap.size", maxMem.toString)
+ val serializer = new KryoSerializer(bmConf)
+ val encryptionKey = if (bmConf.get(IO_ENCRYPTION_ENABLED)) {
+ Some(CryptoStreamUtils.createKey(bmConf))
+ } else {
+ None
+ }
+ val bmSecurityMgr = new SecurityManager(bmConf, encryptionKey)
val transfer = transferService
.getOrElse(new NettyBlockTransferService(conf, securityMgr, "localhost", "localhost", 0, 1))
- val memManager = UnifiedMemoryManager(conf, numCores = 1)
- val serializerManager = new SerializerManager(serializer, conf)
- val blockManager = new BlockManager(name, rpcEnv, master, serializerManager, conf,
- memManager, mapOutputTracker, shuffleManager, transfer, securityMgr, 0)
+ val memManager = UnifiedMemoryManager(bmConf, numCores = 1)
+ val serializerManager = new SerializerManager(serializer, bmConf)
+ val blockManager = new BlockManager(name, rpcEnv, master, serializerManager, bmConf,
+ memManager, mapOutputTracker, shuffleManager, transfer, bmSecurityMgr, 0)
memManager.setMemoryStore(blockManager.memoryStore)
blockManager.initialize("app-id")
blockManager
@@ -610,8 +621,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
assert(store.memoryStore.contains(rdd(0, 3)), "rdd_0_3 was not in store")
}
- test("on-disk storage") {
- store = makeBlockManager(1200)
+ encryptionTest("on-disk storage") { _conf =>
+ store = makeBlockManager(1200, testConf = Some(_conf))
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@@ -623,34 +634,35 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
assert(store.getSingleAndReleaseLock("a1").isDefined, "a1 was in store")
}
- test("disk and memory storage") {
- testDiskAndMemoryStorage(StorageLevel.MEMORY_AND_DISK, getAsBytes = false)
+ encryptionTest("disk and memory storage") { _conf =>
+ testDiskAndMemoryStorage(StorageLevel.MEMORY_AND_DISK, getAsBytes = false, testConf = _conf)
}
- test("disk and memory storage with getLocalBytes") {
- testDiskAndMemoryStorage(StorageLevel.MEMORY_AND_DISK, getAsBytes = true)
+ encryptionTest("disk and memory storage with getLocalBytes") { _conf =>
+ testDiskAndMemoryStorage(StorageLevel.MEMORY_AND_DISK, getAsBytes = true, testConf = _conf)
}
- test("disk and memory storage with serialization") {
- testDiskAndMemoryStorage(StorageLevel.MEMORY_AND_DISK_SER, getAsBytes = false)
+ encryptionTest("disk and memory storage with serialization") { _conf =>
+ testDiskAndMemoryStorage(StorageLevel.MEMORY_AND_DISK_SER, getAsBytes = false, testConf = _conf)
}
- test("disk and memory storage with serialization and getLocalBytes") {
- testDiskAndMemoryStorage(StorageLevel.MEMORY_AND_DISK_SER, getAsBytes = true)
+ encryptionTest("disk and memory storage with serialization and getLocalBytes") { _conf =>
+ testDiskAndMemoryStorage(StorageLevel.MEMORY_AND_DISK_SER, getAsBytes = true, testConf = _conf)
}
- test("disk and off-heap memory storage") {
- testDiskAndMemoryStorage(StorageLevel.OFF_HEAP, getAsBytes = false)
+ encryptionTest("disk and off-heap memory storage") { _conf =>
+ testDiskAndMemoryStorage(StorageLevel.OFF_HEAP, getAsBytes = false, testConf = _conf)
}
- test("disk and off-heap memory storage with getLocalBytes") {
- testDiskAndMemoryStorage(StorageLevel.OFF_HEAP, getAsBytes = true)
+ encryptionTest("disk and off-heap memory storage with getLocalBytes") { _conf =>
+ testDiskAndMemoryStorage(StorageLevel.OFF_HEAP, getAsBytes = true, testConf = _conf)
}
def testDiskAndMemoryStorage(
storageLevel: StorageLevel,
- getAsBytes: Boolean): Unit = {
- store = makeBlockManager(12000)
+ getAsBytes: Boolean,
+ testConf: SparkConf): Unit = {
+ store = makeBlockManager(12000, testConf = Some(testConf))
val accessMethod =
if (getAsBytes) store.getLocalBytesAndReleaseLock else store.getSingleAndReleaseLock
val a1 = new Array[Byte](4000)
@@ -678,8 +690,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
}
}
- test("LRU with mixed storage levels") {
- store = makeBlockManager(12000)
+ encryptionTest("LRU with mixed storage levels") { _conf =>
+ store = makeBlockManager(12000, testConf = Some(_conf))
val a1 = new Array[Byte](4000)
val a2 = new Array[Byte](4000)
val a3 = new Array[Byte](4000)
@@ -700,8 +712,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
assert(store.getSingleAndReleaseLock("a4").isDefined, "a4 was not in store")
}
- test("in-memory LRU with streams") {
- store = makeBlockManager(12000)
+ encryptionTest("in-memory LRU with streams") { _conf =>
+ store = makeBlockManager(12000, testConf = Some(_conf))
val list1 = List(new Array[Byte](2000), new Array[Byte](2000))
val list2 = List(new Array[Byte](2000), new Array[Byte](2000))
val list3 = List(new Array[Byte](2000), new Array[Byte](2000))
@@ -728,8 +740,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
assert(store.getAndReleaseLock("list3") === None, "list1 was in store")
}
- test("LRU with mixed storage levels and streams") {
- store = makeBlockManager(12000)
+ encryptionTest("LRU with mixed storage levels and streams") { _conf =>
+ store = makeBlockManager(12000, testConf = Some(_conf))
val list1 = List(new Array[Byte](2000), new Array[Byte](2000))
val list2 = List(new Array[Byte](2000), new Array[Byte](2000))
val list3 = List(new Array[Byte](2000), new Array[Byte](2000))
@@ -1325,7 +1337,8 @@ private object BlockManagerSuite {
val getAndReleaseLock: (BlockId) => Option[BlockResult] = wrapGet(store.get)
val getSingleAndReleaseLock: (BlockId) => Option[Any] = wrapGet(store.getSingle)
val getLocalBytesAndReleaseLock: (BlockId) => Option[ChunkedByteBuffer] = {
- wrapGet(store.getLocalBytes)
+ val allocator = ByteBuffer.allocate _
+ wrapGet { bid => store.getLocalBytes(bid).map(_.toChunkedByteBuffer(allocator)) }
}
}
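
The makeBlockManager change above boils down to one recurring pattern: generate an I/O encryption key only when the flag is enabled and hand it to the SecurityManager. A minimal sketch of that pattern in isolation (the helper name is illustrative, and since SecurityManager and CryptoStreamUtils are private[spark] this assumes code living under the org.apache.spark tree):

    import org.apache.spark.{SecurityManager, SparkConf}
    import org.apache.spark.internal.config._
    import org.apache.spark.security.CryptoStreamUtils

    // Illustrative helper: the SecurityManager only carries an I/O
    // encryption key when spark.io.encryption.enabled is true.
    def securityManagerFor(conf: SparkConf): SecurityManager = {
      val key = if (conf.get(IO_ENCRYPTION_ENABLED)) {
        Some(CryptoStreamUtils.createKey(conf))
      } else {
        None
      }
      new SecurityManager(conf, key)
    }
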
diff --git a/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala
index 9e6b02b9ea..67fc084e8a 100644
--- a/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala
@@ -18,15 +18,23 @@
package org.apache.spark.storage
import java.nio.{ByteBuffer, MappedByteBuffer}
-import java.util.Arrays
+import java.util.{Arrays, Random}
-import org.apache.spark.{SparkConf, SparkFunSuite}
+import com.google.common.io.{ByteStreams, Files}
+import io.netty.channel.FileRegion
+
+import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite}
+import org.apache.spark.network.util.{ByteArrayWritableChannel, JavaUtils}
+import org.apache.spark.security.CryptoStreamUtils
import org.apache.spark.util.io.ChunkedByteBuffer
import org.apache.spark.util.Utils
class DiskStoreSuite extends SparkFunSuite {
test("reads of memory-mapped and non memory-mapped files are equivalent") {
+ val conf = new SparkConf()
+ val securityManager = new SecurityManager(conf)
+
// It will cause an error when we try to re-open the file store and the
// memory-mapped byte buffer to the file has not been GC'd on Windows.
assume(!Utils.isWindows)
@@ -37,16 +45,18 @@ class DiskStoreSuite extends SparkFunSuite {
val byteBuffer = new ChunkedByteBuffer(ByteBuffer.wrap(bytes))
val blockId = BlockId("rdd_1_2")
- val diskBlockManager = new DiskBlockManager(new SparkConf(), deleteFilesOnStop = true)
+ val diskBlockManager = new DiskBlockManager(conf, deleteFilesOnStop = true)
- val diskStoreMapped = new DiskStore(new SparkConf().set(confKey, "0"), diskBlockManager)
+ val diskStoreMapped = new DiskStore(conf.clone().set(confKey, "0"), diskBlockManager,
+ securityManager)
diskStoreMapped.putBytes(blockId, byteBuffer)
- val mapped = diskStoreMapped.getBytes(blockId)
+ val mapped = diskStoreMapped.getBytes(blockId).asInstanceOf[ByteBufferBlockData].buffer
assert(diskStoreMapped.remove(blockId))
- val diskStoreNotMapped = new DiskStore(new SparkConf().set(confKey, "1m"), diskBlockManager)
+ val diskStoreNotMapped = new DiskStore(conf.clone().set(confKey, "1m"), diskBlockManager,
+ securityManager)
diskStoreNotMapped.putBytes(blockId, byteBuffer)
- val notMapped = diskStoreNotMapped.getBytes(blockId)
+ val notMapped = diskStoreNotMapped.getBytes(blockId).asInstanceOf[ByteBufferBlockData].buffer
// Not possible to do isInstanceOf due to visibility of HeapByteBuffer
assert(notMapped.getChunks().forall(_.getClass.getName.endsWith("HeapByteBuffer")),
@@ -63,4 +73,95 @@ class DiskStoreSuite extends SparkFunSuite {
assert(Arrays.equals(mapped.toArray, bytes))
assert(Arrays.equals(notMapped.toArray, bytes))
}
+
+ test("block size tracking") {
+ val conf = new SparkConf()
+ val diskBlockManager = new DiskBlockManager(conf, deleteFilesOnStop = true)
+ val diskStore = new DiskStore(conf, diskBlockManager, new SecurityManager(conf))
+
+ val blockId = BlockId("rdd_1_2")
+ diskStore.put(blockId) { chan =>
+ val buf = ByteBuffer.wrap(new Array[Byte](32))
+ while (buf.hasRemaining()) {
+ chan.write(buf)
+ }
+ }
+
+ assert(diskStore.getSize(blockId) === 32L)
+ diskStore.remove(blockId)
+ assert(diskStore.getSize(blockId) === 0L)
+ }
+
+ test("block data encryption") {
+ val testDir = Utils.createTempDir()
+ val testData = new Array[Byte](128 * 1024)
+ new Random().nextBytes(testData)
+
+ val conf = new SparkConf()
+ val securityManager = new SecurityManager(conf, Some(CryptoStreamUtils.createKey(conf)))
+ val diskBlockManager = new DiskBlockManager(conf, deleteFilesOnStop = true)
+ val diskStore = new DiskStore(conf, diskBlockManager, securityManager)
+
+ val blockId = BlockId("rdd_1_2")
+ diskStore.put(blockId) { chan =>
+ val buf = ByteBuffer.wrap(testData)
+ while (buf.hasRemaining()) {
+ chan.write(buf)
+ }
+ }
+
+ assert(diskStore.getSize(blockId) === testData.length)
+
+ val diskData = Files.toByteArray(diskBlockManager.getFile(blockId.name))
+ assert(!Arrays.equals(testData, diskData))
+
+ val blockData = diskStore.getBytes(blockId)
+ assert(blockData.isInstanceOf[EncryptedBlockData])
+ assert(blockData.size === testData.length)
+ Map(
+ "input stream" -> readViaInputStream _,
+ "chunked byte buffer" -> readViaChunkedByteBuffer _,
+ "nio byte buffer" -> readViaNioBuffer _,
+ "managed buffer" -> readViaManagedBuffer _
+ ).foreach { case (name, fn) =>
+ val readData = fn(blockData)
+ assert(readData.length === blockData.size, s"Size of data read via $name did not match.")
+ assert(Arrays.equals(testData, readData), s"Data read via $name did not match.")
+ }
+ }
+
+ private def readViaInputStream(data: BlockData): Array[Byte] = {
+ val is = data.toInputStream()
+ try {
+ ByteStreams.toByteArray(is)
+ } finally {
+ is.close()
+ }
+ }
+
+ private def readViaChunkedByteBuffer(data: BlockData): Array[Byte] = {
+ val buf = data.toChunkedByteBuffer(ByteBuffer.allocate _)
+ try {
+ buf.toArray
+ } finally {
+ buf.dispose()
+ }
+ }
+
+ private def readViaNioBuffer(data: BlockData): Array[Byte] = {
+ JavaUtils.bufferToArray(data.toByteBuffer())
+ }
+
+ private def readViaManagedBuffer(data: BlockData): Array[Byte] = {
+ val region = data.toNetty().asInstanceOf[FileRegion]
+ val byteChannel = new ByteArrayWritableChannel(data.size.toInt)
+
+ while (region.transfered() < region.count()) {
+ region.transferTo(byteChannel, region.transfered())
+ }
+
+ byteChannel.close()
+ byteChannel.getData
+ }
+
}
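
The "block data encryption" test above verifies that DiskStore now encrypts transparently: the bytes on disk differ from the payload while every read path recovers the plaintext. Underneath, that relies on the channel helpers in CryptoStreamUtils; only createReadableChannel is exercised directly in the CryptoStreamUtilsSuite hunk near the top of this section, so the writable counterpart below is an assumption that it mirrors the same signature. A hedged round-trip sketch (temp file and payload are illustrative; CryptoStreamUtils is private[spark], so this assumes code under the org.apache.spark tree):

    import java.io.{File, FileInputStream, FileOutputStream}
    import java.nio.ByteBuffer
    import java.nio.channels.Channels

    import com.google.common.io.ByteStreams

    import org.apache.spark.SparkConf
    import org.apache.spark.security.CryptoStreamUtils._

    val conf = new SparkConf()
    val key = createKey(conf)
    val file = File.createTempFile("crypto-sketch", ".bin")

    // Writes go through an encrypting channel (assumed counterpart of
    // createReadableChannel), so the bytes on disk differ from the payload.
    val out = createWritableChannel(new FileOutputStream(file).getChannel(), conf, key)
    try {
      val buf = ByteBuffer.wrap("plaintext payload".getBytes("UTF-8"))
      while (buf.hasRemaining()) {
        out.write(buf)
      }
    } finally {
      out.close()
    }

    // Reads wrap the file channel with a decrypting channel and recover the
    // original plaintext, as in the CryptoStreamUtilsSuite hunk above.
    val in = createReadableChannel(new FileInputStream(file).getChannel(), conf, key)
    try {
      val roundTripped = ByteStreams.toByteArray(Channels.newInputStream(in))
      assert(new String(roundTripped, "UTF-8") == "plaintext payload")
    } finally {
      in.close()
    }
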
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala
index d0864fd367..844760ab61 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala
@@ -158,16 +158,14 @@ class WriteAheadLogBackedBlockRDD[T: ClassTag](
logInfo(s"Read partition data of $this from write ahead log, record handle " +
partition.walRecordHandle)
if (storeInBlockManager) {
- blockManager.putBytes(blockId, new ChunkedByteBuffer(dataRead.duplicate()), storageLevel,
- encrypt = true)
+ blockManager.putBytes(blockId, new ChunkedByteBuffer(dataRead.duplicate()), storageLevel)
logDebug(s"Stored partition data of $this into block manager with level $storageLevel")
dataRead.rewind()
}
serializerManager
.dataDeserializeStream(
blockId,
- new ChunkedByteBuffer(dataRead).toInputStream(),
- maybeEncrypted = false)(elementClassTag)
+ new ChunkedByteBuffer(dataRead).toInputStream())(elementClassTag)
.asInstanceOf[Iterator[T]]
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala
index 2b488038f0..80c07958b4 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala
@@ -87,8 +87,7 @@ private[streaming] class BlockManagerBasedBlockHandler(
putResult
case ByteBufferBlock(byteBuffer) =>
blockManager.putBytes(
- blockId, new ChunkedByteBuffer(byteBuffer.duplicate()), storageLevel, tellMaster = true,
- encrypt = true)
+ blockId, new ChunkedByteBuffer(byteBuffer.duplicate()), storageLevel, tellMaster = true)
case o =>
throw new SparkException(
s"Could not store $blockId to block manager, unexpected block type ${o.getClass.getName}")
@@ -176,11 +175,10 @@ private[streaming] class WriteAheadLogBasedBlockHandler(
val serializedBlock = block match {
case ArrayBufferBlock(arrayBuffer) =>
numRecords = Some(arrayBuffer.size.toLong)
- serializerManager.dataSerialize(blockId, arrayBuffer.iterator, allowEncryption = false)
+ serializerManager.dataSerialize(blockId, arrayBuffer.iterator)
case IteratorBlock(iterator) =>
val countIterator = new CountingIterator(iterator)
- val serializedBlock = serializerManager.dataSerialize(blockId, countIterator,
- allowEncryption = false)
+ val serializedBlock = serializerManager.dataSerialize(blockId, countIterator)
numRecords = countIterator.count
serializedBlock
case ByteBufferBlock(byteBuffer) =>
@@ -195,8 +193,7 @@ private[streaming] class WriteAheadLogBasedBlockHandler(
blockId,
serializedBlock,
effectiveStorageLevel,
- tellMaster = true,
- encrypt = true)
+ tellMaster = true)
if (!putSucceeded) {
throw new SparkException(
s"Could not store $blockId to block manager with storage level $storageLevel")
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala
index c2b0389b8c..3c4a2716ca 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala
@@ -175,8 +175,7 @@ abstract class BaseReceivedBlockHandlerSuite(enableEncryption: Boolean)
reader.close()
serializerManager.dataDeserializeStream(
generateBlockId(),
- new ChunkedByteBuffer(bytes).toInputStream(),
- maybeEncrypted = false)(ClassTag.Any).toList
+ new ChunkedByteBuffer(bytes).toInputStream())(ClassTag.Any).toList
}
loggedData shouldEqual data
}
@@ -357,7 +356,7 @@ abstract class BaseReceivedBlockHandlerSuite(enableEncryption: Boolean)
}
def dataToByteBuffer(b: Seq[String]) =
- serializerManager.dataSerialize(generateBlockId, b.iterator, allowEncryption = false)
+ serializerManager.dataSerialize(generateBlockId, b.iterator)
val blocks = data.grouped(10).toSeq
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala
index 2ac0dc9691..aa69be7ca9 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala
@@ -250,8 +250,7 @@ class WriteAheadLogBackedBlockRDDSuite
require(blockData.size === blockIds.size)
val writer = new FileBasedWriteAheadLogWriter(new File(dir, "logFile").toString, hadoopConf)
val segments = blockData.zip(blockIds).map { case (data, id) =>
- writer.write(serializerManager.dataSerialize(id, data.iterator, allowEncryption = false)
- .toByteBuffer)
+ writer.write(serializerManager.dataSerialize(id, data.iterator).toByteBuffer)
}
writer.close()
segments
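
With the encrypt, allowEncryption and maybeEncrypted parameters removed in the hunks above, dataSerialize and dataDeserializeStream deal only with serialization and compression; a caller that does want ciphertext wraps the streams with CryptoStreamUtils itself. A hedged round-trip sketch of that contract (JavaSerializer, the TestBlockId and the in-memory buffers are illustrative; SerializerManager and CryptoStreamUtils are private[spark], so this assumes code under the org.apache.spark tree):

    import java.io.{ByteArrayInputStream, ByteArrayOutputStream}

    import scala.reflect.ClassTag

    import org.apache.spark.SparkConf
    import org.apache.spark.security.CryptoStreamUtils
    import org.apache.spark.serializer.{JavaSerializer, SerializerManager}
    import org.apache.spark.storage.TestBlockId

    val conf = new SparkConf()
    val key = CryptoStreamUtils.createKey(conf)
    val serializerManager = new SerializerManager(new JavaSerializer(conf), conf)
    val blockId = TestBlockId("crypto-wrapping-sketch")

    // Serialize through an explicitly wrapped crypto stream; the manager
    // itself no longer applies any encryption.
    val baos = new ByteArrayOutputStream()
    val encryptedOut = CryptoStreamUtils.createCryptoOutputStream(baos, conf, key)
    serializerManager.dataSerializeStream(blockId, encryptedOut, Iterator("a", "b", "c"))

    // Deserialize by wrapping the input side with the matching crypto stream.
    val encryptedIn = CryptoStreamUtils.createCryptoInputStream(
      new ByteArrayInputStream(baos.toByteArray), conf, key)
    val values = serializerManager.dataDeserializeStream(
      blockId, encryptedIn)(ClassTag(classOf[String])).toList
    assert(values == List("a", "b", "c"))
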