From 815ecd349ad296adf2f85e67a06ff83ced24288f Mon Sep 17 00:00:00 2001 From: Mosharaf Chowdhury Date: Tue, 30 Nov 2010 18:08:49 -0800 Subject: Made Broadcast Pluggable. Finally! --- conf/java-opts | 2 +- src/scala/spark/BitTorrentBroadcast.scala | 14 +++- src/scala/spark/Broadcast.scala | 121 +++++++++++++++++------------- src/scala/spark/DfsBroadcast.scala | 13 +++- src/scala/spark/SparkContext.scala | 4 +- 5 files changed, 95 insertions(+), 59 deletions(-) diff --git a/conf/java-opts b/conf/java-opts index 7fb9e50bbc..c4f9e48276 100644 --- a/conf/java-opts +++ b/conf/java-opts @@ -1 +1 @@ --Dspark.broadcast.MasterHostAddress=127.0.0.1 -Dspark.broadcast.MasterTrackerPort=11111 -Dspark.broadcast.BlockSize=256 -Dspark.broadcast.MaxRetryCount=2 -Dspark.broadcast.TrackerSocketTimeout=50000 -Dspark.broadcast.ServerSocketTimeout=10000 -Dspark.broadcast.MaxChatTime=500 -Dspark.broadcast.EndGameFraction=0.95 +-Dspark.broadcast.MasterHostAddress=127.0.0.1 -Dspark.broadcast.MasterTrackerPort=11111 -Dspark.broadcast.BlockSize=256 -Dspark.broadcast.MaxRetryCount=2 -Dspark.broadcast.TrackerSocketTimeout=50000 -Dspark.broadcast.ServerSocketTimeout=10000 -Dspark.broadcast.MaxChatTime=500 -Dspark.broadcast.EndGameFraction=0.95 -Dspark.broadcast.Factory=spark.BitTorrentBroadcastFactory diff --git a/src/scala/spark/BitTorrentBroadcast.scala b/src/scala/spark/BitTorrentBroadcast.scala index e029108b09..e8432f9143 100644 --- a/src/scala/spark/BitTorrentBroadcast.scala +++ b/src/scala/spark/BitTorrentBroadcast.scala @@ -9,7 +9,7 @@ import scala.collection.mutable.{ListBuffer, Map, Set} @serializable class BitTorrentBroadcast[T] (@transient var value_ : T, isLocal: Boolean) -extends Broadcast with Logging { +extends Broadcast[T] with Logging { def value = value_ @@ -1028,6 +1028,13 @@ extends Broadcast with Logging { } } +class BitTorrentBroadcastFactory +extends BroadcastFactory { + def initialize (isMaster: Boolean) = BitTorrentBroadcast.initialize (isMaster) + def newBroadcast[T] (value_ : T, isLocal: Boolean) = + new BitTorrentBroadcast[T] (value_, isLocal) +} + private object BitTorrentBroadcast extends Logging { val values = Cache.newKeySpace() @@ -1115,7 +1122,10 @@ extends Logging { trackMV.start logInfo ("TrackMultipleValues started...") } - + + // Initialize DfsBroadcast to be used for broadcast variable persistence + DfsBroadcast.initialize + initialized = true } } diff --git a/src/scala/spark/Broadcast.scala b/src/scala/spark/Broadcast.scala index 75131a9981..cdb1de16db 100644 --- a/src/scala/spark/Broadcast.scala +++ b/src/scala/spark/Broadcast.scala @@ -4,16 +4,84 @@ import java.util.{BitSet, UUID} import java.util.concurrent.{Executors, ThreadPoolExecutor, ThreadFactory} @serializable -trait Broadcast { +trait Broadcast[T] { val uuid = UUID.randomUUID + def value: T + // We cannot have an abstract readObject here due to some weird issues with // readObject having to be 'private' in sub-classes. Possibly a Scala bug! + def sendBroadcast: Unit override def toString = "spark.Broadcast(" + uuid + ")" } +trait BroadcastFactory { + def initialize (isMaster: Boolean): Unit + def newBroadcast[T] (value_ : T, isLocal: Boolean): Broadcast[T] +} + +private object Broadcast +extends Logging { + private var initialized = false + private var broadcastFactory: BroadcastFactory = null + + // Called by SparkContext or Executor before using Broadcast + def initialize (isMaster: Boolean): Unit = { + if (!initialized) { + val broadcastFactoryClass = System.getProperty("spark.broadcast.Factory", + "spark.BitTorrentBroadcastFactory") + val booleanArgs = Array[AnyRef] (isMaster.asInstanceOf[AnyRef]) +// broadcastFactory = Class.forName(broadcastFactoryClass).getConstructors()(0).newInstance(booleanArgs:_*).asInstanceOf[BroadcastFactory] + broadcastFactory = Class.forName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory] + + // Initialize appropriate BroadcastFactory and BroadcastObject + broadcastFactory.initialize(isMaster) + + initialized = true + } + } + + def getBroadcastFactory: BroadcastFactory = { + if (broadcastFactory == null) { + throw new SparkException ("Broadcast.getBroadcastFactory called before initialize") + } + broadcastFactory + } + + // Returns a standard ThreadFactory except all threads are daemons + private def newDaemonThreadFactory: ThreadFactory = { + new ThreadFactory { + def newThread(r: Runnable): Thread = { + var t = Executors.defaultThreadFactory.newThread (r) + t.setDaemon (true) + return t + } + } + } + + // Wrapper over newCachedThreadPool + def newDaemonCachedThreadPool: ThreadPoolExecutor = { + var threadPool = + Executors.newCachedThreadPool.asInstanceOf[ThreadPoolExecutor] + + threadPool.setThreadFactory (newDaemonThreadFactory) + + return threadPool + } + + // Wrapper over newFixedThreadPool + def newDaemonFixedThreadPool (nThreads: Int): ThreadPoolExecutor = { + var threadPool = + Executors.newFixedThreadPool (nThreads).asInstanceOf[ThreadPoolExecutor] + + threadPool.setThreadFactory (newDaemonThreadFactory) + + return threadPool + } +} + @serializable case class SourceInfo (val hostAddress: String, val listenPort: Int, val totalBlocks: Int, val totalBytes: Int) @@ -69,54 +137,3 @@ class SpeedTracker { override def toString = sourceToSpeedMap.toString } - -private object Broadcast -extends Logging { - private var initialized = false - - // Called by SparkContext or Executor before using Broadcast - // Calls all other initializers here - def initialize (isMaster: Boolean): Unit = { - synchronized { - if (!initialized) { - // Initialization for DfsBroadcast - DfsBroadcast.initialize - // Initialization for BitTorrentBroadcast - BitTorrentBroadcast.initialize (isMaster) - - initialized = true - } - } - } - - // Returns a standard ThreadFactory except all threads are daemons - private def newDaemonThreadFactory: ThreadFactory = { - new ThreadFactory { - def newThread(r: Runnable): Thread = { - var t = Executors.defaultThreadFactory.newThread (r) - t.setDaemon (true) - return t - } - } - } - - // Wrapper over newCachedThreadPool - def newDaemonCachedThreadPool: ThreadPoolExecutor = { - var threadPool = - Executors.newCachedThreadPool.asInstanceOf[ThreadPoolExecutor] - - threadPool.setThreadFactory (newDaemonThreadFactory) - - return threadPool - } - - // Wrapper over newFixedThreadPool - def newDaemonFixedThreadPool (nThreads: Int): ThreadPoolExecutor = { - var threadPool = - Executors.newFixedThreadPool (nThreads).asInstanceOf[ThreadPoolExecutor] - - threadPool.setThreadFactory (newDaemonThreadFactory) - - return threadPool - } -} diff --git a/src/scala/spark/DfsBroadcast.scala b/src/scala/spark/DfsBroadcast.scala index a249961fd5..7b1ebce851 100644 --- a/src/scala/spark/DfsBroadcast.scala +++ b/src/scala/spark/DfsBroadcast.scala @@ -10,8 +10,8 @@ import org.apache.hadoop.fs.{FileSystem, Path, RawLocalFileSystem} import spark.compress.lzf.{LZFInputStream, LZFOutputStream} @serializable -class DfsBroadcast[T](@transient var value_ : T, local: Boolean) -extends Broadcast with Logging { +class DfsBroadcast[T](@transient var value_ : T, isLocal: Boolean) +extends Broadcast[T] with Logging { def value = value_ @@ -19,7 +19,7 @@ extends Broadcast with Logging { DfsBroadcast.values.put(uuid, value_) } - if (!local) { + if (!isLocal) { sendBroadcast } @@ -52,6 +52,13 @@ extends Broadcast with Logging { } } +class DfsBroadcastFactory +extends BroadcastFactory { + def initialize (isMaster: Boolean) = DfsBroadcast.initialize + def newBroadcast[T] (value_ : T, isLocal: Boolean) = + new DfsBroadcast[T] (value_, isLocal) +} + private object DfsBroadcast extends Logging { val values = Cache.newKeySpace() diff --git a/src/scala/spark/SparkContext.scala b/src/scala/spark/SparkContext.scala index 841ccf7930..8ef5817359 100644 --- a/src/scala/spark/SparkContext.scala +++ b/src/scala/spark/SparkContext.scala @@ -100,7 +100,9 @@ extends Logging { // TODO: Keep around a weak hash map of values to Cached versions? // def broadcast[T](value: T) = new DfsBroadcast(value, isLocal) // def broadcast[T](value: T) = new ChainedBroadcast(value, isLocal) - def broadcast[T](value: T) = new BitTorrentBroadcast(value, isLocal) + // def broadcast[T](value: T) = new BitTorrentBroadcast(value, isLocal) + def broadcast[T](value: T) = + Broadcast.getBroadcastFactory.newBroadcast[T] (value, isLocal) // Stop the SparkContext def stop() { -- cgit v1.2.3