about | summary | refs | log | tree | commit | diff
path: root/core
diff options
context:
space:
mode:
Diffstat (limited to 'core')
-rw-r--r-- core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala 11
-rw-r--r-- core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala 282
-rw-r--r-- core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala 128
3 files changed, 181 insertions, 240 deletions
diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala
index a8c827030a..6a187b4062 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala
@@ -32,8 +32,19 @@ import org.apache.spark.annotation.DeveloperApi
*/
@DeveloperApi
trait BroadcastFactory {
+
def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager): Unit
+
+ /**
+ * Creates a new broadcast variable.
+ *
+ * @param value value to broadcast
+ * @param isLocal whether we are in local mode (single JVM process)
+ * @param id unique id representing this broadcast variable
+ */
def newBroadcast[T: ClassTag](value: T, isLocal: Boolean, id: Long): Broadcast[T]
+
def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean): Unit
+
def stop(): Unit
}
diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
index d8be649f96..6173fd3a69 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
@@ -18,50 +18,116 @@
package org.apache.spark.broadcast
import java.io._
+import java.nio.ByteBuffer
+import scala.collection.JavaConversions.asJavaEnumeration
import scala.reflect.ClassTag
import scala.util.Random
import org.apache.spark.{Logging, SparkConf, SparkEnv, SparkException}
import org.apache.spark.io.CompressionCodec
import org.apache.spark.storage.{BroadcastBlockId, StorageLevel}
+import org.apache.spark.util.ByteBufferInputStream
/**
- * A [[org.apache.spark.broadcast.Broadcast]] implementation that uses a BitTorrent-like
- * protocol to do a distributed transfer of the broadcasted data to the executors.
- * The mechanism is as follows. The driver divides the serializes the broadcasted data,
- * divides it into smaller chunks, and stores them in the BlockManager of the driver.
- * These chunks are reported to the BlockManagerMaster so that all the executors can
- * learn the location of those chunks. The first time the broadcast variable (sent as
- * part of task) is deserialized at a executor, all the chunks are fetched using
- * the BlockManager. When all the chunks are fetched (initially from the driver's
- * BlockManager), they are combined and deserialized to recreate the broadcasted data.
- * However, the chunks are also stored in the BlockManager and reported to the
- * BlockManagerMaster. As more executors fetch the chunks, BlockManagerMaster learns
- * multiple locations for each chunk. Hence, subsequent fetches of each chunk will be
- * made to other executors who already have those chunks, resulting in a distributed
- * fetching. This prevents the driver from being the bottleneck in sending out multiple
- * copies of the broadcast data (one per executor) as done by the
- * [[org.apache.spark.broadcast.HttpBroadcast]].
+ * A BitTorrent-like implementation of [[org.apache.spark.broadcast.Broadcast]].
+ *
+ * The mechanism is as follows:
+ *
+ * The driver divides the serialized object into small chunks and
+ * stores those chunks in the BlockManager of the driver.
+ *
+ * On each executor, the executor first attempts to fetch the object from its BlockManager. If
+ * it does not exist, it then uses remote fetches to fetch the small chunks from the driver and/or
+ * other executors if available. Once it gets the chunks, it puts the chunks in its own
+ * BlockManager, ready for other executors to fetch from.
+ *
+ * This prevents the driver from being the bottleneck in sending out multiple copies of the
+ * broadcast data (one per executor) as done by the [[org.apache.spark.broadcast.HttpBroadcast]].
+ *
+ * @param obj object to broadcast
+ * @param isLocal whether Spark is running in local mode (single JVM process).
+ * @param id A unique identifier for the broadcast variable.
*/
private[spark] class TorrentBroadcast[T: ClassTag](
- @transient var value_ : T, isLocal: Boolean, id: Long)
+ obj : T,
+ @transient private val isLocal: Boolean,
+ id: Long)
extends Broadcast[T](id) with Logging with Serializable {
- override protected def getValue() = value_
+ /**
+ * Value of the broadcast object. On driver, this is set directly by the constructor.
+ * On executors, this is reconstructed by [[readObject]], which builds this value by reading
+ * blocks from the driver and/or other executors.
+ */
+ @transient private var _value: T = obj
private val broadcastId = BroadcastBlockId(id)
- SparkEnv.get.blockManager.putSingle(
- broadcastId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster = false)
+ /** Total number of blocks this broadcast variable contains. */
+ private val numBlocks: Int = writeBlocks()
+
+ override protected def getValue() = _value
+
+ /**
+ * Divide the object into multiple blocks and put those blocks in the block manager.
+ *
+ * @return number of blocks this broadcast variable is divided into
+ */
+ private def writeBlocks(): Int = {
+ // Always store the whole object in the BlockManager (in every mode) so we can find it later.
+ SparkEnv.get.blockManager.putSingle(
+ broadcastId, _value, StorageLevel.MEMORY_AND_DISK, tellMaster = false)
+
+ if (!isLocal) {
+ val blocks = TorrentBroadcast.blockifyObject(_value)
+ blocks.zipWithIndex.foreach { case (block, i) =>
+ SparkEnv.get.blockManager.putBytes(
+ BroadcastBlockId(id, "piece" + i),
+ block,
+ StorageLevel.MEMORY_AND_DISK_SER,
+ tellMaster = true)
+ }
+ blocks.length
+ } else {
+ 0
+ }
+ }
+
+ /** Fetch torrent blocks from the driver and/or other executors. */
+ private def readBlocks(): Array[ByteBuffer] = {
+ // Fetch chunks of data. Note that all these chunks are stored in the BlockManager and reported
+ // to the driver, so other executors can pull these chunks from this executor as well.
+ val blocks = new Array[ByteBuffer](numBlocks)
+ val bm = SparkEnv.get.blockManager
- @transient private var arrayOfBlocks: Array[TorrentBlock] = null
- @transient private var totalBlocks = -1
- @transient private var totalBytes = -1
- @transient private var hasBlocks = 0
+ for (pid <- Random.shuffle(Seq.range(0, numBlocks))) {
+ val pieceId = BroadcastBlockId(id, "piece" + pid)
- if (!isLocal) {
- sendBroadcast()
+ // First try getLocalBytes because there is a chance that previous attempts to fetch the
+ // broadcast blocks have already fetched some of the blocks. In that case, some blocks
+ // would be available locally (on this executor).
+ var blockOpt = bm.getLocalBytes(pieceId)
+ if (!blockOpt.isDefined) {
+ blockOpt = bm.getRemoteBytes(pieceId)
+ blockOpt match {
+ case Some(block) =>
+ // If we found the block from remote executors/driver's BlockManager, put the block
+ // in this executor's BlockManager.
+ SparkEnv.get.blockManager.putBytes(
+ pieceId,
+ block,
+ StorageLevel.MEMORY_AND_DISK_SER,
+ tellMaster = true)
+
+ case None =>
+ throw new SparkException("Failed to get " + pieceId + " of " + broadcastId)
+ }
+ }
+ // If we get here, the option is defined.
+ blocks(pid) = blockOpt.get
+ }
+ blocks
}
/**
@@ -79,26 +145,6 @@ private[spark] class TorrentBroadcast[T: ClassTag](
TorrentBroadcast.unpersist(id, removeFromDriver = true, blocking)
}
- private def sendBroadcast() {
- val tInfo = TorrentBroadcast.blockifyObject(value_)
- totalBlocks = tInfo.totalBlocks
- totalBytes = tInfo.totalBytes
- hasBlocks = tInfo.totalBlocks
-
- // Store meta-info
- val metaId = BroadcastBlockId(id, "meta")
- val metaInfo = TorrentInfo(null, totalBlocks, totalBytes)
- SparkEnv.get.blockManager.putSingle(
- metaId, metaInfo, StorageLevel.MEMORY_AND_DISK, tellMaster = true)
-
- // Store individual pieces
- for (i <- 0 until totalBlocks) {
- val pieceId = BroadcastBlockId(id, "piece" + i)
- SparkEnv.get.blockManager.putSingle(
- pieceId, tInfo.arrayOfBlocks(i), StorageLevel.MEMORY_AND_DISK, tellMaster = true)
- }
- }
-
/** Used by the JVM when serializing this object. */
private def writeObject(out: ObjectOutputStream) {
assertValid()
@@ -109,99 +155,30 @@ private[spark] class TorrentBroadcast[T: ClassTag](
private def readObject(in: ObjectInputStream) {
in.defaultReadObject()
TorrentBroadcast.synchronized {
- SparkEnv.get.blockManager.getSingle(broadcastId) match {
+ SparkEnv.get.blockManager.getLocal(broadcastId).map(_.data.next()) match {
case Some(x) =>
- value_ = x.asInstanceOf[T]
+ _value = x.asInstanceOf[T]
case None =>
- val start = System.nanoTime
logInfo("Started reading broadcast variable " + id)
-
- // Initialize @transient variables that will receive garbage values from the master.
- resetWorkerVariables()
-
- if (receiveBroadcast()) {
- value_ = TorrentBroadcast.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks)
-
- /* Store the merged copy in cache so that the next worker doesn't need to rebuild it.
- * This creates a trade-off between memory usage and latency. Storing copy doubles
- * the memory footprint; not storing doubles deserialization cost. Also,
- * this does not need to be reported to BlockManagerMaster since other executors
- * does not need to access this block (they only need to fetch the chunks,
- * which are reported).
- */
- SparkEnv.get.blockManager.putSingle(
- broadcastId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster = false)
-
- // Remove arrayOfBlocks from memory once value_ is on local cache
- resetWorkerVariables()
- } else {
- logError("Reading broadcast variable " + id + " failed")
- }
-
- val time = (System.nanoTime - start) / 1e9
+ val start = System.nanoTime()
+ val blocks = readBlocks()
+ val time = (System.nanoTime() - start) / 1e9
logInfo("Reading broadcast variable " + id + " took " + time + " s")
- }
- }
- }
-
- private def resetWorkerVariables() {
- arrayOfBlocks = null
- totalBytes = -1
- totalBlocks = -1
- hasBlocks = 0
- }
-
- private def receiveBroadcast(): Boolean = {
- // Receive meta-info about the size of broadcast data,
- // the number of chunks it is divided into, etc.
- val metaId = BroadcastBlockId(id, "meta")
- var attemptId = 10
- while (attemptId > 0 && totalBlocks == -1) {
- SparkEnv.get.blockManager.getSingle(metaId) match {
- case Some(x) =>
- val tInfo = x.asInstanceOf[TorrentInfo]
- totalBlocks = tInfo.totalBlocks
- totalBytes = tInfo.totalBytes
- arrayOfBlocks = new Array[TorrentBlock](totalBlocks)
- hasBlocks = 0
- case None =>
- Thread.sleep(500)
- }
- attemptId -= 1
- }
-
- if (totalBlocks == -1) {
- return false
- }
-
- /*
- * Fetch actual chunks of data. Note that all these chunks are stored in
- * the BlockManager and reported to the master, so that other executors
- * can find out and pull the chunks from this executor.
- */
- val recvOrder = new Random().shuffle(Array.iterate(0, totalBlocks)(_ + 1).toList)
- for (pid <- recvOrder) {
- val pieceId = BroadcastBlockId(id, "piece" + pid)
- SparkEnv.get.blockManager.getSingle(pieceId) match {
- case Some(x) =>
- arrayOfBlocks(pid) = x.asInstanceOf[TorrentBlock]
- hasBlocks += 1
+ _value = TorrentBroadcast.unBlockifyObject[T](blocks)
+ // Store the merged copy in BlockManager so other tasks on this executor don't
+ // need to re-fetch it.
SparkEnv.get.blockManager.putSingle(
- pieceId, arrayOfBlocks(pid), StorageLevel.MEMORY_AND_DISK, tellMaster = true)
-
- case None =>
- throw new SparkException("Failed to get " + pieceId + " of " + broadcastId)
+ broadcastId, _value, StorageLevel.MEMORY_AND_DISK, tellMaster = false)
}
}
-
- hasBlocks == totalBlocks
}
-
}
-private[broadcast] object TorrentBroadcast extends Logging {
+
+private object TorrentBroadcast extends Logging {
+ /** Size of each block. Default value is 4MB. */
private lazy val BLOCK_SIZE = conf.getInt("spark.broadcast.blockSize", 4096) * 1024
private var initialized = false
private var conf: SparkConf = null
@@ -223,7 +200,9 @@ private[broadcast] object TorrentBroadcast extends Logging {
initialized = false
}
- def blockifyObject[T: ClassTag](obj: T): TorrentInfo = {
+ def blockifyObject[T: ClassTag](obj: T): Array[ByteBuffer] = {
+ // TODO: Create a special ByteArrayOutputStream that splits the output directly into chunks
+ // so we don't need to do the extra memory copy.
val bos = new ByteArrayOutputStream()
val out: OutputStream = if (compress) compressionCodec.compressedOutputStream(bos) else bos
val ser = SparkEnv.get.serializer.newInstance()
@@ -231,44 +210,27 @@ private[broadcast] object TorrentBroadcast extends Logging {
serOut.writeObject[T](obj).close()
val byteArray = bos.toByteArray
val bais = new ByteArrayInputStream(byteArray)
+ val numBlocks = math.ceil(byteArray.length.toDouble / BLOCK_SIZE).toInt
+ val blocks = new Array[ByteBuffer](numBlocks)
- var blockNum = byteArray.length / BLOCK_SIZE
- if (byteArray.length % BLOCK_SIZE != 0) {
- blockNum += 1
- }
-
- val blocks = new Array[TorrentBlock](blockNum)
var blockId = 0
-
for (i <- 0 until (byteArray.length, BLOCK_SIZE)) {
val thisBlockSize = math.min(BLOCK_SIZE, byteArray.length - i)
val tempByteArray = new Array[Byte](thisBlockSize)
bais.read(tempByteArray, 0, thisBlockSize)
- blocks(blockId) = new TorrentBlock(blockId, tempByteArray)
+ blocks(blockId) = ByteBuffer.wrap(tempByteArray)
blockId += 1
}
bais.close()
-
- val info = TorrentInfo(blocks, blockNum, byteArray.length)
- info.hasBlocks = blockNum
- info
+ blocks
}
- def unBlockifyObject[T: ClassTag](
- arrayOfBlocks: Array[TorrentBlock],
- totalBytes: Int,
- totalBlocks: Int): T = {
- val retByteArray = new Array[Byte](totalBytes)
- for (i <- 0 until totalBlocks) {
- System.arraycopy(arrayOfBlocks(i).byteArray, 0, retByteArray,
- i * BLOCK_SIZE, arrayOfBlocks(i).byteArray.length)
- }
+ def unBlockifyObject[T: ClassTag](blocks: Array[ByteBuffer]): T = {
+ val is = new SequenceInputStream(
+ asJavaEnumeration(blocks.iterator.map(block => new ByteBufferInputStream(block))))
+ val in: InputStream = if (compress) compressionCodec.compressedInputStream(is) else is
- val in: InputStream = {
- val arrIn = new ByteArrayInputStream(retByteArray)
- if (compress) compressionCodec.compressedInputStream(arrIn) else arrIn
- }
val ser = SparkEnv.get.serializer.newInstance()
val serIn = ser.deserializeStream(in)
val obj = serIn.readObject[T]()
@@ -284,17 +246,3 @@ private[broadcast] object TorrentBroadcast extends Logging {
SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver, blocking)
}
}
-
-private[broadcast] case class TorrentBlock(
- blockID: Int,
- byteArray: Array[Byte])
- extends Serializable
-
-private[broadcast] case class TorrentInfo(
- @transient arrayOfBlocks: Array[TorrentBlock],
- totalBlocks: Int,
- totalBytes: Int)
- extends Serializable {
-
- @transient var hasBlocks = 0
-}
diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala
index 17c64455b2..978a6ded80 100644
--- a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala
+++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala
@@ -17,10 +17,12 @@
package org.apache.spark.broadcast
-import org.apache.spark.storage.{BroadcastBlockId, _}
-import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkException}
import org.scalatest.FunSuite
+import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkException}
+import org.apache.spark.storage._
+
+
class BroadcastSuite extends FunSuite with LocalSparkContext {
private val httpConf = broadcastConf("HttpBroadcastFactory")
@@ -124,12 +126,10 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
private def testUnpersistHttpBroadcast(distributed: Boolean, removeFromDriver: Boolean) {
val numSlaves = if (distributed) 2 else 0
- def getBlockIds(id: Long) = Seq[BroadcastBlockId](BroadcastBlockId(id))
-
// Verify that the broadcast file is created, and blocks are persisted only on the driver
- def afterCreation(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) {
- assert(blockIds.size === 1)
- val statuses = bmm.getBlockStatus(blockIds.head, askSlaves = true)
+ def afterCreation(broadcastId: Long, bmm: BlockManagerMaster) {
+ val blockId = BroadcastBlockId(broadcastId)
+ val statuses = bmm.getBlockStatus(blockId, askSlaves = true)
assert(statuses.size === 1)
statuses.head match { case (bm, status) =>
assert(bm.executorId === "<driver>", "Block should only be on the driver")
@@ -139,14 +139,14 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
}
if (distributed) {
// this file is only generated in distributed mode
- assert(HttpBroadcast.getFile(blockIds.head.broadcastId).exists, "Broadcast file not found!")
+ assert(HttpBroadcast.getFile(blockId.broadcastId).exists, "Broadcast file not found!")
}
}
// Verify that blocks are persisted in both the executors and the driver
- def afterUsingBroadcast(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) {
- assert(blockIds.size === 1)
- val statuses = bmm.getBlockStatus(blockIds.head, askSlaves = true)
+ def afterUsingBroadcast(broadcastId: Long, bmm: BlockManagerMaster) {
+ val blockId = BroadcastBlockId(broadcastId)
+ val statuses = bmm.getBlockStatus(blockId, askSlaves = true)
assert(statuses.size === numSlaves + 1)
statuses.foreach { case (_, status) =>
assert(status.storageLevel === StorageLevel.MEMORY_AND_DISK)
@@ -157,21 +157,21 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
// Verify that blocks are unpersisted on all executors, and on all nodes if removeFromDriver
// is true. In the latter case, also verify that the broadcast file is deleted on the driver.
- def afterUnpersist(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) {
- assert(blockIds.size === 1)
- val statuses = bmm.getBlockStatus(blockIds.head, askSlaves = true)
+ def afterUnpersist(broadcastId: Long, bmm: BlockManagerMaster) {
+ val blockId = BroadcastBlockId(broadcastId)
+ val statuses = bmm.getBlockStatus(blockId, askSlaves = true)
val expectedNumBlocks = if (removeFromDriver) 0 else 1
val possiblyNot = if (removeFromDriver) "" else " not"
assert(statuses.size === expectedNumBlocks,
"Block should%s be unpersisted on the driver".format(possiblyNot))
if (distributed && removeFromDriver) {
// this file is only generated in distributed mode
- assert(!HttpBroadcast.getFile(blockIds.head.broadcastId).exists,
+ assert(!HttpBroadcast.getFile(blockId.broadcastId).exists,
"Broadcast file should%s be deleted".format(possiblyNot))
}
}
- testUnpersistBroadcast(distributed, numSlaves, httpConf, getBlockIds, afterCreation,
+ testUnpersistBroadcast(distributed, numSlaves, httpConf, afterCreation,
afterUsingBroadcast, afterUnpersist, removeFromDriver)
}
@@ -185,67 +185,51 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
private def testUnpersistTorrentBroadcast(distributed: Boolean, removeFromDriver: Boolean) {
val numSlaves = if (distributed) 2 else 0
- def getBlockIds(id: Long) = {
- val broadcastBlockId = BroadcastBlockId(id)
- val metaBlockId = BroadcastBlockId(id, "meta")
- // Assume broadcast value is small enough to fit into 1 piece
- val pieceBlockId = BroadcastBlockId(id, "piece0")
- if (distributed) {
- // the metadata and piece blocks are generated only in distributed mode
- Seq[BroadcastBlockId](broadcastBlockId, metaBlockId, pieceBlockId)
- } else {
- Seq[BroadcastBlockId](broadcastBlockId)
- }
+ // Verify that blocks are persisted only on the driver
+ def afterCreation(broadcastId: Long, bmm: BlockManagerMaster) {
+ var blockId = BroadcastBlockId(broadcastId)
+ var statuses = bmm.getBlockStatus(blockId, askSlaves = true)
+ assert(statuses.size === 1)
+
+ blockId = BroadcastBlockId(broadcastId, "piece0")
+ statuses = bmm.getBlockStatus(blockId, askSlaves = true)
+ assert(statuses.size === (if (distributed) 1 else 0))
}
- // Verify that blocks are persisted only on the driver
- def afterCreation(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) {
- blockIds.foreach { blockId =>
- val statuses = bmm.getBlockStatus(blockIds.head, askSlaves = true)
+ // Verify that blocks are persisted in both the executors and the driver
+ def afterUsingBroadcast(broadcastId: Long, bmm: BlockManagerMaster) {
+ var blockId = BroadcastBlockId(broadcastId)
+ var statuses = bmm.getBlockStatus(blockId, askSlaves = true)
+ if (distributed) {
+ assert(statuses.size === numSlaves + 1)
+ } else {
assert(statuses.size === 1)
- statuses.head match { case (bm, status) =>
- assert(bm.executorId === "<driver>", "Block should only be on the driver")
- assert(status.storageLevel === StorageLevel.MEMORY_AND_DISK)
- assert(status.memSize > 0, "Block should be in memory store on the driver")
- assert(status.diskSize === 0, "Block should not be in disk store on the driver")
- }
}
- }
- // Verify that blocks are persisted in both the executors and the driver
- def afterUsingBroadcast(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) {
- blockIds.foreach { blockId =>
- val statuses = bmm.getBlockStatus(blockId, askSlaves = true)
- if (blockId.field == "meta") {
- // Meta data is only on the driver
- assert(statuses.size === 1)
- statuses.head match { case (bm, _) => assert(bm.executorId === "<driver>") }
- } else {
- // Other blocks are on both the executors and the driver
- assert(statuses.size === numSlaves + 1,
- blockId + " has " + statuses.size + " statuses: " + statuses.mkString(","))
- statuses.foreach { case (_, status) =>
- assert(status.storageLevel === StorageLevel.MEMORY_AND_DISK)
- assert(status.memSize > 0, "Block should be in memory store")
- assert(status.diskSize === 0, "Block should not be in disk store")
- }
- }
+ blockId = BroadcastBlockId(broadcastId, "piece0")
+ statuses = bmm.getBlockStatus(blockId, askSlaves = true)
+ if (distributed) {
+ assert(statuses.size === numSlaves + 1)
+ } else {
+ assert(statuses.size === 0)
}
}
// Verify that blocks are unpersisted on all executors, and on all nodes if removeFromDriver
// is true.
- def afterUnpersist(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) {
- val expectedNumBlocks = if (removeFromDriver) 0 else 1
- val possiblyNot = if (removeFromDriver) "" else " not"
- blockIds.foreach { blockId =>
- val statuses = bmm.getBlockStatus(blockId, askSlaves = true)
- assert(statuses.size === expectedNumBlocks,
- "Block should%s be unpersisted on the driver".format(possiblyNot))
- }
+ def afterUnpersist(broadcastId: Long, bmm: BlockManagerMaster) {
+ var blockId = BroadcastBlockId(broadcastId)
+ var expectedNumBlocks = if (removeFromDriver) 0 else 1
+ var statuses = bmm.getBlockStatus(blockId, askSlaves = true)
+ assert(statuses.size === expectedNumBlocks)
+
+ blockId = BroadcastBlockId(broadcastId, "piece0")
+ expectedNumBlocks = if (removeFromDriver || !distributed) 0 else 1
+ statuses = bmm.getBlockStatus(blockId, askSlaves = true)
+ assert(statuses.size === expectedNumBlocks)
}
- testUnpersistBroadcast(distributed, numSlaves, torrentConf, getBlockIds, afterCreation,
+ testUnpersistBroadcast(distributed, numSlaves, torrentConf, afterCreation,
afterUsingBroadcast, afterUnpersist, removeFromDriver)
}
@@ -262,10 +246,9 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
distributed: Boolean,
numSlaves: Int, // used only when distributed = true
broadcastConf: SparkConf,
- getBlockIds: Long => Seq[BroadcastBlockId],
- afterCreation: (Seq[BroadcastBlockId], BlockManagerMaster) => Unit,
- afterUsingBroadcast: (Seq[BroadcastBlockId], BlockManagerMaster) => Unit,
- afterUnpersist: (Seq[BroadcastBlockId], BlockManagerMaster) => Unit,
+ afterCreation: (Long, BlockManagerMaster) => Unit,
+ afterUsingBroadcast: (Long, BlockManagerMaster) => Unit,
+ afterUnpersist: (Long, BlockManagerMaster) => Unit,
removeFromDriver: Boolean) {
sc = if (distributed) {
@@ -278,15 +261,14 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
// Create broadcast variable
val broadcast = sc.broadcast(list)
- val blocks = getBlockIds(broadcast.id)
- afterCreation(blocks, blockManagerMaster)
+ afterCreation(broadcast.id, blockManagerMaster)
// Use broadcast variable on all executors
val partitions = 10
assert(partitions > numSlaves)
val results = sc.parallelize(1 to partitions, partitions).map(x => (x, broadcast.value.sum))
assert(results.collect().toSet === (1 to partitions).map(x => (x, list.sum)).toSet)
- afterUsingBroadcast(blocks, blockManagerMaster)
+ afterUsingBroadcast(broadcast.id, blockManagerMaster)
// Unpersist broadcast
if (removeFromDriver) {
@@ -294,7 +276,7 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
} else {
broadcast.unpersist(blocking = true)
}
- afterUnpersist(blocks, blockManagerMaster)
+ afterUnpersist(broadcast.id, blockManagerMaster)
// If the broadcast is removed from driver, all subsequent uses of the broadcast variable
// should throw SparkExceptions. Otherwise, the result should be the same as before.