author     Matei Zaharia <matei@eecs.berkeley.edu>  2012-06-09 15:58:07 -0700
committer  Matei Zaharia <matei@eecs.berkeley.edu>  2012-06-09 15:58:07 -0700
commit     e75b1b5cb480b94f128ae5afe586b2d73be4ae9b (patch)
tree       e0d9c41b1a5eb505ec8b49c857bbc0d07b688993 /core
parent     a96558caa3c0feb20bbf0f3ec367673886fc78c6 (diff)
Change the default broadcast implementation to a simple HTTP-based
broadcast. Fixes #139.
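
With this change HttpBroadcastFactory becomes the default, but the
implementation is still selected through the "spark.broadcast.factory"
system property read in Broadcast.initialize (see the Broadcast.scala hunk
below). A minimal sketch of opting back into the previous DFS-based
implementation, assuming DfsBroadcastFactory remains on the classpath:

    // Must be set before the SparkContext is created, since the factory
    // is resolved once in Broadcast.initialize
    System.setProperty("spark.broadcast.factory", "spark.broadcast.DfsBroadcastFactory")
    val sc = new SparkContext("local", "broadcast-demo")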
Diffstat (limited to 'core')
-rw-r--r--  core/src/main/scala/spark/Utils.scala                      |   6
-rw-r--r--  core/src/main/scala/spark/broadcast/Broadcast.scala        |   4
-rw-r--r--  core/src/main/scala/spark/broadcast/BroadcastFactory.scala |   6
-rw-r--r--  core/src/main/scala/spark/broadcast/HttpBroadcast.scala    | 110
-rw-r--r--  core/src/test/scala/spark/BroadcastSuite.scala             |  23
5 files changed, 144 insertions, 5 deletions
diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala
index cfd6dc8b2a..68ccab24db 100644
--- a/core/src/main/scala/spark/Utils.scala
+++ b/core/src/main/scala/spark/Utils.scala
@@ -76,6 +76,12 @@ object Utils {
}
} catch { case e: IOException => ; }
}
+ // Add a shutdown hook to delete the temp dir when the JVM exits
+ Runtime.getRuntime.addShutdownHook(new Thread("delete Spark temp dir " + dir) {
+ override def run() {
+ Utils.deleteRecursively(dir)
+ }
+ })
return dir
}
diff --git a/core/src/main/scala/spark/broadcast/Broadcast.scala b/core/src/main/scala/spark/broadcast/Broadcast.scala
index cdf05fe5de..06049749a9 100644
--- a/core/src/main/scala/spark/broadcast/Broadcast.scala
+++ b/core/src/main/scala/spark/broadcast/Broadcast.scala
@@ -33,7 +33,7 @@ object Broadcast extends Logging with Serializable {
def initialize (isMaster__ : Boolean): Unit = synchronized {
if (!initialized) {
val broadcastFactoryClass = System.getProperty(
- "spark.broadcast.factory", "spark.broadcast.DfsBroadcastFactory")
+ "spark.broadcast.factory", "spark.broadcast.HttpBroadcastFactory")
broadcastFactory =
Class.forName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory]
@@ -219,4 +219,4 @@ class SpeedTracker extends Serializable {
}
override def toString = sourceToSpeedMap.toString
-}
\ No newline at end of file
+}
diff --git a/core/src/main/scala/spark/broadcast/BroadcastFactory.scala b/core/src/main/scala/spark/broadcast/BroadcastFactory.scala
index 341746d18e..b18908f789 100644
--- a/core/src/main/scala/spark/broadcast/BroadcastFactory.scala
+++ b/core/src/main/scala/spark/broadcast/BroadcastFactory.scala
@@ -7,6 +7,6 @@ package spark.broadcast
* entire Spark job.
*/
trait BroadcastFactory {
- def initialize (isMaster: Boolean): Unit
- def newBroadcast[T] (value_ : T, isLocal: Boolean): Broadcast[T]
-}
\ No newline at end of file
+ def initialize(isMaster: Boolean): Unit
+ def newBroadcast[T](value_ : T, isLocal: Boolean): Broadcast[T]
+}
diff --git a/core/src/main/scala/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/spark/broadcast/HttpBroadcast.scala
new file mode 100644
index 0000000000..4714816591
--- /dev/null
+++ b/core/src/main/scala/spark/broadcast/HttpBroadcast.scala
@@ -0,0 +1,110 @@
+package spark.broadcast
+
+import com.ning.compress.lzf.{LZFInputStream, LZFOutputStream}
+
+import java.io._
+import java.net._
+import java.util.UUID
+
+import it.unimi.dsi.fastutil.io.FastBufferedInputStream
+import it.unimi.dsi.fastutil.io.FastBufferedOutputStream
+
+import spark._
+
+class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean)
+extends Broadcast[T] with Logging with Serializable {
+
+ def value = value_
+
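+  // Cache the value on this node, and write it out for HTTP serving unless running locally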
+ HttpBroadcast.synchronized {
+ HttpBroadcast.values.put(uuid, 0, value_)
+ }
+
+ if (!isLocal) {
+ HttpBroadcast.write(uuid, value_)
+ }
+
+ // Called by JVM when deserializing an object
+ private def readObject(in: ObjectInputStream): Unit = {
+ in.defaultReadObject()
+ HttpBroadcast.synchronized {
+ val cachedVal = HttpBroadcast.values.get(uuid, 0)
+ if (cachedVal != null) {
+ value_ = cachedVal.asInstanceOf[T]
+ } else {
+ logInfo("Started reading broadcast variable " + uuid)
+ val start = System.nanoTime
+ value_ = HttpBroadcast.read(uuid).asInstanceOf[T]
+ HttpBroadcast.values.put(uuid, 0, value_)
+ val time = (System.nanoTime - start) / 1e9
+ logInfo("Reading broadcast variable " + uuid + " took " + time + " s")
+ }
+ }
+ }
+}
+
+class HttpBroadcastFactory extends BroadcastFactory {
+ def initialize(isMaster: Boolean): Unit = HttpBroadcast.initialize(isMaster)
+ def newBroadcast[T](value_ : T, isLocal: Boolean) = new HttpBroadcast[T](value_, isLocal)
+}
+
+private object HttpBroadcast extends Logging {
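+  // Key space in the SparkEnv cache that holds deserialized broadcast values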
+ val values = SparkEnv.get.cache.newKeySpace()
+
+ private var initialized = false
+
+ private var broadcastDir: File = null
+ private var compress: Boolean = false
+ private var bufferSize: Int = 65536
+ private var serverUri: String = null
+ private var server: HttpServer = null
+
+ def initialize(isMaster: Boolean): Unit = {
+ synchronized {
+ if (!initialized) {
+ bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
+ compress = System.getProperty("spark.compress", "false").toBoolean
+ if (isMaster) {
+ createServer()
+ }
+ serverUri = System.getProperty("spark.httpBroadcast.uri")
+ initialized = true
+ }
+ }
+ }
+
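+  // Start an HTTP server on the master that serves files from a fresh temp directory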
+ private def createServer() {
+ broadcastDir = Utils.createTempDir()
+ server = new HttpServer(broadcastDir)
+ server.start()
+ serverUri = server.uri
+ System.setProperty("spark.httpBroadcast.uri", serverUri)
+ logInfo("Broadcast server started at " + serverUri)
+ }
+
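+  // Serialize a broadcast value to a file in the broadcast directory (optionally LZF-compressed)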
+ def write(uuid: UUID, value: Any) {
+ val file = new File(broadcastDir, "broadcast-" + uuid)
+ val out: OutputStream = if (compress) {
+ new LZFOutputStream(new FileOutputStream(file)) // Does its own buffering
+ } else {
+ new FastBufferedOutputStream(new FileOutputStream(file), bufferSize)
+ }
+ val ser = SparkEnv.get.serializer.newInstance()
+ val serOut = ser.outputStream(out)
+ serOut.writeObject(value)
+ serOut.close()
+ }
+
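+  // Fetch a broadcast value from the master's HTTP server and deserialize it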
+ def read(uuid: UUID): Any = {
+ val url = serverUri + "/broadcast-" + uuid
+    val in = if (compress) {
+ new LZFInputStream(new URL(url).openStream()) // Does its own buffering
+ } else {
+ new FastBufferedInputStream(new URL(url).openStream(), bufferSize)
+ }
+ val ser = SparkEnv.get.serializer.newInstance()
+ val serIn = ser.inputStream(in)
+    val obj = serIn.readObject[Any]()
+    serIn.close()
+    obj
+ }
+}
diff --git a/core/src/test/scala/spark/BroadcastSuite.scala b/core/src/test/scala/spark/BroadcastSuite.scala
new file mode 100644
index 0000000000..750703de30
--- /dev/null
+++ b/core/src/test/scala/spark/BroadcastSuite.scala
@@ -0,0 +1,23 @@
+package spark
+
+import org.scalatest.FunSuite
+
+class BroadcastSuite extends FunSuite {
+ test("basic broadcast") {
+ val sc = new SparkContext("local", "test")
+ val list = List(1, 2, 3, 4)
+ val listBroadcast = sc.broadcast(list)
+ val results = sc.parallelize(1 to 2).map(x => (x, listBroadcast.value.sum))
+ assert(results.collect.toSet === Set((1, 10), (2, 10)))
+ sc.stop()
+ }
+
+ test("broadcast variables accessed in multiple threads") {
+ val sc = new SparkContext("local[10]", "test")
+ val list = List(1, 2, 3, 4)
+ val listBroadcast = sc.broadcast(list)
+ val results = sc.parallelize(1 to 10).map(x => (x, listBroadcast.value.sum))
+ assert(results.collect.toSet === (1 to 10).map(x => (x, 10)).toSet)
+ sc.stop()
+ }
+}
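
HttpBroadcast.initialize (above) reads two tuning knobs: "spark.compress"
to LZF-compress the broadcast files and "spark.buffer.size" for the stream
buffer size. A hedged sketch of enabling them, mirroring the tests above;
the property names come from this diff, the app name is made up. Note that
in local mode the value is served from the in-memory cache (the isLocal
guard in the HttpBroadcast constructor), so the HTTP path is only actually
exercised on a cluster:

    // Both properties must be set before the SparkContext starts the
    // broadcast HTTP server, since initialize reads them only once
    System.setProperty("spark.compress", "true")
    System.setProperty("spark.buffer.size", "131072")
    val sc = new SparkContext("local", "compressed-broadcast-demo")
    val listBroadcast = sc.broadcast(List(1, 2, 3, 4))
    assert(sc.parallelize(1 to 2).map(x => listBroadcast.value.sum).collect.toSet == Set(10))
    sc.stop()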