Diffstat (limited to 'core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala')
-rw-r--r--    core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala    128
1 file changed, 55 insertions(+), 73 deletions(-)
diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala
index 17c64455b2..978a6ded80 100644
--- a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala
+++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala
@@ -17,10 +17,12 @@
package org.apache.spark.broadcast
-import org.apache.spark.storage.{BroadcastBlockId, _}
-import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkException}
import org.scalatest.FunSuite
+import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkException}
+import org.apache.spark.storage._
+
+
class BroadcastSuite extends FunSuite with LocalSparkContext {
private val httpConf = broadcastConf("HttpBroadcastFactory")
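
The suite selects the broadcast implementation through the broadcastConf helper, whose body lies outside this hunk. A minimal sketch of what such a helper plausibly looks like, assuming the standard "spark.broadcast.factory" key and the org.apache.spark.broadcast package for the factory classes (both assumptions, not shown in this diff):

    // Hedged sketch: build a SparkConf that selects a broadcast factory by class name.
    // The config key and package prefix are assumed, not taken from this diff.
    private def broadcastConf(factoryName: String): SparkConf = {
      new SparkConf()
        .set("spark.broadcast.factory", "org.apache.spark.broadcast.%s".format(factoryName))
    }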
@@ -124,12 +126,10 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
private def testUnpersistHttpBroadcast(distributed: Boolean, removeFromDriver: Boolean) {
val numSlaves = if (distributed) 2 else 0
- def getBlockIds(id: Long) = Seq[BroadcastBlockId](BroadcastBlockId(id))
-
// Verify that the broadcast file is created, and blocks are persisted only on the driver
- def afterCreation(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) {
- assert(blockIds.size === 1)
- val statuses = bmm.getBlockStatus(blockIds.head, askSlaves = true)
+ def afterCreation(broadcastId: Long, bmm: BlockManagerMaster) {
+ val blockId = BroadcastBlockId(broadcastId)
+ val statuses = bmm.getBlockStatus(blockId, askSlaves = true)
assert(statuses.size === 1)
statuses.head match { case (bm, status) =>
assert(bm.executorId === "<driver>", "Block should only be on the driver")
@@ -139,14 +139,14 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
}
if (distributed) {
// this file is only generated in distributed mode
- assert(HttpBroadcast.getFile(blockIds.head.broadcastId).exists, "Broadcast file not found!")
+ assert(HttpBroadcast.getFile(blockId.broadcastId).exists, "Broadcast file not found!")
}
}
// Verify that blocks are persisted in both the executors and the driver
- def afterUsingBroadcast(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) {
- assert(blockIds.size === 1)
- val statuses = bmm.getBlockStatus(blockIds.head, askSlaves = true)
+ def afterUsingBroadcast(broadcastId: Long, bmm: BlockManagerMaster) {
+ val blockId = BroadcastBlockId(broadcastId)
+ val statuses = bmm.getBlockStatus(blockId, askSlaves = true)
assert(statuses.size === numSlaves + 1)
statuses.foreach { case (_, status) =>
assert(status.storageLevel === StorageLevel.MEMORY_AND_DISK)
@@ -157,21 +157,21 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
// Verify that blocks are unpersisted on all executors, and on all nodes if removeFromDriver
// is true. In the latter case, also verify that the broadcast file is deleted on the driver.
- def afterUnpersist(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) {
- assert(blockIds.size === 1)
- val statuses = bmm.getBlockStatus(blockIds.head, askSlaves = true)
+ def afterUnpersist(broadcastId: Long, bmm: BlockManagerMaster) {
+ val blockId = BroadcastBlockId(broadcastId)
+ val statuses = bmm.getBlockStatus(blockId, askSlaves = true)
val expectedNumBlocks = if (removeFromDriver) 0 else 1
val possiblyNot = if (removeFromDriver) "" else " not"
assert(statuses.size === expectedNumBlocks,
"Block should%s be unpersisted on the driver".format(possiblyNot))
if (distributed && removeFromDriver) {
// this file is only generated in distributed mode
- assert(!HttpBroadcast.getFile(blockIds.head.broadcastId).exists,
+ assert(!HttpBroadcast.getFile(blockId.broadcastId).exists,
"Broadcast file should%s be deleted".format(possiblyNot))
}
}
- testUnpersistBroadcast(distributed, numSlaves, httpConf, getBlockIds, afterCreation,
+ testUnpersistBroadcast(distributed, numSlaves, httpConf, afterCreation,
afterUsingBroadcast, afterUnpersist, removeFromDriver)
}
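
The refactored callbacks now take the raw broadcast id (a Long) and rebuild the block ids they need, rather than receiving a precomputed Seq[BroadcastBlockId]. For reference, a small illustration of how a broadcast id maps onto the block ids used above; the rendered names assume Spark's usual BlockId formatting and are not part of this diff:

    // Illustration only: deriving block ids from a broadcast id inside a callback.
    // BroadcastBlockId(id) names the full value; BroadcastBlockId(id, "piece0") names
    // the first torrent piece. Rendered names assume "broadcast_<id>" formatting.
    val id = 0L
    val valueBlockId = BroadcastBlockId(id)            // name: "broadcast_0"
    val pieceBlockId = BroadcastBlockId(id, "piece0")  // name: "broadcast_0_piece0"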
@@ -185,67 +185,51 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
private def testUnpersistTorrentBroadcast(distributed: Boolean, removeFromDriver: Boolean) {
val numSlaves = if (distributed) 2 else 0
- def getBlockIds(id: Long) = {
- val broadcastBlockId = BroadcastBlockId(id)
- val metaBlockId = BroadcastBlockId(id, "meta")
- // Assume broadcast value is small enough to fit into 1 piece
- val pieceBlockId = BroadcastBlockId(id, "piece0")
- if (distributed) {
- // the metadata and piece blocks are generated only in distributed mode
- Seq[BroadcastBlockId](broadcastBlockId, metaBlockId, pieceBlockId)
- } else {
- Seq[BroadcastBlockId](broadcastBlockId)
- }
+ // Verify that blocks are persisted only on the driver
+ def afterCreation(broadcastId: Long, bmm: BlockManagerMaster) {
+ var blockId = BroadcastBlockId(broadcastId)
+ var statuses = bmm.getBlockStatus(blockId, askSlaves = true)
+ assert(statuses.size === 1)
+
+ blockId = BroadcastBlockId(broadcastId, "piece0")
+ statuses = bmm.getBlockStatus(blockId, askSlaves = true)
+ assert(statuses.size === (if (distributed) 1 else 0))
}
- // Verify that blocks are persisted only on the driver
- def afterCreation(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) {
- blockIds.foreach { blockId =>
- val statuses = bmm.getBlockStatus(blockIds.head, askSlaves = true)
+ // Verify that blocks are persisted in both the executors and the driver
+ def afterUsingBroadcast(broadcastId: Long, bmm: BlockManagerMaster) {
+ var blockId = BroadcastBlockId(broadcastId)
+ var statuses = bmm.getBlockStatus(blockId, askSlaves = true)
+ if (distributed) {
+ assert(statuses.size === numSlaves + 1)
+ } else {
assert(statuses.size === 1)
- statuses.head match { case (bm, status) =>
- assert(bm.executorId === "<driver>", "Block should only be on the driver")
- assert(status.storageLevel === StorageLevel.MEMORY_AND_DISK)
- assert(status.memSize > 0, "Block should be in memory store on the driver")
- assert(status.diskSize === 0, "Block should not be in disk store on the driver")
- }
}
- }
- // Verify that blocks are persisted in both the executors and the driver
- def afterUsingBroadcast(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) {
- blockIds.foreach { blockId =>
- val statuses = bmm.getBlockStatus(blockId, askSlaves = true)
- if (blockId.field == "meta") {
- // Meta data is only on the driver
- assert(statuses.size === 1)
- statuses.head match { case (bm, _) => assert(bm.executorId === "<driver>") }
- } else {
- // Other blocks are on both the executors and the driver
- assert(statuses.size === numSlaves + 1,
- blockId + " has " + statuses.size + " statuses: " + statuses.mkString(","))
- statuses.foreach { case (_, status) =>
- assert(status.storageLevel === StorageLevel.MEMORY_AND_DISK)
- assert(status.memSize > 0, "Block should be in memory store")
- assert(status.diskSize === 0, "Block should not be in disk store")
- }
- }
+ blockId = BroadcastBlockId(broadcastId, "piece0")
+ statuses = bmm.getBlockStatus(blockId, askSlaves = true)
+ if (distributed) {
+ assert(statuses.size === numSlaves + 1)
+ } else {
+ assert(statuses.size === 0)
}
}
// Verify that blocks are unpersisted on all executors, and on all nodes if removeFromDriver
// is true.
- def afterUnpersist(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) {
- val expectedNumBlocks = if (removeFromDriver) 0 else 1
- val possiblyNot = if (removeFromDriver) "" else " not"
- blockIds.foreach { blockId =>
- val statuses = bmm.getBlockStatus(blockId, askSlaves = true)
- assert(statuses.size === expectedNumBlocks,
- "Block should%s be unpersisted on the driver".format(possiblyNot))
- }
+ def afterUnpersist(broadcastId: Long, bmm: BlockManagerMaster) {
+ var blockId = BroadcastBlockId(broadcastId)
+ var expectedNumBlocks = if (removeFromDriver) 0 else 1
+ var statuses = bmm.getBlockStatus(blockId, askSlaves = true)
+ assert(statuses.size === expectedNumBlocks)
+
+ blockId = BroadcastBlockId(broadcastId, "piece0")
+ expectedNumBlocks = if (removeFromDriver || !distributed) 0 else 1
+ statuses = bmm.getBlockStatus(blockId, askSlaves = true)
+ assert(statuses.size === expectedNumBlocks)
}
- testUnpersistBroadcast(distributed, numSlaves, torrentConf, getBlockIds, afterCreation,
+ testUnpersistBroadcast(distributed, numSlaves, torrentConf, afterCreation,
afterUsingBroadcast, afterUnpersist, removeFromDriver)
}
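
All three torrent callbacks lean on BlockManagerMaster.getBlockStatus, which returns one entry per block manager that reports the block; askSlaves = true makes the master poll the executors instead of relying on its cached view. A hedged usage sketch follows; the sc.env.blockManager.master accessor path is an assumption about how the suite obtains blockManagerMaster:

    // Hedged sketch: query every reported location of a broadcast piece.
    // getBlockStatus returns a Map[BlockManagerId, BlockStatus].
    val master = sc.env.blockManager.master
    val pieceStatuses = master.getBlockStatus(BroadcastBlockId(0L, "piece0"), askSlaves = true)
    pieceStatuses.foreach { case (bmId, status) =>
      assert(status.storageLevel.isValid, "unexpected status on " + bmId.executorId)
    }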
@@ -262,10 +246,9 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
distributed: Boolean,
numSlaves: Int, // used only when distributed = true
broadcastConf: SparkConf,
- getBlockIds: Long => Seq[BroadcastBlockId],
- afterCreation: (Seq[BroadcastBlockId], BlockManagerMaster) => Unit,
- afterUsingBroadcast: (Seq[BroadcastBlockId], BlockManagerMaster) => Unit,
- afterUnpersist: (Seq[BroadcastBlockId], BlockManagerMaster) => Unit,
+ afterCreation: (Long, BlockManagerMaster) => Unit,
+ afterUsingBroadcast: (Long, BlockManagerMaster) => Unit,
+ afterUnpersist: (Long, BlockManagerMaster) => Unit,
removeFromDriver: Boolean) {
sc = if (distributed) {
@@ -278,15 +261,14 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
// Create broadcast variable
val broadcast = sc.broadcast(list)
- val blocks = getBlockIds(broadcast.id)
- afterCreation(blocks, blockManagerMaster)
+ afterCreation(broadcast.id, blockManagerMaster)
// Use broadcast variable on all executors
val partitions = 10
assert(partitions > numSlaves)
val results = sc.parallelize(1 to partitions, partitions).map(x => (x, broadcast.value.sum))
assert(results.collect().toSet === (1 to partitions).map(x => (x, list.sum)).toSet)
- afterUsingBroadcast(blocks, blockManagerMaster)
+ afterUsingBroadcast(broadcast.id, blockManagerMaster)
// Unpersist broadcast
if (removeFromDriver) {
@@ -294,7 +276,7 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
} else {
broadcast.unpersist(blocking = true)
}
- afterUnpersist(blocks, blockManagerMaster)
+ afterUnpersist(broadcast.id, blockManagerMaster)
// If the broadcast is removed from driver, all subsequent uses of the broadcast variable
// should throw SparkExceptions. Otherwise, the result should be the same as before.
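
The assertions that follow this comment are cut off by the diff view. A hedged sketch of the behavior the comment describes, using ScalaTest's intercept; the real assertions in the suite may differ:

    // Hedged sketch of the behavior described above; not the truncated original code.
    if (removeFromDriver) {
      // After driver-side removal, tasks can no longer fetch the broadcast value.
      intercept[SparkException] {
        sc.parallelize(1 to partitions, partitions).map(_ => broadcast.value.sum).collect()
      }
    } else {
      // Otherwise the broadcast still works and yields the same results as before.
      val again = sc.parallelize(1 to partitions, partitions).map(x => (x, broadcast.value.sum))
      assert(again.collect().toSet === (1 to partitions).map(x => (x, list.sum)).toSet)
    }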