diff options
-rw-r--r-- | core/src/main/scala/spark/BlockRDD.scala | 42 | ||||
-rw-r--r-- | core/src/main/scala/spark/BoundedMemoryCache.scala | 10 | ||||
-rw-r--r-- | core/src/main/scala/spark/RDD.scala | 19 | ||||
-rw-r--r-- | core/src/main/scala/spark/SparkEnv.scala | 6 | ||||
-rw-r--r-- | core/src/main/scala/spark/scheduler/DAGScheduler.scala | 2 | ||||
-rw-r--r-- | core/src/main/scala/spark/storage/BlockManager.scala | 36 | ||||
-rw-r--r-- | core/src/main/scala/spark/storage/BlockManagerMaster.scala | 36 | ||||
-rw-r--r-- | core/src/main/scala/spark/storage/BlockManagerWorker.scala | 6 | ||||
-rw-r--r-- | core/src/main/scala/spark/storage/BlockMessage.scala | 20 | ||||
-rw-r--r-- | core/src/main/scala/spark/storage/BlockMessageArray.scala | 6 | ||||
-rw-r--r-- | core/src/main/scala/spark/storage/BlockStore.scala | 83 | ||||
-rw-r--r-- | core/src/main/scala/spark/storage/StorageLevel.scala | 23 | ||||
-rw-r--r-- | core/src/test/scala/spark/RDDSuite.scala | 7 | ||||
-rw-r--r-- | core/src/test/scala/spark/storage/BlockManagerSuite.scala | 40 | ||||
-rw-r--r-- | project/SparkBuild.scala | 57 |
15 files changed, 268 insertions, 125 deletions
diff --git a/core/src/main/scala/spark/BlockRDD.scala b/core/src/main/scala/spark/BlockRDD.scala new file mode 100644 index 0000000000..ea009f0f4f --- /dev/null +++ b/core/src/main/scala/spark/BlockRDD.scala @@ -0,0 +1,42 @@ +package spark + +import scala.collection.mutable.HashMap + +class BlockRDDSplit(val blockId: String, idx: Int) extends Split { + val index = idx +} + + +class BlockRDD[T: ClassManifest](sc: SparkContext, blockIds: Array[String]) extends RDD[T](sc) { + + @transient + val splits_ = (0 until blockIds.size).map(i => { + new BlockRDDSplit(blockIds(i), i).asInstanceOf[Split] + }).toArray + + @transient + lazy val locations_ = { + val blockManager = SparkEnv.get.blockManager + /*val locations = blockIds.map(id => blockManager.getLocations(id))*/ + val locations = blockManager.getLocations(blockIds) + HashMap(blockIds.zip(locations):_*) + } + + override def splits = splits_ + + override def compute(split: Split): Iterator[T] = { + val blockManager = SparkEnv.get.blockManager + val blockId = split.asInstanceOf[BlockRDDSplit].blockId + blockManager.get(blockId) match { + case Some(block) => block.asInstanceOf[Iterator[T]] + case None => + throw new Exception("Could not compute split, block " + blockId + " not found") + } + } + + override def preferredLocations(split: Split) = + locations_(split.asInstanceOf[BlockRDDSplit].blockId) + + override val dependencies: List[Dependency[_]] = Nil +} + diff --git a/core/src/main/scala/spark/BoundedMemoryCache.scala b/core/src/main/scala/spark/BoundedMemoryCache.scala index fa5dcee7bb..6fe0b94297 100644 --- a/core/src/main/scala/spark/BoundedMemoryCache.scala +++ b/core/src/main/scala/spark/BoundedMemoryCache.scala @@ -91,7 +91,15 @@ class BoundedMemoryCache(maxBytes: Long) extends Cache with Logging { protected def reportEntryDropped(datasetId: Any, partition: Int, entry: Entry) { logInfo("Dropping key (%s, %d) of size %d to make space".format(datasetId, partition, entry.size)) // TODO: remove BoundedMemoryCache - SparkEnv.get.cacheTracker.dropEntry(datasetId.asInstanceOf[(Int, Int)]._2, partition) + + val (keySpaceId, innerDatasetId) = datasetId.asInstanceOf[(Any, Any)] + innerDatasetId match { + case rddId: Int => + SparkEnv.get.cacheTracker.dropEntry(rddId, partition) + case broadcastUUID: java.util.UUID => + // TODO: Maybe something should be done if the broadcasted variable falls out of cache + case _ => + } } } diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index 429e9c936f..8a79e85cf9 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -94,6 +94,25 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial def getStorageLevel = storageLevel + def checkpoint(level: StorageLevel = StorageLevel.DISK_AND_MEMORY_DESER): RDD[T] = { + if (!level.useDisk && level.replication < 2) { + throw new Exception("Cannot checkpoint without using disk or replication (level requested was " + level + ")") + } + + // This is a hack. Ideally this should re-use the code used by the CacheTracker + // to generate the key. 
+ def getSplitKey(split: Split) = "rdd:%d:%d".format(this.id, split.index) + + persist(level) + sc.runJob(this, (iter: Iterator[T]) => {} ) + + val p = this.partitioner + + new BlockRDD[T](sc, splits.map(getSplitKey).toArray) { + override val partitioner = p + } + } + // Read this RDD; will read from cache if applicable, or otherwise compute final def iterator(split: Split): Iterator[T] = { if (storageLevel != StorageLevel.NONE) { diff --git a/core/src/main/scala/spark/SparkEnv.scala b/core/src/main/scala/spark/SparkEnv.scala index 25593c596b..694db6b2a3 100644 --- a/core/src/main/scala/spark/SparkEnv.scala +++ b/core/src/main/scala/spark/SparkEnv.scala @@ -31,7 +31,7 @@ class SparkEnv ( shuffleFetcher.stop() shuffleManager.stop() blockManager.stop() - BlockManagerMaster.stopBlockManagerMaster() + blockManager.master.stop() actorSystem.shutdown() actorSystem.awaitTermination() } @@ -66,9 +66,9 @@ object SparkEnv { val serializerClass = System.getProperty("spark.serializer", "spark.KryoSerializer") val serializer = Class.forName(serializerClass).newInstance().asInstanceOf[Serializer] - BlockManagerMaster.startBlockManagerMaster(actorSystem, isMaster, isLocal) + val blockManagerMaster = new BlockManagerMaster(actorSystem, isMaster, isLocal) - val blockManager = new BlockManager(serializer) + val blockManager = new BlockManager(blockManagerMaster, serializer) val connectionManager = blockManager.connectionManager diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index 436c16cddd..f7472971b5 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -498,7 +498,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with if (!deadHosts.contains(host)) { logInfo("Host lost: " + host) deadHosts += host - BlockManagerMaster.notifyADeadHost(host) + env.blockManager.master.notifyADeadHost(host) // TODO: This will be really slow if we keep accumulating shuffle map stages for ((shuffleId, stage) <- shuffleToMapStage) { stage.removeOutputsOnHost(host) diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index 15131960d6..5067601198 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -5,7 +5,7 @@ import java.nio._ import java.nio.channels.FileChannel.MapMode import java.util.{HashMap => JHashMap} import java.util.LinkedHashMap -import java.util.UUID +import java.util.concurrent.ConcurrentHashMap import java.util.Collections import scala.actors._ @@ -66,15 +66,15 @@ class BlockLocker(numLockers: Int) { } - -class BlockManager(maxMemory: Long, val serializer: Serializer) extends Logging { +class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, maxMemory: Long) + extends Logging { case class BlockInfo(level: StorageLevel, tellMaster: Boolean) private val NUM_LOCKS = 337 private val locker = new BlockLocker(NUM_LOCKS) - private val blockInfo = Collections.synchronizedMap(new JHashMap[String, BlockInfo]) + private val blockInfo = new ConcurrentHashMap[String, BlockInfo]() private val memoryStore: BlockStore = new MemoryStore(this, maxMemory) private val diskStore: BlockStore = new DiskStore(this, System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir"))) @@ -94,15 +94,16 @@ class BlockManager(maxMemory: Long, val serializer: Serializer) extends Logging /** 
* Construct a BlockManager with a memory limit set based on system properties. */ - def this(serializer: Serializer) = - this(BlockManager.getMaxMemoryFromSystemProperties(), serializer) + def this(master: BlockManagerMaster, serializer: Serializer) = { + this(master, serializer, BlockManager.getMaxMemoryFromSystemProperties) + } /** * Initialize the BlockManager. Register to the BlockManagerMaster, and start the * BlockManagerWorker actor. */ private def initialize() { - BlockManagerMaster.mustRegisterBlockManager( + master.mustRegisterBlockManager( RegisterBlockManager(blockManagerId, maxMemory, maxMemory)) BlockManagerWorker.startBlockManagerWorker(this) } @@ -154,7 +155,7 @@ class BlockManager(maxMemory: Long, val serializer: Serializer) extends Logging */ def getLocations(blockId: String): Seq[String] = { val startTimeMs = System.currentTimeMillis - var managers = BlockManagerMaster.mustGetLocations(GetLocations(blockId)) + var managers = master.mustGetLocations(GetLocations(blockId)) val locations = managers.map(_.ip) logDebug("Get block locations in " + Utils.getUsedTimeMs(startTimeMs)) return locations @@ -165,7 +166,7 @@ class BlockManager(maxMemory: Long, val serializer: Serializer) extends Logging */ def getLocations(blockIds: Array[String]): Array[Seq[String]] = { val startTimeMs = System.currentTimeMillis - val locations = BlockManagerMaster.mustGetLocationsMultipleBlockIds( + val locations = master.mustGetLocationsMultipleBlockIds( GetLocationsMultipleBlockIds(blockIds)).map(_.map(_.ip).toSeq).toArray logDebug("Get multiple block location in " + Utils.getUsedTimeMs(startTimeMs)) return locations @@ -235,7 +236,7 @@ class BlockManager(maxMemory: Long, val serializer: Serializer) extends Logging } logDebug("Getting remote block " + blockId) // Get locations of block - val locations = BlockManagerMaster.mustGetLocations(GetLocations(blockId)) + val locations = master.mustGetLocations(GetLocations(blockId)) // Get block from remote locations for (loc <- locations) { @@ -321,8 +322,8 @@ class BlockManager(maxMemory: Long, val serializer: Serializer) extends Logging if (blockMessage.getType != BlockMessage.TYPE_GOT_BLOCK) { throw new BlockException(oneBlockId, "Unexpected message received from " + cmId) } - val buffer = blockMessage.getData() - val blockId = blockMessage.getId() + val buffer = blockMessage.getData + val blockId = blockMessage.getId val block = dataDeserialize(buffer) blocks.update(blockId, Some(block)) logDebug("Got remote block " + blockId + " in " + Utils.getUsedTimeMs(startTime)) @@ -490,7 +491,7 @@ class BlockManager(maxMemory: Long, val serializer: Serializer) extends Logging private def replicate(blockId: String, data: ByteBuffer, level: StorageLevel) { val tLevel: StorageLevel = new StorageLevel(level.useDisk, level.useMemory, level.deserialized, 1) - var peers = BlockManagerMaster.mustGetPeers(GetPeers(blockManagerId, level.replication - 1)) + var peers = master.mustGetPeers(GetPeers(blockManagerId, level.replication - 1)) for (peer: BlockManagerId <- peers) { val start = System.nanoTime logDebug("Try to replicate BlockId " + blockId + " once; The size of the data is " @@ -564,7 +565,7 @@ class BlockManager(maxMemory: Long, val serializer: Serializer) extends Logging } private def notifyMaster(heartBeat: HeartBeat) { - BlockManagerMaster.mustHeartBeat(heartBeat) + master.mustHeartBeat(heartBeat) } def stop() { @@ -576,12 +577,9 @@ class BlockManager(maxMemory: Long, val serializer: Serializer) extends Logging } } - -object BlockManager extends Logging { 
+object BlockManager { def getMaxMemoryFromSystemProperties(): Long = { val memoryFraction = System.getProperty("spark.storage.memoryFraction", "0.66").toDouble - val bytes = (Runtime.getRuntime.totalMemory * memoryFraction).toLong - logInfo("Maximum memory to use: " + bytes) - bytes + (Runtime.getRuntime.totalMemory * memoryFraction).toLong } } diff --git a/core/src/main/scala/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/spark/storage/BlockManagerMaster.scala index 97a5b0cb45..9f03c5a32c 100644 --- a/core/src/main/scala/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/spark/storage/BlockManagerMaster.scala @@ -85,7 +85,7 @@ case class RemoveHost(host: String) extends ToBlockManagerMaster case object StopBlockManagerMaster extends ToBlockManagerMaster -class BlockManagerMaster(val isLocal: Boolean) extends Actor with Logging { +class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { class BlockManagerInfo( timeMs: Long, @@ -134,19 +134,19 @@ class BlockManagerMaster(val isLocal: Boolean) extends Actor with Logging { } } - def getLastSeenMs(): Long = { + def getLastSeenMs: Long = { return lastSeenMs } - def getRemainedMem(): Long = { + def getRemainedMem: Long = { return remainedMem } - def getRemainedDisk(): Long = { + def getRemainedDisk: Long = { return remainedDisk } - override def toString(): String = { + override def toString: String = { return "BlockManagerInfo " + timeMs + " " + remainedMem + " " + remainedDisk } @@ -329,8 +329,8 @@ class BlockManagerMaster(val isLocal: Boolean) extends Actor with Logging { } } -object BlockManagerMaster extends Logging { - initLogging() +class BlockManagerMaster(actorSystem: ActorSystem, isMaster: Boolean, isLocal: Boolean) + extends Logging { val AKKA_ACTOR_NAME: String = "BlockMasterManager" val REQUEST_RETRY_INTERVAL_MS = 100 @@ -342,20 +342,18 @@ object BlockManagerMaster extends Logging { val timeout = 10.seconds var masterActor: ActorRef = null - def startBlockManagerMaster(actorSystem: ActorSystem, isMaster: Boolean, isLocal: Boolean) { - if (isMaster) { - masterActor = actorSystem.actorOf( - Props(new BlockManagerMaster(isLocal)), name = AKKA_ACTOR_NAME) - logInfo("Registered BlockManagerMaster Actor") - } else { - val url = "akka://spark@%s:%s/user/%s".format( - DEFAULT_MASTER_IP, DEFAULT_MASTER_PORT, AKKA_ACTOR_NAME) - logInfo("Connecting to BlockManagerMaster: " + url) - masterActor = actorSystem.actorFor(url) - } + if (isMaster) { + masterActor = actorSystem.actorOf( + Props(new BlockManagerMasterActor(isLocal)), name = AKKA_ACTOR_NAME) + logInfo("Registered BlockManagerMaster Actor") + } else { + val url = "akka://spark@%s:%s/user/%s".format( + DEFAULT_MASTER_IP, DEFAULT_MASTER_PORT, AKKA_ACTOR_NAME) + logInfo("Connecting to BlockManagerMaster: " + url) + masterActor = actorSystem.actorFor(url) } - def stopBlockManagerMaster() { + def stop() { if (masterActor != null) { communicate(StopBlockManagerMaster) masterActor = null diff --git a/core/src/main/scala/spark/storage/BlockManagerWorker.scala b/core/src/main/scala/spark/storage/BlockManagerWorker.scala index 501183ab1f..c61e280252 100644 --- a/core/src/main/scala/spark/storage/BlockManagerWorker.scala +++ b/core/src/main/scala/spark/storage/BlockManagerWorker.scala @@ -48,15 +48,15 @@ class BlockManagerWorker(val blockManager: BlockManager) extends Logging { } def processBlockMessage(blockMessage: BlockMessage): Option[BlockMessage] = { - blockMessage.getType() match { + blockMessage.getType match { case 
BlockMessage.TYPE_PUT_BLOCK => { - val pB = PutBlock(blockMessage.getId(), blockMessage.getData(), blockMessage.getLevel()) + val pB = PutBlock(blockMessage.getId, blockMessage.getData, blockMessage.getLevel) logInfo("Received [" + pB + "]") putBlock(pB.id, pB.data, pB.level) return None } case BlockMessage.TYPE_GET_BLOCK => { - val gB = new GetBlock(blockMessage.getId()) + val gB = new GetBlock(blockMessage.getId) logInfo("Received [" + gB + "]") val buffer = getBlock(gB.id) if (buffer == null) { diff --git a/core/src/main/scala/spark/storage/BlockMessage.scala b/core/src/main/scala/spark/storage/BlockMessage.scala index bb128dce7a..0b2ed69e07 100644 --- a/core/src/main/scala/spark/storage/BlockMessage.scala +++ b/core/src/main/scala/spark/storage/BlockMessage.scala @@ -102,23 +102,23 @@ class BlockMessage() extends Logging{ set(buffer) } - def getType(): Int = { + def getType: Int = { return typ } - def getId(): String = { + def getId: String = { return id } - def getData(): ByteBuffer = { + def getData: ByteBuffer = { return data } - def getLevel(): StorageLevel = { + def getLevel: StorageLevel = { return level } - def toBufferMessage(): BufferMessage = { + def toBufferMessage: BufferMessage = { val startTime = System.currentTimeMillis val buffers = new ArrayBuffer[ByteBuffer]() var buffer = ByteBuffer.allocate(4 + 4 + id.length() * 2) @@ -128,7 +128,7 @@ class BlockMessage() extends Logging{ buffers += buffer if (typ == BlockMessage.TYPE_PUT_BLOCK) { - buffer = ByteBuffer.allocate(8).putInt(level.toInt()).putInt(level.replication) + buffer = ByteBuffer.allocate(8).putInt(level.toInt).putInt(level.replication) buffer.flip() buffers += buffer @@ -164,7 +164,7 @@ class BlockMessage() extends Logging{ return Message.createBufferMessage(buffers) } - override def toString(): String = { + override def toString: String = { "BlockMessage [type = " + typ + ", id = " + id + ", level = " + level + ", data = " + (if (data != null) data.remaining.toString else "null") + "]" } @@ -209,11 +209,11 @@ object BlockMessage { def main(args: Array[String]) { val B = new BlockMessage() B.set(new PutBlock("ABC", ByteBuffer.allocate(10), StorageLevel.DISK_AND_MEMORY_2)) - val bMsg = B.toBufferMessage() + val bMsg = B.toBufferMessage val C = new BlockMessage() C.set(bMsg) - println(B.getId() + " " + B.getLevel()) - println(C.getId() + " " + C.getLevel()) + println(B.getId + " " + B.getLevel) + println(C.getId + " " + C.getLevel) } } diff --git a/core/src/main/scala/spark/storage/BlockMessageArray.scala b/core/src/main/scala/spark/storage/BlockMessageArray.scala index 5f411d3488..a108ab653e 100644 --- a/core/src/main/scala/spark/storage/BlockMessageArray.scala +++ b/core/src/main/scala/spark/storage/BlockMessageArray.scala @@ -123,13 +123,13 @@ object BlockMessageArray { val newBlockMessageArray = BlockMessageArray.fromBufferMessage(newBufferMessage) println("Converted back to block message array") newBlockMessageArray.foreach(blockMessage => { - blockMessage.getType() match { + blockMessage.getType match { case BlockMessage.TYPE_PUT_BLOCK => { - val pB = PutBlock(blockMessage.getId(), blockMessage.getData(), blockMessage.getLevel()) + val pB = PutBlock(blockMessage.getId, blockMessage.getData, blockMessage.getLevel) println(pB) } case BlockMessage.TYPE_GET_BLOCK => { - val gB = new GetBlock(blockMessage.getId()) + val gB = new GetBlock(blockMessage.getId) println(gB) } } diff --git a/core/src/main/scala/spark/storage/BlockStore.scala b/core/src/main/scala/spark/storage/BlockStore.scala index 
8672a5376e..17f4f51aa8 100644 --- a/core/src/main/scala/spark/storage/BlockStore.scala +++ b/core/src/main/scala/spark/storage/BlockStore.scala @@ -1,16 +1,15 @@ package spark.storage import spark.{Utils, Logging, Serializer, SizeEstimator} - import scala.collection.mutable.ArrayBuffer - import java.io.{File, RandomAccessFile} import java.nio.ByteBuffer import java.nio.channels.FileChannel.MapMode import java.util.{UUID, LinkedHashMap} import java.util.concurrent.Executors - +import java.util.concurrent.ConcurrentHashMap import it.unimi.dsi.fastutil.io._ +import java.util.concurrent.ArrayBlockingQueue /** * Abstract class to store blocks @@ -41,13 +40,29 @@ abstract class BlockStore(blockManager: BlockManager) extends Logging { class MemoryStore(blockManager: BlockManager, maxMemory: Long) extends BlockStore(blockManager) { - class Entry(var value: Any, val size: Long, val deserialized: Boolean) + case class Entry(value: Any, size: Long, deserialized: Boolean, var dropPending: Boolean = false) private val memoryStore = new LinkedHashMap[String, Entry](32, 0.75f, true) private var currentMemory = 0L - private val blockDropper = Executors.newSingleThreadExecutor() - + //private val blockDropper = Executors.newSingleThreadExecutor() + private val blocksToDrop = new ArrayBlockingQueue[String](10000, true) + private val blockDropper = new Thread("memory store - block dropper") { + override def run() { + try{ + while (true) { + val blockId = blocksToDrop.take() + logDebug("Block " + blockId + " ready to be dropped") + blockManager.dropFromMemory(blockId) + } + } catch { + case ie: InterruptedException => + logInfo("Shutting down block dropper") + } + } + } + blockDropper.start() + def putBytes(blockId: String, bytes: ByteBuffer, level: StorageLevel) { if (level.deserialized) { bytes.rewind() @@ -124,41 +139,45 @@ class MemoryStore(blockManager: BlockManager, maxMemory: Long) memoryStore.synchronized { memoryStore.clear() } - blockDropper.shutdown() + //blockDropper.shutdown() + blockDropper.interrupt() logInfo("MemoryStore cleared") } - private def drop(blockId: String) { - blockDropper.submit(new Runnable() { - def run() { - blockManager.dropFromMemory(blockId) - } - }) - } - private def ensureFreeSpace(space: Long) { logInfo("ensureFreeSpace(%d) called with curMem=%d, maxMem=%d".format( space, currentMemory, maxMemory)) - val droppedBlockIds = new ArrayBuffer[String]() - var droppedMemory = 0L - - memoryStore.synchronized { - val iter = memoryStore.entrySet().iterator() - while (maxMemory - (currentMemory - droppedMemory) < space && iter.hasNext) { - val pair = iter.next() - val blockId = pair.getKey - droppedBlockIds += blockId - droppedMemory += pair.getValue.size - logDebug("Decided to drop " + blockId) + if (maxMemory - currentMemory < space) { + + val selectedBlocks = new ArrayBuffer[String]() + var selectedMemory = 0L + + memoryStore.synchronized { + val iter = memoryStore.entrySet().iterator() + while (maxMemory - (currentMemory - selectedMemory) < space && iter.hasNext) { + val pair = iter.next() + val blockId = pair.getKey + val entry = pair.getValue() + if (!entry.dropPending) { + selectedBlocks += blockId + entry.dropPending = true + } + selectedMemory += pair.getValue.size + logDebug("Block " + blockId + " selected for dropping") + } + } + + logDebug("" + selectedBlocks.size + " new blocks selected for dropping, " + + blocksToDrop.size + " blocks pending") + var i = 0 + while (i < selectedBlocks.size) { + blocksToDrop.add(selectedBlocks(i)) + i += 1 } - } - - for (blockId <- 
droppedBlockIds) { - drop(blockId) + selectedBlocks.clear() } - droppedBlockIds.clear() - } + } } diff --git a/core/src/main/scala/spark/storage/StorageLevel.scala b/core/src/main/scala/spark/storage/StorageLevel.scala index 693a679c4e..f067a2a6c5 100644 --- a/core/src/main/scala/spark/storage/StorageLevel.scala +++ b/core/src/main/scala/spark/storage/StorageLevel.scala @@ -11,11 +11,8 @@ class StorageLevel( // TODO: Also add fields for caching priority, dataset ID, and flushing. - def this(booleanInt: Int, replication: Int) { - this(((booleanInt & 4) != 0), - ((booleanInt & 2) != 0), - ((booleanInt & 1) != 0), - replication) + def this(flags: Int, replication: Int) { + this((flags & 4) != 0, (flags & 2) != 0, (flags & 1) != 0, replication) } def this() = this(false, true, false) // For deserialization @@ -33,25 +30,25 @@ class StorageLevel( false } - def isValid() = ((useMemory || useDisk) && (replication > 0)) + def isValid = ((useMemory || useDisk) && (replication > 0)) - def toInt(): Int = { + def toInt: Int = { var ret = 0 if (useDisk) { - ret += 4 + ret |= 4 } if (useMemory) { - ret += 2 + ret |= 2 } if (deserialized) { - ret += 1 + ret |= 1 } return ret } override def writeExternal(out: ObjectOutput) { - out.writeByte(toInt().toByte) - out.writeByte(replication.toByte) + out.writeByte(toInt) + out.writeByte(replication) } override def readExternal(in: ObjectInput) { @@ -62,7 +59,7 @@ class StorageLevel( replication = in.readByte() } - override def toString(): String = + override def toString: String = "StorageLevel(%b, %b, %b, %d)".format(useDisk, useMemory, deserialized, replication) } diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala index 4a79c086e9..20638bba92 100644 --- a/core/src/test/scala/spark/RDDSuite.scala +++ b/core/src/test/scala/spark/RDDSuite.scala @@ -59,4 +59,11 @@ class RDDSuite extends FunSuite with BeforeAndAfter { val result = pairs.aggregate(emptyMap)(mergeElement, mergeMaps) assert(result.toSet === Set(("a", 6), ("b", 2), ("c", 5))) } + + test("checkpointing") { + val sc = new SparkContext("local", "test") + val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).flatMap(x => 1 to x).checkpoint() + assert(rdd.collect().toList === List(1, 1, 2, 1, 2, 3, 1, 2, 3, 4)) + sc.stop() + } } diff --git a/core/src/test/scala/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/spark/storage/BlockManagerSuite.scala index 1ed5519d37..61decd81e6 100644 --- a/core/src/test/scala/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/spark/storage/BlockManagerSuite.scala @@ -5,27 +5,29 @@ import java.nio.ByteBuffer import akka.actor._ import org.scalatest.FunSuite -import org.scalatest.BeforeAndAfterEach +import org.scalatest.BeforeAndAfter import spark.KryoSerializer import spark.util.ByteBufferInputStream -class BlockManagerSuite extends FunSuite with BeforeAndAfterEach { +class BlockManagerSuite extends FunSuite with BeforeAndAfter { var actorSystem: ActorSystem = null + var master: BlockManagerMaster = null - override def beforeEach() { + before { actorSystem = ActorSystem("test") - BlockManagerMaster.startBlockManagerMaster(actorSystem, true, true) + master = new BlockManagerMaster(actorSystem, true, true) } - override def afterEach() { + after { actorSystem.shutdown() actorSystem.awaitTermination() actorSystem = null + master = null } test("manager-master interaction") { - val store = new BlockManager(2000, new KryoSerializer) + val store = new BlockManager(master, new KryoSerializer, 2000) val a1 = new Array[Byte](400) 
val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -41,21 +43,21 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfterEach { assert(store.getSingle("a3") != None, "a3 was not in store") // Checking whether master knows about the blocks or not - assert(BlockManagerMaster.mustGetLocations(GetLocations("a1")).size > 0, "master was not told about a1") - assert(BlockManagerMaster.mustGetLocations(GetLocations("a2")).size > 0, "master was not told about a2") - assert(BlockManagerMaster.mustGetLocations(GetLocations("a3")).size === 0, "master was told about a3") + assert(master.mustGetLocations(GetLocations("a1")).size > 0, "master was not told about a1") + assert(master.mustGetLocations(GetLocations("a2")).size > 0, "master was not told about a2") + assert(master.mustGetLocations(GetLocations("a3")).size === 0, "master was told about a3") // Setting storage level of a1 and a2 to invalid; they should be removed from store and master store.setLevel("a1", new StorageLevel(false, false, false, 1)) store.setLevel("a2", new StorageLevel(true, false, false, 0)) assert(store.getSingle("a1") === None, "a1 not removed from store") assert(store.getSingle("a2") === None, "a2 not removed from store") - assert(BlockManagerMaster.mustGetLocations(GetLocations("a1")).size === 0, "master did not remove a1") - assert(BlockManagerMaster.mustGetLocations(GetLocations("a2")).size === 0, "master did not remove a2") + assert(master.mustGetLocations(GetLocations("a1")).size === 0, "master did not remove a1") + assert(master.mustGetLocations(GetLocations("a2")).size === 0, "master did not remove a2") } test("in-memory LRU storage") { - val store = new BlockManager(1000, new KryoSerializer) + val store = new BlockManager(master, new KryoSerializer, 1000) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -76,7 +78,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfterEach { } test("in-memory LRU storage with serialization") { - val store = new BlockManager(1000, new KryoSerializer) + val store = new BlockManager(master, new KryoSerializer, 1000) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -97,7 +99,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfterEach { } test("on-disk storage") { - val store = new BlockManager(1000, new KryoSerializer) + val store = new BlockManager(master, new KryoSerializer, 1000) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -110,7 +112,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfterEach { } test("disk and memory storage") { - val store = new BlockManager(1000, new KryoSerializer) + val store = new BlockManager(master, new KryoSerializer, 1000) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -124,7 +126,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfterEach { } test("disk and memory storage with serialization") { - val store = new BlockManager(1000, new KryoSerializer) + val store = new BlockManager(master, new KryoSerializer, 1000) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -138,7 +140,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfterEach { } test("LRU with mixed storage levels") { - val store = new BlockManager(1000, new KryoSerializer) + val store = new BlockManager(master, new KryoSerializer, 1000) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new 
Array[Byte](400) @@ -164,7 +166,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfterEach { } test("in-memory LRU with streams") { - val store = new BlockManager(1000, new KryoSerializer) + val store = new BlockManager(master, new KryoSerializer, 1000) val list1 = List(new Array[Byte](200), new Array[Byte](200)) val list2 = List(new Array[Byte](200), new Array[Byte](200)) val list3 = List(new Array[Byte](200), new Array[Byte](200)) @@ -190,7 +192,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfterEach { } test("LRU with mixed storage levels and streams") { - val store = new BlockManager(1000, new KryoSerializer) + val store = new BlockManager(master, new KryoSerializer, 1000) val list1 = List(new Array[Byte](200), new Array[Byte](200)) val list2 = List(new Array[Byte](200), new Array[Byte](200)) val list3 = List(new Array[Byte](200), new Array[Byte](200)) diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index d1445f2ade..684108677f 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -1,5 +1,7 @@ import sbt._ import Keys._ +import classpath.ClasspathUtilities.isArchive +import java.io.FileOutputStream import sbtassembly.Plugin._ import AssemblyKeys._ import twirl.sbt.TwirlPlugin._ @@ -70,12 +72,12 @@ object SparkBuild extends Build { "cc.spray" % "spray-can" % "1.0-M2.1", "cc.spray" % "spray-server" % "1.0-M2.1" ) - ) ++ assemblySettings ++ extraAssemblySettings ++ Seq(Twirl.settings: _*) + ) ++ assemblySettings ++ extraAssemblySettings ++ mergeSettings ++ Seq(Twirl.settings: _*) def replSettings = sharedSettings ++ Seq( name := "spark-repl", libraryDependencies <+= scalaVersion("org.scala-lang" % "scala-compiler" % _) - ) ++ assemblySettings ++ extraAssemblySettings + ) ++ assemblySettings ++ extraAssemblySettings ++ mergeSettings def examplesSettings = sharedSettings ++ Seq( name := "spark-examples" @@ -83,6 +85,57 @@ object SparkBuild extends Build { def bagelSettings = sharedSettings ++ Seq(name := "spark-bagel") + // Fix for "No configuration setting found for key 'akka.version'" exception + // when running Spark from the jar generated by the "assembly" task; see + // http://letitcrash.com/post/21025950392/howto-sbt-assembly-vs-reference-conf + lazy val merge = TaskKey[File]("merge-reference", + "merge all reference.conf") + + lazy val mergeSettings: Seq[Project.Setting[_]] = Seq( + merge <<= (fullClasspath in assembly) map { + c => + // collect from all elements of the full classpath + val (libs, dirs) = + c map (_.data) partition (isArchive) + // goal is to simply concatenate files here + val dest = file("reference.conf") + val out = new FileOutputStream(dest) + val append = IO.transfer(_: File, out) + try { + // first collect from managed sources + (dirs * "reference.conf").get foreach append + // then from dependency jars by unzipping and + // collecting reference.conf if present + for (lib <- libs) { + IO withTemporaryDirectory { + dir => + IO.unzip(lib, dir, "reference.conf") + (dir * "reference.conf").get foreach append + } + } + // return merged file location as task result + dest + } finally { + out.close() + } + }, + + // get rid of the individual files from jars + excludedFiles in assembly <<= + (excludedFiles in assembly) { + (old) => (bases) => + old(bases) ++ (bases flatMap (base => + (base / "reference.conf").get)) + }, + + // tell sbt-assembly to include our merged file + assembledMappings in assembly <<= + (assembledMappings in assembly, merge) map { + (old, merged) => (f) => + old(f) :+(merged, 
"reference.conf") + } + ) + def extraAssemblySettings() = Seq(test in assembly := {}) ++ Seq( mergeStrategy in assembly := { case m if m.toLowerCase.endsWith("manifest.mf") => MergeStrategy.discard |