diff options
author | Matei Zaharia <matei@eecs.berkeley.edu> | 2012-09-29 20:21:54 -0700 |
---|---|---|
committer | Matei Zaharia <matei@eecs.berkeley.edu> | 2012-09-29 20:21:54 -0700 |
commit | 9b326d01e9a9ec4a4a9abf293cf039c07d426293 (patch) | |
tree | 67283c4ae24bf48014715a19129c60833280c389 /core/src | |
parent | 56dcad593641ef8de211fcb4303574a9f4509f89 (diff) | |
download | spark-9b326d01e9a9ec4a4a9abf293cf039c07d426293.tar.gz spark-9b326d01e9a9ec4a4a9abf293cf039c07d426293.tar.bz2 spark-9b326d01e9a9ec4a4a9abf293cf039c07d426293.zip |
Made BlockManager unmap memory-mapped files when necessary to reduce the
number of open files. Also optimized sending of disk-based blocks.
Diffstat (limited to 'core/src')
-rw-r--r-- | core/src/main/scala/spark/CacheTracker.scala | 1 |
-rw-r--r-- | core/src/main/scala/spark/network/Connection.scala | 18 |
-rw-r--r-- | core/src/main/scala/spark/network/Message.scala | 7 |
-rw-r--r-- | core/src/main/scala/spark/storage/BlockManager.scala | 125 |
-rw-r--r-- | core/src/main/scala/spark/storage/BlockManagerWorker.scala | 15 |
-rw-r--r-- | core/src/main/scala/spark/storage/BlockStore.scala | 13 |
-rw-r--r-- | core/src/main/scala/spark/storage/DiskStore.scala | 45 |
-rw-r--r-- | core/src/main/scala/spark/storage/MemoryStore.scala | 72 |
-rw-r--r-- | core/src/main/scala/spark/util/ByteBufferInputStream.scala | 41 |
-rw-r--r-- | core/src/test/scala/spark/DistributedSuite.scala | 60 |
10 files changed, 279 insertions, 118 deletions
diff --git a/core/src/main/scala/spark/CacheTracker.scala b/core/src/main/scala/spark/CacheTracker.scala index 9f88f93269..225a5ad403 100644 --- a/core/src/main/scala/spark/CacheTracker.scala +++ b/core/src/main/scala/spark/CacheTracker.scala @@ -158,7 +158,6 @@ class CacheTracker(actorSystem: ActorSystem, isMaster: Boolean, blockManager: Bl // For BlockManager.scala only def notifyTheCacheTrackerFromBlockManager(t: AddedToCache) { communicate(t) - logInfo("notifyTheCacheTrackerFromBlockManager successful") } // Get a snapshot of the currently known locations diff --git a/core/src/main/scala/spark/network/Connection.scala b/core/src/main/scala/spark/network/Connection.scala index 0209f4b29d..c4350173fc 100644 --- a/core/src/main/scala/spark/network/Connection.scala +++ b/core/src/main/scala/spark/network/Connection.scala @@ -137,9 +137,12 @@ extends Connection(SocketChannel.open, selector_) { if (!message.started) logDebug("Starting to send [" + message + "]") message.started = true return chunk + } else { + /*logInfo("Finished sending [" + message + "] to [" + remoteConnectionManagerId + "]")*/ + message.finishTime = System.currentTimeMillis + logDebug("Finished sending [" + message + "] to [" + remoteConnectionManagerId + + "] in " + message.timeTaken ) } - /*logInfo("Finished sending [" + message + "] to [" + remoteConnectionManagerId + "]")*/ - logDebug("Finished sending [" + message + "] to [" + remoteConnectionManagerId + "] in " + message.timeTaken ) } } None @@ -162,10 +165,11 @@ extends Connection(SocketChannel.open, selector_) { } logTrace("Sending chunk from [" + message+ "] to [" + remoteConnectionManagerId + "]") return chunk - } - /*messages -= message*/ - message.finishTime = System.currentTimeMillis - logDebug("Finished sending [" + message + "] to [" + remoteConnectionManagerId + "] in " + message.timeTaken ) + } else { + message.finishTime = System.currentTimeMillis + logDebug("Finished sending [" + message + "] to [" + remoteConnectionManagerId + 
+ "] in " + message.timeTaken ) + } } } None @@ -219,7 +223,7 @@ extends Connection(SocketChannel.open, selector_) { while(true) { if (currentBuffers.size == 0) { outbox.synchronized { - outbox.getChunk match { + outbox.getChunk() match { case Some(chunk) => { currentBuffers ++= chunk.buffers } diff --git a/core/src/main/scala/spark/network/Message.scala b/core/src/main/scala/spark/network/Message.scala index 2e85803679..62a06d95c5 100644 --- a/core/src/main/scala/spark/network/Message.scala +++ b/core/src/main/scala/spark/network/Message.scala @@ -7,6 +7,7 @@ import scala.collection.mutable.ArrayBuffer import java.nio.ByteBuffer import java.net.InetAddress import java.net.InetSocketAddress +import storage.BlockManager class MessageChunkHeader( val typ: Long, @@ -64,7 +65,7 @@ abstract class Message(val typ: Long, val id: Int) { def timeTaken(): String = (finishTime - startTime).toString + " ms" - override def toString = "" + this.getClass.getSimpleName + "(id = " + id + ", size = " + size + ")" + override def toString = this.getClass.getSimpleName + "(id = " + id + ", size = " + size + ")" } class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: Int) @@ -97,10 +98,11 @@ extends Message(Message.BUFFER_MESSAGE, id_) { while(!buffers.isEmpty) { val buffer = buffers(0) if (buffer.remaining == 0) { + BlockManager.dispose(buffer) buffers -= buffer } else { val newBuffer = if (buffer.remaining <= maxChunkSize) { - buffer.duplicate + buffer.duplicate() } else { buffer.slice().limit(maxChunkSize).asInstanceOf[ByteBuffer] } @@ -147,7 +149,6 @@ extends Message(Message.BUFFER_MESSAGE, id_) { } else { "BufferMessage(id = " + id + ", size = " + size + ")" } - } } diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index bae5c8c567..224c55d9d7 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -6,7 +6,7 @@ 
import akka.util.Duration import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream import java.io.{InputStream, OutputStream, Externalizable, ObjectInput, ObjectOutput} -import java.nio.ByteBuffer +import java.nio.{MappedByteBuffer, ByteBuffer} import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue} import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue} @@ -16,6 +16,7 @@ import spark.{CacheTracker, Logging, Serializer, SizeEstimator, SparkException, import spark.network._ import spark.util.ByteBufferInputStream import com.ning.compress.lzf.{LZFInputStream, LZFOutputStream} +import sun.nio.ch.DirectBuffer class BlockManagerId(var ip: String, var port: Int) extends Externalizable { @@ -179,12 +180,8 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m * Get block from local block manager. */ def getLocal(blockId: String): Option[Iterator[Any]] = { - if (blockId == null) { - throw new IllegalArgumentException("Block Id is null") - } logDebug("Getting local block " + blockId) locker.getLock(blockId).synchronized { - // Check storage level of block val level = getLevel(blockId) if (level != null) { @@ -202,11 +199,9 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m logDebug("Block " + blockId + " not found in memory") } } - } else { - logDebug("Not getting block " + blockId + " from memory") } - // Look for block in disk + // Look for block on disk if (level.useDisk) { logDebug("Getting block " + blockId + " from disk") diskStore.getValues(blockId) match { @@ -215,22 +210,65 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m return Some(iterator) } case None => { - throw new Exception("Block " + blockId + " not found in disk") + throw new Exception("Block " + blockId + " not found on disk, though it should be") return None } } - } else { - logDebug("Not getting block " + blockId + " from disk") } - } else { - logDebug("Level for block " + 
blockId + " not found") + logDebug("Block " + blockId + " not registered locally") } } return None } /** + * Get block from the local block manager as serialized bytes. + */ + def getLocalBytes(blockId: String): Option[ByteBuffer] = { + logDebug("Getting local block " + blockId + " as bytes") + locker.getLock(blockId).synchronized { + // Check storage level of block + val level = getLevel(blockId) + if (level != null) { + logDebug("Level for block " + blockId + " is " + level + " on local machine") + + // Look for the block in memory + if (level.useMemory) { + logDebug("Getting block " + blockId + " from memory") + memoryStore.getBytes(blockId) match { + case Some(bytes) => { + logDebug("Block " + blockId + " found in memory") + return Some(bytes) + } + case None => { + logDebug("Block " + blockId + " not found in memory") + } + } + } + + // Look for block on disk + if (level.useDisk) { + logDebug("Getting block " + blockId + " from disk") + diskStore.getBytes(blockId) match { + case Some(bytes) => { + logDebug("Block " + blockId + " found in disk") + return Some(bytes) + } + case None => { + throw new Exception("Block " + blockId + " not found on disk, though it should be") + return None + } + } + } + } else { + logDebug("Block " + blockId + " not registered locally") + } + } + return None + } + + /** * Get block from remote block managers. 
*/ def getRemote(blockId: String): Option[Iterator[Any]] = { @@ -416,9 +454,9 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m if (level.useMemory && level.useDisk) { // If saving to both memory and disk, then serialize only once - memoryStore.putValues(blockId, values, level) match { + memoryStore.putValues(blockId, values, level, true) match { case Left(newValues) => - diskStore.putValues(blockId, newValues, level) match { + diskStore.putValues(blockId, newValues, level, true) match { case Right(newBytes) => bytes = newBytes case _ => throw new Exception("Unexpected return value") } @@ -428,15 +466,16 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m } } else if (level.useMemory) { // If only save to memory - memoryStore.putValues(blockId, values, level) match { + memoryStore.putValues(blockId, values, level, true) match { case Right(newBytes) => bytes = newBytes case Left(newIterator) => valuesAfterPut = newIterator } } else { // If only save to disk - diskStore.putValues(blockId, values, level) match { + val askForBytes = level.replication > 1 // Don't get back the bytes unless we replicate them + diskStore.putValues(blockId, values, level, askForBytes) match { case Right(newBytes) => bytes = newBytes - case _ => throw new Exception("Unexpected return value") + case _ => } } @@ -458,6 +497,8 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m replicate(blockId, bytes, level) } + BlockManager.dispose(bytes) + // TODO: This code will be removed when CacheTracker is gone. 
if (blockId.startsWith("rdd")) { notifyTheCacheTracker(blockId) @@ -527,7 +568,6 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m Await.ready(replicationFuture, Duration.Inf) } - val finishTime = System.currentTimeMillis if (level.replication > 1) { logDebug("PutBytes for block " + blockId + " with replication took " + Utils.getUsedTimeMs(startTimeMs)) @@ -540,17 +580,14 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m /** * Replicate block to another node. */ - - var firstTime = true - var peers : Seq[BlockManagerId] = null + var cachedPeers: Seq[BlockManagerId] = null private def replicate(blockId: String, data: ByteBuffer, level: StorageLevel) { val tLevel: StorageLevel = new StorageLevel(level.useDisk, level.useMemory, level.deserialized, 1) - if (firstTime) { - peers = master.mustGetPeers(GetPeers(blockManagerId, level.replication - 1)) - firstTime = false; + if (cachedPeers == null) { + cachedPeers = master.mustGetPeers(GetPeers(blockManagerId, level.replication - 1)) } - for (peer: BlockManagerId <- peers) { + for (peer: BlockManagerId <- cachedPeers) { val start = System.nanoTime data.rewind() logDebug("Try to replicate BlockId " + blockId + " once; The size of the data is " @@ -570,7 +607,7 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m val rddInfo = key.split(":") val rddId: Int = rddInfo(1).toInt val splitIndex: Int = rddInfo(2).toInt - val host = System.getProperty("spark.hostname", Utils.localHostName) + val host = System.getProperty("spark.hostname", Utils.localHostName()) cacheTracker.notifyTheCacheTrackerFromBlockManager(spark.AddedToCache(rddId, splitIndex, host)) } @@ -578,7 +615,7 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m * Read a block consisting of a single object. 
*/ def getSingle(blockId: String): Option[Any] = { - get(blockId).map(_.next) + get(blockId).map(_.next()) } /** @@ -608,18 +645,21 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m } } - /** Wrap an output stream for compression if block compression is enabled */ + /** + * Wrap an output stream for compression if block compression is enabled + */ def wrapForCompression(s: OutputStream): OutputStream = { if (compress) new LZFOutputStream(s) else s } - /** Wrap an input stream for compression if block compression is enabled */ + /** + * Wrap an input stream for compression if block compression is enabled + */ def wrapForCompression(s: InputStream): InputStream = { if (compress) new LZFInputStream(s) else s } def dataSerialize(values: Iterator[Any]): ByteBuffer = { - /*serializer.newInstance().serializeMany(values)*/ val byteStream = new FastByteArrayOutputStream(4096) val ser = serializer.newInstance() ser.serializeStream(wrapForCompression(byteStream)).writeAll(values).close() @@ -627,10 +667,14 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m ByteBuffer.wrap(byteStream.array) } + /** + * Deserializes a ByteBuffer into an iterator of values and disposes of it when the end of + * the iterator is reached. 
+ */ def dataDeserialize(bytes: ByteBuffer): Iterator[Any] = { bytes.rewind() val ser = serializer.newInstance() - return ser.deserializeStream(wrapForCompression(new ByteBufferInputStream(bytes))).asIterator + ser.deserializeStream(wrapForCompression(new ByteBufferInputStream(bytes, true))).asIterator } def stop() { @@ -642,8 +686,8 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m } } -object BlockManager { - +private[spark] +object BlockManager extends Logging { def getNumParallelFetchesFromSystemProperties: Int = { System.getProperty("spark.blockManager.parallelFetches", "4").toInt } @@ -652,4 +696,17 @@ object BlockManager { val memoryFraction = System.getProperty("spark.storage.memoryFraction", "0.66").toDouble (Runtime.getRuntime.maxMemory * memoryFraction).toLong } + + /** + * Attempt to clean up a ByteBuffer if it is memory-mapped. This uses an *unsafe* Sun API that + * might cause errors if one attempts to read from the unmapped buffer, but it's better than + * waiting for the GC to find it because that could lead to huge numbers of open files. There's + * unfortunately no standard API to do this. 
+ */ + def dispose(buffer: ByteBuffer) { + if (buffer != null && buffer.isInstanceOf[MappedByteBuffer]) { + logDebug("Unmapping " + buffer) + buffer.asInstanceOf[DirectBuffer].cleaner().clean() + } + } } diff --git a/core/src/main/scala/spark/storage/BlockManagerWorker.scala b/core/src/main/scala/spark/storage/BlockManagerWorker.scala index 0ad1ad056c..47e4d14010 100644 --- a/core/src/main/scala/spark/storage/BlockManagerWorker.scala +++ b/core/src/main/scala/spark/storage/BlockManagerWorker.scala @@ -76,17 +76,10 @@ class BlockManagerWorker(val blockManager: BlockManager) extends Logging { private def getBlock(id: String): ByteBuffer = { val startTimeMs = System.currentTimeMillis() - logDebug("Getblock " + id + " started from " + startTimeMs) - val block = blockManager.getLocal(id) - val buffer = block match { - case Some(tValues) => { - val values = tValues - val buffer = blockManager.dataSerialize(values) - buffer - } - case None => { - null - } + logDebug("GetBlock " + id + " started from " + startTimeMs) + val buffer = blockManager.getLocalBytes(id) match { + case Some(bytes) => bytes + case None => null } logDebug("GetBlock " + id + " used " + Utils.getUsedTimeMs(startTimeMs) + " and got buffer " + buffer) diff --git a/core/src/main/scala/spark/storage/BlockStore.scala b/core/src/main/scala/spark/storage/BlockStore.scala index 64773a3b03..5f123aca78 100644 --- a/core/src/main/scala/spark/storage/BlockStore.scala +++ b/core/src/main/scala/spark/storage/BlockStore.scala @@ -7,14 +7,17 @@ import spark.Logging /** * Abstract class to store blocks */ +private[spark] abstract class BlockStore(val blockManager: BlockManager) extends Logging { def putBytes(blockId: String, bytes: ByteBuffer, level: StorageLevel) /** - * Put in a block and return its content as either bytes or another Iterator. This is used - * to efficiently write the values to multiple locations (e.g. for replication). 
+ * Put in a block and, possibly, also return its content as either bytes or another Iterator. + * This is used to efficiently write the values to multiple locations (e.g. for replication). + * + * @return the values put if returnValues is true, or null otherwise */ - def putValues(blockId: String, values: Iterator[Any], level: StorageLevel) + def putValues(blockId: String, values: Iterator[Any], level: StorageLevel, returnValues: Boolean) : Either[Iterator[Any], ByteBuffer] /** @@ -28,9 +31,5 @@ abstract class BlockStore(val blockManager: BlockManager) extends Logging { def remove(blockId: String) - def dataSerialize(values: Iterator[Any]): ByteBuffer = blockManager.dataSerialize(values) - - def dataDeserialize(bytes: ByteBuffer): Iterator[Any] = blockManager.dataDeserialize(bytes) - def clear() { } } diff --git a/core/src/main/scala/spark/storage/DiskStore.scala b/core/src/main/scala/spark/storage/DiskStore.scala index e46cbdff16..34bb989485 100644 --- a/core/src/main/scala/spark/storage/DiskStore.scala +++ b/core/src/main/scala/spark/storage/DiskStore.scala @@ -6,11 +6,12 @@ import java.nio.channels.FileChannel.MapMode import it.unimi.dsi.fastutil.io.FastBufferedOutputStream import java.util.UUID import spark.Utils +import java.nio.channels.FileChannel /** * Stores BlockManager blocks on disk. 
*/ -class DiskStore(blockManager: BlockManager, rootDirs: String) +private class DiskStore(blockManager: BlockManager, rootDirs: String) extends BlockStore(blockManager) { val MAX_DIR_CREATION_ATTEMPTS: Int = 10 @@ -33,15 +34,20 @@ class DiskStore(blockManager: BlockManager, rootDirs: String) val startTime = System.currentTimeMillis val file = createFile(blockId) val channel = new RandomAccessFile(file, "rw").getChannel() - val buffer = channel.map(MapMode.READ_WRITE, 0, bytes.limit) - buffer.put(bytes) + while (bytes.remaining > 0) { + channel.write(bytes) + } channel.close() val finishTime = System.currentTimeMillis logDebug("Block %s stored to file of %d bytes to disk in %d ms".format( blockId, bytes.limit, (finishTime - startTime))) } - override def putValues(blockId: String, values: Iterator[Any], level: StorageLevel) + override def putValues( + blockId: String, + values: Iterator[Any], + level: StorageLevel, + returnValues: Boolean) : Either[Iterator[Any], ByteBuffer] = { logDebug("Attempting to write values for block " + blockId) @@ -52,30 +58,35 @@ class DiskStore(blockManager: BlockManager, rootDirs: String) objOut.writeAll(values) objOut.close() - // Return a byte buffer for the contents of the file - val channel = new RandomAccessFile(file, "rw").getChannel() - Right(channel.map(MapMode.READ_WRITE, 0, channel.size())) + if (returnValues) { + // Return a byte buffer for the contents of the file + val channel = new RandomAccessFile(file, "r").getChannel() + val buffer = channel.map(MapMode.READ_ONLY, 0, channel.size()) + channel.close() + Right(buffer) + } else { + null + } } override def getBytes(blockId: String): Option[ByteBuffer] = { val file = getFile(blockId) val length = file.length().toInt val channel = new RandomAccessFile(file, "r").getChannel() - Some(channel.map(MapMode.READ_WRITE, 0, length)) + val bytes = channel.map(MapMode.READ_ONLY, 0, length) + channel.close() + Some(bytes) } override def getValues(blockId: String): Option[Iterator[Any]] 
= { - val file = getFile(blockId) - val length = file.length().toInt - val channel = new RandomAccessFile(file, "r").getChannel() - val bytes = channel.map(MapMode.READ_ONLY, 0, length) - val buffer = dataDeserialize(bytes) - channel.close() - Some(buffer) + getBytes(blockId).map(blockManager.dataDeserialize(_)) } override def remove(blockId: String) { - throw new UnsupportedOperationException("Not implemented") + val file = getFile(blockId) + if (file.exists()) { + file.delete() + } } private def createFile(blockId: String): File = { @@ -97,7 +108,7 @@ class DiskStore(blockManager: BlockManager, rootDirs: String) // Create the subdirectory if it doesn't already exist var subDir = subDirs(dirId)(subDirId) if (subDir == null) { - subDir = subDirs(dirId).synchronized { + subDir = subDirs(dirId).synchronized { val old = subDirs(dirId)(subDirId) if (old != null) { old diff --git a/core/src/main/scala/spark/storage/MemoryStore.scala b/core/src/main/scala/spark/storage/MemoryStore.scala index 24a80b7f96..d71585b6e3 100644 --- a/core/src/main/scala/spark/storage/MemoryStore.scala +++ b/core/src/main/scala/spark/storage/MemoryStore.scala @@ -10,19 +10,19 @@ import collection.mutable.ArrayBuffer * Stores blocks in memory, either as ArrayBuffers of deserialized Java objects or as * serialized ByteBuffers. 
*/ -class MemoryStore(blockManager: BlockManager, maxMemory: Long) +private class MemoryStore(blockManager: BlockManager, maxMemory: Long) extends BlockStore(blockManager) { case class Entry(value: Any, size: Long, deserialized: Boolean, var dropPending: Boolean = false) - private val memoryStore = new LinkedHashMap[String, Entry](32, 0.75f, true) + private val entries = new LinkedHashMap[String, Entry](32, 0.75f, true) private var currentMemory = 0L //private val blockDropper = Executors.newSingleThreadExecutor() private val blocksToDrop = new ArrayBlockingQueue[String](10000, true) private val blockDropper = new Thread("memory store - block dropper") { override def run() { - try{ + try { while (true) { val blockId = blocksToDrop.take() logDebug("Block " + blockId + " ready to be dropped") @@ -40,35 +40,39 @@ class MemoryStore(blockManager: BlockManager, maxMemory: Long) def freeMemory: Long = maxMemory - currentMemory override def getSize(blockId: String): Long = { - memoryStore.synchronized { - memoryStore.get(blockId).size + entries.synchronized { + entries.get(blockId).size } } override def putBytes(blockId: String, bytes: ByteBuffer, level: StorageLevel) { if (level.deserialized) { bytes.rewind() - val values = dataDeserialize(bytes) + val values = blockManager.dataDeserialize(bytes) val elements = new ArrayBuffer[Any] elements ++= values val sizeEstimate = SizeEstimator.estimate(elements.asInstanceOf[AnyRef]) ensureFreeSpace(sizeEstimate) val entry = new Entry(elements, sizeEstimate, true) - memoryStore.synchronized { memoryStore.put(blockId, entry) } + entries.synchronized { entries.put(blockId, entry) } currentMemory += sizeEstimate logInfo("Block %s stored as values to memory (estimated size %d, free %d)".format( blockId, sizeEstimate, freeMemory)) } else { val entry = new Entry(bytes, bytes.limit, false) ensureFreeSpace(bytes.limit) - memoryStore.synchronized { memoryStore.put(blockId, entry) } + entries.synchronized { entries.put(blockId, entry) } 
currentMemory += bytes.limit logInfo("Block %s stored as %d bytes to memory (free %d)".format( blockId, bytes.limit, freeMemory)) } } - override def putValues(blockId: String, values: Iterator[Any], level: StorageLevel) + override def putValues( + blockId: String, + values: Iterator[Any], + level: StorageLevel, + returnValues: Boolean) : Either[Iterator[Any], ByteBuffer] = { if (level.deserialized) { @@ -77,44 +81,55 @@ class MemoryStore(blockManager: BlockManager, maxMemory: Long) val sizeEstimate = SizeEstimator.estimate(elements.asInstanceOf[AnyRef]) ensureFreeSpace(sizeEstimate) val entry = new Entry(elements, sizeEstimate, true) - memoryStore.synchronized { memoryStore.put(blockId, entry) } + entries.synchronized { entries.put(blockId, entry) } currentMemory += sizeEstimate logInfo("Block %s stored as values to memory (estimated size %d, free %d)".format( blockId, sizeEstimate, freeMemory)) - return Left(elements.iterator) + Left(elements.iterator) } else { - val bytes = dataSerialize(values) + val bytes = blockManager.dataSerialize(values) ensureFreeSpace(bytes.limit) val entry = new Entry(bytes, bytes.limit, false) - memoryStore.synchronized { memoryStore.put(blockId, entry) } + entries.synchronized { entries.put(blockId, entry) } currentMemory += bytes.limit logInfo("Block %s stored as %d bytes to memory (free %d)".format( blockId, bytes.limit, freeMemory)) - return Right(bytes) + Right(bytes) } } override def getBytes(blockId: String): Option[ByteBuffer] = { - throw new UnsupportedOperationException("Not implemented") + val entry = entries.synchronized { + entries.get(blockId) + } + if (entry == null) { + None + } else if (entry.deserialized) { + Some(blockManager.dataSerialize(entry.value.asInstanceOf[ArrayBuffer[Any]].iterator)) + } else { + Some(entry.value.asInstanceOf[ByteBuffer].duplicate()) // Doesn't actually copy the data + } } override def getValues(blockId: String): Option[Iterator[Any]] = { - val entry = memoryStore.synchronized { 
memoryStore.get(blockId) } - if (entry == null) { - return None + val entry = entries.synchronized { + entries.get(blockId) } - if (entry.deserialized) { - return Some(entry.value.asInstanceOf[ArrayBuffer[Any]].iterator) + if (entry == null) { + None + } else if (entry.deserialized) { + Some(entry.value.asInstanceOf[ArrayBuffer[Any]].iterator) } else { - return Some(dataDeserialize(entry.value.asInstanceOf[ByteBuffer].duplicate())) + val buffer = entry.value.asInstanceOf[ByteBuffer].duplicate() // Doesn't actually copy data + Some(blockManager.dataDeserialize(buffer)) } } override def remove(blockId: String) { - memoryStore.synchronized { - val entry = memoryStore.get(blockId) + entries.synchronized { + val entry = entries.get(blockId) if (entry != null) { - memoryStore.remove(blockId) + entries.remove(blockId) currentMemory -= entry.size logInfo("Block %s of size %d dropped from memory (free %d)".format( blockId, entry.size, freeMemory)) @@ -125,10 +140,9 @@ class MemoryStore(blockManager: BlockManager, maxMemory: Long) } override def clear() { - memoryStore.synchronized { - memoryStore.clear() + entries.synchronized { + entries.clear() } - //blockDropper.shutdown() blockDropper.interrupt() logInfo("MemoryStore cleared") } @@ -142,8 +156,8 @@ class MemoryStore(blockManager: BlockManager, maxMemory: Long) val selectedBlocks = new ArrayBuffer[String]() var selectedMemory = 0L - memoryStore.synchronized { - val iter = memoryStore.entrySet().iterator() + entries.synchronized { + val iter = entries.entrySet().iterator() while (maxMemory - (currentMemory - selectedMemory) < space && iter.hasNext) { val pair = iter.next() val blockId = pair.getKey diff --git a/core/src/main/scala/spark/util/ByteBufferInputStream.scala b/core/src/main/scala/spark/util/ByteBufferInputStream.scala index c92b60a40c..d7e67497fe 100644 --- a/core/src/main/scala/spark/util/ByteBufferInputStream.scala +++ b/core/src/main/scala/spark/util/ByteBufferInputStream.scala @@ -2,10 +2,19 @@ package 
spark.util import java.io.InputStream import java.nio.ByteBuffer +import spark.storage.BlockManager + +/** + * Reads data from a ByteBuffer, and optionally cleans it up using BlockManager.dispose() + * at the end of the stream (e.g. to close a memory-mapped file). + */ +private[spark] +class ByteBufferInputStream(private var buffer: ByteBuffer, dispose: Boolean = false) + extends InputStream { -class ByteBufferInputStream(buffer: ByteBuffer) extends InputStream { override def read(): Int = { - if (buffer.remaining() == 0) { + if (buffer == null || buffer.remaining() == 0) { + cleanUp() -1 } else { buffer.get() & 0xFF @@ -17,7 +26,8 @@ class ByteBufferInputStream(buffer: ByteBuffer) extends InputStream { } override def read(dest: Array[Byte], offset: Int, length: Int): Int = { - if (buffer.remaining() == 0) { + if (buffer == null || buffer.remaining() == 0) { + cleanUp() -1 } else { val amountToGet = math.min(buffer.remaining(), length) @@ -27,10 +37,27 @@ class ByteBufferInputStream(buffer: ByteBuffer) extends InputStream { } override def skip(bytes: Long): Long = { - val amountToSkip = math.min(bytes, buffer.remaining).toInt - buffer.position(buffer.position + amountToSkip) - return amountToSkip + if (buffer != null) { + val amountToSkip = math.min(bytes, buffer.remaining).toInt + buffer.position(buffer.position + amountToSkip) + if (buffer.remaining() == 0) { + cleanUp() + } + amountToSkip + } else { + 0L + } } - def position: Int = buffer.position + /** + * Clean up the buffer, and potentially dispose of it using BlockManager.dispose(). 
+ */ + private def cleanUp() { + if (buffer != null) { + if (dispose) { + BlockManager.dispose(buffer) + } + buffer = null + } + } } diff --git a/core/src/test/scala/spark/DistributedSuite.scala b/core/src/test/scala/spark/DistributedSuite.scala index 93b876d205..fce1deaa5c 100644 --- a/core/src/test/scala/spark/DistributedSuite.scala +++ b/core/src/test/scala/spark/DistributedSuite.scala @@ -13,6 +13,7 @@ import com.google.common.io.Files import scala.collection.mutable.ArrayBuffer import SparkContext._ +import storage.StorageLevel class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter { @@ -26,7 +27,7 @@ class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter sc = null } } - + test("simple groupByKey") { sc = new SparkContext(clusterUrl, "test") val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1)), 5) @@ -64,5 +65,60 @@ class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter assert(thrown.getClass === classOf[SparkException]) assert(thrown.getMessage.contains("more than 4 times")) } -} + test("caching") { + sc = new SparkContext(clusterUrl, "test") + val data = sc.parallelize(1 to 1000, 10).cache() + assert(data.count() === 1000) + assert(data.count() === 1000) + assert(data.count() === 1000) + } + + test("caching on disk") { + sc = new SparkContext(clusterUrl, "test") + val data = sc.parallelize(1 to 1000, 10).persist(StorageLevel.DISK_ONLY) + assert(data.count() === 1000) + assert(data.count() === 1000) + assert(data.count() === 1000) + } + + test("caching in memory, replicated") { + sc = new SparkContext(clusterUrl, "test") + val data = sc.parallelize(1 to 1000, 10).persist(StorageLevel.MEMORY_ONLY_2) + assert(data.count() === 1000) + assert(data.count() === 1000) + assert(data.count() === 1000) + } + + test("caching in memory, serialized, replicated") { + sc = new SparkContext(clusterUrl, "test") + val data = sc.parallelize(1 to 1000, 
10).persist(StorageLevel.MEMORY_ONLY_SER_2) + assert(data.count() === 1000) + assert(data.count() === 1000) + assert(data.count() === 1000) + } + + test("caching on disk, replicated") { + sc = new SparkContext(clusterUrl, "test") + val data = sc.parallelize(1 to 1000, 10).persist(StorageLevel.DISK_ONLY_2) + assert(data.count() === 1000) + assert(data.count() === 1000) + assert(data.count() === 1000) + } + + test("caching in memory and disk, replicated") { + sc = new SparkContext(clusterUrl, "test") + val data = sc.parallelize(1 to 1000, 10).persist(StorageLevel.MEMORY_AND_DISK_2) + assert(data.count() === 1000) + assert(data.count() === 1000) + assert(data.count() === 1000) + } + + test("caching in memory and disk, serialized, replicated") { + sc = new SparkContext(clusterUrl, "test") + val data = sc.parallelize(1 to 1000, 10).persist(StorageLevel.MEMORY_AND_DISK_SER_2) + assert(data.count() === 1000) + assert(data.count() === 1000) + assert(data.count() === 1000) + } +} |