author     Reynold Xin <rxin@apache.org>     2014-08-19 22:11:13 -0700
committer  Reynold Xin <rxin@apache.org>     2014-08-19 22:11:13 -0700
commit     8adfbc2b6b5b647e450d30f89c141f935b6aa94b (patch)
tree       b2277c0f05fd59d145be613dd028b21f674c51bd /core
parent     fce5c0fb6384f3a142a4155525a5d62640725150 (diff)
[SPARK-3119] Re-implementation of TorrentBroadcast.
This is a re-implementation of TorrentBroadcast, with the following changes:

1. Removes most of the mutable, transient state from TorrentBroadcast (e.g. totalBytes, num of blocks fetched).
2. Removes TorrentInfo and TorrentBlock.
3. Replaces the BlockManager.getSingle call in readObject with a getLocal, resulting in one less RPC call to the BlockManagerMasterActor to find the location of the block.
4. Removes the metadata block, resulting in one less block to fetch.
5. Removes an extra memory copy for deserialization (by using Java's SequenceInputStream).

Basically for a regular broadcasted object with only one block, the number of RPC calls goes from 5+1 to 2+1.

Old TorrentBroadcast for an object of a single block:
1 RPC to ask for location of the broadcast variable
1 RPC to ask for location of the metadata block
1 RPC to fetch the metadata block
1 RPC to ask for location of the first data block
1 RPC to fetch the first data block
1 RPC to tell the driver we put the first data block in
i.e. 5 + 1

New TorrentBroadcast for an object of a single block:
1 RPC to ask for location of the first data block
1 RPC to get the first data block
1 RPC to tell the driver we put the first data block in
i.e. 2 + 1

Author: Reynold Xin <rxin@apache.org>

Closes #2030 from rxin/torrentBroadcast and squashes the following commits:

5bacb9d [Reynold Xin] Always add the object to driver's block manager.
0d8ed5b [Reynold Xin] Added getBytes to BlockManager and uses that in TorrentBroadcast.
2d6a5fb [Reynold Xin] Use putBytes/getRemoteBytes throughout.
3670f00 [Reynold Xin] Code review feedback.
c1185cd [Reynold Xin] [SPARK-3119] Re-implementation of TorrentBroadcast.
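Change 5 is worth seeing in isolation. Below is a minimal, self-contained sketch, not part of the patch: blockSize stands in for the spark.broadcast.blockSize setting, and the small ByteBufferInputStream class stands in for org.apache.spark.util.ByteBufferInputStream, which the patch imports. It shows how chunking into ByteBuffers pairs with SequenceInputStream so that deserialization streams directly over the chunks instead of first copying them into one merged byte array, as the old unBlockifyObject did.

import java.io.{InputStream, SequenceInputStream}
import java.nio.ByteBuffer
import scala.collection.JavaConversions.asJavaEnumeration

object ChunkingSketch {

  // Stand-in for the spark.broadcast.blockSize setting (4MB default in the patch).
  val blockSize: Int = 4 * 1024 * 1024

  // Minimal stand-in for org.apache.spark.util.ByteBufferInputStream: reads a
  // ByteBuffer in place, with no intermediate byte array.
  class ByteBufferInputStream(buf: ByteBuffer) extends InputStream {
    override def read(): Int = if (buf.hasRemaining) buf.get() & 0xFF else -1
    override def read(dest: Array[Byte], off: Int, len: Int): Int =
      if (!buf.hasRemaining) -1
      else {
        val n = math.min(len, buf.remaining())
        buf.get(dest, off, n)
        n
      }
  }

  // Split a serialized payload into fixed-size ByteBuffer chunks (cf. blockifyObject).
  def blockify(bytes: Array[Byte]): Array[ByteBuffer] = {
    val numBlocks = math.ceil(bytes.length.toDouble / blockSize).toInt
    Array.tabulate(numBlocks) { i =>
      val from = i * blockSize
      val until = math.min(from + blockSize, bytes.length)
      ByteBuffer.wrap(java.util.Arrays.copyOfRange(bytes, from, until))
    }
  }

  // Reassemble the chunks as a single InputStream: SequenceInputStream walks the
  // per-chunk streams in order, so no merged byte array is ever allocated.
  def unblockify(blocks: Array[ByteBuffer]): InputStream =
    new SequenceInputStream(
      asJavaEnumeration(blocks.iterator.map(b => new ByteBufferInputStream(b.duplicate()))))
}

The patch's unBlockifyObject (see the diff below) wraps exactly this kind of stream in the compression codec and serializer streams before reading the object back.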
Diffstat (limited to 'core')
-rw-r--r--  core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala  |  11
-rw-r--r--  core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala  | 282
-rw-r--r--  core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala    | 128
3 files changed, 181 insertions(+), 240 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala
index a8c827030a..6a187b4062 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala
@@ -32,8 +32,19 @@ import org.apache.spark.annotation.DeveloperApi
*/
@DeveloperApi
trait BroadcastFactory {
+
def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager): Unit
+
+ /**
+ * Creates a new broadcast variable.
+ *
+ * @param value value to broadcast
+ * @param isLocal whether we are in local mode (single JVM process)
+ * @param id unique id representing this broadcast variable
+ */
def newBroadcast[T: ClassTag](value: T, isLocal: Boolean, id: Long): Broadcast[T]
+
def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean): Unit
+
def stop(): Unit
}
diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
index d8be649f96..6173fd3a69 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
@@ -18,50 +18,116 @@
package org.apache.spark.broadcast
import java.io._
+import java.nio.ByteBuffer
+import scala.collection.JavaConversions.asJavaEnumeration
import scala.reflect.ClassTag
import scala.util.Random
import org.apache.spark.{Logging, SparkConf, SparkEnv, SparkException}
import org.apache.spark.io.CompressionCodec
import org.apache.spark.storage.{BroadcastBlockId, StorageLevel}
+import org.apache.spark.util.ByteBufferInputStream
/**
- * A [[org.apache.spark.broadcast.Broadcast]] implementation that uses a BitTorrent-like
- * protocol to do a distributed transfer of the broadcasted data to the executors.
- * The mechanism is as follows. The driver serializes the broadcasted data,
- * divides it into smaller chunks, and stores them in the BlockManager of the driver.
- * These chunks are reported to the BlockManagerMaster so that all the executors can
- * learn the location of those chunks. The first time the broadcast variable (sent as
- * part of task) is deserialized at a executor, all the chunks are fetched using
- * the BlockManager. When all the chunks are fetched (initially from the driver's
- * BlockManager), they are combined and deserialized to recreate the broadcasted data.
- * However, the chunks are also stored in the BlockManager and reported to the
- * BlockManagerMaster. As more executors fetch the chunks, BlockManagerMaster learns
- * multiple locations for each chunk. Hence, subsequent fetches of each chunk will be
- * made to other executors who already have those chunks, resulting in a distributed
- * fetching. This prevents the driver from being the bottleneck in sending out multiple
- * copies of the broadcast data (one per executor) as done by the
- * [[org.apache.spark.broadcast.HttpBroadcast]].
+ * A BitTorrent-like implementation of [[org.apache.spark.broadcast.Broadcast]].
+ *
+ * The mechanism is as follows:
+ *
+ * The driver divides the serialized object into small chunks and
+ * stores those chunks in the BlockManager of the driver.
+ *
+ * On each executor, the executor first attempts to fetch the object from its BlockManager. If
+ * it does not exist, it then uses remote fetches to fetch the small chunks from the driver and/or
+ * other executors if available. Once it gets the chunks, it puts the chunks in its own
+ * BlockManager, ready for other executors to fetch from.
+ *
+ * This prevents the driver from being the bottleneck in sending out multiple copies of the
+ * broadcast data (one per executor) as done by the [[org.apache.spark.broadcast.HttpBroadcast]].
+ *
+ * @param obj object to broadcast
+ * @param isLocal whether Spark is running in local mode (single JVM process).
+ * @param id A unique identifier for the broadcast variable.
*/
private[spark] class TorrentBroadcast[T: ClassTag](
- @transient var value_ : T, isLocal: Boolean, id: Long)
+ obj : T,
+ @transient private val isLocal: Boolean,
+ id: Long)
extends Broadcast[T](id) with Logging with Serializable {
- override protected def getValue() = value_
+ /**
+ * Value of the broadcast object. On driver, this is set directly by the constructor.
+ * On executors, this is reconstructed by [[readObject]], which builds this value by reading
+ * blocks from the driver and/or other executors.
+ */
+ @transient private var _value: T = obj
private val broadcastId = BroadcastBlockId(id)
- SparkEnv.get.blockManager.putSingle(
- broadcastId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster = false)
+ /** Total number of blocks this broadcast variable contains. */
+ private val numBlocks: Int = writeBlocks()
+
+ override protected def getValue() = _value
+
+ /**
+ * Divide the object into multiple blocks and put those blocks in the block manager.
+ *
+ * @return number of blocks this broadcast variable is divided into
+ */
+ private def writeBlocks(): Int = {
+ // For local mode, just put the object in the BlockManager so we can find it later.
+ SparkEnv.get.blockManager.putSingle(
+ broadcastId, _value, StorageLevel.MEMORY_AND_DISK, tellMaster = false)
+
+ if (!isLocal) {
+ val blocks = TorrentBroadcast.blockifyObject(_value)
+ blocks.zipWithIndex.foreach { case (block, i) =>
+ SparkEnv.get.blockManager.putBytes(
+ BroadcastBlockId(id, "piece" + i),
+ block,
+ StorageLevel.MEMORY_AND_DISK_SER,
+ tellMaster = true)
+ }
+ blocks.length
+ } else {
+ 0
+ }
+ }
+
+ /** Fetch torrent blocks from the driver and/or other executors. */
+ private def readBlocks(): Array[ByteBuffer] = {
+ // Fetch chunks of data. Note that all these chunks are stored in the BlockManager and reported
+ // to the driver, so other executors can pull these chunks from this executor as well.
+ val blocks = new Array[ByteBuffer](numBlocks)
+ val bm = SparkEnv.get.blockManager
- @transient private var arrayOfBlocks: Array[TorrentBlock] = null
- @transient private var totalBlocks = -1
- @transient private var totalBytes = -1
- @transient private var hasBlocks = 0
+ for (pid <- Random.shuffle(Seq.range(0, numBlocks))) {
+ val pieceId = BroadcastBlockId(id, "piece" + pid)
- if (!isLocal) {
- sendBroadcast()
+ // First try getLocalBytes because there is a chance that previous attempts to fetch the
+ // broadcast blocks have already fetched some of the blocks. In that case, some blocks
+ // would be available locally (on this executor).
+ var blockOpt = bm.getLocalBytes(pieceId)
+ if (!blockOpt.isDefined) {
+ blockOpt = bm.getRemoteBytes(pieceId)
+ blockOpt match {
+ case Some(block) =>
+ // If we found the block from remote executors/driver's BlockManager, put the block
+ // in this executor's BlockManager.
+ SparkEnv.get.blockManager.putBytes(
+ pieceId,
+ block,
+ StorageLevel.MEMORY_AND_DISK_SER,
+ tellMaster = true)
+
+ case None =>
+ throw new SparkException("Failed to get " + pieceId + " of " + broadcastId)
+ }
+ }
+ // If we get here, the option is defined.
+ blocks(pid) = blockOpt.get
+ }
+ blocks
}
/**
@@ -79,26 +145,6 @@ private[spark] class TorrentBroadcast[T: ClassTag](
TorrentBroadcast.unpersist(id, removeFromDriver = true, blocking)
}
- private def sendBroadcast() {
- val tInfo = TorrentBroadcast.blockifyObject(value_)
- totalBlocks = tInfo.totalBlocks
- totalBytes = tInfo.totalBytes
- hasBlocks = tInfo.totalBlocks
-
- // Store meta-info
- val metaId = BroadcastBlockId(id, "meta")
- val metaInfo = TorrentInfo(null, totalBlocks, totalBytes)
- SparkEnv.get.blockManager.putSingle(
- metaId, metaInfo, StorageLevel.MEMORY_AND_DISK, tellMaster = true)
-
- // Store individual pieces
- for (i <- 0 until totalBlocks) {
- val pieceId = BroadcastBlockId(id, "piece" + i)
- SparkEnv.get.blockManager.putSingle(
- pieceId, tInfo.arrayOfBlocks(i), StorageLevel.MEMORY_AND_DISK, tellMaster = true)
- }
- }
-
/** Used by the JVM when serializing this object. */
private def writeObject(out: ObjectOutputStream) {
assertValid()
@@ -109,99 +155,30 @@ private[spark] class TorrentBroadcast[T: ClassTag](
private def readObject(in: ObjectInputStream) {
in.defaultReadObject()
TorrentBroadcast.synchronized {
- SparkEnv.get.blockManager.getSingle(broadcastId) match {
+ SparkEnv.get.blockManager.getLocal(broadcastId).map(_.data.next()) match {
case Some(x) =>
- value_ = x.asInstanceOf[T]
+ _value = x.asInstanceOf[T]
case None =>
- val start = System.nanoTime
logInfo("Started reading broadcast variable " + id)
-
- // Initialize @transient variables that will receive garbage values from the master.
- resetWorkerVariables()
-
- if (receiveBroadcast()) {
- value_ = TorrentBroadcast.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks)
-
- /* Store the merged copy in cache so that the next worker doesn't need to rebuild it.
- * This creates a trade-off between memory usage and latency. Storing copy doubles
- * the memory footprint; not storing doubles deserialization cost. Also,
- * this does not need to be reported to BlockManagerMaster since other executors
- * does not need to access this block (they only need to fetch the chunks,
- * which are reported).
- */
- SparkEnv.get.blockManager.putSingle(
- broadcastId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster = false)
-
- // Remove arrayOfBlocks from memory once value_ is on local cache
- resetWorkerVariables()
- } else {
- logError("Reading broadcast variable " + id + " failed")
- }
-
- val time = (System.nanoTime - start) / 1e9
+ val start = System.nanoTime()
+ val blocks = readBlocks()
+ val time = (System.nanoTime() - start) / 1e9
logInfo("Reading broadcast variable " + id + " took " + time + " s")
- }
- }
- }
-
- private def resetWorkerVariables() {
- arrayOfBlocks = null
- totalBytes = -1
- totalBlocks = -1
- hasBlocks = 0
- }
-
- private def receiveBroadcast(): Boolean = {
- // Receive meta-info about the size of broadcast data,
- // the number of chunks it is divided into, etc.
- val metaId = BroadcastBlockId(id, "meta")
- var attemptId = 10
- while (attemptId > 0 && totalBlocks == -1) {
- SparkEnv.get.blockManager.getSingle(metaId) match {
- case Some(x) =>
- val tInfo = x.asInstanceOf[TorrentInfo]
- totalBlocks = tInfo.totalBlocks
- totalBytes = tInfo.totalBytes
- arrayOfBlocks = new Array[TorrentBlock](totalBlocks)
- hasBlocks = 0
- case None =>
- Thread.sleep(500)
- }
- attemptId -= 1
- }
-
- if (totalBlocks == -1) {
- return false
- }
-
- /*
- * Fetch actual chunks of data. Note that all these chunks are stored in
- * the BlockManager and reported to the master, so that other executors
- * can find out and pull the chunks from this executor.
- */
- val recvOrder = new Random().shuffle(Array.iterate(0, totalBlocks)(_ + 1).toList)
- for (pid <- recvOrder) {
- val pieceId = BroadcastBlockId(id, "piece" + pid)
- SparkEnv.get.blockManager.getSingle(pieceId) match {
- case Some(x) =>
- arrayOfBlocks(pid) = x.asInstanceOf[TorrentBlock]
- hasBlocks += 1
+ _value = TorrentBroadcast.unBlockifyObject[T](blocks)
+ // Store the merged copy in BlockManager so other tasks on this executor don't
+ // need to re-fetch it.
SparkEnv.get.blockManager.putSingle(
- pieceId, arrayOfBlocks(pid), StorageLevel.MEMORY_AND_DISK, tellMaster = true)
-
- case None =>
- throw new SparkException("Failed to get " + pieceId + " of " + broadcastId)
+ broadcastId, _value, StorageLevel.MEMORY_AND_DISK, tellMaster = false)
}
}
-
- hasBlocks == totalBlocks
}
-
}
-private[broadcast] object TorrentBroadcast extends Logging {
+
+private object TorrentBroadcast extends Logging {
+ /** Size of each block. Default value is 4MB. */
private lazy val BLOCK_SIZE = conf.getInt("spark.broadcast.blockSize", 4096) * 1024
private var initialized = false
private var conf: SparkConf = null
@@ -223,7 +200,9 @@ private[broadcast] object TorrentBroadcast extends Logging {
initialized = false
}
- def blockifyObject[T: ClassTag](obj: T): TorrentInfo = {
+ def blockifyObject[T: ClassTag](obj: T): Array[ByteBuffer] = {
+ // TODO: Create a special ByteArrayOutputStream that splits the output directly into chunks
+ // so we don't need to do the extra memory copy.
val bos = new ByteArrayOutputStream()
val out: OutputStream = if (compress) compressionCodec.compressedOutputStream(bos) else bos
val ser = SparkEnv.get.serializer.newInstance()
@@ -231,44 +210,27 @@ private[broadcast] object TorrentBroadcast extends Logging {
serOut.writeObject[T](obj).close()
val byteArray = bos.toByteArray
val bais = new ByteArrayInputStream(byteArray)
+ val numBlocks = math.ceil(byteArray.length.toDouble / BLOCK_SIZE).toInt
+ val blocks = new Array[ByteBuffer](numBlocks)
- var blockNum = byteArray.length / BLOCK_SIZE
- if (byteArray.length % BLOCK_SIZE != 0) {
- blockNum += 1
- }
-
- val blocks = new Array[TorrentBlock](blockNum)
var blockId = 0
-
for (i <- 0 until (byteArray.length, BLOCK_SIZE)) {
val thisBlockSize = math.min(BLOCK_SIZE, byteArray.length - i)
val tempByteArray = new Array[Byte](thisBlockSize)
bais.read(tempByteArray, 0, thisBlockSize)
- blocks(blockId) = new TorrentBlock(blockId, tempByteArray)
+ blocks(blockId) = ByteBuffer.wrap(tempByteArray)
blockId += 1
}
bais.close()
-
- val info = TorrentInfo(blocks, blockNum, byteArray.length)
- info.hasBlocks = blockNum
- info
+ blocks
}
- def unBlockifyObject[T: ClassTag](
- arrayOfBlocks: Array[TorrentBlock],
- totalBytes: Int,
- totalBlocks: Int): T = {
- val retByteArray = new Array[Byte](totalBytes)
- for (i <- 0 until totalBlocks) {
- System.arraycopy(arrayOfBlocks(i).byteArray, 0, retByteArray,
- i * BLOCK_SIZE, arrayOfBlocks(i).byteArray.length)
- }
+ def unBlockifyObject[T: ClassTag](blocks: Array[ByteBuffer]): T = {
+ val is = new SequenceInputStream(
+ asJavaEnumeration(blocks.iterator.map(block => new ByteBufferInputStream(block))))
+ val in: InputStream = if (compress) compressionCodec.compressedInputStream(is) else is
- val in: InputStream = {
- val arrIn = new ByteArrayInputStream(retByteArray)
- if (compress) compressionCodec.compressedInputStream(arrIn) else arrIn
- }
val ser = SparkEnv.get.serializer.newInstance()
val serIn = ser.deserializeStream(in)
val obj = serIn.readObject[T]()
@@ -284,17 +246,3 @@ private[broadcast] object TorrentBroadcast extends Logging {
SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver, blocking)
}
}
-
-private[broadcast] case class TorrentBlock(
- blockID: Int,
- byteArray: Array[Byte])
- extends Serializable
-
-private[broadcast] case class TorrentInfo(
- @transient arrayOfBlocks: Array[TorrentBlock],
- totalBlocks: Int,
- totalBytes: Int)
- extends Serializable {
-
- @transient var hasBlocks = 0
-}
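The local-first fetch loop in readBlocks() above reduces to the following pattern. This is a hypothetical, self-contained sketch, not Spark code: BlockStore and its getLocal/getRemote/put methods are stand-ins for the BlockManager calls (getLocalBytes, getRemoteBytes, putBytes) used in the patch.

object FetchSketch {

  // Hypothetical stand-in for the BlockManager surface used by readBlocks().
  trait BlockStore {
    def getLocal(id: String): Option[Array[Byte]]
    def getRemote(id: String): Option[Array[Byte]]
    def put(id: String, bytes: Array[Byte], tellMaster: Boolean): Unit
  }

  def fetchPiece(store: BlockStore, pieceId: String): Array[Byte] =
    store.getLocal(pieceId).getOrElse {
      // Not available locally: pull the chunk from the driver or another executor.
      val block = store.getRemote(pieceId).getOrElse(
        sys.error("Failed to get " + pieceId))
      // Re-publish with tellMaster = true so the master learns this executor now
      // holds the chunk; later readers can fetch it from here instead of the driver.
      store.put(pieceId, block, tellMaster = true)
      block
    }
}

Re-publishing each fetched chunk with tellMaster = true is what makes the transfer BitTorrent-like: the more executors have read the broadcast, the more sources the next reader can pull from.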
diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala
index 17c64455b2..978a6ded80 100644
--- a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala
+++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala
@@ -17,10 +17,12 @@
package org.apache.spark.broadcast
-import org.apache.spark.storage.{BroadcastBlockId, _}
-import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkException}
import org.scalatest.FunSuite
+import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkException}
+import org.apache.spark.storage._
+
+
class BroadcastSuite extends FunSuite with LocalSparkContext {
private val httpConf = broadcastConf("HttpBroadcastFactory")
@@ -124,12 +126,10 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
private def testUnpersistHttpBroadcast(distributed: Boolean, removeFromDriver: Boolean) {
val numSlaves = if (distributed) 2 else 0
- def getBlockIds(id: Long) = Seq[BroadcastBlockId](BroadcastBlockId(id))
-
// Verify that the broadcast file is created, and blocks are persisted only on the driver
- def afterCreation(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) {
- assert(blockIds.size === 1)
- val statuses = bmm.getBlockStatus(blockIds.head, askSlaves = true)
+ def afterCreation(broadcastId: Long, bmm: BlockManagerMaster) {
+ val blockId = BroadcastBlockId(broadcastId)
+ val statuses = bmm.getBlockStatus(blockId, askSlaves = true)
assert(statuses.size === 1)
statuses.head match { case (bm, status) =>
assert(bm.executorId === "<driver>", "Block should only be on the driver")
@@ -139,14 +139,14 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
}
if (distributed) {
// this file is only generated in distributed mode
- assert(HttpBroadcast.getFile(blockIds.head.broadcastId).exists, "Broadcast file not found!")
+ assert(HttpBroadcast.getFile(blockId.broadcastId).exists, "Broadcast file not found!")
}
}
// Verify that blocks are persisted in both the executors and the driver
- def afterUsingBroadcast(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) {
- assert(blockIds.size === 1)
- val statuses = bmm.getBlockStatus(blockIds.head, askSlaves = true)
+ def afterUsingBroadcast(broadcastId: Long, bmm: BlockManagerMaster) {
+ val blockId = BroadcastBlockId(broadcastId)
+ val statuses = bmm.getBlockStatus(blockId, askSlaves = true)
assert(statuses.size === numSlaves + 1)
statuses.foreach { case (_, status) =>
assert(status.storageLevel === StorageLevel.MEMORY_AND_DISK)
@@ -157,21 +157,21 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
// Verify that blocks are unpersisted on all executors, and on all nodes if removeFromDriver
// is true. In the latter case, also verify that the broadcast file is deleted on the driver.
- def afterUnpersist(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) {
- assert(blockIds.size === 1)
- val statuses = bmm.getBlockStatus(blockIds.head, askSlaves = true)
+ def afterUnpersist(broadcastId: Long, bmm: BlockManagerMaster) {
+ val blockId = BroadcastBlockId(broadcastId)
+ val statuses = bmm.getBlockStatus(blockId, askSlaves = true)
val expectedNumBlocks = if (removeFromDriver) 0 else 1
val possiblyNot = if (removeFromDriver) "" else " not"
assert(statuses.size === expectedNumBlocks,
"Block should%s be unpersisted on the driver".format(possiblyNot))
if (distributed && removeFromDriver) {
// this file is only generated in distributed mode
- assert(!HttpBroadcast.getFile(blockIds.head.broadcastId).exists,
+ assert(!HttpBroadcast.getFile(blockId.broadcastId).exists,
"Broadcast file should%s be deleted".format(possiblyNot))
}
}
- testUnpersistBroadcast(distributed, numSlaves, httpConf, getBlockIds, afterCreation,
+ testUnpersistBroadcast(distributed, numSlaves, httpConf, afterCreation,
afterUsingBroadcast, afterUnpersist, removeFromDriver)
}
@@ -185,67 +185,51 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
private def testUnpersistTorrentBroadcast(distributed: Boolean, removeFromDriver: Boolean) {
val numSlaves = if (distributed) 2 else 0
- def getBlockIds(id: Long) = {
- val broadcastBlockId = BroadcastBlockId(id)
- val metaBlockId = BroadcastBlockId(id, "meta")
- // Assume broadcast value is small enough to fit into 1 piece
- val pieceBlockId = BroadcastBlockId(id, "piece0")
- if (distributed) {
- // the metadata and piece blocks are generated only in distributed mode
- Seq[BroadcastBlockId](broadcastBlockId, metaBlockId, pieceBlockId)
- } else {
- Seq[BroadcastBlockId](broadcastBlockId)
- }
+ // Verify that blocks are persisted only on the driver
+ def afterCreation(broadcastId: Long, bmm: BlockManagerMaster) {
+ var blockId = BroadcastBlockId(broadcastId)
+ var statuses = bmm.getBlockStatus(blockId, askSlaves = true)
+ assert(statuses.size === 1)
+
+ blockId = BroadcastBlockId(broadcastId, "piece0")
+ statuses = bmm.getBlockStatus(blockId, askSlaves = true)
+ assert(statuses.size === (if (distributed) 1 else 0))
}
- // Verify that blocks are persisted only on the driver
- def afterCreation(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) {
- blockIds.foreach { blockId =>
- val statuses = bmm.getBlockStatus(blockIds.head, askSlaves = true)
+ // Verify that blocks are persisted in both the executors and the driver
+ def afterUsingBroadcast(broadcastId: Long, bmm: BlockManagerMaster) {
+ var blockId = BroadcastBlockId(broadcastId)
+ var statuses = bmm.getBlockStatus(blockId, askSlaves = true)
+ if (distributed) {
+ assert(statuses.size === numSlaves + 1)
+ } else {
assert(statuses.size === 1)
- statuses.head match { case (bm, status) =>
- assert(bm.executorId === "<driver>", "Block should only be on the driver")
- assert(status.storageLevel === StorageLevel.MEMORY_AND_DISK)
- assert(status.memSize > 0, "Block should be in memory store on the driver")
- assert(status.diskSize === 0, "Block should not be in disk store on the driver")
- }
}
- }
- // Verify that blocks are persisted in both the executors and the driver
- def afterUsingBroadcast(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) {
- blockIds.foreach { blockId =>
- val statuses = bmm.getBlockStatus(blockId, askSlaves = true)
- if (blockId.field == "meta") {
- // Meta data is only on the driver
- assert(statuses.size === 1)
- statuses.head match { case (bm, _) => assert(bm.executorId === "<driver>") }
- } else {
- // Other blocks are on both the executors and the driver
- assert(statuses.size === numSlaves + 1,
- blockId + " has " + statuses.size + " statuses: " + statuses.mkString(","))
- statuses.foreach { case (_, status) =>
- assert(status.storageLevel === StorageLevel.MEMORY_AND_DISK)
- assert(status.memSize > 0, "Block should be in memory store")
- assert(status.diskSize === 0, "Block should not be in disk store")
- }
- }
+ blockId = BroadcastBlockId(broadcastId, "piece0")
+ statuses = bmm.getBlockStatus(blockId, askSlaves = true)
+ if (distributed) {
+ assert(statuses.size === numSlaves + 1)
+ } else {
+ assert(statuses.size === 0)
}
}
// Verify that blocks are unpersisted on all executors, and on all nodes if removeFromDriver
// is true.
- def afterUnpersist(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) {
- val expectedNumBlocks = if (removeFromDriver) 0 else 1
- val possiblyNot = if (removeFromDriver) "" else " not"
- blockIds.foreach { blockId =>
- val statuses = bmm.getBlockStatus(blockId, askSlaves = true)
- assert(statuses.size === expectedNumBlocks,
- "Block should%s be unpersisted on the driver".format(possiblyNot))
- }
+ def afterUnpersist(broadcastId: Long, bmm: BlockManagerMaster) {
+ var blockId = BroadcastBlockId(broadcastId)
+ var expectedNumBlocks = if (removeFromDriver) 0 else 1
+ var statuses = bmm.getBlockStatus(blockId, askSlaves = true)
+ assert(statuses.size === expectedNumBlocks)
+
+ blockId = BroadcastBlockId(broadcastId, "piece0")
+ expectedNumBlocks = if (removeFromDriver || !distributed) 0 else 1
+ statuses = bmm.getBlockStatus(blockId, askSlaves = true)
+ assert(statuses.size === expectedNumBlocks)
}
- testUnpersistBroadcast(distributed, numSlaves, torrentConf, getBlockIds, afterCreation,
+ testUnpersistBroadcast(distributed, numSlaves, torrentConf, afterCreation,
afterUsingBroadcast, afterUnpersist, removeFromDriver)
}
@@ -262,10 +246,9 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
distributed: Boolean,
numSlaves: Int, // used only when distributed = true
broadcastConf: SparkConf,
- getBlockIds: Long => Seq[BroadcastBlockId],
- afterCreation: (Seq[BroadcastBlockId], BlockManagerMaster) => Unit,
- afterUsingBroadcast: (Seq[BroadcastBlockId], BlockManagerMaster) => Unit,
- afterUnpersist: (Seq[BroadcastBlockId], BlockManagerMaster) => Unit,
+ afterCreation: (Long, BlockManagerMaster) => Unit,
+ afterUsingBroadcast: (Long, BlockManagerMaster) => Unit,
+ afterUnpersist: (Long, BlockManagerMaster) => Unit,
removeFromDriver: Boolean) {
sc = if (distributed) {
@@ -278,15 +261,14 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
// Create broadcast variable
val broadcast = sc.broadcast(list)
- val blocks = getBlockIds(broadcast.id)
- afterCreation(blocks, blockManagerMaster)
+ afterCreation(broadcast.id, blockManagerMaster)
// Use broadcast variable on all executors
val partitions = 10
assert(partitions > numSlaves)
val results = sc.parallelize(1 to partitions, partitions).map(x => (x, broadcast.value.sum))
assert(results.collect().toSet === (1 to partitions).map(x => (x, list.sum)).toSet)
- afterUsingBroadcast(blocks, blockManagerMaster)
+ afterUsingBroadcast(broadcast.id, blockManagerMaster)
// Unpersist broadcast
if (removeFromDriver) {
@@ -294,7 +276,7 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
} else {
broadcast.unpersist(blocking = true)
}
- afterUnpersist(blocks, blockManagerMaster)
+ afterUnpersist(broadcast.id, blockManagerMaster)
// If the broadcast is removed from driver, all subsequent uses of the broadcast variable
// should throw SparkExceptions. Otherwise, the result should be the same as before.