diff options
7 files changed, 328 insertions, 12 deletions
diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala new file mode 100644 index 0000000000..073a0a5029 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -0,0 +1,247 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.broadcast + +import java.io._ + +import scala.math +import scala.util.Random + +import org.apache.spark._ +import org.apache.spark.storage.{BroadcastBlockId, BroadcastHelperBlockId, StorageLevel} +import org.apache.spark.util.Utils + + +private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long) +extends Broadcast[T](id) with Logging with Serializable { + + def value = value_ + + def broadcastId = BroadcastBlockId(id) + + TorrentBroadcast.synchronized { + SparkEnv.get.blockManager.putSingle(broadcastId, value_, StorageLevel.MEMORY_AND_DISK, false) + } + + @transient var arrayOfBlocks: Array[TorrentBlock] = null + @transient var totalBlocks = -1 + @transient var totalBytes = -1 + @transient var hasBlocks = 0 + + if (!isLocal) { + sendBroadcast() + } + + def sendBroadcast() { + var tInfo = TorrentBroadcast.blockifyObject(value_) + + totalBlocks = tInfo.totalBlocks + totalBytes = tInfo.totalBytes + hasBlocks = tInfo.totalBlocks + + // Store meta-info + val metaId = BroadcastHelperBlockId(broadcastId, "meta") + val metaInfo = TorrentInfo(null, totalBlocks, totalBytes) + TorrentBroadcast.synchronized { + SparkEnv.get.blockManager.putSingle( + metaId, metaInfo, StorageLevel.MEMORY_AND_DISK, true) + } + + // Store individual pieces + for (i <- 0 until totalBlocks) { + val pieceId = BroadcastHelperBlockId(broadcastId, "piece" + i) + TorrentBroadcast.synchronized { + SparkEnv.get.blockManager.putSingle( + pieceId, tInfo.arrayOfBlocks(i), StorageLevel.MEMORY_AND_DISK, true) + } + } + } + + // Called by JVM when deserializing an object + private def readObject(in: ObjectInputStream) { + in.defaultReadObject() + TorrentBroadcast.synchronized { + SparkEnv.get.blockManager.getSingle(broadcastId) match { + case Some(x) => + value_ = x.asInstanceOf[T] + + case None => + val start = System.nanoTime + logInfo("Started reading broadcast variable " + id) + + // Initialize @transient variables that will receive garbage values from the master. + resetWorkerVariables() + + if (receiveBroadcast(id)) { + value_ = TorrentBroadcast.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks) + + // Store the merged copy in cache so that the next worker doesn't need to rebuild it. + // This creates a tradeoff between memory usage and latency. + // Storing copy doubles the memory footprint; not storing doubles deserialization cost. + SparkEnv.get.blockManager.putSingle( + broadcastId, value_, StorageLevel.MEMORY_AND_DISK, false) + + // Remove arrayOfBlocks from memory once value_ is on local cache + resetWorkerVariables() + } else { + logError("Reading broadcast variable " + id + " failed") + } + + val time = (System.nanoTime - start) / 1e9 + logInfo("Reading broadcast variable " + id + " took " + time + " s") + } + } + } + + private def resetWorkerVariables() { + arrayOfBlocks = null + totalBytes = -1 + totalBlocks = -1 + hasBlocks = 0 + } + + def receiveBroadcast(variableID: Long): Boolean = { + // Receive meta-info + val metaId = BroadcastHelperBlockId(broadcastId, "meta") + var attemptId = 10 + while (attemptId > 0 && totalBlocks == -1) { + TorrentBroadcast.synchronized { + SparkEnv.get.blockManager.getSingle(metaId) match { + case Some(x) => + val tInfo = x.asInstanceOf[TorrentInfo] + totalBlocks = tInfo.totalBlocks + totalBytes = tInfo.totalBytes + arrayOfBlocks = new Array[TorrentBlock](totalBlocks) + hasBlocks = 0 + + case None => + Thread.sleep(500) + } + } + attemptId -= 1 + } + if (totalBlocks == -1) { + return false + } + + // Receive actual blocks + val recvOrder = new Random().shuffle(Array.iterate(0, totalBlocks)(_ + 1).toList) + for (pid <- recvOrder) { + val pieceId = BroadcastHelperBlockId(broadcastId, "piece" + pid) + TorrentBroadcast.synchronized { + SparkEnv.get.blockManager.getSingle(pieceId) match { + case Some(x) => + arrayOfBlocks(pid) = x.asInstanceOf[TorrentBlock] + hasBlocks += 1 + SparkEnv.get.blockManager.putSingle( + pieceId, arrayOfBlocks(pid), StorageLevel.MEMORY_AND_DISK, true) + + case None => + throw new SparkException("Failed to get " + pieceId + " of " + broadcastId) + } + } + } + + (hasBlocks == totalBlocks) + } + +} + +private object TorrentBroadcast +extends Logging { + + private var initialized = false + + def initialize(_isDriver: Boolean) { + synchronized { + if (!initialized) { + initialized = true + } + } + } + + def stop() { + initialized = false + } + + val BLOCK_SIZE = System.getProperty("spark.broadcast.blockSize", "4096").toInt * 1024 + + def blockifyObject[T](obj: T): TorrentInfo = { + val byteArray = Utils.serialize[T](obj) + val bais = new ByteArrayInputStream(byteArray) + + var blockNum = (byteArray.length / BLOCK_SIZE) + if (byteArray.length % BLOCK_SIZE != 0) + blockNum += 1 + + var retVal = new Array[TorrentBlock](blockNum) + var blockID = 0 + + for (i <- 0 until (byteArray.length, BLOCK_SIZE)) { + val thisBlockSize = math.min(BLOCK_SIZE, byteArray.length - i) + var tempByteArray = new Array[Byte](thisBlockSize) + val hasRead = bais.read(tempByteArray, 0, thisBlockSize) + + retVal(blockID) = new TorrentBlock(blockID, tempByteArray) + blockID += 1 + } + bais.close() + + var tInfo = TorrentInfo(retVal, blockNum, byteArray.length) + tInfo.hasBlocks = blockNum + + return tInfo + } + + def unBlockifyObject[T](arrayOfBlocks: Array[TorrentBlock], + totalBytes: Int, + totalBlocks: Int): T = { + var retByteArray = new Array[Byte](totalBytes) + for (i <- 0 until totalBlocks) { + System.arraycopy(arrayOfBlocks(i).byteArray, 0, retByteArray, + i * BLOCK_SIZE, arrayOfBlocks(i).byteArray.length) + } + Utils.deserialize[T](retByteArray, Thread.currentThread.getContextClassLoader) + } + +} + +private[spark] case class TorrentBlock( + blockID: Int, + byteArray: Array[Byte]) + extends Serializable + +private[spark] case class TorrentInfo( + @transient arrayOfBlocks : Array[TorrentBlock], + totalBlocks: Int, + totalBytes: Int) + extends Serializable { + + @transient var hasBlocks = 0 +} + +private[spark] class TorrentBroadcastFactory + extends BroadcastFactory { + + def initialize(isDriver: Boolean) { TorrentBroadcast.initialize(isDriver) } + + def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) = + new TorrentBroadcast[T](value_, isLocal, id) + + def stop() { TorrentBroadcast.stop() } +} diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala index c7efc67a4a..7156d855d8 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala @@ -32,7 +32,7 @@ private[spark] sealed abstract class BlockId { def asRDDId = if (isRDD) Some(asInstanceOf[RDDBlockId]) else None def isRDD = isInstanceOf[RDDBlockId] def isShuffle = isInstanceOf[ShuffleBlockId] - def isBroadcast = isInstanceOf[BroadcastBlockId] + def isBroadcast = isInstanceOf[BroadcastBlockId] || isInstanceOf[BroadcastHelperBlockId] override def toString = name override def hashCode = name.hashCode @@ -55,6 +55,10 @@ private[spark] case class BroadcastBlockId(broadcastId: Long) extends BlockId { def name = "broadcast_" + broadcastId } +private[spark] case class BroadcastHelperBlockId(broadcastId: BroadcastBlockId, hType: String) extends BlockId { + def name = broadcastId.name + "_" + hType +} + private[spark] case class TaskResultBlockId(taskId: Long) extends BlockId { def name = "taskresult_" + taskId } @@ -72,6 +76,7 @@ private[spark] object BlockId { val RDD = "rdd_([0-9]+)_([0-9]+)".r val SHUFFLE = "shuffle_([0-9]+)_([0-9]+)_([0-9]+)".r val BROADCAST = "broadcast_([0-9]+)".r + val BROADCAST_HELPER = "broadcast_([0-9]+)_([A-Za-z0-9]+)".r val TASKRESULT = "taskresult_([0-9]+)".r val STREAM = "input-([0-9]+)-([0-9]+)".r val TEST = "test_(.*)".r @@ -84,6 +89,8 @@ private[spark] object BlockId { ShuffleBlockId(shuffleId.toInt, mapId.toInt, reduceId.toInt) case BROADCAST(broadcastId) => BroadcastBlockId(broadcastId.toLong) + case BROADCAST_HELPER(broadcastId, hType) => + BroadcastHelperBlockId(BroadcastBlockId(broadcastId.toLong), hType) case TASKRESULT(taskId) => TaskResultBlockId(taskId.toLong) case STREAM(streamId, uniqueId) => diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 801f88a3db..c67a61515e 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -21,6 +21,7 @@ import java.io.{InputStream, OutputStream} import java.nio.{ByteBuffer, MappedByteBuffer} import scala.collection.mutable.{HashMap, ArrayBuffer, HashSet} +import scala.util.Random import akka.actor.{ActorSystem, Cancellable, Props} import akka.dispatch.{Await, Future} @@ -269,7 +270,7 @@ private[spark] class BlockManager( } /** - * Actually send a UpdateBlockInfo message. Returns the mater's response, + * Actually send a UpdateBlockInfo message. Returns the master's response, * which will be true if the block was successfully recorded and false if * the slave needs to re-register. */ @@ -478,7 +479,7 @@ private[spark] class BlockManager( } logDebug("Getting remote block " + blockId) // Get locations of block - val locations = master.getLocations(blockId) + val locations = Random.shuffle(master.getLocations(blockId)) // Get block from remote locations for (loc <- locations) { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala index 633230c0a8..f8cf14b503 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala @@ -227,9 +227,7 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { } private def register(id: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) { - if (id.executorId == "<driver>" && !isLocal) { - // Got a register message from the master node; don't register it - } else if (!blockManagerInfo.contains(id)) { + if (!blockManagerInfo.contains(id)) { blockManagerIdByExecutor.get(id.executorId) match { case Some(manager) => // A block manager of the same executor already exists. diff --git a/core/src/test/scala/org/apache/spark/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/BroadcastSuite.scala index b3a53d928b..e022accee6 100644 --- a/core/src/test/scala/org/apache/spark/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/BroadcastSuite.scala @@ -20,8 +20,42 @@ package org.apache.spark import org.scalatest.FunSuite class BroadcastSuite extends FunSuite with LocalSparkContext { - - test("basic broadcast") { + + override def afterEach() { + super.afterEach() + System.clearProperty("spark.broadcast.factory") + } + + test("Using HttpBroadcast locally") { + System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory") + sc = new SparkContext("local", "test") + val list = List(1, 2, 3, 4) + val listBroadcast = sc.broadcast(list) + val results = sc.parallelize(1 to 2).map(x => (x, listBroadcast.value.sum)) + assert(results.collect.toSet === Set((1, 10), (2, 10))) + } + + test("Accessing HttpBroadcast variables from multiple threads") { + System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory") + sc = new SparkContext("local[10]", "test") + val list = List(1, 2, 3, 4) + val listBroadcast = sc.broadcast(list) + val results = sc.parallelize(1 to 10).map(x => (x, listBroadcast.value.sum)) + assert(results.collect.toSet === (1 to 10).map(x => (x, 10)).toSet) + } + + test("Accessing HttpBroadcast variables in a local cluster") { + System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory") + val numSlaves = 4 + sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test") + val list = List(1, 2, 3, 4) + val listBroadcast = sc.broadcast(list) + val results = sc.parallelize(1 to numSlaves).map(x => (x, listBroadcast.value.sum)) + assert(results.collect.toSet === (1 to numSlaves).map(x => (x, 10)).toSet) + } + + test("Using TorrentBroadcast locally") { + System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory") sc = new SparkContext("local", "test") val list = List(1, 2, 3, 4) val listBroadcast = sc.broadcast(list) @@ -29,11 +63,23 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { assert(results.collect.toSet === Set((1, 10), (2, 10))) } - test("broadcast variables accessed in multiple threads") { + test("Accessing TorrentBroadcast variables from multiple threads") { + System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory") sc = new SparkContext("local[10]", "test") val list = List(1, 2, 3, 4) val listBroadcast = sc.broadcast(list) val results = sc.parallelize(1 to 10).map(x => (x, listBroadcast.value.sum)) assert(results.collect.toSet === (1 to 10).map(x => (x, 10)).toSet) } + + test("Accessing TorrentBroadcast variables in a local cluster") { + System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory") + val numSlaves = 4 + sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test") + val list = List(1, 2, 3, 4) + val listBroadcast = sc.broadcast(list) + val results = sc.parallelize(1 to numSlaves).map(x => (x, listBroadcast.value.sum)) + assert(results.collect.toSet === (1 to numSlaves).map(x => (x, 10)).toSet) + } + } diff --git a/docs/configuration.md b/docs/configuration.md index 7940d41a27..c5900d0e09 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -319,6 +319,14 @@ Apart from these, the following properties are also available, and may be useful Should be greater than or equal to 1. Number of allowed retries = this value - 1. </td> </tr> +<tr> + <td>spark.broadcast.blockSize</td> + <td>4096</td> + <td> + Size of each piece of a block in kilobytes for <code>TorrentBroadcastFactory</code>. + Too large a value decreases parallelism during broadcast (makes it slower); however, if it is too small, <code>BlockManager</code> might take a performance hit. + </td> +</tr> </table> diff --git a/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala b/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala index 868ff81f67..529709c2f9 100644 --- a/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala @@ -22,12 +22,19 @@ import org.apache.spark.SparkContext object BroadcastTest { def main(args: Array[String]) { if (args.length == 0) { - System.err.println("Usage: BroadcastTest <master> [<slices>] [numElem]") + System.err.println("Usage: BroadcastTest <master> [slices] [numElem] [broadcastAlgo] [blockSize]") System.exit(1) } - val sc = new SparkContext(args(0), "Broadcast Test", + val bcName = if (args.length > 3) args(3) else "Http" + val blockSize = if (args.length > 4) args(4) else "4096" + + System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast." + bcName + "BroadcastFactory") + System.setProperty("spark.broadcast.blockSize", blockSize) + + val sc = new SparkContext(args(0), "Broadcast Test 2", System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) + val slices = if (args.length > 1) args(1).toInt else 2 val num = if (args.length > 2) args(2).toInt else 1000000 @@ -36,13 +43,15 @@ object BroadcastTest { arr1(i) = i } - for (i <- 0 until 2) { + for (i <- 0 until 3) { println("Iteration " + i) println("===========") + val startTime = System.nanoTime val barr1 = sc.broadcast(arr1) sc.parallelize(1 to 10, slices).foreach { i => println(barr1.value.size) } + println("Iteration %d took %.0f milliseconds".format(i, (System.nanoTime - startTime) / 1E6)) } System.exit(0) |