 core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala | 43
 core/src/main/scala/org/apache/spark/util/Utils.scala                 | 15
 core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala   | 30
 3 files changed, 67 insertions(+), 21 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
index 75e64c1bf4..94142d3336 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
@@ -56,11 +56,13 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
extends Broadcast[T](id) with Logging with Serializable {
/**
- * Value of the broadcast object. On driver, this is set directly by the constructor.
- * On executors, this is reconstructed by [[readObject]], which builds this value by reading
- * blocks from the driver and/or other executors.
+ * Value of the broadcast object on executors. This is reconstructed by [[readBroadcastBlock]],
+ * which builds this value by reading blocks from the driver and/or other executors.
+ *
+ * On the driver, if the value is required, it is read lazily from the block manager.
*/
- @transient private var _value: T = obj
+ @transient private lazy val _value: T = readBroadcastBlock()
+
/** The compression codec to use, or None if compression is disabled */
@transient private var compressionCodec: Option[CompressionCodec] = _
/** Size of each block. Default value is 4MB. This value is only read by the broadcaster. */
@@ -79,22 +81,24 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
private val broadcastId = BroadcastBlockId(id)
/** Total number of blocks this broadcast variable contains. */
- private val numBlocks: Int = writeBlocks()
+ private val numBlocks: Int = writeBlocks(obj)
- override protected def getValue() = _value
+ override protected def getValue() = {
+ _value
+ }
/**
* Divide the object into multiple blocks and put those blocks in the block manager.
- *
+ * @param value the object to divide
* @return number of blocks this broadcast variable is divided into
*/
- private def writeBlocks(): Int = {
+ private def writeBlocks(value: T): Int = {
// Store a copy of the broadcast variable in the driver so that tasks run on the driver
// do not create a duplicate copy of the broadcast variable's value.
- SparkEnv.get.blockManager.putSingle(broadcastId, _value, StorageLevel.MEMORY_AND_DISK,
+ SparkEnv.get.blockManager.putSingle(broadcastId, value, StorageLevel.MEMORY_AND_DISK,
tellMaster = false)
val blocks =
- TorrentBroadcast.blockifyObject(_value, blockSize, SparkEnv.get.serializer, compressionCodec)
+ TorrentBroadcast.blockifyObject(value, blockSize, SparkEnv.get.serializer, compressionCodec)
blocks.zipWithIndex.foreach { case (block, i) =>
SparkEnv.get.blockManager.putBytes(
BroadcastBlockId(id, "piece" + i),
@@ -157,31 +161,30 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
out.defaultWriteObject()
}
- /** Used by the JVM when deserializing this object. */
- private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException {
- in.defaultReadObject()
+ private def readBroadcastBlock(): T = Utils.tryOrIOException {
TorrentBroadcast.synchronized {
setConf(SparkEnv.get.conf)
SparkEnv.get.blockManager.getLocal(broadcastId).map(_.data.next()) match {
case Some(x) =>
- _value = x.asInstanceOf[T]
+ x.asInstanceOf[T]
case None =>
logInfo("Started reading broadcast variable " + id)
- val start = System.nanoTime()
+ val startTimeMs = System.currentTimeMillis()
val blocks = readBlocks()
- val time = (System.nanoTime() - start) / 1e9
- logInfo("Reading broadcast variable " + id + " took " + time + " s")
+ logInfo("Reading broadcast variable " + id + " took" + Utils.getUsedTimeMs(startTimeMs))
- _value =
- TorrentBroadcast.unBlockifyObject[T](blocks, SparkEnv.get.serializer, compressionCodec)
+ val obj = TorrentBroadcast.unBlockifyObject[T](
+ blocks, SparkEnv.get.serializer, compressionCodec)
// Store the merged copy in BlockManager so other tasks on this executor don't
// need to re-fetch it.
SparkEnv.get.blockManager.putSingle(
- broadcastId, _value, StorageLevel.MEMORY_AND_DISK, tellMaster = false)
+ broadcastId, obj, StorageLevel.MEMORY_AND_DISK, tellMaster = false)
+ obj
}
}
}
+
}
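
The heart of this change is the switch from an eagerly assigned @transient field (repopulated inside readObject) to a @transient lazy val, so the broadcast value is only materialized on first access. A minimal single-JVM sketch of that pattern, with LocalStore standing in for the BlockManager (LocalStore and LazyHandle are illustrative names, not from this patch):

    // A serializable handle whose payload is NOT carried in its serialized form.
    // The payload lives in an external store and is re-read lazily, mirroring
    // `@transient private lazy val _value: T = readBroadcastBlock()` above.
    object LocalStore {
      private val data = scala.collection.concurrent.TrieMap.empty[Long, Any]
      def put(id: Long, v: Any): Unit = { data.put(id, v) }
      def get(id: Long): Any = data(id)
    }

    class LazyHandle[T](obj: T, val id: Long) extends Serializable {
      LocalStore.put(id, obj)  // written once at construction, like writeBlocks(obj)
      // Transient and lazy: never serialized, recomputed on first use after
      // deserialization. In Spark this read may fetch torrent pieces remotely.
      @transient private lazy val _value: T = LocalStore.get(id).asInstanceOf[T]
      def value: T = _value
    }

Because _value is no longer assigned in the constructor, constructing the handle (or serializing it into a task closure) does not force a read; only the first call to value does.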
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 4660030155..612eca308b 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -988,6 +988,21 @@ private[spark] object Utils extends Logging {
}
}
+ /**
+ * Execute a block of code that returns a value, re-throwing any non-fatal uncaught
+ * exceptions as IOException. This is used when implementing Externalizable and Serializable's
+ * read and write methods, since Java's serializer will not report non-IOExceptions properly;
+ * see SPARK-4080 for more context.
+ */
+ def tryOrIOException[T](block: => T): T = {
+ try {
+ block
+ } catch {
+ case e: IOException => throw e
+ case NonFatal(t) => throw new IOException(t)
+ }
+ }
+
/** Default filtering function for finding call sites using `getCallSite`. */
private def coreExclusionFunction(className: String): Boolean = {
// A regular expression to match classes of the "core" Spark API that we want to skip when
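
The new Utils.tryOrIOException is intended to wrap custom serialization hooks, since Java serialization only propagates IOExceptions cleanly out of readObject/writeObject (see SPARK-4080). A hedged usage sketch, with a made-up class for illustration (not from this patch):

    import java.io.ObjectInputStream

    class CachedThing(@transient var cache: Map[String, Int]) extends Serializable {
      // Any NonFatal exception thrown while rebuilding the transient cache is
      // rethrown as an IOException, so the deserialization failure is reported
      // properly instead of being swallowed by the JVM serializer.
      private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException {
        in.defaultReadObject()
        if (cache == null) cache = Map.empty   // restore transient state
      }
    }

Note that Utils is private[spark], so this helper is only reachable from Spark-internal classes.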
diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala
index 1014fd62d9..b0a70f012f 100644
--- a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala
+++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala
@@ -21,11 +21,28 @@ import scala.util.Random
import org.scalatest.{Assertions, FunSuite}
-import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkException}
+import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkException, SparkEnv}
import org.apache.spark.io.SnappyCompressionCodec
+import org.apache.spark.rdd.RDD
import org.apache.spark.serializer.JavaSerializer
import org.apache.spark.storage._
+// Dummy class that creates a broadcast variable but doesn't use it
+class DummyBroadcastClass(rdd: RDD[Int]) extends Serializable {
+ @transient val list = List(1, 2, 3, 4)
+ val broadcast = rdd.context.broadcast(list)
+ val bid = broadcast.id
+
+ def doSomething() = {
+ rdd.map { x =>
+ val bm = SparkEnv.get.blockManager
+ // Check if broadcast block was fetched
+ val isFound = bm.getLocal(BroadcastBlockId(bid)).isDefined
+ (x, isFound)
+ }.collect().toSet
+ }
+}
+
class BroadcastSuite extends FunSuite with LocalSparkContext {
private val httpConf = broadcastConf("HttpBroadcastFactory")
@@ -105,6 +122,17 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
}
}
+ test("Test Lazy Broadcast variables with TorrentBroadcast") {
+ val numSlaves = 2
+ val conf = torrentConf.clone
+ sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", conf)
+ val rdd = sc.parallelize(1 to numSlaves)
+
+ val results = new DummyBroadcastClass(rdd).doSomething()
+
+ assert(results.toSet === (1 to numSlaves).map(x => (x, false)).toSet)
+ }
+
test("Unpersisting HttpBroadcast on executors only in local mode") {
testUnpersistHttpBroadcast(distributed = false, removeFromDriver = false)
}
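
For contrast with the lazy test above, a sketch of the opposite assertion (illustrative only, not part of this patch): if the task body dereferences broadcast.value before probing the BlockManager, each executor should then hold the broadcast block locally.

    test("Broadcast block appears locally once the value is used") {
      val numSlaves = 2
      val conf = torrentConf.clone
      sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", conf)
      val broadcast = sc.broadcast(List(1, 2, 3, 4))
      val results = sc.parallelize(1 to numSlaves).map { x =>
        broadcast.value  // forces readBroadcastBlock(), which caches the block locally
        val bm = SparkEnv.get.blockManager
        (x, bm.getLocal(BroadcastBlockId(broadcast.id)).isDefined)
      }.collect().toSet
      assert(results === (1 to numSlaves).map(x => (x, true)).toSet)
    }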