diff options
10 files changed, 144 insertions, 20 deletions
diff --git a/core/src/main/java/org/apache/spark/network/netty/FileServerHandler.java b/core/src/main/java/org/apache/spark/network/netty/FileServerHandler.java index a5ec5e2fb1..299162f12c 100644 --- a/core/src/main/java/org/apache/spark/network/netty/FileServerHandler.java +++ b/core/src/main/java/org/apache/spark/network/netty/FileServerHandler.java @@ -36,7 +36,7 @@ class FileServerHandler extends ChannelInboundMessageHandlerAdapter<String> { @Override public void messageReceived(ChannelHandlerContext ctx, String blockIdString) { - BlockId blockId = BlockId.fromString(blockIdString); + BlockId blockId = BlockId.apply(blockIdString); String path = pResolver.getAbsolutePath(blockId.filename()); // if getFilePath returns null, close the channel if (path == null) { diff --git a/core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala b/core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala index d8cd2355c1..9ade373823 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala @@ -58,7 +58,7 @@ private[spark] object FileHeader { for (i <- 1 to idLength) { idBuilder += buf.readByte().asInstanceOf[Char] } - val blockId = BlockId.fromString(idBuilder.toString()) + val blockId = BlockId(idBuilder.toString()) new FileHeader(length, blockId) } diff --git a/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala b/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala index de3cab4e78..bb1c8ac20e 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala @@ -100,7 +100,7 @@ private[spark] object ShuffleCopier extends Logging { } val host = args(0) val port = args(1).toInt - val blockId = BlockId.fromString(args(2)) + val blockId = BlockId(args(2)) val threads = if (args.length > 3) args(3).toInt else 10 val copiers = Executors.newFixedThreadPool(80) diff --git a/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala b/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala index b6062bc347..611a44e5b9 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala @@ -55,7 +55,7 @@ private[spark] object ShuffleSender { val pResovler = new PathResolver { override def getAbsolutePath(blockIdString: String): String = { - val blockId = BlockId.fromString(blockIdString) + val blockId = BlockId(blockIdString) if (!blockId.isShuffle) { throw new Exception("Block " + blockId + " is not a shuffle block") } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala index 62b4e37db6..0fc212d8f9 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala @@ -42,28 +42,29 @@ private[spark] abstract class BlockId { } } -case class RDDBlockId(rddId: Int, splitIndex: Int) extends BlockId { +private[spark] case class RDDBlockId(rddId: Int, splitIndex: Int) extends BlockId { def filename = "rdd_" + rddId + "_" + splitIndex } +private[spark] case class ShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int) extends BlockId { def filename = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId } -case class BroadcastBlockId(broadcastId: Long) extends BlockId { +private[spark] case class BroadcastBlockId(broadcastId: Long) extends BlockId { def filename = "broadcast_" + broadcastId } -case class TaskResultBlockId(taskId: Long) extends BlockId { +private[spark] case class TaskResultBlockId(taskId: Long) extends BlockId { def filename = "taskresult_" + taskId } -case class StreamBlockId(streamId: Int, uniqueId: Long) extends BlockId { +private[spark] case class StreamBlockId(streamId: Int, uniqueId: Long) extends BlockId { def filename = "input-" + streamId + "-" + uniqueId } // Intended only for testing purposes -case class TestBlockId(id: String) extends BlockId { +private[spark] case class TestBlockId(id: String) extends BlockId { def filename = "test_" + id } @@ -76,7 +77,7 @@ private[spark] object BlockId { val StreamInput = "input-([0-9]+)-([0-9]+)".r val Test = "test_(.*)".r - def fromString(id: String) = id match { + def apply(id: String) = id match { case RDD(rddId, splitIndex) => RDDBlockId(rddId.toInt, splitIndex.toInt) case Shuffle(shuffleId, mapId, reduceId) => ShuffleBlockId(shuffleId.toInt, mapId.toInt, reduceId.toInt) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala index b0832fd28e..cc0c46ec16 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala @@ -70,7 +70,7 @@ private[storage] object BlockManagerMessages { override def readExternal(in: ObjectInput) { blockManagerId = BlockManagerId(in) - blockId = BlockId.fromString(in.readUTF()) + blockId = BlockId(in.readUTF()) storageLevel = StorageLevel(in) memSize = in.readLong() diskSize = in.readLong() diff --git a/core/src/main/scala/org/apache/spark/storage/BlockMessage.scala b/core/src/main/scala/org/apache/spark/storage/BlockMessage.scala index 8ccda83890..7e8ee2486e 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockMessage.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockMessage.scala @@ -74,7 +74,7 @@ private[spark] class BlockMessage() { for (i <- 1 to idLength) { idBuilder += buffer.getChar() } - id = BlockId.fromString(idBuilder.toString) + id = BlockId(idBuilder.toString) if (typ == BlockMessage.TYPE_PUT_BLOCK) { @@ -117,7 +117,7 @@ private[spark] class BlockMessage() { def toBufferMessage: BufferMessage = { val startTime = System.currentTimeMillis val buffers = new ArrayBuffer[ByteBuffer]() - var buffer = ByteBuffer.allocate(4 + 4 + id.filename.length * 2) // TODO: Why x2? + var buffer = ByteBuffer.allocate(4 + 4 + id.filename.length * 2) buffer.putInt(typ).putInt(id.filename.length) id.filename.foreach((x: Char) => buffer.putChar(x)) buffer.flip() @@ -201,8 +201,8 @@ private[spark] object BlockMessage { def main(args: Array[String]) { val B = new BlockMessage() - B.set(new PutBlock( - new TestBlockId("ABC"), ByteBuffer.allocate(10), StorageLevel.MEMORY_AND_DISK_SER_2)) + val blockId = TestBlockId("ABC") + B.set(new PutBlock(blockId, ByteBuffer.allocate(10), StorageLevel.MEMORY_AND_DISK_SER_2)) val bMsg = B.toBufferMessage val C = new BlockMessage() C.set(bMsg) diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala index 9e3ee87654..4200935d93 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala @@ -316,7 +316,7 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) private[storage] def startShuffleBlockSender(port: Int): Int = { val pResolver = new PathResolver { override def getAbsolutePath(blockIdString: String): String = { - val blockId = BlockId.fromString(blockIdString) + val blockId = BlockId(blockIdString) if (!blockId.isShuffle) null else DiskStore.this.getFile(blockId).getAbsolutePath } diff --git a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala index 8358596861..1720007e4e 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala @@ -25,15 +25,15 @@ private[spark] case class StorageStatus(blockManagerId: BlockManagerId, maxMem: Long, blocks: Map[BlockId, BlockStatus]) { - def memUsed() = blocks.values.map(_.memSize).reduceOption(_+_).getOrElse(0l) + def memUsed() = blocks.values.map(_.memSize).reduceOption(_+_).getOrElse(0L) def memUsedByRDD(rddId: Int) = - rddBlocks.filterKeys(_.rddId == rddId).values.map(_.memSize).reduceOption(_+_).getOrElse(0l) + rddBlocks.filterKeys(_.rddId == rddId).values.map(_.memSize).reduceOption(_+_).getOrElse(0L) - def diskUsed() = blocks.values.map(_.diskSize).reduceOption(_+_).getOrElse(0l) + def diskUsed() = blocks.values.map(_.diskSize).reduceOption(_+_).getOrElse(0L) def diskUsedByRDD(rddId: Int) = - rddBlocks.filterKeys(_.rddId == rddId).values.map(_.diskSize).reduceOption(_+_).getOrElse(0l) + rddBlocks.filterKeys(_.rddId == rddId).values.map(_.diskSize).reduceOption(_+_).getOrElse(0L) def memRemaining : Long = maxMem - memUsed() diff --git a/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala new file mode 100644 index 0000000000..538482f6ff --- /dev/null +++ b/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import org.scalatest.FunSuite + +class BlockIdSuite extends FunSuite { + def assertSame(id1: BlockId, id2: BlockId) { + assert(id1.filename === id2.filename) + assert(id1.toString === id2.toString) + assert(id1.hashCode === id2.hashCode) + assert(id1 === id2) + } + + def assertDifferent(id1: BlockId, id2: BlockId) { + assert(id1.filename != id2.filename) + assert(id1.toString != id2.toString) + assert(id1.hashCode != id2.hashCode) + assert(id1 != id2) + } + + test("basic-functions") { + case class MyBlockId(filename: String) extends BlockId + + val id = MyBlockId("a") + assertSame(id, MyBlockId("a")) + assertDifferent(id, MyBlockId("b")) + assert(id.asRDDId === None) + + try { + // Try to deserialize an invalid block id. + BlockId("a") + fail() + } catch { + case e: IllegalStateException => // OK + case _ => fail() + } + } + + test("rdd") { + val id = RDDBlockId(1, 2) + assertSame(id, RDDBlockId(1, 2)) + assertDifferent(id, RDDBlockId(1, 1)) + assert(id.toString === "rdd_1_2") + assert(id.asRDDId.get.rddId === 1) + assert(id.asRDDId.get.splitIndex === 2) + assert(id.isRDD) + assertSame(id, BlockId(id.toString)) + } + + test("shuffle") { + val id = ShuffleBlockId(1, 2, 3) + assertSame(id, ShuffleBlockId(1, 2, 3)) + assertDifferent(id, ShuffleBlockId(3, 2, 3)) + assert(id.toString === "shuffle_1_2_3") + assert(id.asRDDId === None) + assert(id.shuffleId === 1) + assert(id.mapId === 2) + assert(id.reduceId === 3) + assert(id.isShuffle) + assertSame(id, BlockId(id.toString)) + } + + test("broadcast") { + val id = BroadcastBlockId(42) + assertSame(id, BroadcastBlockId(42)) + assertDifferent(id, BroadcastBlockId(123)) + assert(id.toString === "broadcast_42") + assert(id.asRDDId === None) + assert(id.broadcastId === 42) + assert(id.isBroadcast) + assertSame(id, BlockId(id.toString)) + } + + test("taskresult") { + val id = TaskResultBlockId(60) + assertSame(id, TaskResultBlockId(60)) + assertDifferent(id, TaskResultBlockId(61)) + assert(id.toString === "taskresult_60") + assert(id.asRDDId === None) + assert(id.taskId === 60) + assert(!id.isRDD) + assertSame(id, BlockId(id.toString)) + } + + test("stream") { + val id = StreamBlockId(1, 100) + assertSame(id, StreamBlockId(1, 100)) + assertDifferent(id, StreamBlockId(2, 101)) + assert(id.toString === "input-1-100") + assert(id.asRDDId === None) + assert(id.streamId === 1) + assert(id.uniqueId === 100) + assert(!id.isBroadcast) + assertSame(id, BlockId(id.toString)) + } + + test("test") { + val id = TestBlockId("abc") + assertSame(id, TestBlockId("abc")) + assertDifferent(id, TestBlockId("ab")) + assert(id.toString === "test_abc") + assert(id.asRDDId === None) + assert(id.id === "abc") + assert(!id.isShuffle) + assertSame(id, BlockId(id.toString)) + } +} |