diff options
Diffstat (limited to 'core')
40 files changed, 527 insertions, 369 deletions
diff --git a/core/src/main/java/org/apache/spark/network/netty/FileClientHandler.java b/core/src/main/java/org/apache/spark/network/netty/FileClientHandler.java index c4aa2669e0..8a09210245 100644 --- a/core/src/main/java/org/apache/spark/network/netty/FileClientHandler.java +++ b/core/src/main/java/org/apache/spark/network/netty/FileClientHandler.java @@ -21,6 +21,7 @@ import io.netty.buffer.ByteBuf; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundByteHandlerAdapter; +import org.apache.spark.storage.BlockId; abstract class FileClientHandler extends ChannelInboundByteHandlerAdapter { @@ -33,7 +34,7 @@ abstract class FileClientHandler extends ChannelInboundByteHandlerAdapter { } public abstract void handle(ChannelHandlerContext ctx, ByteBuf in, FileHeader header); - public abstract void handleError(String blockId); + public abstract void handleError(BlockId blockId); @Override public ByteBuf newInboundBuffer(ChannelHandlerContext ctx) { diff --git a/core/src/main/java/org/apache/spark/network/netty/FileServerHandler.java b/core/src/main/java/org/apache/spark/network/netty/FileServerHandler.java index d3d57a0255..cfd8132891 100644 --- a/core/src/main/java/org/apache/spark/network/netty/FileServerHandler.java +++ b/core/src/main/java/org/apache/spark/network/netty/FileServerHandler.java @@ -24,6 +24,7 @@ import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundMessageHandlerAdapter; import io.netty.channel.DefaultFileRegion; +import org.apache.spark.storage.BlockId; class FileServerHandler extends ChannelInboundMessageHandlerAdapter<String> { @@ -34,8 +35,9 @@ class FileServerHandler extends ChannelInboundMessageHandlerAdapter<String> { } @Override - public void messageReceived(ChannelHandlerContext ctx, String blockId) { - String path = pResolver.getAbsolutePath(blockId); + public void messageReceived(ChannelHandlerContext ctx, String blockIdString) { + BlockId blockId = BlockId.apply(blockIdString); + String path = pResolver.getAbsolutePath(blockId.name()); // if getFilePath returns null, close the channel if (path == null) { //ctx.close(); diff --git a/core/src/main/scala/org/apache/spark/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/BlockStoreShuffleFetcher.scala index 908ff56a6b..f8af6b0fbe 100644 --- a/core/src/main/scala/org/apache/spark/BlockStoreShuffleFetcher.scala +++ b/core/src/main/scala/org/apache/spark/BlockStoreShuffleFetcher.scala @@ -22,7 +22,7 @@ import scala.collection.mutable.HashMap import org.apache.spark.executor.{ShuffleReadMetrics, TaskMetrics} import org.apache.spark.serializer.Serializer -import org.apache.spark.storage.BlockManagerId +import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId} import org.apache.spark.util.CompletionIterator @@ -45,12 +45,12 @@ private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Loggin splitsByAddress.getOrElseUpdate(address, ArrayBuffer()) += ((index, size)) } - val blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])] = splitsByAddress.toSeq.map { + val blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])] = splitsByAddress.toSeq.map { case (address, splits) => - (address, splits.map(s => ("shuffle_%d_%d_%d".format(shuffleId, s._1, reduceId), s._2))) + (address, splits.map(s => (ShuffleBlockId(shuffleId, s._1, reduceId), s._2))) } - def unpackBlock(blockPair: (String, Option[Iterator[Any]])) : Iterator[T] = { + def unpackBlock(blockPair: (BlockId, Option[Iterator[Any]])) : Iterator[T] = { val blockId = blockPair._1 val blockOption = blockPair._2 blockOption match { @@ -58,9 +58,8 @@ private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Loggin block.asInstanceOf[Iterator[T]] } case None => { - val regex = "shuffle_([0-9]*)_([0-9]*)_([0-9]*)".r blockId match { - case regex(shufId, mapId, _) => + case ShuffleBlockId(shufId, mapId, _) => val address = statuses(mapId.toInt)._1 throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId, null) case _ => diff --git a/core/src/main/scala/org/apache/spark/CacheManager.scala b/core/src/main/scala/org/apache/spark/CacheManager.scala index 4cf7eb96da..221bb68c61 100644 --- a/core/src/main/scala/org/apache/spark/CacheManager.scala +++ b/core/src/main/scala/org/apache/spark/CacheManager.scala @@ -18,7 +18,7 @@ package org.apache.spark import scala.collection.mutable.{ArrayBuffer, HashSet} -import org.apache.spark.storage.{BlockManager, StorageLevel} +import org.apache.spark.storage.{BlockId, BlockManager, StorageLevel, RDDBlockId} import org.apache.spark.rdd.RDD @@ -28,12 +28,12 @@ import org.apache.spark.rdd.RDD private[spark] class CacheManager(blockManager: BlockManager) extends Logging { /** Keys of RDD splits that are being computed/loaded. */ - private val loading = new HashSet[String]() + private val loading = new HashSet[RDDBlockId]() /** Gets or computes an RDD split. Used by RDD.iterator() when an RDD is cached. */ def getOrCompute[T](rdd: RDD[T], split: Partition, context: TaskContext, storageLevel: StorageLevel) : Iterator[T] = { - val key = "rdd_%d_%d".format(rdd.id, split.index) + val key = RDDBlockId(rdd.id, split.index) logDebug("Looking for partition " + key) blockManager.get(key) match { case Some(values) => @@ -73,7 +73,7 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging { if (context.runningLocally) { return computedValues } val elements = new ArrayBuffer[Any] elements ++= computedValues - blockManager.put(key, elements, storageLevel, true) + blockManager.put(key, elements, storageLevel, tellMaster = true) return elements.iterator.asInstanceOf[Iterator[T]] } finally { loading.synchronized { diff --git a/core/src/main/scala/org/apache/spark/broadcast/BitTorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/BitTorrentBroadcast.scala index f82dea9f3a..b6c484bfe1 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/BitTorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/BitTorrentBroadcast.scala @@ -26,7 +26,7 @@ import scala.collection.mutable.{ListBuffer, Map, Set} import scala.math import org.apache.spark._ -import org.apache.spark.storage.{BlockManager, StorageLevel} +import org.apache.spark.storage.{BroadcastBlockId, StorageLevel} import org.apache.spark.util.Utils private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long) @@ -36,7 +36,7 @@ private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal: def value = value_ - def blockId: String = BlockManager.toBroadcastId(id) + def blockId = BroadcastBlockId(id) MultiTracker.synchronized { SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false) diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala index a4ceb0d6af..609464e38d 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala @@ -25,16 +25,15 @@ import it.unimi.dsi.fastutil.io.FastBufferedOutputStream import org.apache.spark.{HttpServer, Logging, SparkEnv} import org.apache.spark.io.CompressionCodec -import org.apache.spark.storage.{BlockManager, StorageLevel} -import org.apache.spark.util.{MetadataCleanerType, Utils, MetadataCleaner, TimeStampedHashSet} - +import org.apache.spark.storage.{BroadcastBlockId, StorageLevel} +import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashSet, Utils} private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long) extends Broadcast[T](id) with Logging with Serializable { def value = value_ - def blockId: String = BlockManager.toBroadcastId(id) + def blockId = BroadcastBlockId(id) HttpBroadcast.synchronized { SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false) @@ -121,7 +120,7 @@ private object HttpBroadcast extends Logging { } def write(id: Long, value: Any) { - val file = new File(broadcastDir, "broadcast-" + id) + val file = new File(broadcastDir, BroadcastBlockId(id).name) val out: OutputStream = { if (compress) { compressionCodec.compressedOutputStream(new FileOutputStream(file)) @@ -137,7 +136,7 @@ private object HttpBroadcast extends Logging { } def read[T](id: Long): T = { - val url = serverUri + "/broadcast-" + id + val url = serverUri + "/" + BroadcastBlockId(id).name val in = { if (compress) { compressionCodec.compressedInputStream(new URL(url).openStream()) diff --git a/core/src/main/scala/org/apache/spark/broadcast/TreeBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TreeBroadcast.scala index b664f28e42..e6674d49a7 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TreeBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TreeBroadcast.scala @@ -19,13 +19,11 @@ package org.apache.spark.broadcast import java.io._ import java.net._ -import java.util.{Comparator, Random, UUID} -import scala.collection.mutable.{ListBuffer, Map, Set} -import scala.math +import scala.collection.mutable.{ListBuffer, Set} import org.apache.spark._ -import org.apache.spark.storage.{BlockManager, StorageLevel} +import org.apache.spark.storage.{BroadcastBlockId, StorageLevel} import org.apache.spark.util.Utils private[spark] class TreeBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long) @@ -33,7 +31,7 @@ extends Broadcast[T](id) with Logging with Serializable { def value = value_ - def blockId = BlockManager.toBroadcastId(id) + def blockId = BroadcastBlockId(id) MultiTracker.synchronized { SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false) diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index acdb8d0343..eff0c0f274 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -27,7 +27,7 @@ import scala.collection.mutable.HashMap import org.apache.spark.scheduler._ import org.apache.spark._ -import org.apache.spark.storage.StorageLevel +import org.apache.spark.storage.{StorageLevel, TaskResultBlockId} import org.apache.spark.util.Utils /** @@ -173,7 +173,7 @@ private[spark] class Executor( val serializedResult = { if (serializedDirectResult.limit >= akkaFrameSize - 1024) { logInfo("Storing result for " + taskId + " in local BlockManager") - val blockId = "taskresult_" + taskId + val blockId = TaskResultBlockId(taskId) env.blockManager.putBytes( blockId, serializedDirectResult, StorageLevel.MEMORY_AND_DISK_SER) ser.serialize(new IndirectTaskResult[Any](blockId)) diff --git a/core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala b/core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala index 3c29700920..1b9fa1e53a 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala @@ -20,17 +20,18 @@ package org.apache.spark.network.netty import io.netty.buffer._ import org.apache.spark.Logging +import org.apache.spark.storage.{TestBlockId, BlockId} private[spark] class FileHeader ( val fileLen: Int, - val blockId: String) extends Logging { + val blockId: BlockId) extends Logging { lazy val buffer = { val buf = Unpooled.buffer() buf.capacity(FileHeader.HEADER_SIZE) buf.writeInt(fileLen) - buf.writeInt(blockId.length) - blockId.foreach((x: Char) => buf.writeByte(x)) + buf.writeInt(blockId.name.length) + blockId.name.foreach((x: Char) => buf.writeByte(x)) //padding the rest of header if (FileHeader.HEADER_SIZE - buf.readableBytes > 0 ) { buf.writeZero(FileHeader.HEADER_SIZE - buf.readableBytes) @@ -57,18 +58,15 @@ private[spark] object FileHeader { for (i <- 1 to idLength) { idBuilder += buf.readByte().asInstanceOf[Char] } - val blockId = idBuilder.toString() + val blockId = BlockId(idBuilder.toString()) new FileHeader(length, blockId) } - - def main (args:Array[String]){ - - val header = new FileHeader(25,"block_0"); - val buf = header.buffer; - val newheader = FileHeader.create(buf); - System.out.println("id="+newheader.blockId+",size="+newheader.fileLen) - + def main (args:Array[String]) { + val header = new FileHeader(25, TestBlockId("my_block")) + val buf = header.buffer + val newHeader = FileHeader.create(buf) + System.out.println("id=" + newHeader.blockId + ",size=" + newHeader.fileLen) } } diff --git a/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala b/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala index 9493ccffd9..481ff8c3e0 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala @@ -27,12 +27,13 @@ import org.apache.spark.Logging import org.apache.spark.network.ConnectionManagerId import scala.collection.JavaConverters._ +import org.apache.spark.storage.BlockId private[spark] class ShuffleCopier extends Logging { - def getBlock(host: String, port: Int, blockId: String, - resultCollectCallback: (String, Long, ByteBuf) => Unit) { + def getBlock(host: String, port: Int, blockId: BlockId, + resultCollectCallback: (BlockId, Long, ByteBuf) => Unit) { val handler = new ShuffleCopier.ShuffleClientHandler(resultCollectCallback) val connectTimeout = System.getProperty("spark.shuffle.netty.connect.timeout", "60000").toInt @@ -41,7 +42,7 @@ private[spark] class ShuffleCopier extends Logging { try { fc.init() fc.connect(host, port) - fc.sendRequest(blockId) + fc.sendRequest(blockId.name) fc.waitForClose() fc.close() } catch { @@ -53,14 +54,14 @@ private[spark] class ShuffleCopier extends Logging { } } - def getBlock(cmId: ConnectionManagerId, blockId: String, - resultCollectCallback: (String, Long, ByteBuf) => Unit) { + def getBlock(cmId: ConnectionManagerId, blockId: BlockId, + resultCollectCallback: (BlockId, Long, ByteBuf) => Unit) { getBlock(cmId.host, cmId.port, blockId, resultCollectCallback) } def getBlocks(cmId: ConnectionManagerId, - blocks: Seq[(String, Long)], - resultCollectCallback: (String, Long, ByteBuf) => Unit) { + blocks: Seq[(BlockId, Long)], + resultCollectCallback: (BlockId, Long, ByteBuf) => Unit) { for ((blockId, size) <- blocks) { getBlock(cmId, blockId, resultCollectCallback) @@ -71,7 +72,7 @@ private[spark] class ShuffleCopier extends Logging { private[spark] object ShuffleCopier extends Logging { - private class ShuffleClientHandler(resultCollectCallBack: (String, Long, ByteBuf) => Unit) + private class ShuffleClientHandler(resultCollectCallBack: (BlockId, Long, ByteBuf) => Unit) extends FileClientHandler with Logging { override def handle(ctx: ChannelHandlerContext, in: ByteBuf, header: FileHeader) { @@ -79,14 +80,14 @@ private[spark] object ShuffleCopier extends Logging { resultCollectCallBack(header.blockId, header.fileLen.toLong, in.readBytes(header.fileLen)) } - override def handleError(blockId: String) { + override def handleError(blockId: BlockId) { if (!isComplete) { resultCollectCallBack(blockId, -1, null) } } } - def echoResultCollectCallBack(blockId: String, size: Long, content: ByteBuf) { + def echoResultCollectCallBack(blockId: BlockId, size: Long, content: ByteBuf) { if (size != -1) { logInfo("File: " + blockId + " content is : \" " + content.toString(CharsetUtil.UTF_8) + "\"") } @@ -99,7 +100,7 @@ private[spark] object ShuffleCopier extends Logging { } val host = args(0) val port = args(1).toInt - val file = args(2) + val blockId = BlockId(args(2)) val threads = if (args.length > 3) args(3).toInt else 10 val copiers = Executors.newFixedThreadPool(80) @@ -107,12 +108,12 @@ private[spark] object ShuffleCopier extends Logging { Executors.callable(new Runnable() { def run() { val copier = new ShuffleCopier() - copier.getBlock(host, port, file, echoResultCollectCallBack) + copier.getBlock(host, port, blockId, echoResultCollectCallBack) } }) }).asJava copiers.invokeAll(tasks) - copiers.shutdown + copiers.shutdown() System.exit(0) } } diff --git a/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala b/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala index 0c5ded3145..1586dff254 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala @@ -21,7 +21,7 @@ import java.io.File import org.apache.spark.Logging import org.apache.spark.util.Utils -import org.apache.spark.storage.ShuffleBlockManager +import org.apache.spark.storage.BlockId private[spark] class ShuffleSender(portIn: Int, val pResolver: PathResolver) extends Logging { @@ -54,8 +54,9 @@ private[spark] object ShuffleSender { val localDirs = args.drop(2).map(new File(_)) val pResovler = new PathResolver { - override def getAbsolutePath(blockId: String): String = { - if (!ShuffleBlockManager.isShuffle(blockId)) { + override def getAbsolutePath(blockIdString: String): String = { + val blockId = BlockId(blockIdString) + if (!blockId.isShuffle) { throw new Exception("Block " + blockId + " is not a shuffle block") } // Figure out which local directory it hashes to, and which subdirectory in that @@ -63,7 +64,7 @@ private[spark] object ShuffleSender { val dirId = hash % localDirs.length val subDirId = (hash / localDirs.length) % subDirsPerLocalDir val subDir = new File(localDirs(dirId), "%02x".format(subDirId)) - val file = new File(subDir, blockId) + val file = new File(subDir, blockId.name) return file.getAbsolutePath } } diff --git a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala index bca6956a18..44ea573a7c 100644 --- a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala @@ -18,14 +18,14 @@ package org.apache.spark.rdd import org.apache.spark.{SparkContext, SparkEnv, Partition, TaskContext} -import org.apache.spark.storage.BlockManager +import org.apache.spark.storage.{BlockId, BlockManager} -private[spark] class BlockRDDPartition(val blockId: String, idx: Int) extends Partition { +private[spark] class BlockRDDPartition(val blockId: BlockId, idx: Int) extends Partition { val index = idx } private[spark] -class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[String]) +class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[BlockId]) extends RDD[T](sc, Nil) { @transient lazy val locations_ = BlockManager.blockIdsToHosts(blockIds, SparkEnv.get) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 4226617cfb..5c51852985 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -28,8 +28,8 @@ import org.apache.spark._ import org.apache.spark.rdd.RDD import org.apache.spark.executor.TaskMetrics import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult} -import org.apache.spark.storage.{BlockManager, BlockManagerMaster} -import org.apache.spark.util.{MetadataCleanerType, MetadataCleaner, TimeStampedHashMap} +import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerMaster, RDDBlockId} +import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap} /** * The high-level scheduling layer that implements stage-oriented scheduling. It computes a DAG of @@ -156,7 +156,7 @@ class DAGScheduler( private def getCacheLocs(rdd: RDD[_]): Array[Seq[TaskLocation]] = { if (!cacheLocs.contains(rdd.id)) { - val blockIds = rdd.partitions.indices.map(index=> "rdd_%d_%d".format(rdd.id, index)).toArray + val blockIds = rdd.partitions.indices.map(index=> RDDBlockId(rdd.id, index)).toArray[BlockId] val locs = BlockManager.blockIdsToBlockManagers(blockIds, env, blockManagerMaster) cacheLocs(rdd.id) = blockIds.map { id => locs.getOrElse(id, Nil).map(bm => TaskLocation(bm.host, bm.executorId)) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala index db3954a9d3..7e468d0d67 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala @@ -24,13 +24,14 @@ import org.apache.spark.executor.TaskMetrics import org.apache.spark.{SparkEnv} import java.nio.ByteBuffer import org.apache.spark.util.Utils +import org.apache.spark.storage.BlockId // Task result. Also contains updates to accumulator variables. private[spark] sealed trait TaskResult[T] /** A reference to a DirectTaskResult that has been stored in the worker's BlockManager. */ private[spark] -case class IndirectTaskResult[T](val blockId: String) extends TaskResult[T] with Serializable +case class IndirectTaskResult[T](blockId: BlockId) extends TaskResult[T] with Serializable /** A TaskResult that contains the task's return value and accumulator updates. */ private[spark] diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index e936b1cfed..55b25f145a 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -26,9 +26,8 @@ import com.esotericsoftware.kryo.io.{Input => KryoInput, Output => KryoOutput} import com.twitter.chill.{EmptyScalaKryoInstantiator, AllScalaRegistrar} import org.apache.spark.{SerializableWritable, Logging} -import org.apache.spark.storage.{GetBlock, GotBlock, PutBlock, StorageLevel} - import org.apache.spark.broadcast.HttpBroadcast +import org.apache.spark.storage.{GetBlock,GotBlock, PutBlock, StorageLevel, TestBlockId} /** * A Spark serializer that uses the [[http://code.google.com/p/kryo/wiki/V1Documentation Kryo 1.x library]]. @@ -43,13 +42,14 @@ class KryoSerializer extends org.apache.spark.serializer.Serializer with Logging val kryo = instantiator.newKryo() val classLoader = Thread.currentThread.getContextClassLoader + val blockId = TestBlockId("1") // Register some commonly used classes val toRegister: Seq[AnyRef] = Seq( ByteBuffer.allocate(1), StorageLevel.MEMORY_ONLY, - PutBlock("1", ByteBuffer.allocate(1), StorageLevel.MEMORY_ONLY), - GotBlock("1", ByteBuffer.allocate(1)), - GetBlock("1"), + PutBlock(blockId, ByteBuffer.allocate(1), StorageLevel.MEMORY_ONLY), + GotBlock(blockId, ByteBuffer.allocate(1)), + GetBlock(blockId), 1 to 10, 1 until 10, 1L to 10L, diff --git a/core/src/main/scala/org/apache/spark/storage/BlockException.scala b/core/src/main/scala/org/apache/spark/storage/BlockException.scala index 290dbce4f5..0d0a2dadc7 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockException.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockException.scala @@ -18,5 +18,5 @@ package org.apache.spark.storage private[spark] -case class BlockException(blockId: String, message: String) extends Exception(message) +case class BlockException(blockId: BlockId, message: String) extends Exception(message) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala index 3aeda3879d..e51c5b30a3 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala @@ -47,7 +47,7 @@ import org.apache.spark.util.Utils */ private[storage] -trait BlockFetcherIterator extends Iterator[(String, Option[Iterator[Any]])] +trait BlockFetcherIterator extends Iterator[(BlockId, Option[Iterator[Any]])] with Logging with BlockFetchTracker { def initialize() } @@ -57,20 +57,20 @@ private[storage] object BlockFetcherIterator { // A request to fetch one or more blocks, complete with their sizes - class FetchRequest(val address: BlockManagerId, val blocks: Seq[(String, Long)]) { + class FetchRequest(val address: BlockManagerId, val blocks: Seq[(BlockId, Long)]) { val size = blocks.map(_._2).sum } // A result of a fetch. Includes the block ID, size in bytes, and a function to deserialize // the block (since we want all deserializaton to happen in the calling thread); can also // represent a fetch failure if size == -1. - class FetchResult(val blockId: String, val size: Long, val deserialize: () => Iterator[Any]) { + class FetchResult(val blockId: BlockId, val size: Long, val deserialize: () => Iterator[Any]) { def failed: Boolean = size == -1 } class BasicBlockFetcherIterator( private val blockManager: BlockManager, - val blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])], + val blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])], serializer: Serializer) extends BlockFetcherIterator { @@ -92,12 +92,12 @@ object BlockFetcherIterator { // This represents the number of local blocks, also counting zero-sized blocks private var numLocal = 0 // BlockIds for local blocks that need to be fetched. Excludes zero-sized blocks - protected val localBlocksToFetch = new ArrayBuffer[String]() + protected val localBlocksToFetch = new ArrayBuffer[BlockId]() // This represents the number of remote blocks, also counting zero-sized blocks private var numRemote = 0 // BlockIds for remote blocks that need to be fetched. Excludes zero-sized blocks - protected val remoteBlocksToFetch = new HashSet[String]() + protected val remoteBlocksToFetch = new HashSet[BlockId]() // A queue to hold our results. protected val results = new LinkedBlockingQueue[FetchResult] @@ -167,7 +167,7 @@ object BlockFetcherIterator { logInfo("maxBytesInFlight: " + maxBytesInFlight + ", minRequest: " + minRequestSize) val iterator = blockInfos.iterator var curRequestSize = 0L - var curBlocks = new ArrayBuffer[(String, Long)] + var curBlocks = new ArrayBuffer[(BlockId, Long)] while (iterator.hasNext) { val (blockId, size) = iterator.next() // Skip empty blocks @@ -183,7 +183,7 @@ object BlockFetcherIterator { // Add this FetchRequest remoteRequests += new FetchRequest(address, curBlocks) curRequestSize = 0 - curBlocks = new ArrayBuffer[(String, Long)] + curBlocks = new ArrayBuffer[(BlockId, Long)] } } // Add in the final request @@ -241,7 +241,7 @@ object BlockFetcherIterator { override def hasNext: Boolean = resultsGotten < _numBlocksToFetch - override def next(): (String, Option[Iterator[Any]]) = { + override def next(): (BlockId, Option[Iterator[Any]]) = { resultsGotten += 1 val startFetchWait = System.currentTimeMillis() val result = results.take() @@ -267,7 +267,7 @@ object BlockFetcherIterator { class NettyBlockFetcherIterator( blockManager: BlockManager, - blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])], + blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])], serializer: Serializer) extends BasicBlockFetcherIterator(blockManager, blocksByAddress, serializer) { @@ -303,7 +303,7 @@ object BlockFetcherIterator { override protected def sendRequest(req: FetchRequest) { - def putResult(blockId: String, blockSize: Long, blockData: ByteBuf) { + def putResult(blockId: BlockId, blockSize: Long, blockData: ByteBuf) { val fetchResult = new FetchResult(blockId, blockSize, () => dataDeserialize(blockId, blockData.nioBuffer, serializer)) results.put(fetchResult) @@ -337,7 +337,7 @@ object BlockFetcherIterator { logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms") } - override def next(): (String, Option[Iterator[Any]]) = { + override def next(): (BlockId, Option[Iterator[Any]]) = { resultsGotten += 1 val result = results.take() // If all the results has been retrieved, copiers will exit automatically diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala new file mode 100644 index 0000000000..c7efc67a4a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +/** + * Identifies a particular Block of data, usually associated with a single file. + * A Block can be uniquely identified by its filename, but each type of Block has a different + * set of keys which produce its unique name. + * + * If your BlockId should be serializable, be sure to add it to the BlockId.fromString() method. + */ +private[spark] sealed abstract class BlockId { + /** A globally unique identifier for this Block. Can be used for ser/de. */ + def name: String + + // convenience methods + def asRDDId = if (isRDD) Some(asInstanceOf[RDDBlockId]) else None + def isRDD = isInstanceOf[RDDBlockId] + def isShuffle = isInstanceOf[ShuffleBlockId] + def isBroadcast = isInstanceOf[BroadcastBlockId] + + override def toString = name + override def hashCode = name.hashCode + override def equals(other: Any): Boolean = other match { + case o: BlockId => getClass == o.getClass && name.equals(o.name) + case _ => false + } +} + +private[spark] case class RDDBlockId(rddId: Int, splitIndex: Int) extends BlockId { + def name = "rdd_" + rddId + "_" + splitIndex +} + +private[spark] +case class ShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int) extends BlockId { + def name = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId +} + +private[spark] case class BroadcastBlockId(broadcastId: Long) extends BlockId { + def name = "broadcast_" + broadcastId +} + +private[spark] case class TaskResultBlockId(taskId: Long) extends BlockId { + def name = "taskresult_" + taskId +} + +private[spark] case class StreamBlockId(streamId: Int, uniqueId: Long) extends BlockId { + def name = "input-" + streamId + "-" + uniqueId +} + +// Intended only for testing purposes +private[spark] case class TestBlockId(id: String) extends BlockId { + def name = "test_" + id +} + +private[spark] object BlockId { + val RDD = "rdd_([0-9]+)_([0-9]+)".r + val SHUFFLE = "shuffle_([0-9]+)_([0-9]+)_([0-9]+)".r + val BROADCAST = "broadcast_([0-9]+)".r + val TASKRESULT = "taskresult_([0-9]+)".r + val STREAM = "input-([0-9]+)-([0-9]+)".r + val TEST = "test_(.*)".r + + /** Converts a BlockId "name" String back into a BlockId. */ + def apply(id: String) = id match { + case RDD(rddId, splitIndex) => + RDDBlockId(rddId.toInt, splitIndex.toInt) + case SHUFFLE(shuffleId, mapId, reduceId) => + ShuffleBlockId(shuffleId.toInt, mapId.toInt, reduceId.toInt) + case BROADCAST(broadcastId) => + BroadcastBlockId(broadcastId.toLong) + case TASKRESULT(taskId) => + TaskResultBlockId(taskId.toLong) + case STREAM(streamId, uniqueId) => + StreamBlockId(streamId.toInt, uniqueId.toLong) + case TEST(value) => + TestBlockId(value) + case _ => + throw new IllegalStateException("Unrecognized BlockId: " + id) + } +} diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 2322922f75..801f88a3db 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -37,7 +37,6 @@ import org.apache.spark.util._ import sun.nio.ch.DirectBuffer - private[spark] class BlockManager( executorId: String, actorSystem: ActorSystem, @@ -103,7 +102,7 @@ private[spark] class BlockManager( val shuffleBlockManager = new ShuffleBlockManager(this) - private val blockInfo = new TimeStampedHashMap[String, BlockInfo] + private val blockInfo = new TimeStampedHashMap[BlockId, BlockInfo] private[storage] val memoryStore: BlockStore = new MemoryStore(this, maxMemory) private[storage] val diskStore: DiskStore = @@ -249,7 +248,7 @@ private[spark] class BlockManager( /** * Get storage level of local block. If no info exists for the block, then returns null. */ - def getLevel(blockId: String): StorageLevel = blockInfo.get(blockId).map(_.level).orNull + def getLevel(blockId: BlockId): StorageLevel = blockInfo.get(blockId).map(_.level).orNull /** * Tell the master about the current storage status of a block. This will send a block update @@ -259,7 +258,7 @@ private[spark] class BlockManager( * droppedMemorySize exists to account for when block is dropped from memory to disk (so it is still valid). * This ensures that update in master will compensate for the increase in memory on slave. */ - def reportBlockStatus(blockId: String, info: BlockInfo, droppedMemorySize: Long = 0L) { + def reportBlockStatus(blockId: BlockId, info: BlockInfo, droppedMemorySize: Long = 0L) { val needReregister = !tryToReportBlockStatus(blockId, info, droppedMemorySize) if (needReregister) { logInfo("Got told to reregister updating block " + blockId) @@ -274,7 +273,7 @@ private[spark] class BlockManager( * which will be true if the block was successfully recorded and false if * the slave needs to re-register. */ - private def tryToReportBlockStatus(blockId: String, info: BlockInfo, droppedMemorySize: Long = 0L): Boolean = { + private def tryToReportBlockStatus(blockId: BlockId, info: BlockInfo, droppedMemorySize: Long = 0L): Boolean = { val (curLevel, inMemSize, onDiskSize, tellMaster) = info.synchronized { info.level match { case null => @@ -299,7 +298,7 @@ private[spark] class BlockManager( /** * Get locations of an array of blocks. */ - def getLocationBlockIds(blockIds: Array[String]): Array[Seq[BlockManagerId]] = { + def getLocationBlockIds(blockIds: Array[BlockId]): Array[Seq[BlockManagerId]] = { val startTimeMs = System.currentTimeMillis val locations = master.getLocations(blockIds).toArray logDebug("Got multiple block location in " + Utils.getUsedTimeMs(startTimeMs)) @@ -311,7 +310,7 @@ private[spark] class BlockManager( * shuffle blocks. It is safe to do so without a lock on block info since disk store * never deletes (recent) items. */ - def getLocalFromDisk(blockId: String, serializer: Serializer): Option[Iterator[Any]] = { + def getLocalFromDisk(blockId: BlockId, serializer: Serializer): Option[Iterator[Any]] = { diskStore.getValues(blockId, serializer).orElse( sys.error("Block " + blockId + " not found on disk, though it should be")) } @@ -319,7 +318,7 @@ private[spark] class BlockManager( /** * Get block from local block manager. */ - def getLocal(blockId: String): Option[Iterator[Any]] = { + def getLocal(blockId: BlockId): Option[Iterator[Any]] = { logDebug("Getting local block " + blockId) val info = blockInfo.get(blockId).orNull if (info != null) { @@ -400,13 +399,13 @@ private[spark] class BlockManager( /** * Get block from the local block manager as serialized bytes. */ - def getLocalBytes(blockId: String): Option[ByteBuffer] = { + def getLocalBytes(blockId: BlockId): Option[ByteBuffer] = { // TODO: This whole thing is very similar to getLocal; we need to refactor it somehow logDebug("Getting local block " + blockId + " as bytes") // As an optimization for map output fetches, if the block is for a shuffle, return it // without acquiring a lock; the disk store never deletes (recent) items so this should work - if (ShuffleBlockManager.isShuffle(blockId)) { + if (blockId.isShuffle) { return diskStore.getBytes(blockId) match { case Some(bytes) => Some(bytes) @@ -473,7 +472,7 @@ private[spark] class BlockManager( /** * Get block from remote block managers. */ - def getRemote(blockId: String): Option[Iterator[Any]] = { + def getRemote(blockId: BlockId): Option[Iterator[Any]] = { if (blockId == null) { throw new IllegalArgumentException("Block Id is null") } @@ -498,7 +497,7 @@ private[spark] class BlockManager( /** * Get block from remote block managers as serialized bytes. */ - def getRemoteBytes(blockId: String): Option[ByteBuffer] = { + def getRemoteBytes(blockId: BlockId): Option[ByteBuffer] = { // TODO: As with getLocalBytes, this is very similar to getRemote and perhaps should be // refactored. if (blockId == null) { @@ -523,7 +522,7 @@ private[spark] class BlockManager( /** * Get a block from the block manager (either local or remote). */ - def get(blockId: String): Option[Iterator[Any]] = { + def get(blockId: BlockId): Option[Iterator[Any]] = { val local = getLocal(blockId) if (local.isDefined) { logInfo("Found block %s locally".format(blockId)) @@ -544,7 +543,7 @@ private[spark] class BlockManager( * so that we can control the maxMegabytesInFlight for the fetch. */ def getMultiple( - blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])], serializer: Serializer) + blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])], serializer: Serializer) : BlockFetcherIterator = { val iter = @@ -558,7 +557,7 @@ private[spark] class BlockManager( iter } - def put(blockId: String, values: Iterator[Any], level: StorageLevel, tellMaster: Boolean) + def put(blockId: BlockId, values: Iterator[Any], level: StorageLevel, tellMaster: Boolean) : Long = { val elements = new ArrayBuffer[Any] elements ++= values @@ -570,7 +569,7 @@ private[spark] class BlockManager( * This is currently used for writing shuffle files out. Callers should handle error * cases. */ - def getDiskBlockWriter(blockId: String, serializer: Serializer, bufferSize: Int) + def getDiskBlockWriter(blockId: BlockId, serializer: Serializer, bufferSize: Int) : BlockObjectWriter = { val writer = diskStore.getBlockWriter(blockId, serializer, bufferSize) writer.registerCloseEventHandler(() => { @@ -584,7 +583,7 @@ private[spark] class BlockManager( /** * Put a new block of values to the block manager. Returns its (estimated) size in bytes. */ - def put(blockId: String, values: ArrayBuffer[Any], level: StorageLevel, + def put(blockId: BlockId, values: ArrayBuffer[Any], level: StorageLevel, tellMaster: Boolean = true) : Long = { if (blockId == null) { @@ -704,7 +703,7 @@ private[spark] class BlockManager( * Put a new block of serialized bytes to the block manager. */ def putBytes( - blockId: String, bytes: ByteBuffer, level: StorageLevel, tellMaster: Boolean = true) { + blockId: BlockId, bytes: ByteBuffer, level: StorageLevel, tellMaster: Boolean = true) { if (blockId == null) { throw new IllegalArgumentException("Block Id is null") @@ -805,7 +804,7 @@ private[spark] class BlockManager( * Replicate block to another node. */ var cachedPeers: Seq[BlockManagerId] = null - private def replicate(blockId: String, data: ByteBuffer, level: StorageLevel) { + private def replicate(blockId: BlockId, data: ByteBuffer, level: StorageLevel) { val tLevel = StorageLevel(level.useDisk, level.useMemory, level.deserialized, 1) if (cachedPeers == null) { cachedPeers = master.getPeers(blockManagerId, level.replication - 1) @@ -828,14 +827,14 @@ private[spark] class BlockManager( /** * Read a block consisting of a single object. */ - def getSingle(blockId: String): Option[Any] = { + def getSingle(blockId: BlockId): Option[Any] = { get(blockId).map(_.next()) } /** * Write a block consisting of a single object. */ - def putSingle(blockId: String, value: Any, level: StorageLevel, tellMaster: Boolean = true) { + def putSingle(blockId: BlockId, value: Any, level: StorageLevel, tellMaster: Boolean = true) { put(blockId, Iterator(value), level, tellMaster) } @@ -843,7 +842,7 @@ private[spark] class BlockManager( * Drop a block from memory, possibly putting it on disk if applicable. Called when the memory * store reaches its limit and needs to free up space. */ - def dropFromMemory(blockId: String, data: Either[ArrayBuffer[Any], ByteBuffer]) { + def dropFromMemory(blockId: BlockId, data: Either[ArrayBuffer[Any], ByteBuffer]) { logInfo("Dropping block " + blockId + " from memory") val info = blockInfo.get(blockId).orNull if (info != null) { @@ -892,16 +891,15 @@ private[spark] class BlockManager( // TODO: Instead of doing a linear scan on the blockInfo map, create another map that maps // from RDD.id to blocks. logInfo("Removing RDD " + rddId) - val rddPrefix = "rdd_" + rddId + "_" - val blocksToRemove = blockInfo.filter(_._1.startsWith(rddPrefix)).map(_._1) - blocksToRemove.foreach(blockId => removeBlock(blockId, false)) + val blocksToRemove = blockInfo.keys.flatMap(_.asRDDId).filter(_.rddId == rddId) + blocksToRemove.foreach(blockId => removeBlock(blockId, tellMaster = false)) blocksToRemove.size } /** * Remove a block from both memory and disk. */ - def removeBlock(blockId: String, tellMaster: Boolean = true) { + def removeBlock(blockId: BlockId, tellMaster: Boolean = true) { logInfo("Removing block " + blockId) val info = blockInfo.get(blockId).orNull if (info != null) info.synchronized { @@ -928,7 +926,7 @@ private[spark] class BlockManager( while (iterator.hasNext) { val entry = iterator.next() val (id, info, time) = (entry.getKey, entry.getValue._1, entry.getValue._2) - if (time < cleanupTime && ! BlockManager.isBroadcastBlock(id) ) { + if (time < cleanupTime && !id.isBroadcast) { info.synchronized { val level = info.level if (level.useMemory) { @@ -951,7 +949,7 @@ private[spark] class BlockManager( while (iterator.hasNext) { val entry = iterator.next() val (id, info, time) = (entry.getKey, entry.getValue._1, entry.getValue._2) - if (time < cleanupTime && BlockManager.isBroadcastBlock(id) ) { + if (time < cleanupTime && id.isBroadcast) { info.synchronized { val level = info.level if (level.useMemory) { @@ -968,34 +966,29 @@ private[spark] class BlockManager( } } - def shouldCompress(blockId: String): Boolean = { - if (ShuffleBlockManager.isShuffle(blockId)) { - compressShuffle - } else if (BlockManager.isBroadcastBlock(blockId)) { - compressBroadcast - } else if (blockId.startsWith("rdd_")) { - compressRdds - } else { - false // Won't happen in a real cluster, but it can in tests - } + def shouldCompress(blockId: BlockId): Boolean = blockId match { + case ShuffleBlockId(_, _, _) => compressShuffle + case BroadcastBlockId(_) => compressBroadcast + case RDDBlockId(_, _) => compressRdds + case _ => false } /** * Wrap an output stream for compression if block compression is enabled for its block type */ - def wrapForCompression(blockId: String, s: OutputStream): OutputStream = { + def wrapForCompression(blockId: BlockId, s: OutputStream): OutputStream = { if (shouldCompress(blockId)) compressionCodec.compressedOutputStream(s) else s } /** * Wrap an input stream for compression if block compression is enabled for its block type */ - def wrapForCompression(blockId: String, s: InputStream): InputStream = { + def wrapForCompression(blockId: BlockId, s: InputStream): InputStream = { if (shouldCompress(blockId)) compressionCodec.compressedInputStream(s) else s } def dataSerialize( - blockId: String, + blockId: BlockId, values: Iterator[Any], serializer: Serializer = defaultSerializer): ByteBuffer = { val byteStream = new FastByteArrayOutputStream(4096) @@ -1010,7 +1003,7 @@ private[spark] class BlockManager( * the iterator is reached. */ def dataDeserialize( - blockId: String, + blockId: BlockId, bytes: ByteBuffer, serializer: Serializer = defaultSerializer): Iterator[Any] = { bytes.rewind() @@ -1065,10 +1058,10 @@ private[spark] object BlockManager extends Logging { } def blockIdsToBlockManagers( - blockIds: Array[String], + blockIds: Array[BlockId], env: SparkEnv, blockManagerMaster: BlockManagerMaster = null) - : Map[String, Seq[BlockManagerId]] = + : Map[BlockId, Seq[BlockManagerId]] = { // env == null and blockManagerMaster != null is used in tests assert (env != null || blockManagerMaster != null) @@ -1078,7 +1071,7 @@ private[spark] object BlockManager extends Logging { blockManagerMaster.getLocations(blockIds) } - val blockManagers = new HashMap[String, Seq[BlockManagerId]] + val blockManagers = new HashMap[BlockId, Seq[BlockManagerId]] for (i <- 0 until blockIds.length) { blockManagers(blockIds(i)) = blockLocations(i) } @@ -1086,25 +1079,21 @@ private[spark] object BlockManager extends Logging { } def blockIdsToExecutorIds( - blockIds: Array[String], + blockIds: Array[BlockId], env: SparkEnv, blockManagerMaster: BlockManagerMaster = null) - : Map[String, Seq[String]] = + : Map[BlockId, Seq[String]] = { blockIdsToBlockManagers(blockIds, env, blockManagerMaster).mapValues(s => s.map(_.executorId)) } def blockIdsToHosts( - blockIds: Array[String], + blockIds: Array[BlockId], env: SparkEnv, blockManagerMaster: BlockManagerMaster = null) - : Map[String, Seq[String]] = + : Map[BlockId, Seq[String]] = { blockIdsToBlockManagers(blockIds, env, blockManagerMaster).mapValues(s => s.map(_.host)) } - - def isBroadcastBlock(blockId: String): Boolean = null != blockId && blockId.startsWith("broadcast_") - - def toBroadcastId(id: Long): String = "broadcast_" + id } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index cf463d6ffc..94038649b3 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -60,7 +60,7 @@ private[spark] class BlockManagerMaster(var driverActor: ActorRef) extends Loggi def updateBlockInfo( blockManagerId: BlockManagerId, - blockId: String, + blockId: BlockId, storageLevel: StorageLevel, memSize: Long, diskSize: Long): Boolean = { @@ -71,12 +71,12 @@ private[spark] class BlockManagerMaster(var driverActor: ActorRef) extends Loggi } /** Get locations of the blockId from the driver */ - def getLocations(blockId: String): Seq[BlockManagerId] = { + def getLocations(blockId: BlockId): Seq[BlockManagerId] = { askDriverWithReply[Seq[BlockManagerId]](GetLocations(blockId)) } /** Get locations of multiple blockIds from the driver */ - def getLocations(blockIds: Array[String]): Seq[Seq[BlockManagerId]] = { + def getLocations(blockIds: Array[BlockId]): Seq[Seq[BlockManagerId]] = { askDriverWithReply[Seq[Seq[BlockManagerId]]](GetLocationsMultipleBlockIds(blockIds)) } @@ -94,7 +94,7 @@ private[spark] class BlockManagerMaster(var driverActor: ActorRef) extends Loggi * Remove a block from the slaves that have it. This can only be used to remove * blocks that the driver knows about. */ - def removeBlock(blockId: String) { + def removeBlock(blockId: BlockId) { askDriverWithReply(RemoveBlock(blockId)) } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala index c7b23ab094..633230c0a8 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala @@ -48,7 +48,7 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { private val blockManagerIdByExecutor = new mutable.HashMap[String, BlockManagerId] // Mapping from block id to the set of block managers that have the block. - private val blockLocations = new JHashMap[String, mutable.HashSet[BlockManagerId]] + private val blockLocations = new JHashMap[BlockId, mutable.HashSet[BlockManagerId]] val akkaTimeout = Duration.create( System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds") @@ -129,10 +129,9 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { // First remove the metadata for the given RDD, and then asynchronously remove the blocks // from the slaves. - val prefix = "rdd_" + rddId + "_" // Find all blocks for the given RDD, remove the block from both blockLocations and // the blockManagerInfo that is tracking the blocks. - val blocks = blockLocations.keySet().filter(_.startsWith(prefix)) + val blocks = blockLocations.keys.flatMap(_.asRDDId).filter(_.rddId == rddId) blocks.foreach { blockId => val bms: mutable.HashSet[BlockManagerId] = blockLocations.get(blockId) bms.foreach(bm => blockManagerInfo.get(bm).foreach(_.removeBlock(blockId))) @@ -198,7 +197,7 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { // Remove a block from the slaves that have it. This can only be used to remove // blocks that the master knows about. - private def removeBlockFromWorkers(blockId: String) { + private def removeBlockFromWorkers(blockId: BlockId) { val locations = blockLocations.get(blockId) if (locations != null) { locations.foreach { blockManagerId: BlockManagerId => @@ -247,7 +246,7 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { private def updateBlockInfo( blockManagerId: BlockManagerId, - blockId: String, + blockId: BlockId, storageLevel: StorageLevel, memSize: Long, diskSize: Long) { @@ -292,11 +291,11 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { sender ! true } - private def getLocations(blockId: String): Seq[BlockManagerId] = { + private def getLocations(blockId: BlockId): Seq[BlockManagerId] = { if (blockLocations.containsKey(blockId)) blockLocations.get(blockId).toSeq else Seq.empty } - private def getLocationsMultipleBlockIds(blockIds: Array[String]): Seq[Seq[BlockManagerId]] = { + private def getLocationsMultipleBlockIds(blockIds: Array[BlockId]): Seq[Seq[BlockManagerId]] = { blockIds.map(blockId => getLocations(blockId)) } @@ -330,7 +329,7 @@ object BlockManagerMasterActor { private var _remainingMem: Long = maxMem // Mapping from block id to its status. - private val _blocks = new JHashMap[String, BlockStatus] + private val _blocks = new JHashMap[BlockId, BlockStatus] logInfo("Registering block manager %s with %s RAM".format( blockManagerId.hostPort, Utils.bytesToString(maxMem))) @@ -339,7 +338,7 @@ object BlockManagerMasterActor { _lastSeenMs = System.currentTimeMillis() } - def updateBlockInfo(blockId: String, storageLevel: StorageLevel, memSize: Long, + def updateBlockInfo(blockId: BlockId, storageLevel: StorageLevel, memSize: Long, diskSize: Long) { updateLastSeenMs() @@ -383,7 +382,7 @@ object BlockManagerMasterActor { } } - def removeBlock(blockId: String) { + def removeBlock(blockId: BlockId) { if (_blocks.containsKey(blockId)) { _remainingMem += _blocks.get(blockId).memSize _blocks.remove(blockId) @@ -394,7 +393,7 @@ object BlockManagerMasterActor { def lastSeenMs: Long = _lastSeenMs - def blocks: JHashMap[String, BlockStatus] = _blocks + def blocks: JHashMap[BlockId, BlockStatus] = _blocks override def toString: String = "BlockManagerInfo " + timeMs + " " + _remainingMem diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala index 24333a179c..45f51da288 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala @@ -30,7 +30,7 @@ private[storage] object BlockManagerMessages { // Remove a block from the slaves that have it. This can only be used to remove // blocks that the master knows about. - case class RemoveBlock(blockId: String) extends ToBlockManagerSlave + case class RemoveBlock(blockId: BlockId) extends ToBlockManagerSlave // Remove all blocks belonging to a specific RDD. case class RemoveRdd(rddId: Int) extends ToBlockManagerSlave @@ -51,7 +51,7 @@ private[storage] object BlockManagerMessages { class UpdateBlockInfo( var blockManagerId: BlockManagerId, - var blockId: String, + var blockId: BlockId, var storageLevel: StorageLevel, var memSize: Long, var diskSize: Long) @@ -62,7 +62,7 @@ private[storage] object BlockManagerMessages { override def writeExternal(out: ObjectOutput) { blockManagerId.writeExternal(out) - out.writeUTF(blockId) + out.writeUTF(blockId.name) storageLevel.writeExternal(out) out.writeLong(memSize) out.writeLong(diskSize) @@ -70,7 +70,7 @@ private[storage] object BlockManagerMessages { override def readExternal(in: ObjectInput) { blockManagerId = BlockManagerId(in) - blockId = in.readUTF() + blockId = BlockId(in.readUTF()) storageLevel = StorageLevel(in) memSize = in.readLong() diskSize = in.readLong() @@ -79,7 +79,7 @@ private[storage] object BlockManagerMessages { object UpdateBlockInfo { def apply(blockManagerId: BlockManagerId, - blockId: String, + blockId: BlockId, storageLevel: StorageLevel, memSize: Long, diskSize: Long): UpdateBlockInfo = { @@ -87,14 +87,14 @@ private[storage] object BlockManagerMessages { } // For pattern-matching - def unapply(h: UpdateBlockInfo): Option[(BlockManagerId, String, StorageLevel, Long, Long)] = { + def unapply(h: UpdateBlockInfo): Option[(BlockManagerId, BlockId, StorageLevel, Long, Long)] = { Some((h.blockManagerId, h.blockId, h.storageLevel, h.memSize, h.diskSize)) } } - case class GetLocations(blockId: String) extends ToBlockManagerMaster + case class GetLocations(blockId: BlockId) extends ToBlockManagerMaster - case class GetLocationsMultipleBlockIds(blockIds: Array[String]) extends ToBlockManagerMaster + case class GetLocationsMultipleBlockIds(blockIds: Array[BlockId]) extends ToBlockManagerMaster case class GetPeers(blockManagerId: BlockManagerId, size: Int) extends ToBlockManagerMaster diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala index 678c38203c..0c66addf9d 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala @@ -77,7 +77,7 @@ private[spark] class BlockManagerWorker(val blockManager: BlockManager) extends } } - private def putBlock(id: String, bytes: ByteBuffer, level: StorageLevel) { + private def putBlock(id: BlockId, bytes: ByteBuffer, level: StorageLevel) { val startTimeMs = System.currentTimeMillis() logDebug("PutBlock " + id + " started from " + startTimeMs + " with data: " + bytes) blockManager.putBytes(id, bytes, level) @@ -85,7 +85,7 @@ private[spark] class BlockManagerWorker(val blockManager: BlockManager) extends + " with data size: " + bytes.limit) } - private def getBlock(id: String): ByteBuffer = { + private def getBlock(id: BlockId): ByteBuffer = { val startTimeMs = System.currentTimeMillis() logDebug("GetBlock " + id + " started from " + startTimeMs) val buffer = blockManager.getLocalBytes(id) match { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockMessage.scala b/core/src/main/scala/org/apache/spark/storage/BlockMessage.scala index d8fa6a91d1..80dcb5a207 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockMessage.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockMessage.scala @@ -24,9 +24,9 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.network._ -private[spark] case class GetBlock(id: String) -private[spark] case class GotBlock(id: String, data: ByteBuffer) -private[spark] case class PutBlock(id: String, data: ByteBuffer, level: StorageLevel) +private[spark] case class GetBlock(id: BlockId) +private[spark] case class GotBlock(id: BlockId, data: ByteBuffer) +private[spark] case class PutBlock(id: BlockId, data: ByteBuffer, level: StorageLevel) private[spark] class BlockMessage() { // Un-initialized: typ = 0 @@ -34,7 +34,7 @@ private[spark] class BlockMessage() { // GotBlock: typ = 2 // PutBlock: typ = 3 private var typ: Int = BlockMessage.TYPE_NON_INITIALIZED - private var id: String = null + private var id: BlockId = null private var data: ByteBuffer = null private var level: StorageLevel = null @@ -74,7 +74,7 @@ private[spark] class BlockMessage() { for (i <- 1 to idLength) { idBuilder += buffer.getChar() } - id = idBuilder.toString() + id = BlockId(idBuilder.toString) if (typ == BlockMessage.TYPE_PUT_BLOCK) { @@ -109,28 +109,17 @@ private[spark] class BlockMessage() { set(buffer) } - def getType: Int = { - return typ - } - - def getId: String = { - return id - } - - def getData: ByteBuffer = { - return data - } - - def getLevel: StorageLevel = { - return level - } + def getType: Int = typ + def getId: BlockId = id + def getData: ByteBuffer = data + def getLevel: StorageLevel = level def toBufferMessage: BufferMessage = { val startTime = System.currentTimeMillis val buffers = new ArrayBuffer[ByteBuffer]() - var buffer = ByteBuffer.allocate(4 + 4 + id.length() * 2) - buffer.putInt(typ).putInt(id.length()) - id.foreach((x: Char) => buffer.putChar(x)) + var buffer = ByteBuffer.allocate(4 + 4 + id.name.length * 2) + buffer.putInt(typ).putInt(id.name.length) + id.name.foreach((x: Char) => buffer.putChar(x)) buffer.flip() buffers += buffer @@ -212,7 +201,8 @@ private[spark] object BlockMessage { def main(args: Array[String]) { val B = new BlockMessage() - B.set(new PutBlock("ABC", ByteBuffer.allocate(10), StorageLevel.MEMORY_AND_DISK_SER_2)) + val blockId = TestBlockId("ABC") + B.set(new PutBlock(blockId, ByteBuffer.allocate(10), StorageLevel.MEMORY_AND_DISK_SER_2)) val bMsg = B.toBufferMessage val C = new BlockMessage() C.set(bMsg) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockMessageArray.scala b/core/src/main/scala/org/apache/spark/storage/BlockMessageArray.scala index 0aaf846b5b..6ce9127c74 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockMessageArray.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockMessageArray.scala @@ -111,14 +111,15 @@ private[spark] object BlockMessageArray { } def main(args: Array[String]) { - val blockMessages = + val blockMessages = (0 until 10).map { i => if (i % 2 == 0) { val buffer = ByteBuffer.allocate(100) buffer.clear - BlockMessage.fromPutBlock(PutBlock(i.toString, buffer, StorageLevel.MEMORY_ONLY_SER)) + BlockMessage.fromPutBlock(PutBlock(TestBlockId(i.toString), buffer, + StorageLevel.MEMORY_ONLY_SER)) } else { - BlockMessage.fromGetBlock(GetBlock(i.toString)) + BlockMessage.fromGetBlock(GetBlock(TestBlockId(i.toString))) } } val blockMessageArray = new BlockMessageArray(blockMessages) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala index 39f103297f..2a67800c45 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala @@ -25,7 +25,7 @@ package org.apache.spark.storage * * This interface does not support concurrent writes. */ -abstract class BlockObjectWriter(val blockId: String) { +abstract class BlockObjectWriter(val blockId: BlockId) { var closeEventHandler: () => Unit = _ diff --git a/core/src/main/scala/org/apache/spark/storage/BlockStore.scala b/core/src/main/scala/org/apache/spark/storage/BlockStore.scala index fa834371f4..ea42656240 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockStore.scala @@ -27,7 +27,7 @@ import org.apache.spark.Logging */ private[spark] abstract class BlockStore(val blockManager: BlockManager) extends Logging { - def putBytes(blockId: String, bytes: ByteBuffer, level: StorageLevel) + def putBytes(blockId: BlockId, bytes: ByteBuffer, level: StorageLevel) /** * Put in a block and, possibly, also return its content as either bytes or another Iterator. @@ -36,26 +36,26 @@ abstract class BlockStore(val blockManager: BlockManager) extends Logging { * @return a PutResult that contains the size of the data, as well as the values put if * returnValues is true (if not, the result's data field can be null) */ - def putValues(blockId: String, values: ArrayBuffer[Any], level: StorageLevel, + def putValues(blockId: BlockId, values: ArrayBuffer[Any], level: StorageLevel, returnValues: Boolean) : PutResult /** * Return the size of a block in bytes. */ - def getSize(blockId: String): Long + def getSize(blockId: BlockId): Long - def getBytes(blockId: String): Option[ByteBuffer] + def getBytes(blockId: BlockId): Option[ByteBuffer] - def getValues(blockId: String): Option[Iterator[Any]] + def getValues(blockId: BlockId): Option[Iterator[Any]] /** * Remove a block, if it exists. * @param blockId the block to remove. * @return True if the block was found and removed, False otherwise. */ - def remove(blockId: String): Boolean + def remove(blockId: BlockId): Boolean - def contains(blockId: String): Boolean + def contains(blockId: BlockId): Boolean def clear() { } } diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala index 63447baf8c..b7ca61e938 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala @@ -42,7 +42,7 @@ import org.apache.spark.util.Utils private class DiskStore(blockManager: BlockManager, rootDirs: String) extends BlockStore(blockManager) with Logging { - class DiskBlockObjectWriter(blockId: String, serializer: Serializer, bufferSize: Int) + class DiskBlockObjectWriter(blockId: BlockId, serializer: Serializer, bufferSize: Int) extends BlockObjectWriter(blockId) { private val f: File = createFile(blockId /*, allowAppendExisting */) @@ -124,16 +124,16 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) addShutdownHook() - def getBlockWriter(blockId: String, serializer: Serializer, bufferSize: Int) + def getBlockWriter(blockId: BlockId, serializer: Serializer, bufferSize: Int) : BlockObjectWriter = { new DiskBlockObjectWriter(blockId, serializer, bufferSize) } - override def getSize(blockId: String): Long = { + override def getSize(blockId: BlockId): Long = { getFile(blockId).length() } - override def putBytes(blockId: String, _bytes: ByteBuffer, level: StorageLevel) { + override def putBytes(blockId: BlockId, _bytes: ByteBuffer, level: StorageLevel) { // So that we do not modify the input offsets ! // duplicate does not copy buffer, so inexpensive val bytes = _bytes.duplicate() @@ -163,7 +163,7 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) } override def putValues( - blockId: String, + blockId: BlockId, values: ArrayBuffer[Any], level: StorageLevel, returnValues: Boolean) @@ -192,13 +192,13 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) } } - override def getBytes(blockId: String): Option[ByteBuffer] = { + override def getBytes(blockId: BlockId): Option[ByteBuffer] = { val file = getFile(blockId) val bytes = getFileBytes(file) Some(bytes) } - override def getValues(blockId: String): Option[Iterator[Any]] = { + override def getValues(blockId: BlockId): Option[Iterator[Any]] = { getBytes(blockId).map(bytes => blockManager.dataDeserialize(blockId, bytes)) } @@ -206,11 +206,11 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) * A version of getValues that allows a custom serializer. This is used as part of the * shuffle short-circuit code. */ - def getValues(blockId: String, serializer: Serializer): Option[Iterator[Any]] = { + def getValues(blockId: BlockId, serializer: Serializer): Option[Iterator[Any]] = { getBytes(blockId).map(bytes => blockManager.dataDeserialize(blockId, bytes, serializer)) } - override def remove(blockId: String): Boolean = { + override def remove(blockId: BlockId): Boolean = { val file = getFile(blockId) if (file.exists()) { file.delete() @@ -219,11 +219,11 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) } } - override def contains(blockId: String): Boolean = { + override def contains(blockId: BlockId): Boolean = { getFile(blockId).exists() } - private def createFile(blockId: String, allowAppendExisting: Boolean = false): File = { + private def createFile(blockId: BlockId, allowAppendExisting: Boolean = false): File = { val file = getFile(blockId) if (!allowAppendExisting && file.exists()) { // NOTE(shivaram): Delete the file if it exists. This might happen if a ShuffleMap task @@ -234,7 +234,7 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) file } - private def getFile(blockId: String): File = { + private def getFile(blockId: BlockId): File = { logDebug("Getting file for block " + blockId) // Figure out which local directory it hashes to, and which subdirectory in that @@ -258,7 +258,7 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) } } - new File(subDir, blockId) + new File(subDir, blockId.name) } private def createLocalDirs(): Array[File] = { @@ -307,7 +307,7 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) } } if (shuffleSender != null) { - shuffleSender.stop + shuffleSender.stop() } } }) @@ -315,11 +315,10 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) private[storage] def startShuffleBlockSender(port: Int): Int = { val pResolver = new PathResolver { - override def getAbsolutePath(blockId: String): String = { - if (!blockId.startsWith("shuffle_")) { - return null - } - DiskStore.this.getFile(blockId).getAbsolutePath() + override def getAbsolutePath(blockIdString: String): String = { + val blockId = BlockId(blockIdString) + if (!blockId.isShuffle) null + else DiskStore.this.getFile(blockId).getAbsolutePath } } shuffleSender = new ShuffleSender(port, pResolver) diff --git a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala index 77a39c71ed..05f676c6e2 100644 --- a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala @@ -32,7 +32,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) case class Entry(value: Any, size: Long, deserialized: Boolean) - private val entries = new LinkedHashMap[String, Entry](32, 0.75f, true) + private val entries = new LinkedHashMap[BlockId, Entry](32, 0.75f, true) @volatile private var currentMemory = 0L // Object used to ensure that only one thread is putting blocks and if necessary, dropping // blocks from the memory store. @@ -42,13 +42,13 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) def freeMemory: Long = maxMemory - currentMemory - override def getSize(blockId: String): Long = { + override def getSize(blockId: BlockId): Long = { entries.synchronized { entries.get(blockId).size } } - override def putBytes(blockId: String, _bytes: ByteBuffer, level: StorageLevel) { + override def putBytes(blockId: BlockId, _bytes: ByteBuffer, level: StorageLevel) { // Work on a duplicate - since the original input might be used elsewhere. val bytes = _bytes.duplicate() bytes.rewind() @@ -64,7 +64,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) } override def putValues( - blockId: String, + blockId: BlockId, values: ArrayBuffer[Any], level: StorageLevel, returnValues: Boolean) @@ -81,7 +81,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) } } - override def getBytes(blockId: String): Option[ByteBuffer] = { + override def getBytes(blockId: BlockId): Option[ByteBuffer] = { val entry = entries.synchronized { entries.get(blockId) } @@ -94,7 +94,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) } } - override def getValues(blockId: String): Option[Iterator[Any]] = { + override def getValues(blockId: BlockId): Option[Iterator[Any]] = { val entry = entries.synchronized { entries.get(blockId) } @@ -108,7 +108,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) } } - override def remove(blockId: String): Boolean = { + override def remove(blockId: BlockId): Boolean = { entries.synchronized { val entry = entries.remove(blockId) if (entry != null) { @@ -131,14 +131,10 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) } /** - * Return the RDD ID that a given block ID is from, or null if it is not an RDD block. + * Return the RDD ID that a given block ID is from, or None if it is not an RDD block. */ - private def getRddId(blockId: String): String = { - if (blockId.startsWith("rdd_")) { - blockId.split('_')(1) - } else { - null - } + private def getRddId(blockId: BlockId): Option[Int] = { + blockId.asRDDId.map(_.rddId) } /** @@ -151,7 +147,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) * blocks to free memory for one block, another thread may use up the freed space for * another block. */ - private def tryToPut(blockId: String, value: Any, size: Long, deserialized: Boolean): Boolean = { + private def tryToPut(blockId: BlockId, value: Any, size: Long, deserialized: Boolean): Boolean = { // TODO: Its possible to optimize the locking by locking entries only when selecting blocks // to be dropped. Once the to-be-dropped blocks have been selected, and lock on entries has been // released, it must be ensured that those to-be-dropped blocks are not double counted for @@ -195,7 +191,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) * Assumes that a lock is held by the caller to ensure only one thread is dropping blocks. * Otherwise, the freed space may fill up before the caller puts in their new value. */ - private def ensureFreeSpace(blockIdToAdd: String, space: Long): Boolean = { + private def ensureFreeSpace(blockIdToAdd: BlockId, space: Long): Boolean = { logInfo("ensureFreeSpace(%d) called with curMem=%d, maxMem=%d".format( space, currentMemory, maxMemory)) @@ -207,7 +203,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) if (maxMemory - currentMemory < space) { val rddToAdd = getRddId(blockIdToAdd) - val selectedBlocks = new ArrayBuffer[String]() + val selectedBlocks = new ArrayBuffer[BlockId]() var selectedMemory = 0L // This is synchronized to ensure that the set of entries is not changed @@ -218,7 +214,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) while (maxMemory - (currentMemory - selectedMemory) < space && iterator.hasNext) { val pair = iterator.next() val blockId = pair.getKey - if (rddToAdd != null && rddToAdd == getRddId(blockId)) { + if (rddToAdd != None && rddToAdd == getRddId(blockId)) { logInfo("Will not store " + blockIdToAdd + " as it would require dropping another " + "block from the same RDD") return false @@ -252,7 +248,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) return true } - override def contains(blockId: String): Boolean = { + override def contains(blockId: BlockId): Boolean = { entries.synchronized { entries.containsKey(blockId) } } } diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala index 9da11efb57..f39fcd87fb 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala @@ -30,7 +30,6 @@ trait ShuffleBlocks { def releaseWriters(group: ShuffleWriterGroup) } - private[spark] class ShuffleBlockManager(blockManager: BlockManager) { @@ -40,7 +39,7 @@ class ShuffleBlockManager(blockManager: BlockManager) { override def acquireWriters(mapId: Int): ShuffleWriterGroup = { val bufferSize = System.getProperty("spark.shuffle.file.buffer.kb", "100").toInt * 1024 val writers = Array.tabulate[BlockObjectWriter](numBuckets) { bucketId => - val blockId = ShuffleBlockManager.blockId(shuffleId, bucketId, mapId) + val blockId = ShuffleBlockId(shuffleId, mapId, bucketId) blockManager.getDiskBlockWriter(blockId, serializer, bufferSize) } new ShuffleWriterGroup(mapId, writers) @@ -52,16 +51,3 @@ class ShuffleBlockManager(blockManager: BlockManager) { } } } - - -private[spark] -object ShuffleBlockManager { - - // Returns the block id for a given shuffle block. - def blockId(shuffleId: Int, bucketId: Int, groupId: Int): String = { - "shuffle_" + shuffleId + "_" + groupId + "_" + bucketId - } - - // Returns true if the block is a shuffle block. - def isShuffle(blockId: String): Boolean = blockId.startsWith("shuffle_") -} diff --git a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala index 2bb7715696..1720007e4e 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala @@ -23,20 +23,24 @@ import org.apache.spark.util.Utils private[spark] case class StorageStatus(blockManagerId: BlockManagerId, maxMem: Long, - blocks: Map[String, BlockStatus]) { + blocks: Map[BlockId, BlockStatus]) { - def memUsed(blockPrefix: String = "") = { - blocks.filterKeys(_.startsWith(blockPrefix)).values.map(_.memSize). - reduceOption(_+_).getOrElse(0l) - } + def memUsed() = blocks.values.map(_.memSize).reduceOption(_+_).getOrElse(0L) - def diskUsed(blockPrefix: String = "") = { - blocks.filterKeys(_.startsWith(blockPrefix)).values.map(_.diskSize). - reduceOption(_+_).getOrElse(0l) - } + def memUsedByRDD(rddId: Int) = + rddBlocks.filterKeys(_.rddId == rddId).values.map(_.memSize).reduceOption(_+_).getOrElse(0L) + + def diskUsed() = blocks.values.map(_.diskSize).reduceOption(_+_).getOrElse(0L) + + def diskUsedByRDD(rddId: Int) = + rddBlocks.filterKeys(_.rddId == rddId).values.map(_.diskSize).reduceOption(_+_).getOrElse(0L) def memRemaining : Long = maxMem - memUsed() + def rddBlocks = blocks.flatMap { + case (rdd: RDDBlockId, status) => Some(rdd, status) + case _ => None + } } case class RDDInfo(id: Int, name: String, storageLevel: StorageLevel, @@ -60,7 +64,7 @@ object StorageUtils { /* Returns RDD-level information, compiled from a list of StorageStatus objects */ def rddInfoFromStorageStatus(storageStatusList: Seq[StorageStatus], sc: SparkContext) : Array[RDDInfo] = { - rddInfoFromBlockStatusList(storageStatusList.flatMap(_.blocks).toMap, sc) + rddInfoFromBlockStatusList(storageStatusList.flatMap(_.rddBlocks).toMap[RDDBlockId, BlockStatus], sc) } /* Returns a map of blocks to their locations, compiled from a list of StorageStatus objects */ @@ -71,26 +75,21 @@ object StorageUtils { } /* Given a list of BlockStatus objets, returns information for each RDD */ - def rddInfoFromBlockStatusList(infos: Map[String, BlockStatus], + def rddInfoFromBlockStatusList(infos: Map[RDDBlockId, BlockStatus], sc: SparkContext) : Array[RDDInfo] = { // Group by rddId, ignore the partition name - val groupedRddBlocks = infos.filterKeys(_.startsWith("rdd_")).groupBy { case(k, v) => - k.substring(0,k.lastIndexOf('_')) - }.mapValues(_.values.toArray) + val groupedRddBlocks = infos.groupBy { case(k, v) => k.rddId }.mapValues(_.values.toArray) // For each RDD, generate an RDDInfo object - val rddInfos = groupedRddBlocks.map { case (rddKey, rddBlocks) => + val rddInfos = groupedRddBlocks.map { case (rddId, rddBlocks) => // Add up memory and disk sizes val memSize = rddBlocks.map(_.memSize).reduce(_ + _) val diskSize = rddBlocks.map(_.diskSize).reduce(_ + _) - // Find the id of the RDD, e.g. rdd_1 => 1 - val rddId = rddKey.split("_").last.toInt - // Get the friendly name and storage level for the RDD, if available sc.persistentRdds.get(rddId).map { r => - val rddName = Option(r.name).getOrElse(rddKey) + val rddName = Option(r.name).getOrElse(rddId.toString) val rddStorageLevel = r.getStorageLevel RDDInfo(rddId, rddName, rddStorageLevel, rddBlocks.length, r.partitions.size, memSize, diskSize) } @@ -101,16 +100,14 @@ object StorageUtils { rddInfos } - /* Removes all BlockStatus object that are not part of a block prefix */ - def filterStorageStatusByPrefix(storageStatusList: Array[StorageStatus], - prefix: String) : Array[StorageStatus] = { + /* Filters storage status by a given RDD id. */ + def filterStorageStatusByRDD(storageStatusList: Array[StorageStatus], rddId: Int) + : Array[StorageStatus] = { storageStatusList.map { status => - val newBlocks = status.blocks.filterKeys(_.startsWith(prefix)) + val newBlocks = status.rddBlocks.filterKeys(_.rddId == rddId).toMap[BlockId, BlockStatus] //val newRemainingMem = status.maxMem - newBlocks.values.map(_.memSize).reduce(_ + _) StorageStatus(status.blockManagerId, status.maxMem, newBlocks) } - } - } diff --git a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala b/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala index f2ae8dd97d..860e680576 100644 --- a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala +++ b/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala @@ -36,11 +36,11 @@ private[spark] object ThreadingTest { val numBlocksPerProducer = 20000 private[spark] class ProducerThread(manager: BlockManager, id: Int) extends Thread { - val queue = new ArrayBlockingQueue[(String, Seq[Int])](100) + val queue = new ArrayBlockingQueue[(BlockId, Seq[Int])](100) override def run() { for (i <- 1 to numBlocksPerProducer) { - val blockId = "b-" + id + "-" + i + val blockId = TestBlockId("b-" + id + "-" + i) val blockSize = Random.nextInt(1000) val block = (1 to blockSize).map(_ => Random.nextInt()) val level = randomLevel() @@ -64,7 +64,7 @@ private[spark] object ThreadingTest { private[spark] class ConsumerThread( manager: BlockManager, - queue: ArrayBlockingQueue[(String, Seq[Int])] + queue: ArrayBlockingQueue[(BlockId, Seq[Int])] ) extends Thread { var numBlockConsumed = 0 diff --git a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala index 43c1257677..b83cd54f3c 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala @@ -21,7 +21,7 @@ import javax.servlet.http.HttpServletRequest import scala.xml.Node -import org.apache.spark.storage.{StorageStatus, StorageUtils} +import org.apache.spark.storage.{BlockId, StorageStatus, StorageUtils} import org.apache.spark.storage.BlockManagerMasterActor.BlockStatus import org.apache.spark.ui.UIUtils._ import org.apache.spark.ui.Page._ @@ -33,21 +33,20 @@ private[spark] class RDDPage(parent: BlockManagerUI) { val sc = parent.sc def render(request: HttpServletRequest): Seq[Node] = { - val id = request.getParameter("id") - val prefix = "rdd_" + id.toString + val id = request.getParameter("id").toInt val storageStatusList = sc.getExecutorStorageStatus - val filteredStorageStatusList = StorageUtils. - filterStorageStatusByPrefix(storageStatusList, prefix) + val filteredStorageStatusList = StorageUtils.filterStorageStatusByRDD(storageStatusList, id) val rddInfo = StorageUtils.rddInfoFromStorageStatus(filteredStorageStatusList, sc).head val workerHeaders = Seq("Host", "Memory Usage", "Disk Usage") - val workers = filteredStorageStatusList.map((prefix, _)) + val workers = filteredStorageStatusList.map((id, _)) val workerTable = listingTable(workerHeaders, workerRow, workers) val blockHeaders = Seq("Block Name", "Storage Level", "Size in Memory", "Size on Disk", "Executors") - val blockStatuses = filteredStorageStatusList.flatMap(_.blocks).toArray.sortWith(_._1 < _._1) + val blockStatuses = filteredStorageStatusList.flatMap(_.blocks).toArray. + sortWith(_._1.name < _._1.name) val blockLocations = StorageUtils.blockLocationsFromStorageStatus(filteredStorageStatusList) val blocks = blockStatuses.map { case(id, status) => (id, status, blockLocations.get(id).getOrElse(Seq("UNKNOWN"))) @@ -99,7 +98,7 @@ private[spark] class RDDPage(parent: BlockManagerUI) { headerSparkPage(content, parent.sc, "RDD Storage Info for " + rddInfo.name, Storage) } - def blockRow(row: (String, BlockStatus, Seq[String])): Seq[Node] = { + def blockRow(row: (BlockId, BlockStatus, Seq[String])): Seq[Node] = { val (id, block, locations) = row <tr> <td>{id}</td> @@ -118,15 +117,15 @@ private[spark] class RDDPage(parent: BlockManagerUI) { </tr> } - def workerRow(worker: (String, StorageStatus)): Seq[Node] = { - val (prefix, status) = worker + def workerRow(worker: (Int, StorageStatus)): Seq[Node] = { + val (rddId, status) = worker <tr> <td>{status.blockManagerId.host + ":" + status.blockManagerId.port}</td> <td> - {Utils.bytesToString(status.memUsed(prefix))} + {Utils.bytesToString(status.memUsedByRDD(rddId))} ({Utils.bytesToString(status.memRemaining)} Remaining) </td> - <td>{Utils.bytesToString(status.diskUsed(prefix))}</td> + <td>{Utils.bytesToString(status.diskUsedByRDD(rddId))}</td> </tr> } } diff --git a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala index 3a7171c488..ced036c58d 100644 --- a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala @@ -23,7 +23,7 @@ import org.scalatest.{BeforeAndAfter, FunSuite} import org.scalatest.mock.EasyMockSugar import org.apache.spark.rdd.RDD -import org.apache.spark.storage.{BlockManager, StorageLevel} +import org.apache.spark.storage.{BlockManager, RDDBlockId, StorageLevel} // TODO: Test the CacheManager's thread-safety aspects class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar { @@ -52,9 +52,9 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar test("get uncached rdd") { expecting { - blockManager.get("rdd_0_0").andReturn(None) - blockManager.put("rdd_0_0", ArrayBuffer[Any](1, 2, 3, 4), StorageLevel.MEMORY_ONLY, true). - andReturn(0) + blockManager.get(RDDBlockId(0, 0)).andReturn(None) + blockManager.put(RDDBlockId(0, 0), ArrayBuffer[Any](1, 2, 3, 4), StorageLevel.MEMORY_ONLY, + true).andReturn(0) } whenExecuting(blockManager) { @@ -66,7 +66,7 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar test("get cached rdd") { expecting { - blockManager.get("rdd_0_0").andReturn(Some(ArrayBuffer(5, 6, 7).iterator)) + blockManager.get(RDDBlockId(0, 0)).andReturn(Some(ArrayBuffer(5, 6, 7).iterator)) } whenExecuting(blockManager) { @@ -79,7 +79,7 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar test("get uncached local rdd") { expecting { // Local computation should not persist the resulting value, so don't expect a put(). - blockManager.get("rdd_0_0").andReturn(None) + blockManager.get(RDDBlockId(0, 0)).andReturn(None) } whenExecuting(blockManager) { diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala index d9103aebb7..7ca5f16202 100644 --- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala +++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala @@ -21,7 +21,7 @@ import org.scalatest.FunSuite import java.io.File import org.apache.spark.rdd._ import org.apache.spark.SparkContext._ -import storage.StorageLevel +import org.apache.spark.storage.{BlockId, StorageLevel, TestBlockId} import org.apache.spark.util.Utils class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { @@ -83,7 +83,7 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { } test("BlockRDD") { - val blockId = "id" + val blockId = TestBlockId("id") val blockManager = SparkEnv.get.blockManager blockManager.putSingle(blockId, "test", StorageLevel.MEMORY_ONLY) val blockRDD = new BlockRDD[String](sc, Array(blockId)) @@ -191,7 +191,7 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { } test("CheckpointRDD with zero partitions") { - val rdd = new BlockRDD[Int](sc, Array[String]()) + val rdd = new BlockRDD[Int](sc, Array[BlockId]()) assert(rdd.partitions.size === 0) assert(rdd.isCheckpointed === false) rdd.checkpoint() diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala index cd2bf9a8ff..480bac84f3 100644 --- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala @@ -18,24 +18,14 @@ package org.apache.spark import network.ConnectionManagerId -import org.scalatest.FunSuite import org.scalatest.BeforeAndAfter import org.scalatest.concurrent.Timeouts._ +import org.scalatest.FunSuite import org.scalatest.matchers.ShouldMatchers -import org.scalatest.prop.Checkers import org.scalatest.time.{Span, Millis} -import org.scalacheck.Arbitrary._ -import org.scalacheck.Gen -import org.scalacheck.Prop._ -import org.eclipse.jetty.server.{Server, Request, Handler} - -import com.google.common.io.Files - -import scala.collection.mutable.ArrayBuffer import SparkContext._ -import storage.{GetBlock, BlockManagerWorker, StorageLevel} -import ui.JettyUtils +import org.apache.spark.storage.{BlockManagerWorker, GetBlock, RDDBlockId, StorageLevel} class NotSerializableClass @@ -193,7 +183,7 @@ class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter // Get all the locations of the first partition and try to fetch the partitions // from those locations. - val blockIds = data.partitions.indices.map(index => "rdd_%d_%d".format(data.id, index)).toArray + val blockIds = data.partitions.indices.map(index => RDDBlockId(data.id, index)).toArray val blockId = blockIds(0) val blockManager = SparkEnv.get.blockManager blockManager.master.getLocations(blockId).foreach(id => { diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 2f933246b0..3952ee9264 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -30,7 +30,7 @@ import org.apache.spark.Partition import org.apache.spark.TaskContext import org.apache.spark.{Dependency, ShuffleDependency, OneToOneDependency} import org.apache.spark.{FetchFailed, Success, TaskEndReason} -import org.apache.spark.storage.{BlockManagerId, BlockManagerMaster} +import org.apache.spark.storage.{BlockId, BlockManagerId, BlockManagerMaster} import org.apache.spark.scheduler.SchedulingMode.SchedulingMode @@ -75,15 +75,10 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont val cacheLocations = new HashMap[(Int, Int), Seq[BlockManagerId]] // stub out BlockManagerMaster.getLocations to use our cacheLocations val blockManagerMaster = new BlockManagerMaster(null) { - override def getLocations(blockIds: Array[String]): Seq[Seq[BlockManagerId]] = { - blockIds.map { name => - val pieces = name.split("_") - if (pieces(0) == "rdd") { - val key = pieces(1).toInt -> pieces(2).toInt - cacheLocations.getOrElse(key, Seq()) - } else { - Seq() - } + override def getLocations(blockIds: Array[BlockId]): Seq[Seq[BlockManagerId]] = { + blockIds.map { + _.asRDDId.map(id => (id.rddId -> id.splitIndex)).flatMap(key => cacheLocations.get(key)). + getOrElse(Seq()) }.toSeq } override def removeExecutor(execId: String) { diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/TaskResultGetterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/TaskResultGetterSuite.scala index 119ba30090..ee150a3107 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/TaskResultGetterSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/TaskResultGetterSuite.scala @@ -23,6 +23,7 @@ import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite} import org.apache.spark.{LocalSparkContext, SparkContext, SparkEnv} import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, TaskResult} +import org.apache.spark.storage.TaskResultBlockId /** * Removes the TaskResult from the BlockManager before delegating to a normal TaskResultGetter. @@ -85,7 +86,7 @@ class TaskResultGetterSuite extends FunSuite with BeforeAndAfter with BeforeAndA val result = sc.parallelize(Seq(1), 1).map(x => 1.to(akkaFrameSize).toArray).reduce((x, y) => x) assert(result === 1.to(akkaFrameSize).toArray) - val RESULT_BLOCK_ID = "taskresult_0" + val RESULT_BLOCK_ID = TaskResultBlockId(0) assert(sc.env.blockManager.master.getLocations(RESULT_BLOCK_ID).size === 0, "Expect result to be removed from the block manager.") } diff --git a/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala new file mode 100644 index 0000000000..cb76275e39 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import org.scalatest.FunSuite + +class BlockIdSuite extends FunSuite { + def assertSame(id1: BlockId, id2: BlockId) { + assert(id1.name === id2.name) + assert(id1.hashCode === id2.hashCode) + assert(id1 === id2) + } + + def assertDifferent(id1: BlockId, id2: BlockId) { + assert(id1.name != id2.name) + assert(id1.hashCode != id2.hashCode) + assert(id1 != id2) + } + + test("test-bad-deserialization") { + try { + // Try to deserialize an invalid block id. + BlockId("myblock") + fail() + } catch { + case e: IllegalStateException => // OK + case _ => fail() + } + } + + test("rdd") { + val id = RDDBlockId(1, 2) + assertSame(id, RDDBlockId(1, 2)) + assertDifferent(id, RDDBlockId(1, 1)) + assert(id.name === "rdd_1_2") + assert(id.asRDDId.get.rddId === 1) + assert(id.asRDDId.get.splitIndex === 2) + assert(id.isRDD) + assertSame(id, BlockId(id.toString)) + } + + test("shuffle") { + val id = ShuffleBlockId(1, 2, 3) + assertSame(id, ShuffleBlockId(1, 2, 3)) + assertDifferent(id, ShuffleBlockId(3, 2, 3)) + assert(id.name === "shuffle_1_2_3") + assert(id.asRDDId === None) + assert(id.shuffleId === 1) + assert(id.mapId === 2) + assert(id.reduceId === 3) + assert(id.isShuffle) + assertSame(id, BlockId(id.toString)) + } + + test("broadcast") { + val id = BroadcastBlockId(42) + assertSame(id, BroadcastBlockId(42)) + assertDifferent(id, BroadcastBlockId(123)) + assert(id.name === "broadcast_42") + assert(id.asRDDId === None) + assert(id.broadcastId === 42) + assert(id.isBroadcast) + assertSame(id, BlockId(id.toString)) + } + + test("taskresult") { + val id = TaskResultBlockId(60) + assertSame(id, TaskResultBlockId(60)) + assertDifferent(id, TaskResultBlockId(61)) + assert(id.name === "taskresult_60") + assert(id.asRDDId === None) + assert(id.taskId === 60) + assert(!id.isRDD) + assertSame(id, BlockId(id.toString)) + } + + test("stream") { + val id = StreamBlockId(1, 100) + assertSame(id, StreamBlockId(1, 100)) + assertDifferent(id, StreamBlockId(2, 101)) + assert(id.name === "input-1-100") + assert(id.asRDDId === None) + assert(id.streamId === 1) + assert(id.uniqueId === 100) + assert(!id.isBroadcast) + assertSame(id, BlockId(id.toString)) + } + + test("test") { + val id = TestBlockId("abc") + assertSame(id, TestBlockId("abc")) + assertDifferent(id, TestBlockId("ab")) + assert(id.name === "test_abc") + assert(id.asRDDId === None) + assert(id.id === "abc") + assert(!id.isShuffle) + assertSame(id, BlockId(id.toString)) + } +} diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 038a9acb85..484a654108 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -32,7 +32,6 @@ import org.scalatest.time.SpanSugar._ import org.apache.spark.util.{SizeEstimator, Utils, AkkaUtils, ByteBufferInputStream} import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} - class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodTester { var store: BlockManager = null var store2: BlockManager = null @@ -46,6 +45,10 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT System.setProperty("spark.kryoserializer.buffer.mb", "1") val serializer = new KryoSerializer + // Implicitly convert strings to BlockIds for test clarity. + implicit def StringToBlockId(value: String): BlockId = new TestBlockId(value) + def rdd(rddId: Int, splitId: Int) = RDDBlockId(rddId, splitId) + before { val (actorSystem, boundPort) = AkkaUtils.createActorSystem("test", "localhost", 0) this.actorSystem = actorSystem @@ -229,31 +232,31 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) // Putting a1, a2 and a3 in memory. - store.putSingle("rdd_0_0", a1, StorageLevel.MEMORY_ONLY) - store.putSingle("rdd_0_1", a2, StorageLevel.MEMORY_ONLY) + store.putSingle(rdd(0, 0), a1, StorageLevel.MEMORY_ONLY) + store.putSingle(rdd(0, 1), a2, StorageLevel.MEMORY_ONLY) store.putSingle("nonrddblock", a3, StorageLevel.MEMORY_ONLY) master.removeRdd(0, blocking = false) eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { - store.getSingle("rdd_0_0") should be (None) - master.getLocations("rdd_0_0") should have size 0 + store.getSingle(rdd(0, 0)) should be (None) + master.getLocations(rdd(0, 0)) should have size 0 } eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { - store.getSingle("rdd_0_1") should be (None) - master.getLocations("rdd_0_1") should have size 0 + store.getSingle(rdd(0, 1)) should be (None) + master.getLocations(rdd(0, 1)) should have size 0 } eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { store.getSingle("nonrddblock") should not be (None) master.getLocations("nonrddblock") should have size (1) } - store.putSingle("rdd_0_0", a1, StorageLevel.MEMORY_ONLY) - store.putSingle("rdd_0_1", a2, StorageLevel.MEMORY_ONLY) + store.putSingle(rdd(0, 0), a1, StorageLevel.MEMORY_ONLY) + store.putSingle(rdd(0, 1), a2, StorageLevel.MEMORY_ONLY) master.removeRdd(0, blocking = true) - store.getSingle("rdd_0_0") should be (None) - master.getLocations("rdd_0_0") should have size 0 - store.getSingle("rdd_0_1") should be (None) - master.getLocations("rdd_0_1") should have size 0 + store.getSingle(rdd(0, 0)) should be (None) + master.getLocations(rdd(0, 0)) should have size 0 + store.getSingle(rdd(0, 1)) should be (None) + master.getLocations(rdd(0, 1)) should have size 0 } test("reregistration on heart beat") { @@ -372,41 +375,41 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) - store.putSingle("rdd_0_1", a1, StorageLevel.MEMORY_ONLY) - store.putSingle("rdd_0_2", a2, StorageLevel.MEMORY_ONLY) - store.putSingle("rdd_0_3", a3, StorageLevel.MEMORY_ONLY) + store.putSingle(rdd(0, 1), a1, StorageLevel.MEMORY_ONLY) + store.putSingle(rdd(0, 2), a2, StorageLevel.MEMORY_ONLY) + store.putSingle(rdd(0, 3), a3, StorageLevel.MEMORY_ONLY) // Even though we accessed rdd_0_3 last, it should not have replaced partitions 1 and 2 // from the same RDD - assert(store.getSingle("rdd_0_3") === None, "rdd_0_3 was in store") - assert(store.getSingle("rdd_0_2") != None, "rdd_0_2 was not in store") - assert(store.getSingle("rdd_0_1") != None, "rdd_0_1 was not in store") + assert(store.getSingle(rdd(0, 3)) === None, "rdd_0_3 was in store") + assert(store.getSingle(rdd(0, 2)) != None, "rdd_0_2 was not in store") + assert(store.getSingle(rdd(0, 1)) != None, "rdd_0_1 was not in store") // Check that rdd_0_3 doesn't replace them even after further accesses - assert(store.getSingle("rdd_0_3") === None, "rdd_0_3 was in store") - assert(store.getSingle("rdd_0_3") === None, "rdd_0_3 was in store") - assert(store.getSingle("rdd_0_3") === None, "rdd_0_3 was in store") + assert(store.getSingle(rdd(0, 3)) === None, "rdd_0_3 was in store") + assert(store.getSingle(rdd(0, 3)) === None, "rdd_0_3 was in store") + assert(store.getSingle(rdd(0, 3)) === None, "rdd_0_3 was in store") } test("in-memory LRU for partitions of multiple RDDs") { store = new BlockManager("<driver>", actorSystem, master, serializer, 1200) - store.putSingle("rdd_0_1", new Array[Byte](400), StorageLevel.MEMORY_ONLY) - store.putSingle("rdd_0_2", new Array[Byte](400), StorageLevel.MEMORY_ONLY) - store.putSingle("rdd_1_1", new Array[Byte](400), StorageLevel.MEMORY_ONLY) + store.putSingle(rdd(0, 1), new Array[Byte](400), StorageLevel.MEMORY_ONLY) + store.putSingle(rdd(0, 2), new Array[Byte](400), StorageLevel.MEMORY_ONLY) + store.putSingle(rdd(1, 1), new Array[Byte](400), StorageLevel.MEMORY_ONLY) // At this point rdd_1_1 should've replaced rdd_0_1 - assert(store.memoryStore.contains("rdd_1_1"), "rdd_1_1 was not in store") - assert(!store.memoryStore.contains("rdd_0_1"), "rdd_0_1 was in store") - assert(store.memoryStore.contains("rdd_0_2"), "rdd_0_2 was not in store") + assert(store.memoryStore.contains(rdd(1, 1)), "rdd_1_1 was not in store") + assert(!store.memoryStore.contains(rdd(0, 1)), "rdd_0_1 was in store") + assert(store.memoryStore.contains(rdd(0, 2)), "rdd_0_2 was not in store") // Do a get() on rdd_0_2 so that it is the most recently used item - assert(store.getSingle("rdd_0_2") != None, "rdd_0_2 was not in store") + assert(store.getSingle(rdd(0, 2)) != None, "rdd_0_2 was not in store") // Put in more partitions from RDD 0; they should replace rdd_1_1 - store.putSingle("rdd_0_3", new Array[Byte](400), StorageLevel.MEMORY_ONLY) - store.putSingle("rdd_0_4", new Array[Byte](400), StorageLevel.MEMORY_ONLY) + store.putSingle(rdd(0, 3), new Array[Byte](400), StorageLevel.MEMORY_ONLY) + store.putSingle(rdd(0, 4), new Array[Byte](400), StorageLevel.MEMORY_ONLY) // Now rdd_1_1 should be dropped to add rdd_0_3, but then rdd_0_2 should *not* be dropped // when we try to add rdd_0_4. - assert(!store.memoryStore.contains("rdd_1_1"), "rdd_1_1 was in store") - assert(!store.memoryStore.contains("rdd_0_1"), "rdd_0_1 was in store") - assert(!store.memoryStore.contains("rdd_0_4"), "rdd_0_4 was in store") - assert(store.memoryStore.contains("rdd_0_2"), "rdd_0_2 was not in store") - assert(store.memoryStore.contains("rdd_0_3"), "rdd_0_3 was not in store") + assert(!store.memoryStore.contains(rdd(1, 1)), "rdd_1_1 was in store") + assert(!store.memoryStore.contains(rdd(0, 1)), "rdd_0_1 was in store") + assert(!store.memoryStore.contains(rdd(0, 4)), "rdd_0_4 was in store") + assert(store.memoryStore.contains(rdd(0, 2)), "rdd_0_2 was not in store") + assert(store.memoryStore.contains(rdd(0, 3)), "rdd_0_3 was not in store") } test("on-disk storage") { @@ -590,43 +593,46 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT try { System.setProperty("spark.shuffle.compress", "true") store = new BlockManager("exec1", actorSystem, master, serializer, 2000) - store.putSingle("shuffle_0_0_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) - assert(store.memoryStore.getSize("shuffle_0_0_0") <= 100, "shuffle_0_0_0 was not compressed") + store.putSingle(ShuffleBlockId(0, 0, 0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) + assert(store.memoryStore.getSize(ShuffleBlockId(0, 0, 0)) <= 100, + "shuffle_0_0_0 was not compressed") store.stop() store = null System.setProperty("spark.shuffle.compress", "false") store = new BlockManager("exec2", actorSystem, master, serializer, 2000) - store.putSingle("shuffle_0_0_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) - assert(store.memoryStore.getSize("shuffle_0_0_0") >= 1000, "shuffle_0_0_0 was compressed") + store.putSingle(ShuffleBlockId(0, 0, 0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) + assert(store.memoryStore.getSize(ShuffleBlockId(0, 0, 0)) >= 1000, + "shuffle_0_0_0 was compressed") store.stop() store = null System.setProperty("spark.broadcast.compress", "true") store = new BlockManager("exec3", actorSystem, master, serializer, 2000) - store.putSingle("broadcast_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) - assert(store.memoryStore.getSize("broadcast_0") <= 100, "broadcast_0 was not compressed") + store.putSingle(BroadcastBlockId(0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) + assert(store.memoryStore.getSize(BroadcastBlockId(0)) <= 100, + "broadcast_0 was not compressed") store.stop() store = null System.setProperty("spark.broadcast.compress", "false") store = new BlockManager("exec4", actorSystem, master, serializer, 2000) - store.putSingle("broadcast_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) - assert(store.memoryStore.getSize("broadcast_0") >= 1000, "broadcast_0 was compressed") + store.putSingle(BroadcastBlockId(0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) + assert(store.memoryStore.getSize(BroadcastBlockId(0)) >= 1000, "broadcast_0 was compressed") store.stop() store = null System.setProperty("spark.rdd.compress", "true") store = new BlockManager("exec5", actorSystem, master, serializer, 2000) - store.putSingle("rdd_0_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) - assert(store.memoryStore.getSize("rdd_0_0") <= 100, "rdd_0_0 was not compressed") + store.putSingle(rdd(0, 0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) + assert(store.memoryStore.getSize(rdd(0, 0)) <= 100, "rdd_0_0 was not compressed") store.stop() store = null System.setProperty("spark.rdd.compress", "false") store = new BlockManager("exec6", actorSystem, master, serializer, 2000) - store.putSingle("rdd_0_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) - assert(store.memoryStore.getSize("rdd_0_0") >= 1000, "rdd_0_0 was compressed") + store.putSingle(rdd(0, 0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) + assert(store.memoryStore.getSize(rdd(0, 0)) >= 1000, "rdd_0_0 was compressed") store.stop() store = null |