-rw-r--r--   core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala   32
1 file changed, 32 insertions(+), 0 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
index e8d6d587b4..f350784378 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
@@ -19,6 +19,7 @@ package org.apache.spark.broadcast
import java.io._
import java.nio.ByteBuffer
+import java.util.zip.Adler32
import scala.collection.JavaConverters._
import scala.reflect.ClassTag
@@ -77,6 +78,7 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
}
// Note: use getSizeAsKb (not bytes) to maintain compatibility if no units are provided
blockSize = conf.getSizeAsKb("spark.broadcast.blockSize", "4m").toInt * 1024
+ checksumEnabled = conf.getBoolean("spark.broadcast.checksum", true)
}
setConf(SparkEnv.get.conf)
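The flag added here defaults to true, so checksums are on unless explicitly disabled. A minimal sketch of opting out from user code, assuming an ordinary SparkConf-based setup (only the config key comes from this patch; everything else is the standard core API):

    import org.apache.spark.SparkConf

    // Sketch: turning broadcast block checksums off. The key
    // "spark.broadcast.checksum" is introduced by this patch and
    // defaults to true, so setting it is only needed to disable the check.
    val conf = new SparkConf()
      .setAppName("no-broadcast-checksum") // illustrative app name
      .set("spark.broadcast.checksum", "false")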
@@ -85,10 +87,27 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
/** Total number of blocks this broadcast variable contains. */
private val numBlocks: Int = writeBlocks(obj)
+ /** Whether to generate checksum for blocks or not. */
+ private var checksumEnabled: Boolean = false
+ /** The checksum for all the blocks. */
+ private var checksums: Array[Int] = _
+
override protected def getValue() = {
_value
}
+ private def calcChecksum(block: ByteBuffer): Int = {
+ val adler = new Adler32()
+ if (block.hasArray) {
+ adler.update(block.array, block.arrayOffset + block.position, block.limit - block.position)
+ } else {
+ val bytes = new Array[Byte](block.remaining())
+ block.duplicate.get(bytes)
+ adler.update(bytes)
+ }
+ adler.getValue.toInt
+ }
+
/**
* Divide the object into multiple blocks and put those blocks in the block manager.
*
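For readers skimming the helper above: calcChecksum feeds Adler32 either the backing array of a heap buffer (no copy) or a copy read through duplicate(), so the caller's buffer position is never disturbed. A stand-alone sketch of the same technique, runnable in a Scala REPL (the name demoChecksum is ours, not part of the patch):

    import java.nio.ByteBuffer
    import java.util.zip.Adler32

    // Illustrative stand-alone version of calcChecksum.
    def demoChecksum(block: ByteBuffer): Int = {
      val adler = new Adler32()
      if (block.hasArray) {
        // Heap-backed buffer: hash the backing array in place, no copy.
        adler.update(block.array, block.arrayOffset + block.position, block.remaining)
      } else {
        // Direct buffer: copy the remaining bytes out of a duplicate so
        // the caller's position and limit are left untouched.
        val bytes = new Array[Byte](block.remaining)
        block.duplicate.get(bytes)
        adler.update(bytes)
      }
      adler.getValue.toInt
    }

    // Same bytes give the same checksum for heap and direct buffers,
    // and neither buffer is consumed by the call.
    val heap   = ByteBuffer.wrap("spark".getBytes("UTF-8"))
    val direct = ByteBuffer.allocateDirect(5)
    direct.put("spark".getBytes("UTF-8")).flip()
    assert(demoChecksum(heap) == demoChecksum(direct))
    assert(heap.position == 0 && direct.position == 0)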
@@ -105,7 +124,13 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
}
val blocks =
TorrentBroadcast.blockifyObject(value, blockSize, SparkEnv.get.serializer, compressionCodec)
+ if (checksumEnabled) {
+ checksums = new Array[Int](blocks.length)
+ }
blocks.zipWithIndex.foreach { case (block, i) =>
+ if (checksumEnabled) {
+ checksums(i) = calcChecksum(block)
+ }
val pieceId = BroadcastBlockId(id, "piece" + i)
val bytes = new ChunkedByteBuffer(block.duplicate())
if (!blockManager.putBytes(pieceId, bytes, MEMORY_AND_DISK_SER, tellMaster = true)) {
@@ -135,6 +160,13 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
case None =>
bm.getRemoteBytes(pieceId) match {
case Some(b) =>
+ if (checksumEnabled) {
+ val sum = calcChecksum(b.chunks(0))
+ if (sum != checksums(pid)) {
+ throw new SparkException(s"corrupt remote block $pieceId of $broadcastId:" +
+ s" $sum != ${checksums(pid)}")
+ }
+ }
// We found the block from remote executors/driver's BlockManager, so put the block
// in this executor's BlockManager.
if (!bm.putBytes(pieceId, b, StorageLevel.MEMORY_AND_DISK_SER, tellMaster = true)) {
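End to end, the patch checksums each piece on the driver in writeBlocks and re-checksums it in readBlocks on executors that fetch the piece remotely, failing the fetch with the SparkException above rather than silently deserializing a corrupt block. A hedged driver-side sketch that exercises the write path (note: in local mode blocks are served from the local BlockManager, so the remote verification branch only runs on a real cluster):

    import org.apache.spark.{SparkConf, SparkContext}

    // Sketch: a broadcast large enough to be split into several 4 MB
    // pieces, each checksummed in writeBlocks. Executors that fetch a
    // piece remotely verify it in readBlocks before storing it.
    val sc = new SparkContext(
      new SparkConf().setAppName("broadcast-checksum-demo").setMaster("local[4]"))

    val payload = Array.fill[Byte](16 * 1024 * 1024)(1.toByte) // ~16 MiB => 4 pieces
    val bc = sc.broadcast(payload)
    val sizes = sc.parallelize(1 to 8, 8).map(_ => bc.value.length).collect()
    assert(sizes.forall(_ == payload.length))
    sc.stop()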