author     Matei Zaharia <matei@eecs.berkeley.edu>  2011-05-17 12:41:13 -0700
committer  Matei Zaharia <matei@eecs.berkeley.edu>  2011-05-17 12:41:13 -0700
commit     fd1d255821bde844af28e897fabd59a715659038 (patch)
tree       ed3bd4a7009aff50cf92c1df38155a35c2fa578c /core
parent     4db50e26c75263b2edae468b0e8a9283b5c2e6f1 (diff)
download   spark-fd1d255821bde844af28e897fabd59a715659038.tar.gz
           spark-fd1d255821bde844af28e897fabd59a715659038.tar.bz2
           spark-fd1d255821bde844af28e897fabd59a715659038.zip
Stop objectifying various trackers, caches, etc.
Diffstat (limited to 'core')
-rw-r--r--  core/src/main/scala/spark/BitTorrentBroadcast.scala |  2
-rw-r--r--  core/src/main/scala/spark/Cache.scala               | 24
-rw-r--r--  core/src/main/scala/spark/CacheTracker.scala (renamed from core/src/main/scala/spark/RDDCache.scala) | 60
-rw-r--r--  core/src/main/scala/spark/ChainedBroadcast.scala    |  2
-rw-r--r--  core/src/main/scala/spark/CoGroupedRDD.scala        |  4
-rw-r--r--  core/src/main/scala/spark/DAGScheduler.scala        | 11
-rw-r--r--  core/src/main/scala/spark/DfsBroadcast.scala        |  2
-rw-r--r--  core/src/main/scala/spark/DiskSpillingCache.scala   |  4
-rw-r--r--  core/src/main/scala/spark/Executor.scala            | 14
-rw-r--r--  core/src/main/scala/spark/JavaSerializer.scala      |  6
-rw-r--r--  core/src/main/scala/spark/KryoSerialization.scala   | 10
-rw-r--r--  core/src/main/scala/spark/LocalFileShuffle.scala    |  4
-rw-r--r--  core/src/main/scala/spark/LocalScheduler.scala      |  6
-rw-r--r--  core/src/main/scala/spark/MapOutputTracker.scala    | 29
-rw-r--r--  core/src/main/scala/spark/RDD.scala                 |  2
-rw-r--r--  core/src/main/scala/spark/Serializer.scala          | 49
-rw-r--r--  core/src/main/scala/spark/SerializingCache.scala    |  4
-rw-r--r--  core/src/main/scala/spark/ShuffledRDD.scala         |  4
-rw-r--r--  core/src/main/scala/spark/SparkContext.scala        | 20
-rw-r--r--  core/src/main/scala/spark/SparkEnv.scala            | 36
20 files changed, 153 insertions(+), 140 deletions(-)
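
This commit replaces the process-wide singleton objects (Cache, Serializer, MapOutputTracker and the old RDDCache) with plain instances grouped into a new thread-local SparkEnv. A minimal before/after sketch of the access pattern, illustrative only; the real call sites are in the hunks below.

    // Before this commit: global companion objects, wired up by explicit initialize() calls.
    Cache.initialize()
    Serializer.initialize()
    val keySpace  = Cache.newKeySpace()
    val ser       = Serializer.newInstance()

    // After this commit: one SparkEnv per driver/executor, published through a ThreadLocal.
    // (In practice SparkContext and the Executor set spark.master.host/port before this point.)
    val env = SparkEnv.createFromSystemProperties(true)   // true = isMaster
    SparkEnv.set(env)
    val keySpace2 = SparkEnv.get.cache.newKeySpace()
    val ser2      = SparkEnv.get.serializer.newInstance()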
diff --git a/core/src/main/scala/spark/BitTorrentBroadcast.scala b/core/src/main/scala/spark/BitTorrentBroadcast.scala
index 96d3643ffd..2f5d063438 100644
--- a/core/src/main/scala/spark/BitTorrentBroadcast.scala
+++ b/core/src/main/scala/spark/BitTorrentBroadcast.scala
@@ -1037,7 +1037,7 @@ extends BroadcastFactory {
private object BitTorrentBroadcast
extends Logging {
- val values = Cache.newKeySpace()
+ val values = SparkEnv.get.cache.newKeySpace()
var valueToGuideMap = Map[UUID, SourceInfo] ()
diff --git a/core/src/main/scala/spark/Cache.scala b/core/src/main/scala/spark/Cache.scala
index 9887520758..89befae1a4 100644
--- a/core/src/main/scala/spark/Cache.scala
+++ b/core/src/main/scala/spark/Cache.scala
@@ -37,27 +37,3 @@ class KeySpace(cache: Cache, id: Long) {
def get(key: Any): Any = cache.get((id, key))
def put(key: Any, value: Any): Unit = cache.put((id, key), value)
}
-
-
-/**
- * The Cache object maintains a global Cache instance, of the type specified
- * by the spark.cache.class property.
- */
-object Cache {
- private var instance: Cache = null
-
- def initialize() {
- val cacheClass = System.getProperty("spark.cache.class",
- "spark.SoftReferenceCache")
- instance = Class.forName(cacheClass).newInstance().asInstanceOf[Cache]
- }
-
- def getInstance(): Cache = {
- if (instance == null) {
- throw new SparkException("Cache.getInstance called before initialize")
- }
- instance
- }
-
- def newKeySpace(): KeySpace = getInstance().newKeySpace()
-}
diff --git a/core/src/main/scala/spark/RDDCache.scala b/core/src/main/scala/spark/CacheTracker.scala
index c5557159a6..8b5c99cf3c 100644
--- a/core/src/main/scala/spark/RDDCache.scala
+++ b/core/src/main/scala/spark/CacheTracker.scala
@@ -6,22 +6,22 @@ import scala.actors.remote._
import scala.collection.mutable.HashMap
import scala.collection.mutable.HashSet
-sealed trait CacheMessage
-case class AddedToCache(rddId: Int, partition: Int, host: String) extends CacheMessage
-case class DroppedFromCache(rddId: Int, partition: Int, host: String) extends CacheMessage
-case class MemoryCacheLost(host: String) extends CacheMessage
-case class RegisterRDD(rddId: Int, numPartitions: Int) extends CacheMessage
-case object GetCacheLocations extends CacheMessage
-case object StopCacheTracker extends CacheMessage
+sealed trait CacheTrackerMessage
+case class AddedToCache(rddId: Int, partition: Int, host: String) extends CacheTrackerMessage
+case class DroppedFromCache(rddId: Int, partition: Int, host: String) extends CacheTrackerMessage
+case class MemoryCacheLost(host: String) extends CacheTrackerMessage
+case class RegisterRDD(rddId: Int, numPartitions: Int) extends CacheTrackerMessage
+case object GetCacheLocations extends CacheTrackerMessage
+case object StopCacheTracker extends CacheTrackerMessage
-class RDDCacheTracker extends DaemonActor with Logging {
+class CacheTrackerActor extends DaemonActor with Logging {
val locs = new HashMap[Int, Array[List[String]]]
// TODO: Should probably store (String, CacheType) tuples
def act() {
- val port = System.getProperty("spark.master.port", "50501").toInt
+ val port = System.getProperty("spark.master.port").toInt
RemoteActor.alive(port)
- RemoteActor.register('RDDCacheTracker, self)
+ RemoteActor.register('CacheTracker, self)
logInfo("Registered actor on port " + port)
loop {
@@ -60,31 +60,27 @@ class RDDCacheTracker extends DaemonActor with Logging {
}
}
-private object RDDCache extends Logging {
- // Stores map results for various splits locally
- var cache: KeySpace = null
-
- // Remembers which splits are currently being loaded (on worker nodes)
- val loading = new HashSet[(Int, Int)]
-
+class CacheTracker(isMaster: Boolean, theCache: Cache) extends Logging {
// Tracker actor on the master, or remote reference to it on workers
var trackerActor: AbstractActor = null
- var registeredRddIds: HashSet[Int] = null
-
- def initialize(isMaster: Boolean) {
- if (isMaster) {
- val tracker = new RDDCacheTracker
- tracker.start
- trackerActor = tracker
- } else {
- val host = System.getProperty("spark.master.host")
- val port = System.getProperty("spark.master.port").toInt
- trackerActor = RemoteActor.select(Node(host, port), 'RDDCacheTracker)
- }
- registeredRddIds = new HashSet[Int]
- cache = Cache.newKeySpace()
+ if (isMaster) {
+ val tracker = new CacheTrackerActor
+ tracker.start
+ trackerActor = tracker
+ } else {
+ val host = System.getProperty("spark.master.host")
+ val port = System.getProperty("spark.master.port").toInt
+ trackerActor = RemoteActor.select(Node(host, port), 'CacheTracker)
}
+
+ val registeredRddIds = new HashSet[Int]
+
+ // Stores map results for various splits locally
+ val cache = theCache.newKeySpace()
+
+ // Remembers which splits are currently being loaded (on worker nodes)
+ val loading = new HashSet[(Int, Int)]
// Registers an RDD (on master only)
def registerRDD(rddId: Int, numPartitions: Int) {
@@ -102,7 +98,7 @@ private object RDDCache extends Logging {
(trackerActor !? GetCacheLocations) match {
case h: HashMap[Int, Array[List[String]]] => h
case _ => throw new SparkException(
- "Internal error: RDDCache did not reply with a HashMap")
+ "Internal error: CacheTrackerActor did not reply with a HashMap")
}
}
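
With the RDDCache object folded into the CacheTracker class, the master/worker wiring now happens in the constructor, and the tracker actor assumes spark.master.host and spark.master.port are already set (SparkContext and the Executor arrange this). A small sketch of driving the new class directly; the property values below are placeholders, not part of the commit.

    System.setProperty("spark.master.host", "localhost")
    System.setProperty("spark.master.port", "50501")
    val env = SparkEnv.createFromSystemProperties(true)   // isMaster = true starts a CacheTrackerActor
    SparkEnv.set(env)
    env.cacheTracker.registerRDD(0, 4)                    // register RDD 0 with 4 partitions (master only)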
diff --git a/core/src/main/scala/spark/ChainedBroadcast.scala b/core/src/main/scala/spark/ChainedBroadcast.scala
index afd3c0293c..63c79c693e 100644
--- a/core/src/main/scala/spark/ChainedBroadcast.scala
+++ b/core/src/main/scala/spark/ChainedBroadcast.scala
@@ -719,7 +719,7 @@ extends BroadcastFactory {
private object ChainedBroadcast
extends Logging {
- val values = Cache.newKeySpace()
+ val values = SparkEnv.get.cache.newKeySpace()
var valueToGuidePortMap = Map[UUID, Int] ()
diff --git a/core/src/main/scala/spark/CoGroupedRDD.scala b/core/src/main/scala/spark/CoGroupedRDD.scala
index 4c427bd67c..53cae76e3a 100644
--- a/core/src/main/scala/spark/CoGroupedRDD.scala
+++ b/core/src/main/scala/spark/CoGroupedRDD.scala
@@ -83,7 +83,7 @@ extends RDD[(K, Seq[Seq[_]])](rdds.first.context) with Logging {
// Read map outputs of shuffle
logInfo("Grabbing map outputs for shuffle ID " + shuffleId)
val splitsByUri = new HashMap[String, ArrayBuffer[Int]]
- val serverUris = MapOutputTracker.getServerUris(shuffleId)
+ val serverUris = SparkEnv.get.mapOutputTracker.getServerUris(shuffleId)
for ((serverUri, index) <- serverUris.zipWithIndex) {
splitsByUri.getOrElseUpdate(serverUri, ArrayBuffer()) += index
}
@@ -109,4 +109,4 @@ extends RDD[(K, Seq[Seq[_]])](rdds.first.context) with Logging {
}
map.iterator
}
-}
\ No newline at end of file
+}
diff --git a/core/src/main/scala/spark/DAGScheduler.scala b/core/src/main/scala/spark/DAGScheduler.scala
index 2e427dcb0c..048a0faf2f 100644
--- a/core/src/main/scala/spark/DAGScheduler.scala
+++ b/core/src/main/scala/spark/DAGScheduler.scala
@@ -35,12 +35,15 @@ private trait DAGScheduler extends Scheduler with Logging {
var cacheLocs = new HashMap[Int, Array[List[String]]]
+ val cacheTracker = SparkEnv.get.cacheTracker
+ val mapOutputTracker = SparkEnv.get.mapOutputTracker
+
def getCacheLocs(rdd: RDD[_]): Array[List[String]] = {
cacheLocs(rdd.id)
}
def updateCacheLocs() {
- cacheLocs = RDDCache.getLocationsSnapshot()
+ cacheLocs = cacheTracker.getLocationsSnapshot()
}
def getShuffleMapStage(shuf: ShuffleDependency[_,_,_]): Stage = {
@@ -56,7 +59,7 @@ private trait DAGScheduler extends Scheduler with Logging {
def newStage(rdd: RDD[_], shuffleDep: Option[ShuffleDependency[_,_,_]]): Stage = {
// Kind of ugly: need to register RDDs with the cache here since
// we can't do it in its constructor because # of splits is unknown
- RDDCache.registerRDD(rdd.id, rdd.splits.size)
+ cacheTracker.registerRDD(rdd.id, rdd.splits.size)
val id = newStageId()
val stage = new Stage(id, rdd, shuffleDep, getParentStages(rdd))
idToStage(id) = stage
@@ -71,7 +74,7 @@ private trait DAGScheduler extends Scheduler with Logging {
visited += r
// Kind of ugly: need to register RDDs with the cache here since
// we can't do it in its constructor because # of splits is unknown
- RDDCache.registerRDD(r.id, r.splits.size)
+ cacheTracker.registerRDD(r.id, r.splits.size)
for (dep <- r.dependencies) {
dep match {
case shufDep: ShuffleDependency[_,_,_] =>
@@ -187,7 +190,7 @@ private trait DAGScheduler extends Scheduler with Logging {
logInfo(stage + " finished; looking for newly runnable stages")
running -= stage
if (stage.shuffleDep != None) {
- MapOutputTracker.registerMapOutputs(
+ mapOutputTracker.registerMapOutputs(
stage.shuffleDep.get.shuffleId,
stage.outputLocs.map(_.first).toArray)
}
diff --git a/core/src/main/scala/spark/DfsBroadcast.scala b/core/src/main/scala/spark/DfsBroadcast.scala
index 480d6dd9b1..895f55ca22 100644
--- a/core/src/main/scala/spark/DfsBroadcast.scala
+++ b/core/src/main/scala/spark/DfsBroadcast.scala
@@ -61,7 +61,7 @@ extends BroadcastFactory {
private object DfsBroadcast
extends Logging {
- val values = Cache.newKeySpace()
+ val values = SparkEnv.get.cache.newKeySpace()
private var initialized = false
diff --git a/core/src/main/scala/spark/DiskSpillingCache.scala b/core/src/main/scala/spark/DiskSpillingCache.scala
index 9e52fee69e..80e13a2519 100644
--- a/core/src/main/scala/spark/DiskSpillingCache.scala
+++ b/core/src/main/scala/spark/DiskSpillingCache.scala
@@ -14,7 +14,7 @@ class DiskSpillingCache extends BoundedMemoryCache {
override def get(key: Any): Any = {
synchronized {
- val ser = Serializer.newInstance()
+ val ser = SparkEnv.get.serializer.newInstance()
super.get(key) match {
case bytes: Any => // found in memory
ser.deserialize(bytes.asInstanceOf[Array[Byte]])
@@ -46,7 +46,7 @@ class DiskSpillingCache extends BoundedMemoryCache {
}
override def put(key: Any, value: Any) {
- var ser = Serializer.newInstance()
+ var ser = SparkEnv.get.serializer.newInstance()
super.put(key, ser.serialize(value))
}
diff --git a/core/src/main/scala/spark/Executor.scala b/core/src/main/scala/spark/Executor.scala
index 98d757e116..a3666fdbae 100644
--- a/core/src/main/scala/spark/Executor.scala
+++ b/core/src/main/scala/spark/Executor.scala
@@ -15,6 +15,7 @@ import mesos.{TaskDescription, TaskState, TaskStatus}
class Executor extends mesos.Executor with Logging {
var classLoader: ClassLoader = null
var threadPool: ExecutorService = null
+ var env: SparkEnv = null
override def init(d: ExecutorDriver, args: ExecutorArgs) {
// Read spark.* system properties from executor arg
@@ -22,19 +23,19 @@ class Executor extends mesos.Executor with Logging {
for ((key, value) <- props)
System.setProperty(key, value)
- // Initialize cache and broadcast system (uses some properties read above)
- Cache.initialize()
- Serializer.initialize()
+ // Initialize Spark environment (using system properties read above)
+ env = SparkEnv.createFromSystemProperties(false)
+ SparkEnv.set(env)
+ // Old stuff that isn't yet using env
Broadcast.initialize(false)
- MapOutputTracker.initialize(false)
- RDDCache.initialize(false)
// Create our ClassLoader (using spark properties) and set it on this thread
classLoader = createClassLoader()
Thread.currentThread.setContextClassLoader(classLoader)
// Start worker thread pool (they will inherit our context ClassLoader)
- threadPool = new ThreadPoolExecutor(1, 128, 600, TimeUnit.SECONDS, new LinkedBlockingQueue[Runnable])
+ threadPool = new ThreadPoolExecutor(
+ 1, 128, 600, TimeUnit.SECONDS, new LinkedBlockingQueue[Runnable])
}
override def launchTask(d: ExecutorDriver, desc: TaskDescription) {
@@ -46,6 +47,7 @@ class Executor extends mesos.Executor with Logging {
def run() = {
logInfo("Running task ID " + taskId)
try {
+ SparkEnv.set(env)
Accumulators.clear
val task = Utils.deserialize[Task[Any]](arg, classLoader)
val value = task.run
diff --git a/core/src/main/scala/spark/JavaSerializer.scala b/core/src/main/scala/spark/JavaSerializer.scala
index 8ee3044058..af390d55d8 100644
--- a/core/src/main/scala/spark/JavaSerializer.scala
+++ b/core/src/main/scala/spark/JavaSerializer.scala
@@ -19,7 +19,7 @@ class JavaDeserializationStream(in: InputStream) extends DeserializationStream {
def close() { objIn.close() }
}
-class JavaSerializer extends Serializer {
+class JavaSerializerInstance extends SerializerInstance {
def serialize[T](t: T): Array[Byte] = {
val bos = new ByteArrayOutputStream()
val out = outputStream(bos)
@@ -43,6 +43,6 @@ class JavaSerializer extends Serializer {
}
}
-class JavaSerialization extends SerializationStrategy {
- def newSerializer(): Serializer = new JavaSerializer
+class JavaSerializer extends Serializer {
+ def newInstance(): SerializerInstance = new JavaSerializerInstance
}
diff --git a/core/src/main/scala/spark/KryoSerialization.scala b/core/src/main/scala/spark/KryoSerialization.scala
index 54427ecf71..ba34a5452a 100644
--- a/core/src/main/scala/spark/KryoSerialization.scala
+++ b/core/src/main/scala/spark/KryoSerialization.scala
@@ -82,8 +82,8 @@ extends DeserializationStream {
def close() { in.close() }
}
-class KryoSerializer(strat: KryoSerialization) extends Serializer {
- val buf = strat.threadBuf.get()
+class KryoSerializerInstance(ks: KryoSerializer) extends SerializerInstance {
+ val buf = ks.threadBuf.get()
def serialize[T](t: T): Array[Byte] = {
buf.writeClassAndObject(t)
@@ -94,7 +94,7 @@ class KryoSerializer(strat: KryoSerialization) extends Serializer {
}
def outputStream(s: OutputStream): SerializationStream = {
- new KryoSerializationStream(strat.kryo, strat.threadByteBuf.get(), s)
+ new KryoSerializationStream(ks.kryo, ks.threadByteBuf.get(), s)
}
def inputStream(s: InputStream): DeserializationStream = {
@@ -107,7 +107,7 @@ trait KryoRegistrator {
def registerClasses(kryo: Kryo): Unit
}
-class KryoSerialization extends SerializationStrategy with Logging {
+class KryoSerializer extends Serializer with Logging {
val kryo = createKryo()
val threadBuf = new ThreadLocal[ObjectBuffer] {
@@ -162,5 +162,5 @@ class KryoSerialization extends SerializationStrategy with Logging {
kryo
}
- def newSerializer(): Serializer = new KryoSerializer(this)
+ def newInstance(): SerializerInstance = new KryoSerializerInstance(this)
}
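
KryoSerialization/KryoSerializer are renamed to KryoSerializer/KryoSerializerInstance so they fit the new Serializer traits, and selection now goes through the spark.serializer property read by SparkEnv.createFromSystemProperties. A hedged sketch of opting in; the application name is a placeholder and the property must be set before the SparkContext is created.

    System.setProperty("spark.serializer", "spark.KryoSerializer")
    val sc  = new SparkContext("local[2]", "kryo-demo")     // builds a SparkEnv from system properties
    val ser = SparkEnv.get.serializer.newInstance()         // now a KryoSerializerInstance
    val bytes = ser.serialize("hello")                      // String is a type Kryo handles out of the box
    val back  = ser.deserialize[String](bytes)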
diff --git a/core/src/main/scala/spark/LocalFileShuffle.scala b/core/src/main/scala/spark/LocalFileShuffle.scala
index 057a7ff43d..ee57ddbf61 100644
--- a/core/src/main/scala/spark/LocalFileShuffle.scala
+++ b/core/src/main/scala/spark/LocalFileShuffle.scala
@@ -47,7 +47,7 @@ class LocalFileShuffle[K, V, C] extends Shuffle[K, V, C] with Logging {
case None => createCombiner(v)
}
}
- val ser = Serializer.newInstance()
+ val ser = SparkEnv.get.serializer.newInstance()
for (i <- 0 until numOutputSplits) {
val file = LocalFileShuffle.getOutputFile(shuffleId, myIndex, i)
val out = ser.outputStream(new FileOutputStream(file))
@@ -70,7 +70,7 @@ class LocalFileShuffle[K, V, C] extends Shuffle[K, V, C] with Logging {
val indexes = sc.parallelize(0 until numOutputSplits, numOutputSplits)
return indexes.flatMap((myId: Int) => {
val combiners = new HashMap[K, C]
- val ser = Serializer.newInstance()
+ val ser = SparkEnv.get.serializer.newInstance()
for ((serverUri, inputIds) <- Utils.shuffle(splitsByUri)) {
for (i <- inputIds) {
val url = "%s/shuffle/%d/%d/%d".format(serverUri, shuffleId, i, myId)
diff --git a/core/src/main/scala/spark/LocalScheduler.scala b/core/src/main/scala/spark/LocalScheduler.scala
index 0287082687..832ab8cca8 100644
--- a/core/src/main/scala/spark/LocalScheduler.scala
+++ b/core/src/main/scala/spark/LocalScheduler.scala
@@ -8,6 +8,8 @@ import java.util.concurrent._
private class LocalScheduler(threads: Int) extends DAGScheduler with Logging {
var threadPool: ExecutorService =
Executors.newFixedThreadPool(threads, DaemonThreadFactory)
+
+ val env = SparkEnv.get
override def start() {}
@@ -18,6 +20,8 @@ private class LocalScheduler(threads: Int) extends DAGScheduler with Logging {
threadPool.submit(new Runnable {
def run() {
logInfo("Running task " + i)
+ // Set the Spark execution environment for the worker thread
+ SparkEnv.set(env)
try {
// Serialize and deserialize the task so that accumulators are
// changed to thread-local ones; this adds a bit of unnecessary
@@ -47,4 +51,4 @@ private class LocalScheduler(threads: Int) extends DAGScheduler with Logging {
override def stop() {}
override def numCores() = threads
-}
\ No newline at end of file
+}
diff --git a/core/src/main/scala/spark/MapOutputTracker.scala b/core/src/main/scala/spark/MapOutputTracker.scala
index 4334034ecb..d36fbc7703 100644
--- a/core/src/main/scala/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/spark/MapOutputTracker.scala
@@ -11,10 +11,10 @@ sealed trait MapOutputTrackerMessage
case class GetMapOutputLocations(shuffleId: Int) extends MapOutputTrackerMessage
case object StopMapOutputTracker extends MapOutputTrackerMessage
-class MapOutputTracker(serverUris: ConcurrentHashMap[Int, Array[String]])
+class MapOutputTrackerActor(serverUris: ConcurrentHashMap[Int, Array[String]])
extends DaemonActor with Logging {
def act() {
- val port = System.getProperty("spark.master.port", "50501").toInt
+ val port = System.getProperty("spark.master.port").toInt
RemoteActor.alive(port)
RemoteActor.register('MapOutputTracker, self)
logInfo("Registered actor on port " + port)
@@ -32,22 +32,20 @@ extends DaemonActor with Logging {
}
}
-object MapOutputTracker extends Logging {
+class MapOutputTracker(isMaster: Boolean) extends Logging {
var trackerActor: AbstractActor = null
- private val serverUris = new ConcurrentHashMap[Int, Array[String]]
-
- def initialize(isMaster: Boolean) {
- if (isMaster) {
- val tracker = new MapOutputTracker(serverUris)
- tracker.start
- trackerActor = tracker
- } else {
- val host = System.getProperty("spark.master.host")
- val port = System.getProperty("spark.master.port").toInt
- trackerActor = RemoteActor.select(Node(host, port), 'MapOutputTracker)
- }
+ if (isMaster) {
+ val tracker = new MapOutputTrackerActor(serverUris)
+ tracker.start
+ trackerActor = tracker
+ } else {
+ val host = System.getProperty("spark.master.host")
+ val port = System.getProperty("spark.master.port").toInt
+ trackerActor = RemoteActor.select(Node(host, port), 'MapOutputTracker)
}
+
+ private val serverUris = new ConcurrentHashMap[Int, Array[String]]
def registerMapOutput(shuffleId: Int, numMaps: Int, mapId: Int, serverUri: String) {
var array = serverUris.get(shuffleId)
@@ -62,7 +60,6 @@ object MapOutputTracker extends Logging {
serverUris.put(shuffleId, Array[String]() ++ locs)
}
-
// Remembers which map output locations are currently being fetched
val fetching = new HashSet[Int]
diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala
index 40eb7967ec..6accd5e356 100644
--- a/core/src/main/scala/spark/RDD.scala
+++ b/core/src/main/scala/spark/RDD.scala
@@ -43,7 +43,7 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) {
// Read this RDD; will read from cache if applicable, or otherwise compute
final def iterator(split: Split): Iterator[T] = {
if (shouldCache) {
- RDDCache.getOrCompute[T](this, split)
+ SparkEnv.get.cacheTracker.getOrCompute[T](this, split)
} else {
compute(split)
}
diff --git a/core/src/main/scala/spark/Serializer.scala b/core/src/main/scala/spark/Serializer.scala
index a182f6bddc..cfc6d978bc 100644
--- a/core/src/main/scala/spark/Serializer.scala
+++ b/core/src/main/scala/spark/Serializer.scala
@@ -2,39 +2,38 @@ package spark
import java.io.{InputStream, OutputStream}
-trait SerializationStream {
- def writeObject[T](t: T): Unit
- def flush(): Unit
- def close(): Unit
-}
-
-trait DeserializationStream {
- def readObject[T](): T
- def close(): Unit
+/**
+ * A serializer. Because some serialization libraries are not thread safe,
+ * this class is used to create SerializerInstances that do the actual
+ * serialization.
+ */
+trait Serializer {
+ def newInstance(): SerializerInstance
}
-trait Serializer {
+/**
+ * An instance of the serializer, for use by one thread at a time.
+ */
+trait SerializerInstance {
def serialize[T](t: T): Array[Byte]
def deserialize[T](bytes: Array[Byte]): T
def outputStream(s: OutputStream): SerializationStream
def inputStream(s: InputStream): DeserializationStream
}
-trait SerializationStrategy {
- def newSerializer(): Serializer
+/**
+ * A stream for writing serialized objects.
+ */
+trait SerializationStream {
+ def writeObject[T](t: T): Unit
+ def flush(): Unit
+ def close(): Unit
}
-object Serializer {
- var strat: SerializationStrategy = null
-
- def initialize() {
- val cls = System.getProperty("spark.serialization",
- "spark.JavaSerialization")
- strat = Class.forName(cls).newInstance().asInstanceOf[SerializationStrategy]
- }
-
- // Return a serializer ** for use by a single thread **
- def newInstance(): Serializer = {
- strat.newSerializer()
- }
+/**
+ * A stream for reading serialized objects.
+ */
+trait DeserializationStream {
+ def readObject[T](): T
+ def close(): Unit
}
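
The reorganized traits separate the thread-safe factory (Serializer) from the per-thread worker (SerializerInstance) and its streams. A short usage sketch against the traits as defined above; the output path is a placeholder.

    import java.io.FileOutputStream

    val serializer: Serializer = new JavaSerializer                // safe to share across threads
    val instance: SerializerInstance = serializer.newInstance()    // use from one thread at a time
    val out: SerializationStream = instance.outputStream(new FileOutputStream("/tmp/objects.bin"))
    out.writeObject(List(1, 2, 3))
    out.flush()
    out.close()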
diff --git a/core/src/main/scala/spark/SerializingCache.scala b/core/src/main/scala/spark/SerializingCache.scala
index cbc64736e6..2c1f96a700 100644
--- a/core/src/main/scala/spark/SerializingCache.scala
+++ b/core/src/main/scala/spark/SerializingCache.scala
@@ -10,14 +10,14 @@ class SerializingCache extends Cache with Logging {
val bmc = new BoundedMemoryCache
override def put(key: Any, value: Any) {
- val ser = Serializer.newInstance()
+ val ser = SparkEnv.get.serializer.newInstance()
bmc.put(key, ser.serialize(value))
}
override def get(key: Any): Any = {
val bytes = bmc.get(key)
if (bytes != null) {
- val ser = Serializer.newInstance()
+ val ser = SparkEnv.get.serializer.newInstance()
return ser.deserialize(bytes.asInstanceOf[Array[Byte]])
} else {
return null
diff --git a/core/src/main/scala/spark/ShuffledRDD.scala b/core/src/main/scala/spark/ShuffledRDD.scala
index 683df12019..f730f0580e 100644
--- a/core/src/main/scala/spark/ShuffledRDD.scala
+++ b/core/src/main/scala/spark/ShuffledRDD.scala
@@ -33,7 +33,7 @@ extends RDD[(K, C)](parent.context) {
val shuffleId = dep.shuffleId
val splitId = split.index
val splitsByUri = new HashMap[String, ArrayBuffer[Int]]
- val serverUris = MapOutputTracker.getServerUris(shuffleId)
+ val serverUris = SparkEnv.get.mapOutputTracker.getServerUris(shuffleId)
for ((serverUri, index) <- serverUris.zipWithIndex) {
splitsByUri.getOrElseUpdate(serverUri, ArrayBuffer()) += index
}
@@ -58,4 +58,4 @@ extends RDD[(K, C)](parent.context) {
}
combiners.iterator
}
-}
\ No newline at end of file
+}
diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala
index dc6964e14b..c1807de0ef 100644
--- a/core/src/main/scala/spark/SparkContext.scala
+++ b/core/src/main/scala/spark/SparkContext.scala
@@ -20,7 +20,13 @@ extends Logging {
System.setProperty("spark.master.host", Utils.localHostName)
if (System.getProperty("spark.master.port") == null)
System.setProperty("spark.master.port", "50501")
+
+ // Create the Spark execution environment (cache, map output tracker, etc)
+ val env = SparkEnv.createFromSystemProperties(true)
+ SparkEnv.set(env)
+ Broadcast.initialize(true)
+ // Create and start the scheduler
private var scheduler: Scheduler = {
// Regular expression used for local[N] master format
val LOCAL_N_REGEX = """local\[([0-9]+)\]""".r
@@ -34,16 +40,9 @@ extends Logging {
new MesosScheduler(this, master, frameworkName)
}
}
+ scheduler.start()
private val isLocal = scheduler.isInstanceOf[LocalScheduler]
-
- // Start the scheduler, the cache and the broadcast system
- scheduler.start()
- Cache.initialize()
- Serializer.initialize()
- Broadcast.initialize(true)
- MapOutputTracker.initialize(true)
- RDDCache.initialize(true)
// Methods for creating RDDs
@@ -122,8 +121,9 @@ extends Logging {
scheduler.stop()
scheduler = null
// TODO: Broadcast.stop(), Cache.stop()?
- MapOutputTracker.stop()
- RDDCache.stop()
+ env.mapOutputTracker.stop()
+ env.cacheTracker.stop()
+ SparkEnv.set(null)
}
// Wait for the scheduler to be registered
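
SparkContext now builds the SparkEnv (and initializes Broadcast) before constructing the scheduler, and tears the trackers down when the context shuts down. A sketch of the resulting lifecycle from user code, assuming the shutdown hunk above sits in SparkContext.stop(); the application name and job body are placeholders.

    val sc = new SparkContext("local[2]", "demo")   // creates and publishes a SparkEnv internally
    // ... build RDDs and run jobs; tasks read SparkEnv.get on their worker threads ...
    sc.stop()                                       // stops mapOutputTracker and cacheTracker, clears the env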
diff --git a/core/src/main/scala/spark/SparkEnv.scala b/core/src/main/scala/spark/SparkEnv.scala
new file mode 100644
index 0000000000..1bfd0172d7
--- /dev/null
+++ b/core/src/main/scala/spark/SparkEnv.scala
@@ -0,0 +1,36 @@
+package spark
+
+class SparkEnv (
+ val cache: Cache,
+ val serializer: Serializer,
+ val cacheTracker: CacheTracker,
+ val mapOutputTracker: MapOutputTracker
+)
+
+object SparkEnv {
+ private val env = new ThreadLocal[SparkEnv]
+
+ def set(e: SparkEnv) {
+ env.set(e)
+ }
+
+ def get: SparkEnv = {
+ env.get()
+ }
+
+ def createFromSystemProperties(isMaster: Boolean): SparkEnv = {
+ val cacheClass = System.getProperty("spark.cache.class",
+ "spark.SoftReferenceCache")
+ val cache = Class.forName(cacheClass).newInstance().asInstanceOf[Cache]
+
+ val serClass = System.getProperty("spark.serializer",
+ "spark.JavaSerializer")
+ val ser = Class.forName(serClass).newInstance().asInstanceOf[Serializer]
+
+ val cacheTracker = new CacheTracker(isMaster, cache)
+
+ val mapOutputTracker = new MapOutputTracker(isMaster)
+
+ new SparkEnv(cache, ser, cacheTracker, mapOutputTracker)
+ }
+}
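
Because SparkEnv lives in a ThreadLocal, anything that hands work to other threads has to re-publish it, which is exactly what Executor.run and LocalScheduler do above. A self-contained sketch of that pattern; the pool size and task body are placeholders.

    import java.util.concurrent.Executors

    val env = SparkEnv.get                          // captured on the thread that created the context
    val pool = Executors.newFixedThreadPool(4)
    pool.submit(new Runnable {
      def run() {
        SparkEnv.set(env)                           // make the env visible to this worker thread
        val ser = SparkEnv.get.serializer.newInstance()
        // ... deserialize and run a task using ser, env.cacheTracker, env.mapOutputTracker ...
      }
    })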