From c6156da9e27a8a247555c7b1b498d384377c0351 Mon Sep 17 00:00:00 2001
From: Tathagata Das
Date: Wed, 13 Jun 2012 16:01:31 -0400
Subject: Multiple bug fixes to pass the test suites ShuffleSuite and
 BlockManagerSuite.

---
 core/src/main/scala/spark/SparkContext.scala        |   5 +
 .../scala/spark/network/ConnectionManager.scala     |  11 +-
 .../main/scala/spark/scheduler/DAGScheduler.scala   |   3 +
 .../scala/spark/scheduler/ShuffleMapTask.scala      |   7 +
 .../main/scala/spark/storage/BlockManager.scala     | 206 ++++++++++++++-------
 .../scala/spark/storage/BlockManagerMaster.scala    |  75 ++++----
 .../scala/spark/storage/BlockManagerWorker.scala    |   2 +-
 core/src/main/scala/spark/storage/BlockStore.scala  |  12 +-
 .../main/scala/spark/storage/StorageLevel.scala     |   4 +-
 core/src/test/scala/spark/ShuffleSuite.scala        |   6 +-
 .../scala/spark/storage/BlockManagerSuite.scala     |  64 +++++--
 11 files changed, 264 insertions(+), 131 deletions(-)

(limited to 'core')

diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala
index 7a9a70fee0..eae71571db 100644
--- a/core/src/main/scala/spark/SparkContext.scala
+++ b/core/src/main/scala/spark/SparkContext.scala
@@ -38,11 +38,13 @@ import spark.broadcast._
 
 import spark.partial.ApproximateEvaluator
 import spark.partial.PartialResult
+import spark.scheduler.ShuffleMapTask
 import spark.scheduler.DAGScheduler
 import spark.scheduler.TaskScheduler
 import spark.scheduler.local.LocalScheduler
 import spark.scheduler.mesos.MesosScheduler
 import spark.scheduler.mesos.CoarseMesosScheduler
