Diffstat (limited to 'core/src/main/scala/org')
-rw-r--r--  core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala        | 136
-rw-r--r--  core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala |  11
2 files changed, 70 insertions(+), 77 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
index 42d58682a1..99af2e9608 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
@@ -26,6 +26,7 @@ import scala.util.Random
import org.apache.spark.{Logging, SparkConf, SparkEnv, SparkException}
import org.apache.spark.io.CompressionCodec
+import org.apache.spark.serializer.Serializer
import org.apache.spark.storage.{BroadcastBlockId, StorageLevel}
import org.apache.spark.util.ByteBufferInputStream
import org.apache.spark.util.io.ByteArrayChunkOutputStream
@@ -46,14 +47,12 @@ import org.apache.spark.util.io.ByteArrayChunkOutputStream
* This prevents the driver from being the bottleneck in sending out multiple copies of the
* broadcast data (one per executor) as done by the [[org.apache.spark.broadcast.HttpBroadcast]].
*
+ * When initialized, TorrentBroadcast objects read SparkEnv.get.conf.
+ *
* @param obj object to broadcast
- * @param isLocal whether Spark is running in local mode (single JVM process).
* @param id A unique identifier for the broadcast variable.
*/
-private[spark] class TorrentBroadcast[T: ClassTag](
- obj : T,
- @transient private val isLocal: Boolean,
- id: Long)
+private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
extends Broadcast[T](id) with Logging with Serializable {
/**
@@ -62,6 +61,20 @@ private[spark] class TorrentBroadcast[T: ClassTag](
* blocks from the driver and/or other executors.
*/
@transient private var _value: T = obj
+ /** The compression codec to use, or None if compression is disabled */
+ @transient private var compressionCodec: Option[CompressionCodec] = _
+ /** Size of each block. Default value is 4MB. This value is only read by the broadcaster. */
+ @transient private var blockSize: Int = _
+
+ private def setConf(conf: SparkConf) {
+ compressionCodec = if (conf.getBoolean("spark.broadcast.compress", true)) {
+ Some(CompressionCodec.createCodec(conf))
+ } else {
+ None
+ }
+ blockSize = conf.getInt("spark.broadcast.blockSize", 4096) * 1024
+ }
+ setConf(SparkEnv.get.conf)
private val broadcastId = BroadcastBlockId(id)
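[Note: the per-instance setConf above replaces the old global initialize(). The two properties it reads are set on the application's SparkConf as usual; a minimal sketch, where the key names are the real ones from this diff and the app name is illustrative:

    import org.apache.spark.SparkConf

    val conf = new SparkConf()
      .setAppName("broadcast-demo")             // illustrative name
      .set("spark.broadcast.compress", "true")  // use a codec (the default)
      .set("spark.broadcast.blockSize", "4096") // piece size in KB, i.e. 4MB
]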
@@ -76,23 +89,20 @@ private[spark] class TorrentBroadcast[T: ClassTag](
* @return number of blocks this broadcast variable is divided into
*/
private def writeBlocks(): Int = {
- // For local mode, just put the object in the BlockManager so we can find it later.
- SparkEnv.get.blockManager.putSingle(
- broadcastId, _value, StorageLevel.MEMORY_AND_DISK, tellMaster = false)
-
- if (!isLocal) {
- val blocks = TorrentBroadcast.blockifyObject(_value)
- blocks.zipWithIndex.foreach { case (block, i) =>
- SparkEnv.get.blockManager.putBytes(
- BroadcastBlockId(id, "piece" + i),
- block,
- StorageLevel.MEMORY_AND_DISK_SER,
- tellMaster = true)
- }
- blocks.length
- } else {
- 0
+ // Store a copy of the broadcast variable in the driver so that tasks run on the driver
+ // do not create a duplicate copy of the broadcast variable's value.
+ SparkEnv.get.blockManager.putSingle(broadcastId, _value, StorageLevel.MEMORY_AND_DISK,
+ tellMaster = false)
+ val blocks =
+ TorrentBroadcast.blockifyObject(_value, blockSize, SparkEnv.get.serializer, compressionCodec)
+ blocks.zipWithIndex.foreach { case (block, i) =>
+ SparkEnv.get.blockManager.putBytes(
+ BroadcastBlockId(id, "piece" + i),
+ block,
+ StorageLevel.MEMORY_AND_DISK_SER,
+ tellMaster = true)
}
+ blocks.length
}
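[Note: writeBlocks now always blockifies: the value is serialized, optionally compressed, and cut into blockSize pieces, each stored under BroadcastBlockId(id, "piece" + i). A standalone sketch of the chunking step, where chunk is a hypothetical helper standing in for ByteArrayChunkOutputStream.toArrays:

    import java.nio.ByteBuffer

    // Hypothetical helper: split a serialized payload into fixed-size pieces.
    def chunk(bytes: Array[Byte], blockSize: Int): Array[ByteBuffer] =
      bytes.grouped(blockSize).map(ByteBuffer.wrap).toArray

    val pieces = chunk(new Array[Byte](10 * 1024 * 1024), 4 * 1024 * 1024)
    // pieces.length == 3; piece i would be stored as BroadcastBlockId(id, "piece" + i)
]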
/** Fetch torrent blocks from the driver and/or other executors. */
@@ -104,29 +114,24 @@ private[spark] class TorrentBroadcast[T: ClassTag](
for (pid <- Random.shuffle(Seq.range(0, numBlocks))) {
val pieceId = BroadcastBlockId(id, "piece" + pid)
-
- // First try getLocalBytes because there is a chance that previous attempts to fetch the
+ logDebug(s"Reading piece $pieceId of $broadcastId")
+ // First try getLocalBytes because there is a chance that previous attempts to fetch the
// broadcast blocks have already fetched some of the blocks. In that case, some blocks
// would be available locally (on this executor).
- var blockOpt = bm.getLocalBytes(pieceId)
- if (!blockOpt.isDefined) {
- blockOpt = bm.getRemoteBytes(pieceId)
- blockOpt match {
- case Some(block) =>
- // If we found the block from remote executors/driver's BlockManager, put the block
- // in this executor's BlockManager.
- SparkEnv.get.blockManager.putBytes(
- pieceId,
- block,
- StorageLevel.MEMORY_AND_DISK_SER,
- tellMaster = true)
-
- case None =>
- throw new SparkException("Failed to get " + pieceId + " of " + broadcastId)
- }
+ def getLocal: Option[ByteBuffer] = bm.getLocalBytes(pieceId)
+ def getRemote: Option[ByteBuffer] = bm.getRemoteBytes(pieceId).map { block =>
+ // If we found the block from remote executors/driver's BlockManager, put the block
+ // in this executor's BlockManager.
+ SparkEnv.get.blockManager.putBytes(
+ pieceId,
+ block,
+ StorageLevel.MEMORY_AND_DISK_SER,
+ tellMaster = true)
+ block
}
- // If we get here, the option is defined.
- blocks(pid) = blockOpt.get
+ val block: ByteBuffer = getLocal.orElse(getRemote).getOrElse(
+ throw new SparkException(s"Failed to get $pieceId of $broadcastId"))
+ blocks(pid) = block
}
blocks
}
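[Note: the refactored lookup collapses the old if/else into Option combinators. The same shape extracted as a generic sketch, where fetch and its argument names are illustrative (the real code throws SparkException):

    // Try the cheap local source first, fall back to remote, fail loudly otherwise.
    def fetch[T](local: => Option[T], remote: => Option[T], what: String): T =
      local.orElse(remote).getOrElse {
        throw new IllegalStateException(s"Failed to get $what")
      }
]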
@@ -156,6 +161,7 @@ private[spark] class TorrentBroadcast[T: ClassTag](
private def readObject(in: ObjectInputStream) {
in.defaultReadObject()
TorrentBroadcast.synchronized {
+ setConf(SparkEnv.get.conf)
SparkEnv.get.blockManager.getLocal(broadcastId).map(_.data.next()) match {
case Some(x) =>
_value = x.asInstanceOf[T]
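[Note: setConf is re-run inside readObject because the config-derived fields are @transient and come back null after default Java deserialization. The general pattern, as a small self-contained sketch with a hypothetical class:

    import java.io.ObjectInputStream

    // @transient fields are not restored by defaultReadObject(), so they
    // must be reinitialized by hand, as setConf does for the codec above.
    class CachedThing extends Serializable {
      @transient private var cache: Map[String, Int] = Map.empty
      private def readObject(in: ObjectInputStream): Unit = {
        in.defaultReadObject()
        cache = Map.empty // rebuild transient state
      }
    }
]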
@@ -167,7 +173,8 @@ private[spark] class TorrentBroadcast[T: ClassTag](
val time = (System.nanoTime() - start) / 1e9
logInfo("Reading broadcast variable " + id + " took " + time + " s")
- _value = TorrentBroadcast.unBlockifyObject[T](blocks)
+ _value =
+ TorrentBroadcast.unBlockifyObject[T](blocks, SparkEnv.get.serializer, compressionCodec)
// Store the merged copy in BlockManager so other tasks on this executor don't
// need to re-fetch it.
SparkEnv.get.blockManager.putSingle(
@@ -179,43 +186,29 @@ private[spark] class TorrentBroadcast[T: ClassTag](
private object TorrentBroadcast extends Logging {
- /** Size of each block. Default value is 4MB. */
- private lazy val BLOCK_SIZE = conf.getInt("spark.broadcast.blockSize", 4096) * 1024
- private var initialized = false
- private var conf: SparkConf = null
- private var compress: Boolean = false
- private var compressionCodec: CompressionCodec = null
-
- def initialize(_isDriver: Boolean, conf: SparkConf) {
- TorrentBroadcast.conf = conf // TODO: we might have to fix it in tests
- synchronized {
- if (!initialized) {
- compress = conf.getBoolean("spark.broadcast.compress", true)
- compressionCodec = CompressionCodec.createCodec(conf)
- initialized = true
- }
- }
- }
- def stop() {
- initialized = false
- }
-
- def blockifyObject[T: ClassTag](obj: T): Array[ByteBuffer] = {
- val bos = new ByteArrayChunkOutputStream(BLOCK_SIZE)
- val out: OutputStream = if (compress) compressionCodec.compressedOutputStream(bos) else bos
- val ser = SparkEnv.get.serializer.newInstance()
+ def blockifyObject[T: ClassTag](
+ obj: T,
+ blockSize: Int,
+ serializer: Serializer,
+ compressionCodec: Option[CompressionCodec]): Array[ByteBuffer] = {
+ val bos = new ByteArrayChunkOutputStream(blockSize)
+ val out: OutputStream = compressionCodec.map(c => c.compressedOutputStream(bos)).getOrElse(bos)
+ val ser = serializer.newInstance()
val serOut = ser.serializeStream(out)
serOut.writeObject[T](obj).close()
bos.toArrays.map(ByteBuffer.wrap)
}
- def unBlockifyObject[T: ClassTag](blocks: Array[ByteBuffer]): T = {
+ def unBlockifyObject[T: ClassTag](
+ blocks: Array[ByteBuffer],
+ serializer: Serializer,
+ compressionCodec: Option[CompressionCodec]): T = {
+ require(blocks.nonEmpty, "Cannot unblockify an empty array of blocks")
val is = new SequenceInputStream(
asJavaEnumeration(blocks.iterator.map(block => new ByteBufferInputStream(block))))
- val in: InputStream = if (compress) compressionCodec.compressedInputStream(is) else is
-
- val ser = SparkEnv.get.serializer.newInstance()
+ val in: InputStream = compressionCodec.map(c => c.compressedInputStream(is)).getOrElse(is)
+ val ser = serializer.newInstance()
val serIn = ser.deserializeStream(in)
val obj = serIn.readObject[T]()
serIn.close()
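[Note: blockifyObject and unBlockifyObject are now pure functions of their arguments. A self-contained round trip in the same shape, with JDK object streams and GZIP standing in for Spark's pluggable Serializer and CompressionCodec:

    import java.io._
    import java.util.zip.{GZIPInputStream, GZIPOutputStream}

    // Serialize and compress, mirroring blockifyObject (minus the chunking).
    def toBytes(obj: AnyRef): Array[Byte] = {
      val bos = new ByteArrayOutputStream()
      val out = new ObjectOutputStream(new GZIPOutputStream(bos))
      out.writeObject(obj)
      out.close() // flushes and finishes the GZIP stream
      bos.toByteArray
    }

    // Decompress and deserialize, mirroring unBlockifyObject.
    def fromBytes[T](bytes: Array[Byte]): T = {
      val in = new ObjectInputStream(new GZIPInputStream(new ByteArrayInputStream(bytes)))
      try in.readObject().asInstanceOf[T] finally in.close()
    }

    assert(fromBytes[String](toBytes("hello")) == "hello")
]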
@@ -227,6 +220,7 @@ private object TorrentBroadcast extends Logging {
* If removeFromDriver is true, also remove these persisted blocks on the driver.
*/
def unpersist(id: Long, removeFromDriver: Boolean, blocking: Boolean) = {
+ logDebug(s"Unpersisting TorrentBroadcast $id")
SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver, blocking)
}
}
diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala
index ad0f701d7a..fb024c1209 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala
@@ -28,14 +28,13 @@ import org.apache.spark.{SecurityManager, SparkConf}
*/
class TorrentBroadcastFactory extends BroadcastFactory {
- override def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) {
- TorrentBroadcast.initialize(isDriver, conf)
- }
+ override def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) { }
- override def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean, id: Long) =
- new TorrentBroadcast[T](value_, isLocal, id)
+ override def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean, id: Long) = {
+ new TorrentBroadcast[T](value_, id)
+ }
- override def stop() { TorrentBroadcast.stop() }
+ override def stop() { }
/**
* Remove all persisted state associated with the torrent broadcast with the given ID.