+import spark.storage.BlockManagerMaster
 
 class SparkContext(
     master: String,
@@ -266,8 +268,11 @@ class SparkContext(
     env.cacheTracker.stop()
     env.shuffleFetcher.stop()
     env.shuffleManager.stop()
+    env.blockManager.stop()
+    BlockManagerMaster.stopBlockManagerMaster()
     env.connectionManager.stop()
     SparkEnv.set(null)
+    ShuffleMapTask.clearCache()
   }
 
   // Wait for the scheduler to be registered with the cluster manager
diff --git a/core/src/main/scala/spark/network/ConnectionManager.scala b/core/src/main/scala/spark/network/ConnectionManager.scala
index e9f254d0f3..f0b942c492 100644
--- a/core/src/main/scala/spark/network/ConnectionManager.scala
+++ b/core/src/main/scala/spark/network/ConnectionManager.scala
@@ -36,7 +36,6 @@ class ConnectionManager(port: Int) extends Logging {
   }
 
   val selector = SelectorProvider.provider.openSelector()
-  /*val handleMessageExecutor = new ThreadPoolExecutor(4, 4, 600, TimeUnit.SECONDS, new LinkedBlockingQueue()) */
   val handleMessageExecutor = Executors.newFixedThreadPool(4)
   val serverChannel = ServerSocketChannel.open()
   val connectionsByKey = new HashMap[SelectionKey, Connection] with SynchronizedMap[SelectionKey, Connection]
@@ -59,7 +58,7 @@ class ConnectionManager(port: Int) extends Logging {
   logInfo("Bound socket to port " + serverChannel.socket.getLocalPort() + " with id = " + id)
 
   val thisInstance = this
-  var selectorThread = new Thread("connection-manager-thread") {
+  val selectorThread = new Thread("connection-manager-thread") {
     override def run() {
       thisInstance.run()
     }
@@ -331,9 +330,11 @@ class ConnectionManager(port: Int) extends Logging {
   }
 
   def stop() {
-    selectorThread.interrupt()
-    selectorThread.join()
-    selector.close()
+    if (selectorThread.isAlive) {
+      selectorThread.interrupt()
+      selectorThread.join()
+      selector.close()
+    }
     val connections = connectionsByKey.values
     connections.foreach(_.close())
     if (connectionsByKey.size != 0) {
diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala
index f31e2c65a0..f9d53d3b5d 100644
--- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala
@@ -190,6 +190,9 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
       allowLocal: Boolean)
       (implicit m: ClassManifest[U]): Array[U] = {
+    if (partitions.size == 0) {
+      return new Array[U](0)
+    }
     val waiter = new JobWaiter(partitions.size)
     val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
     eventQueue.put(JobSubmitted(finalRdd, func2, partitions.toArray, allowLocal, waiter))
diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
index 317faa0851..79cca0f294 100644
--- a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
@@ -56,6 +56,13 @@ object ShuffleMapTask {
       }
     }
   }
+
+  def clearCache() {
+    synchronized {
+      serializedInfoCache.clear()
+      deserializedInfoCache.clear()
+    }
+  }
 }
 
 class ShuffleMapTask(
diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala
index 367c79dd76..999bbc2128 100644
--- a/core/src/main/scala/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/spark/storage/BlockManager.scala
@@ -27,6 +27,7 @@ import spark.SizeEstimator
 import spark.SparkEnv
 import spark.SparkException
 import spark.Utils
+import spark.util.ByteBufferInputStream
 import spark.network._
 
 class BlockManagerId(var ip: String, var port: Int) extends Externalizable {
@@ -65,19 +66,15 @@ class BlockLocker(numLockers: Int) {
 }
 
-/**
- * A start towards a block manager class. This will eventually be used for both RDD persistence
- * and shuffle outputs.
- *
- * TODO: Should make the communication with Master or Peers code more robust and log friendly.
- */
+
 class BlockManager(maxMemory: Long, val serializer: Serializer) extends Logging {
-  
+
+  case class BlockInfo(level: StorageLevel, tellMaster: Boolean)
+
   private val NUM_LOCKS = 337
   private val locker = new BlockLocker(NUM_LOCKS)
 
-  private val storageLevels = Collections.synchronizedMap(new JHashMap[String, StorageLevel])
-  
+  private val blockInfo = Collections.synchronizedMap(new JHashMap[String, BlockInfo])
   private val memoryStore: BlockStore = new MemoryStore(this, maxMemory)
   private val diskStore: BlockStore = new DiskStore(this,
     System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir")))
@@ -87,7 +84,7 @@ class BlockManager(maxMemory: Long, val serializer: Serializer) extends Logging
   val connectionManagerId = connectionManager.id
   val blockManagerId = new BlockManagerId(connectionManagerId.host, connectionManagerId.port)
 
-  // TODO(Haoyuan): This will be removed after cacheTracker is removed from the code base.
+  // TODO: This will be removed after cacheTracker is removed from the code base.
   var cacheTracker: CacheTracker = null
 
   initLogging()
@@ -104,12 +101,54 @@ class BlockManager(maxMemory: Long, val serializer: Serializer) extends Logging
    * Initialize the BlockManager. Register to the BlockManagerMaster, and start the
    * BlockManagerWorker actor.
    */
-  def initialize() {
+  private def initialize() {
     BlockManagerMaster.mustRegisterBlockManager(
       RegisterBlockManager(blockManagerId, maxMemory, maxMemory))
     BlockManagerWorker.startBlockManagerWorker(this)
   }
-  
+
+  /**
+   * Get storage level of local block. If no info exists for the block, then returns null.
+   */
+  def getLevel(blockId: String): StorageLevel = {
+    val info = blockInfo.get(blockId)
+    if (info != null) info.level else null
+  }
+
+  /**
+   * Change storage level for a local block and tell master if necessary.
+   * If new level is invalid, then block info (if it exists) will be silently removed.
+   */
+  def setLevel(blockId: String, level: StorageLevel, tellMaster: Boolean = true) {
+    if (level == null) {
+      throw new IllegalArgumentException("Storage level is null")
+    }
+
+    // If there was earlier info about the block, then use earlier tellMaster
+    val oldInfo = blockInfo.get(blockId)
+    val newTellMaster = if (oldInfo != null) oldInfo.tellMaster else tellMaster
+    if (oldInfo != null && oldInfo.tellMaster != tellMaster) {
+      logWarning("Ignoring tellMaster setting as it is different from earlier setting")
+    }
+
+    // If level is valid, store the block info, else remove the block info
+    if (level.isValid) {
+      blockInfo.put(blockId, new BlockInfo(level, newTellMaster))
+      logDebug("Info for block " + blockId + " updated with new level as " + level)
+    } else {
+      blockInfo.remove(blockId)
+      logDebug("Info for block " + blockId + " removed as new level is invalid")
+    }
+
+    // Tell master if necessary
+    if (newTellMaster) {
+      logDebug("Told master about block " + blockId)
+      notifyMaster(HeartBeat(blockManagerId, blockId, level, 0, 0))
+    } else {
+      logDebug("Did not tell master about block " + blockId)
+    }
+  }
+
   /**
    * Get locations of the block.
    */
@@ -122,9 +161,9 @@ class BlockManager(maxMemory: Long, val serializer: Serializer) extends Logging
   }
 
   /**
-   * Get locations of an array of blocks
+   * Get locations of an array of blocks.
    */
-  def getLocationsMultipleBlockIds(blockIds: Array[String]): Array[Seq[String]] = {
+  def getLocations(blockIds: Array[String]): Array[Seq[String]] = {
     val startTimeMs = System.currentTimeMillis
     val locations = BlockManagerMaster.mustGetLocationsMultipleBlockIds(
       GetLocationsMultipleBlockIds(blockIds)).map(_.map(_.ip).toSeq).toArray
@@ -132,12 +171,18 @@ class BlockManager(maxMemory: Long, val serializer: Serializer) extends Logging
     return locations
   }
 
+  /**
+   * Get block from local block manager.
+   */
   def getLocal(blockId: String): Option[Iterator[Any]] = {
-    logDebug("Getting block " + blockId)
+    if (blockId == null) {
+      throw new IllegalArgumentException("Block Id is null")
+    }
+    logDebug("Getting local block " + blockId)
     locker.getLock(blockId).synchronized {
       // Check storage level of block
-      val level = storageLevels.get(blockId)
+      val level = getLevel(blockId)
       if (level != null) {
         logDebug("Level for block " + blockId + " is " + level + " on local machine")
@@ -181,12 +226,20 @@ class BlockManager(maxMemory: Long, val serializer: Serializer) extends Logging
     return None
   }
 
+  /**
+   * Get block from remote block managers.
+   */
   def getRemote(blockId: String): Option[Iterator[Any]] = {
+    if (blockId == null) {
+      throw new IllegalArgumentException("Block Id is null")
+    }
+    logDebug("Getting remote block " + blockId)
     // Get locations of block
     val locations = BlockManagerMaster.mustGetLocations(GetLocations(blockId))
 
     // Get block from remote locations
     for (loc <- locations) {
+      logDebug("Getting remote block " + blockId + " from " + loc)
       val data = BlockManagerWorker.syncGetBlock(
         GetBlock(blockId), ConnectionManagerId(loc.ip, loc.port))
       if (data != null) {
@@ -200,16 +253,19 @@ class BlockManager(maxMemory: Long, val serializer: Serializer) extends Logging
   }
 
   /**
-   * Read a block from the block manager.
+   * Get a block from the block manager (either local or remote).
    */
   def get(blockId: String): Option[Iterator[Any]] = {
     getLocal(blockId).orElse(getRemote(blockId))
   }
 
   /**
-   * Read many blocks from block manager using their BlockManagerIds.
+   * Get many blocks from local and remote block manager using their BlockManagerIds.
    */
   def get(blocksByAddress: Seq[(BlockManagerId, Seq[String])]): HashMap[String, Option[Iterator[Any]]] = {
+    if (blocksByAddress == null) {
+      throw new IllegalArgumentException("BlocksByAddress is null")
+    }
     logDebug("Getting " + blocksByAddress.map(_._2.size).sum + " blocks")
     var startTime = System.currentTimeMillis
     val blocks = new HashMap[String,Option[Iterator[Any]]]()
@@ -235,7 +291,8 @@ class BlockManager(maxMemory: Long, val serializer: Serializer) extends Logging
       val future = connectionManager.sendMessageReliably(cmId, blockMessageArray.toBufferMessage)
       (cmId, future)
     }
-    logDebug("Started remote gets for " + remoteBlockIds.size + " blocks in " + Utils.getUsedTimeMs(startTime) + " ms")
+    logDebug("Started remote gets for " + remoteBlockIds.size + " blocks in " +
+      Utils.getUsedTimeMs(startTime) + " ms")
 
     // Get the local blocks while remote blocks are being fetched
     startTime = System.currentTimeMillis
@@ -276,7 +333,8 @@ class BlockManager(maxMemory: Long, val serializer: Serializer) extends Logging
         throw new BlockException(oneBlockId, "Could not get blocks from " + cmId)
       }
     }
-    logDebug("Got remote " + count + " blocks from " + cmId.host + " in " + Utils.getUsedTimeMs(startTime) + " ms")
+    logDebug("Got remote " + count + " blocks from " + cmId.host + " in " +
+      Utils.getUsedTimeMs(startTime) + " ms")
   }
 
   logDebug("Got all blocks in " + Utils.getUsedTimeMs(startTime) + " ms")
@@ -284,29 +342,32 @@ class BlockManager(maxMemory: Long, val serializer: Serializer) extends Logging
   }
 
   /**
-   * Write a new block to the block manager.
+   * Put a new block of values to the block manager.
    */
   def put(blockId: String, values: Iterator[Any], level: StorageLevel, tellMaster: Boolean = true) {
-    if (!level.useDisk && !level.useMemory) {
-      throw new IllegalArgumentException("Storage level has neither useMemory nor useDisk set")
+    if (blockId == null) {
+      throw new IllegalArgumentException("Block Id is null")
+    }
+    if (values == null) {
+      throw new IllegalArgumentException("Values is null")
+    }
+    if (level == null || !level.isValid) {
+      throw new IllegalArgumentException("Storage level is null or invalid")
     }
 
     val startTimeMs = System.currentTimeMillis
     var bytes: ByteBuffer = null
 
     locker.getLock(blockId).synchronized {
-      logDebug("Put for block " + blockId + " used " + Utils.getUsedTimeMs(startTimeMs)
+      logDebug("Put for block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs)
         + " to get into synchronized block")
 
       // Check and warn if block with same id already exists
-      if (storageLevels.get(blockId) != null) {
+      if (getLevel(blockId) != null) {
         logWarning("Block " + blockId + " already exists in local machine")
         return
       }
 
-      // Store the storage level
-      storageLevels.put(blockId, level)
-      
       if (level.useMemory && level.useDisk) {
         // If saving to both memory and disk, then serialize only once
         memoryStore.putValues(blockId, values, level) match {
@@ -333,11 +394,10 @@ class BlockManager(maxMemory: Long, val serializer: Serializer) extends Logging
         }
       }
 
-      if (tellMaster) {
-        notifyMaster(HeartBeat(blockManagerId, blockId, level, 0, 0))
-        logDebug("Put block " + blockId + " after notifying the master " + Utils.getUsedTimeMs(startTimeMs))
-      }
+      // Store the storage level
+      setLevel(blockId, level, tellMaster)
     }
+    logDebug("Put block " + blockId + " locally took " + Utils.getUsedTimeMs(startTimeMs))
 
     // Replicate block if required
     if (level.replication > 1) {
@@ -347,21 +407,32 @@ class BlockManager(maxMemory: Long, val serializer: Serializer) extends Logging
       replicate(blockId, bytes, level)
     }
 
-    // TODO(Haoyuan): This code will be removed when CacheTracker is gone.
+    // TODO: This code will be removed when CacheTracker is gone.
     if (blockId.startsWith("rdd")) {
       notifyTheCacheTracker(blockId)
    }
-    logDebug("Put block " + blockId + " after notifying the CacheTracker " + Utils.getUsedTimeMs(startTimeMs))
+    logDebug("Put block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs))
   }
 
+  /**
+   * Put a new block of serialized bytes to the block manager.
+   */
   def putBytes(blockId: String, bytes: ByteBuffer, level: StorageLevel, tellMaster: Boolean = true) {
-    val startTime = System.currentTimeMillis
-    if (!level.useDisk && !level.useMemory) {
-      throw new IllegalArgumentException("Storage level has neither useMemory nor useDisk set")
-    } else if (level.deserialized) {
-      throw new IllegalArgumentException("Storage level cannot have deserialized when putBytes is used")
+    if (blockId == null) {
+      throw new IllegalArgumentException("Block Id is null")
+    }
+    if (bytes == null) {
+      throw new IllegalArgumentException("Bytes is null")
+    }
+    if (level == null || !level.isValid) {
+      throw new IllegalArgumentException("Storage level is null or invalid")
     }
+
+    val startTimeMs = System.currentTimeMillis
+
+    // Initiate the replication before storing it locally. This is faster as
+    // data is already serialized and ready for sending
     val replicationFuture = if (level.replication > 1) {
       future {
         replicate(blockId, bytes, level)
@@ -371,13 +442,12 @@ class BlockManager(maxMemory: Long, val serializer: Serializer) extends Logging
     }
 
     locker.getLock(blockId).synchronized {
-      logDebug("PutBytes for block " + blockId + " used " + Utils.getUsedTimeMs(startTime)
+      logDebug("PutBytes for block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs)
        + " to get into synchronized block")
-      if (storageLevels.get(blockId) != null) {
+      if (getLevel(blockId) != null) {
         logWarning("Block " + blockId + " already exists")
         return
       }
-      storageLevels.put(blockId, level)
 
       if (level.useMemory) {
         memoryStore.putBytes(blockId, bytes, level)
@@ -385,15 +455,17 @@ class BlockManager(maxMemory: Long, val serializer: Serializer) extends Logging
       if (level.useDisk) {
         diskStore.putBytes(blockId, bytes, level)
       }
-      if (tellMaster) {
-        notifyMaster(HeartBeat(blockManagerId, blockId, level, 0, 0))
-      }
+
+      // Store the storage level
+      setLevel(blockId, level, tellMaster)
     }
 
+    // TODO: This code will be removed when CacheTracker is gone.
     if (blockId.startsWith("rdd")) {
       notifyTheCacheTracker(blockId)
     }
-    
+
+    // If replication had started, then wait for it to finish
     if (level.replication > 1) {
       if (replicationFuture == null) {
         throw new Exception("Unexpected")
@@ -403,13 +475,18 @@ class BlockManager(maxMemory: Long, val serializer: Serializer) extends Logging
 
     val finishTime = System.currentTimeMillis
     if (level.replication > 1) {
-      logDebug("PutBytes with replication took " + (finishTime - startTime) + " ms")
+      logDebug("PutBytes for block " + blockId + " with replication took " +
+        Utils.getUsedTimeMs(startTimeMs))
     } else {
-      logDebug("PutBytes without replication took " + (finishTime - startTime) + " ms")
+      logDebug("PutBytes for block " + blockId + " without replication took " +
+        Utils.getUsedTimeMs(startTimeMs))
     }
-
   }
 
+  /**
+   * Replicate block to another node.
+   */
+
   private def replicate(blockId: String, data: ByteBuffer, level: StorageLevel) {
     val tLevel: StorageLevel = new StorageLevel(level.useDisk, level.useMemory, level.deserialized, 1)
@@ -429,8 +506,8 @@ class BlockManager(maxMemory: Long, val serializer: Serializer) extends Logging
     }
   }
 
-  // TODO(Haoyuan): This code will be removed when CacheTracker is gone.
-  def notifyTheCacheTracker(key: String) {
+  // TODO: This code will be removed when CacheTracker is gone.
+  private def notifyTheCacheTracker(key: String) {
     val rddInfo = key.split(":")
     val rddId: Int = rddInfo(1).toInt
     val splitIndex: Int = rddInfo(2).toInt
@@ -448,8 +525,8 @@ class BlockManager(maxMemory: Long, val serializer: Serializer) extends Logging
   /**
    * Write a block consisting of a single object.
    */
-  def putSingle(blockId: String, value: Any, level: StorageLevel) {
-    put(blockId, Iterator(value), level)
+  def putSingle(blockId: String, value: Any, level: StorageLevel, tellMaster: Boolean = true) {
+    put(blockId, Iterator(value), level, tellMaster)
   }
 
   /**
@@ -457,7 +534,7 @@ class BlockManager(maxMemory: Long, val serializer: Serializer) extends Logging
    */
   def dropFromMemory(blockId: String) {
     locker.getLock(blockId).synchronized {
-      val level = storageLevels.get(blockId)
+      val level = getLevel(blockId)
       if (level == null) {
         logWarning("Block " + blockId + " cannot be removed from memory as it does not exist")
         return
@@ -467,14 +544,8 @@ class BlockManager(maxMemory: Long, val serializer: Serializer) extends Logging
         return
       }
      memoryStore.remove(blockId)
-      if (!level.useDisk) {
-        storageLevels.remove(blockId)
-      } else {
-        val newLevel = level.clone
-        newLevel.useMemory = false
-        storageLevels.remove(blockId)
-        storageLevels.put(blockId, newLevel)
-      }
+      val newLevel = new StorageLevel(level.useDisk, false, level.deserialized, level.replication)
+      setLevel(blockId, newLevel)
     }
   }
 
@@ -489,14 +560,23 @@ class BlockManager(maxMemory: Long, val serializer: Serializer) extends Logging
   def dataDeserialize(bytes: ByteBuffer): Iterator[Any] = {
     /*serializer.newInstance().deserializeMany(bytes)*/
     val ser = serializer.newInstance()
-    return ser.deserializeStream(new FastByteArrayInputStream(bytes.array())).toIterator
+    bytes.rewind()
+    return ser.deserializeStream(new ByteBufferInputStream(bytes)).toIterator
   }
 
   private def notifyMaster(heartBeat: HeartBeat) {
     BlockManagerMaster.mustHeartBeat(heartBeat)
   }
+
+  def stop() {
+    connectionManager.stop()
+    blockInfo.clear()
+    memoryStore.clear()
+    diskStore.clear()
+  }
 }
 
+
 object BlockManager extends Logging {
   def getMaxMemoryFromSystemProperties(): Long = {
     val memoryFraction = System.getProperty("spark.storage.memoryFraction", "0.66").toDouble
diff --git a/core/src/main/scala/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/spark/storage/BlockManagerMaster.scala
index bd94c185e9..85edbbe0cd 100644
--- a/core/src/main/scala/spark/storage/BlockManagerMaster.scala
+++ b/core/src/main/scala/spark/storage/BlockManagerMaster.scala
@@ -1,6 +1,7 @@
 package spark.storage
 
 import java.io._
+import java.util.{HashMap => JHashMap}
 
 import scala.collection.mutable.ArrayBuffer
 import scala.collection.mutable.HashMap
@@ -86,7 +87,9 @@ case class RemoveHost(
     host: String)
   extends ToBlockManagerMaster
 
+
 class BlockManagerMaster(val isLocal: Boolean) extends Actor with Logging {
+
   class BlockManagerInfo(
       timeMs: Long,
       maxMem: Long,
@@ -94,7 +97,7 @@ class BlockManagerMaster(val isLocal: Boolean) extends Actor with Logging {
     private var lastSeenMs = timeMs
     private var remainedMem = maxMem
     private var remainedDisk = maxDisk
-    private val blocks = new HashMap[String, StorageLevel]
+    private val blocks = new JHashMap[String, StorageLevel]
 
     def updateLastSeenMs() {
       lastSeenMs = System.currentTimeMillis() / 1000
@@ -104,8 +107,8 @@ class BlockManagerMaster(val isLocal: Boolean) extends Actor with Logging {
 
     synchronized {
       updateLastSeenMs()
 
-      if (blocks.contains(blockId)) {
-        val oriLevel: StorageLevel = blocks(blockId)
+      if (blocks.containsKey(blockId)) {
+        val oriLevel: StorageLevel = blocks.get(blockId)
 
         if (oriLevel.deserialized) {
           remainedMem += deserializedSize
@@ -117,20 +120,19 @@ class BlockManagerMaster(val isLocal: Boolean) extends Actor with Logging {
           remainedDisk += size
         }
       }
-      
-      blocks += (blockId -> storageLevel)
-      
-      if (storageLevel.deserialized) {
-        remainedMem -= deserializedSize
-      }
-      if (storageLevel.useMemory) {
-        remainedMem -= size
-      }
-      if (storageLevel.useDisk) {
-        remainedDisk -= size
-      }
-      if (!(storageLevel.deserialized || storageLevel.useMemory || storageLevel.useDisk)) {
+
+      if (storageLevel.isValid) {
+        blocks.put(blockId, storageLevel)
+        if (storageLevel.deserialized) {
+          remainedMem -= deserializedSize
+        }
+        if (storageLevel.useMemory) {
+          remainedMem -= size
+        }
+        if (storageLevel.useDisk) {
+          remainedDisk -= size
+        }
+      } else {
         blocks.remove(blockId)
       }
     }
@@ -150,10 +152,14 @@ class BlockManagerMaster(val isLocal: Boolean) extends Actor with Logging {
     override def toString(): String = {
       return "BlockManagerInfo " + timeMs + " " + remainedMem + " " + remainedDisk
     }
+
+    def clear() {
+      blocks.clear()
+    }
   }
 
   private val blockManagerInfo = new HashMap[BlockManagerId, BlockManagerInfo]
-  private val blockIdMap = new HashMap[String, Pair[Int, HashSet[BlockManagerId]]]
+  private val blockInfo = new JHashMap[String, Pair[Int, HashSet[BlockManagerId]]]
 
   initLogging()
@@ -215,7 +221,6 @@ class BlockManagerMaster(val isLocal: Boolean) extends Actor with Logging {
     val startTimeMs = System.currentTimeMillis()
     val tmp = " " + blockManagerId + " " + blockId + " "
-    logDebug("Got in heartBeat 0" + tmp + Utils.getUsedTimeMs(startTimeMs))
 
     if (blockId == null) {
       blockManagerInfo(blockManagerId).updateLastSeenMs()
@@ -224,29 +229,24 @@ class BlockManagerMaster(val isLocal: Boolean) extends Actor with Logging {
     }
 
     blockManagerInfo(blockManagerId).addBlock(blockId, storageLevel, deserializedSize, size)
-    logDebug("Got in heartBeat 2" + tmp + " used " + Utils.getUsedTimeMs(startTimeMs))
 
     var locations: HashSet[BlockManagerId] = null
-    if (blockIdMap.contains(blockId)) {
-      locations = blockIdMap(blockId)._2
+    if (blockInfo.containsKey(blockId)) {
+      locations = blockInfo.get(blockId)._2
     } else {
       locations = new HashSet[BlockManagerId]
-      blockIdMap += (blockId -> (storageLevel.replication, locations))
+      blockInfo.put(blockId, (storageLevel.replication, locations))
     }
-    logDebug("Got in heartBeat 3" + tmp + " used " + Utils.getUsedTimeMs(startTimeMs))
 
-    if (storageLevel.deserialized || storageLevel.useDisk || storageLevel.useMemory) {
+    if (storageLevel.isValid) {
       locations += blockManagerId
     } else {
       locations.remove(blockManagerId)
     }
-    logDebug("Got in heartBeat 4" + tmp + " used " + Utils.getUsedTimeMs(startTimeMs))
 
     if (locations.size == 0) {
-      blockIdMap.remove(blockId)
+      blockInfo.remove(blockId)
     }
-
-    logDebug("Got in heartBeat 5" + tmp + " used " + Utils.getUsedTimeMs(startTimeMs))
     self.reply(true)
   }
 
@@ -254,9 +254,9 @@ class BlockManagerMaster(val isLocal: Boolean) extends Actor with Logging {
     val startTimeMs = System.currentTimeMillis()
     val tmp = " " + blockId + " "
     logDebug("Got in getLocations 0" + tmp + Utils.getUsedTimeMs(startTimeMs))
-    if (blockIdMap.contains(blockId)) {
+    if (blockInfo.containsKey(blockId)) {
       var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId]
-      res.appendAll(blockIdMap(blockId)._2)
+      res.appendAll(blockInfo.get(blockId)._2)
       logDebug("Got in getLocations 1" + tmp + " as "+ res.toSeq + " at "
         + Utils.getUsedTimeMs(startTimeMs))
       self.reply(res.toSeq)
@@ -271,9 +271,9 @@ class BlockManagerMaster(val isLocal: Boolean) extends Actor with Logging {
     def getLocations(blockId: String): Seq[BlockManagerId] = {
       val tmp = blockId
       logDebug("Got in getLocationsMultipleBlockIds Sub 0 " + tmp)
-      if (blockIdMap.contains(blockId)) {
+      if (blockInfo.containsKey(blockId)) {
         var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId]
-        res.appendAll(blockIdMap(blockId)._2)
+        res.appendAll(blockInfo.get(blockId)._2)
         logDebug("Got in getLocationsMultipleBlockIds Sub 1 " + tmp + " " + res.toSeq)
         return res.toSeq
       } else {
@@ -293,24 +293,18 @@ class BlockManagerMaster(val isLocal: Boolean) extends Actor with Logging {
   }
 
   private def getPeers(blockManagerId: BlockManagerId, size: Int) {
-    val startTimeMs = System.currentTimeMillis()
-    val tmp = " " + blockManagerId + " "
-    logDebug("Got in getPeers 0" + tmp + Utils.getUsedTimeMs(startTimeMs))
     var peers: Array[BlockManagerId] = blockManagerInfo.keySet.toArray
     var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId]
     res.appendAll(peers)
     res -= blockManagerId
     val rand = new Random(System.currentTimeMillis())
-    logDebug("Got in getPeers 1" + tmp + Utils.getUsedTimeMs(startTimeMs))
     while (res.length > size) {
       res.remove(rand.nextInt(res.length))
     }
-    logDebug("Got in getPeers 2" + tmp + Utils.getUsedTimeMs(startTimeMs))
     self.reply(res.toSeq)
   }
 
   private def getPeers_Deterministic(blockManagerId: BlockManagerId, size: Int) {
-    val startTimeMs = System.currentTimeMillis()
     var peers: Array[BlockManagerId] = blockManagerInfo.keySet.toArray
     var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId]
@@ -329,7 +323,6 @@ class BlockManagerMaster(val isLocal: Boolean) extends Actor with Logging {
       res += peers(index % peers.size)
     }
     val resStr = res.map(_.toString).reduceLeft(_ + ", " + _)
-    logDebug("Got peers for " + blockManagerId + " as [" + resStr + "]")
     self.reply(res.toSeq)
   }
 }
@@ -358,6 +351,10 @@ object BlockManagerMaster extends Logging {
     }
   }
 
+  def stopBlockManagerMaster() {
+    if (masterActor != null) masterActor.stop()
+  }
+
   def notifyADeadHost(host: String) {
     (masterActor ? RemoveHost(host + ":" + DEFAULT_MANAGER_PORT)).as[Any] match {
       case Some(true) =>
diff --git a/core/src/main/scala/spark/storage/BlockManagerWorker.scala b/core/src/main/scala/spark/storage/BlockManagerWorker.scala
index a4cdbd8ddd..3a8574a815 100644
--- a/core/src/main/scala/spark/storage/BlockManagerWorker.scala
+++ b/core/src/main/scala/spark/storage/BlockManagerWorker.scala
@@ -79,7 +79,7 @@ class BlockManagerWorker(val blockManager: BlockManager) extends Logging {
   private def getBlock(id: String): ByteBuffer = {
     val startTimeMs = System.currentTimeMillis()
     logDebug("Getblock " + id + " started from " + startTimeMs)
-    val block = blockManager.get(id)
+    val block = blockManager.getLocal(id)
     val buffer = block match {
       case Some(tValues) => {
         val values = tValues.asInstanceOf[Iterator[Any]]
diff --git a/core/src/main/scala/spark/storage/BlockStore.scala b/core/src/main/scala/spark/storage/BlockStore.scala
index 0584cc2d4f..52f2cc32e8 100644
--- a/core/src/main/scala/spark/storage/BlockStore.scala
+++ b/core/src/main/scala/spark/storage/BlockStore.scala
@@ -31,6 +31,8 @@ abstract class BlockStore(blockManager: BlockManager) extends Logging {
   def dataSerialize(values: Iterator[Any]): ByteBuffer = blockManager.dataSerialize(values)
 
   def dataDeserialize(bytes: ByteBuffer): Iterator[Any] = blockManager.dataDeserialize(bytes)
+
+  def clear() { }
 }
 
 /**
@@ -118,6 +120,13 @@ class MemoryStore(blockManager: BlockManager, maxMemory: Long)
     }
   }
 
+  override def clear() {
+    memoryStore.synchronized {
+      memoryStore.clear()
+    }
+    blockDropper.shutdown()
+  }
+
   private def drop(blockId: String) {
     blockDropper.submit(new Runnable() {
       def run() {
@@ -147,8 +156,7 @@ class MemoryStore(blockManager: BlockManager, maxMemory: Long)
       for (blockId <- droppedBlockIds) {
        drop(blockId)
       }
-      
-      droppedBlockIds.clear
+      droppedBlockIds.clear()
     }
   }
 
diff --git a/core/src/main/scala/spark/storage/StorageLevel.scala b/core/src/main/scala/spark/storage/StorageLevel.scala
index a2833a7090..693a679c4e 100644
--- a/core/src/main/scala/spark/storage/StorageLevel.scala
+++ b/core/src/main/scala/spark/storage/StorageLevel.scala
@@ -32,7 +32,9 @@ class StorageLevel(
     case _ =>
       false
   }
-  
+
+  def isValid() = ((useMemory || useDisk) && (replication > 0))
+
   def toInt(): Int = {
     var ret = 0
     if (useDisk) {
diff --git a/core/src/test/scala/spark/ShuffleSuite.scala b/core/src/test/scala/spark/ShuffleSuite.scala
index c61cb90f82..00b24464a6 100644
--- a/core/src/test/scala/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/spark/ShuffleSuite.scala
@@ -48,7 +48,7 @@ class ShuffleSuite extends FunSuite {
     assert(valuesFor2.toList.sorted === List(1))
     sc.stop()
   }
-  
+
   test("groupByKey with many output partitions") {
     val sc = new SparkContext("local", "test")
     val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1)))
@@ -189,7 +189,7 @@ class ShuffleSuite extends FunSuite {
     ))
     sc.stop()
   }
-  
+
   test("zero-partition RDD") {
     val sc = new SparkContext("local", "test")
     val emptyDir = Files.createTempDir()
@@ -199,5 +199,5 @@ class ShuffleSuite extends FunSuite {
     // Test that a shuffle on the file works, because this used to be a bug
     assert(file.map(line => (line, 1)).reduceByKey(_ + _).collect().toList === Nil)
     sc.stop()
-  }  
+  }
 }
diff --git a/core/src/test/scala/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/spark/storage/BlockManagerSuite.scala
index ea7e6ebbb1..63501f0613 100644
--- a/core/src/test/scala/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/spark/storage/BlockManagerSuite.scala
@@ -9,6 +9,36 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter{
   before {
     BlockManagerMaster.startBlockManagerMaster(true, true)
   }
+
+  test("manager-master interaction") {
+    val store = new BlockManager(2000, new KryoSerializer)
+    val a1 = new Array[Byte](400)
+    val a2 = new Array[Byte](400)
+    val a3 = new Array[Byte](400)
+
+    // Putting a1, a2 and a3 in memory and telling master only about a1 and a2
+    store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY_DESER)
+    store.putSingle("a2", a2, StorageLevel.MEMORY_ONLY_DESER)
+    store.putSingle("a3", a3, StorageLevel.MEMORY_ONLY_DESER, false)
+
+    // Checking whether blocks are in memory
+    assert(store.getSingle("a1") != None, "a1 was not in store")
+    assert(store.getSingle("a2") != None, "a2 was not in store")
+    assert(store.getSingle("a3") != None, "a3 was not in store")
+
+    // Checking whether master knows about the blocks or not
+    assert(BlockManagerMaster.mustGetLocations(GetLocations("a1")).size > 0, "master was not told about a1")
+    assert(BlockManagerMaster.mustGetLocations(GetLocations("a2")).size > 0, "master was not told about a2")
+    assert(BlockManagerMaster.mustGetLocations(GetLocations("a3")).size === 0, "master was told about a3")
+
+    // Setting storage level of a1 and a2 to invalid; they should be removed from store and master
+    store.setLevel("a1", new StorageLevel(false, false, false, 1))
+    store.setLevel("a2", new StorageLevel(true, false, false, 0))
+    assert(store.getSingle("a1") === None, "a1 not removed from store")
+    assert(store.getSingle("a2") === None, "a2 not removed from store")
+    assert(BlockManagerMaster.mustGetLocations(GetLocations("a1")).size === 0, "master did not remove a1")
+    assert(BlockManagerMaster.mustGetLocations(GetLocations("a2")).size === 0, "master did not remove a2")
+  }
 
   test("in-memory LRU storage") {
     val store = new BlockManager(1000, new KryoSerializer)
@@ -21,14 +51,14 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter{
     assert(store.getSingle("a2") != None, "a2 was not in store")
     assert(store.getSingle("a3") != None, "a3 was not in store")
     Thread.sleep(100)
-    assert(store.getSingle("a1") == None, "a1 was in store")
+    assert(store.getSingle("a1") === None, "a1 was in store")
     assert(store.getSingle("a2") != None, "a2 was not in store")
     // At this point a2 was gotten last, so LRU will getSingle rid of a3
     store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY_DESER)
     assert(store.getSingle("a1") != None, "a1 was not in store")
     assert(store.getSingle("a2") != None, "a2 was not in store")
     Thread.sleep(100)
-    assert(store.getSingle("a3") == None, "a3 was in store")
+    assert(store.getSingle("a3") === None, "a3 was in store")
   }
 
   test("in-memory LRU storage with serialization") {
@@ -42,16 +72,16 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter{
     Thread.sleep(100)
     assert(store.getSingle("a2") != None, "a2 was not in store")
     assert(store.getSingle("a3") != None, "a3 was not in store")
-    assert(store.getSingle("a1") == None, "a1 was in store")
+    assert(store.getSingle("a1") === None, "a1 was in store")
     assert(store.getSingle("a2") != None, "a2 was not in store")
     // At this point a2 was gotten last, so LRU will getSingle rid of a3
     store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY_DESER)
     Thread.sleep(100)
     assert(store.getSingle("a1") != None, "a1 was not in store")
     assert(store.getSingle("a2") != None, "a2 was not in store")
-    assert(store.getSingle("a3") == None, "a1 was in store")
+    assert(store.getSingle("a3") === None, "a3 was in store")
   }
-  
+
   test("on-disk storage") {
     val store = new BlockManager(1000, new KryoSerializer)
     val a1 = new Array[Byte](400)
@@ -132,7 +162,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter{
     assert(store.get("list2").get.size == 2)
     assert(store.get("list3") != None, "list3 was not in store")
     assert(store.get("list3").get.size == 2)
-    assert(store.get("list1") == None, "list1 was in store")
+    assert(store.get("list1") === None, "list1 was in store")
     assert(store.get("list2") != None, "list2 was not in store")
     assert(store.get("list2").get.size == 2)
     // At this point list2 was gotten last, so LRU will getSingle rid of list3
@@ -142,7 +172,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter{
     assert(store.get("list1").get.size == 2)
     assert(store.get("list2") != None, "list2 was not in store")
     assert(store.get("list2").get.size == 2)
-    assert(store.get("list3") == None, "list1 was in store")
+    assert(store.get("list3") === None, "list3 was in store")
   }
 
   test("LRU with mixed storage levels and streams") {
@@ -158,25 +188,25 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter{
     Thread.sleep(100)
     // At this point LRU should not kick in because list3 is only on disk
     assert(store.get("list1") != None, "list2 was not in store")
-    assert(store.get("list1").get.size == 2)
+    assert(store.get("list1").get.size === 2)
     assert(store.get("list2") != None, "list3 was not in store")
-    assert(store.get("list2").get.size == 2)
+    assert(store.get("list2").get.size === 2)
     assert(store.get("list3") != None, "list1 was not in store")
-    assert(store.get("list3").get.size == 2)
+    assert(store.get("list3").get.size === 2)
     assert(store.get("list1") != None, "list2 was not in store")
-    assert(store.get("list1").get.size == 2)
+    assert(store.get("list1").get.size === 2)
     assert(store.get("list2") != None, "list3 was not in store")
-    assert(store.get("list2").get.size == 2)
+    assert(store.get("list2").get.size === 2)
     assert(store.get("list3") != None, "list1 was not in store")
-    assert(store.get("list3").get.size == 2)
+    assert(store.get("list3").get.size === 2)
     // Now let's add in list4, which uses both disk and memory; list1 should drop out
     store.put("list4", list4.iterator, StorageLevel.DISK_AND_MEMORY)
-    assert(store.get("list1") == None, "list1 was in store")
+    assert(store.get("list1") === None, "list1 was in store")
     assert(store.get("list2") != None, "list3 was not in store")
-    assert(store.get("list2").get.size == 2)
+    assert(store.get("list2").get.size === 2)
     assert(store.get("list3") != None, "list1 was not in store")
-    assert(store.get("list3").get.size == 2)
+    assert(store.get("list3").get.size === 2)
     assert(store.get("list4") != None, "list4 was not in store")
-    assert(store.get("list4").get.size == 2)
+    assert(store.get("list4").get.size === 2)
   }
 }
-- 
cgit v1.2.3