aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--bagel/src/main/scala/spark/bagel/examples/WikipediaPageRankStandalone.scala4
-rw-r--r--core/src/main/scala/spark/BoundedMemoryCache.scala96
-rw-r--r--core/src/main/scala/spark/Cache.scala44
-rw-r--r--core/src/main/scala/spark/CacheTracker.scala179
-rw-r--r--core/src/main/scala/spark/DiskSpillingCache.scala36
-rw-r--r--core/src/main/scala/spark/Executor.scala13
-rw-r--r--core/src/main/scala/spark/JavaSerializer.scala9
-rw-r--r--core/src/main/scala/spark/KryoSerializer.scala15
-rw-r--r--core/src/main/scala/spark/LocalScheduler.scala21
-rw-r--r--core/src/main/scala/spark/MesosScheduler.scala25
-rw-r--r--core/src/main/scala/spark/PipedRDD.scala14
-rw-r--r--core/src/main/scala/spark/RDD.scala7
-rw-r--r--core/src/main/scala/spark/Serializer.scala1
-rw-r--r--core/src/main/scala/spark/SerializingCache.scala8
-rw-r--r--core/src/main/scala/spark/SimpleJob.scala19
-rw-r--r--core/src/main/scala/spark/SoftReferenceCache.scala9
-rw-r--r--core/src/main/scala/spark/SparkEnv.scala15
-rw-r--r--core/src/main/scala/spark/UnionRDD.scala6
-rw-r--r--core/src/main/scala/spark/Utils.scala35
-rw-r--r--core/src/main/scala/spark/WeakReferenceCache.scala14
-rw-r--r--core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala8
-rw-r--r--core/src/main/scala/spark/broadcast/ChainedBroadcast.scala8
-rw-r--r--core/src/main/scala/spark/broadcast/DfsBroadcast.scala6
-rw-r--r--core/src/main/scala/spark/broadcast/TreeBroadcast.scala8
-rw-r--r--core/src/test/scala/spark/BoundedMemoryCacheTest.scala31
-rw-r--r--core/src/test/scala/spark/CacheTrackerSuite.scala97
-rw-r--r--core/src/test/scala/spark/FailureSuite.scala16
-rw-r--r--core/src/test/scala/spark/MesosSchedulerSuite.scala28
-rw-r--r--core/src/test/scala/spark/PipedRDDSuite.scala37
-rw-r--r--core/src/test/scala/spark/UtilsSuite.scala29
-rw-r--r--repl/src/main/scala/spark/repl/Main.scala2
31 files changed, 655 insertions, 185 deletions
diff --git a/bagel/src/main/scala/spark/bagel/examples/WikipediaPageRankStandalone.scala b/bagel/src/main/scala/spark/bagel/examples/WikipediaPageRankStandalone.scala
index 2e38376499..7084ff97d9 100644
--- a/bagel/src/main/scala/spark/bagel/examples/WikipediaPageRankStandalone.scala
+++ b/bagel/src/main/scala/spark/bagel/examples/WikipediaPageRankStandalone.scala
@@ -126,6 +126,10 @@ class WPRSerializerInstance extends SerializerInstance {
throw new UnsupportedOperationException()
}
+ def deserialize[T](bytes: Array[Byte], loader: ClassLoader): T = {
+ throw new UnsupportedOperationException()
+ }
+
def outputStream(s: OutputStream): SerializationStream = {
new WPRSerializationStream(s)
}
diff --git a/core/src/main/scala/spark/BoundedMemoryCache.scala b/core/src/main/scala/spark/BoundedMemoryCache.scala
index e8e50ac360..1162e34ab0 100644
--- a/core/src/main/scala/spark/BoundedMemoryCache.scala
+++ b/core/src/main/scala/spark/BoundedMemoryCache.scala
@@ -9,19 +9,19 @@ import java.util.LinkedHashMap
* some cache entries have pointers to a shared object. Nonetheless, this Cache should work well
* when most of the space is used by arrays of primitives or of simple classes.
*/
-class BoundedMemoryCache extends Cache with Logging {
- private val maxBytes: Long = getMaxBytes()
+class BoundedMemoryCache(maxBytes: Long) extends Cache with Logging {
logInfo("BoundedMemoryCache.maxBytes = " + maxBytes)
- private var currentBytes = 0L
- private val map = new LinkedHashMap[Any, Entry](32, 0.75f, true)
+ def this() {
+ this(BoundedMemoryCache.getMaxBytes)
+ }
- // An entry in our map; stores a cached object and its size in bytes
- class Entry(val value: Any, val size: Long) {}
+ private var currentBytes = 0L
+ private val map = new LinkedHashMap[(Any, Int), Entry](32, 0.75f, true)
- override def get(key: Any): Any = {
+ override def get(datasetId: Any, partition: Int): Any = {
synchronized {
- val entry = map.get(key)
+ val entry = map.get((datasetId, partition))
if (entry != null) {
entry.value
} else {
@@ -30,46 +30,80 @@ class BoundedMemoryCache extends Cache with Logging {
}
}
- override def put(key: Any, value: Any) {
+ override def put(datasetId: Any, partition: Int, value: Any): CachePutResponse = {
+ val key = (datasetId, partition)
logInfo("Asked to add key " + key)
+ val size = estimateValueSize(key, value)
+ synchronized {
+ if (size > getCapacity) {
+ return CachePutFailure()
+ } else if (ensureFreeSpace(datasetId, size)) {
+ logInfo("Adding key " + key)
+ map.put(key, new Entry(value, size))
+ currentBytes += size
+ logInfo("Number of entries is now " + map.size)
+ return CachePutSuccess(size)
+ } else {
+ logInfo("Didn't add key " + key + " because we would have evicted part of same dataset")
+ return CachePutFailure()
+ }
+ }
+ }
+
+ override def getCapacity: Long = maxBytes
+
+ /**
+ * Estimate sizeOf 'value'
+ */
+ private def estimateValueSize(key: (Any, Int), value: Any) = {
val startTime = System.currentTimeMillis
val size = SizeEstimator.estimate(value.asInstanceOf[AnyRef])
val timeTaken = System.currentTimeMillis - startTime
logInfo("Estimated size for key %s is %d".format(key, size))
logInfo("Size estimation for key %s took %d ms".format(key, timeTaken))
- synchronized {
- ensureFreeSpace(size)
- logInfo("Adding key " + key)
- map.put(key, new Entry(value, size))
- currentBytes += size
- logInfo("Number of entries is now " + map.size)
- }
- }
-
- private def getMaxBytes(): Long = {
- val memoryFractionToUse = System.getProperty(
- "spark.boundedMemoryCache.memoryFraction", "0.66").toDouble
- (Runtime.getRuntime.maxMemory * memoryFractionToUse).toLong
+ size
}
/**
- * Remove least recently used entries from the map until at least space bytes are free. Assumes
+ * Remove least recently used entries from the map until at least space bytes are free, in order
+ * to make space for a partition from the given dataset ID. If this cannot be done without
+ * evicting other data from the same dataset, returns false; otherwise, returns true. Assumes
* that a lock is held on the BoundedMemoryCache.
*/
- private def ensureFreeSpace(space: Long) {
- logInfo("ensureFreeSpace(%d) called with curBytes=%d, maxBytes=%d".format(
- space, currentBytes, maxBytes))
- val iter = map.entrySet.iterator
+ private def ensureFreeSpace(datasetId: Any, space: Long): Boolean = {
+ logInfo("ensureFreeSpace(%s, %d) called with curBytes=%d, maxBytes=%d".format(
+ datasetId, space, currentBytes, maxBytes))
+ val iter = map.entrySet.iterator // Will give entries in LRU order
while (maxBytes - currentBytes < space && iter.hasNext) {
val mapEntry = iter.next()
- dropEntry(mapEntry.getKey, mapEntry.getValue)
+ val (entryDatasetId, entryPartition) = mapEntry.getKey
+ if (entryDatasetId == datasetId) {
+ // Cannot make space without removing part of the same dataset, or a more recently used one
+ return false
+ }
+ reportEntryDropped(entryDatasetId, entryPartition, mapEntry.getValue)
currentBytes -= mapEntry.getValue.size
iter.remove()
}
+ return true
}
- protected def dropEntry(key: Any, entry: Entry) {
- logInfo("Dropping key %s of size %d to make space".format(key, entry.size))
- SparkEnv.get.cacheTracker.dropEntry(key)
+ protected def reportEntryDropped(datasetId: Any, partition: Int, entry: Entry) {
+ logInfo("Dropping key (%s, %d) of size %d to make space".format(datasetId, partition, entry.size))
+ SparkEnv.get.cacheTracker.dropEntry(datasetId, partition)
}
}
+
+// An entry in our map; stores a cached object and its size in bytes
+case class Entry(value: Any, size: Long)
+
+object BoundedMemoryCache {
+ /**
+ * Get maximum cache capacity from system configuration
+ */
+ def getMaxBytes: Long = {
+ val memoryFractionToUse = System.getProperty("spark.boundedMemoryCache.memoryFraction", "0.66").toDouble
+ (Runtime.getRuntime.maxMemory * memoryFractionToUse).toLong
+ }
+}
+
diff --git a/core/src/main/scala/spark/Cache.scala b/core/src/main/scala/spark/Cache.scala
index 696fff4e5e..150fe14e2c 100644
--- a/core/src/main/scala/spark/Cache.scala
+++ b/core/src/main/scala/spark/Cache.scala
@@ -1,10 +1,16 @@
package spark
-import java.util.concurrent.atomic.AtomicLong
+import java.util.concurrent.atomic.AtomicInteger
+
+sealed trait CachePutResponse
+case class CachePutSuccess(size: Long) extends CachePutResponse
+case class CachePutFailure() extends CachePutResponse
/**
* An interface for caches in Spark, to allow for multiple implementations. Caches are used to store
- * both partitions of cached RDDs and broadcast variables on Spark executors.
+ * both partitions of cached RDDs and broadcast variables on Spark executors. Caches are also aware
+ * of which entries are part of the same dataset (for example, partitions in the same RDD). The key
+ * for each value in a cache is a (datasetID, partition) pair.
*
* A single Cache instance gets created on each machine and is shared by all caches (i.e. both the
* RDD split cache and the broadcast variable cache), to enable global replacement policies.
@@ -17,19 +23,41 @@ import java.util.concurrent.atomic.AtomicLong
* keys that are unique across modules.
*/
abstract class Cache {
- private val nextKeySpaceId = new AtomicLong(0)
+ private val nextKeySpaceId = new AtomicInteger(0)
private def newKeySpaceId() = nextKeySpaceId.getAndIncrement()
def newKeySpace() = new KeySpace(this, newKeySpaceId())
- def get(key: Any): Any
- def put(key: Any, value: Any): Unit
+ /**
+ * Get the value for a given (datasetId, partition), or null if it is not
+ * found.
+ */
+ def get(datasetId: Any, partition: Int): Any
+
+ /**
+ * Attempt to put a value in the cache; returns CachePutFailure if this was
+ * not successful (e.g. because the cache replacement policy forbids it), and
+ * CachePutSuccess if successful. If size estimation is available, the cache
+ * implementation should set the size field in CachePutSuccess.
+ */
+ def put(datasetId: Any, partition: Int, value: Any): CachePutResponse
+
+ /**
+ * Report the capacity of the cache partition. By default this just reports
+ * zero. Specific implementations can choose to provide the capacity number.
+ */
+ def getCapacity: Long = 0L
}
/**
* A key namespace in a Cache.
*/
-class KeySpace(cache: Cache, id: Long) {
- def get(key: Any): Any = cache.get((id, key))
- def put(key: Any, value: Any): Unit = cache.put((id, key), value)
+class KeySpace(cache: Cache, val keySpaceId: Int) {
+ def get(datasetId: Any, partition: Int): Any =
+ cache.get((keySpaceId, datasetId), partition)
+
+ def put(datasetId: Any, partition: Int, value: Any): CachePutResponse =
+ cache.put((keySpaceId, datasetId), partition, value)
+
+ def getCapacity: Long = cache.getCapacity
}
diff --git a/core/src/main/scala/spark/CacheTracker.scala b/core/src/main/scala/spark/CacheTracker.scala
index 5b6eed743f..4867829c17 100644
--- a/core/src/main/scala/spark/CacheTracker.scala
+++ b/core/src/main/scala/spark/CacheTracker.scala
@@ -7,16 +7,32 @@ import scala.collection.mutable.HashMap
import scala.collection.mutable.HashSet
sealed trait CacheTrackerMessage
-case class AddedToCache(rddId: Int, partition: Int, host: String) extends CacheTrackerMessage
-case class DroppedFromCache(rddId: Int, partition: Int, host: String) extends CacheTrackerMessage
+case class AddedToCache(rddId: Int, partition: Int, host: String, size: Long = 0L)
+ extends CacheTrackerMessage
+case class DroppedFromCache(rddId: Int, partition: Int, host: String, size: Long = 0L)
+ extends CacheTrackerMessage
case class MemoryCacheLost(host: String) extends CacheTrackerMessage
case class RegisterRDD(rddId: Int, numPartitions: Int) extends CacheTrackerMessage
+case class SlaveCacheStarted(host: String, size: Long) extends CacheTrackerMessage
+case object GetCacheStatus extends CacheTrackerMessage
case object GetCacheLocations extends CacheTrackerMessage
case object StopCacheTracker extends CacheTrackerMessage
+
class CacheTrackerActor extends DaemonActor with Logging {
- val locs = new HashMap[Int, Array[List[String]]]
+ private val locs = new HashMap[Int, Array[List[String]]]
+
+ /**
+ * A map from the slave's host name to its cache size.
+ */
+ private val slaveCapacity = new HashMap[String, Long]
+ private val slaveUsage = new HashMap[String, Long]
+
// TODO: Should probably store (String, CacheType) tuples
+
+ private def getCacheUsage(host: String): Long = slaveUsage.getOrElse(host, 0L)
+ private def getCacheCapacity(host: String): Long = slaveCapacity.getOrElse(host, 0L)
+ private def getCacheAvailable(host: String): Long = getCacheCapacity(host) - getCacheUsage(host)
def act() {
val port = System.getProperty("spark.master.port").toInt
@@ -26,31 +42,61 @@ class CacheTrackerActor extends DaemonActor with Logging {
loop {
react {
+ case SlaveCacheStarted(host: String, size: Long) =>
+ logInfo("Started slave cache (size %s) on %s".format(
+ Utils.memoryBytesToString(size), host))
+ slaveCapacity.put(host, size)
+ slaveUsage.put(host, 0)
+ reply('OK)
+
case RegisterRDD(rddId: Int, numPartitions: Int) =>
logInfo("Registering RDD " + rddId + " with " + numPartitions + " partitions")
locs(rddId) = Array.fill[List[String]](numPartitions)(Nil)
reply('OK)
- case AddedToCache(rddId, partition, host) =>
- logInfo("Cache entry added: (%s, %s) on %s".format(rddId, partition, host))
+ case AddedToCache(rddId, partition, host, size) =>
+ if (size > 0) {
+ slaveUsage.put(host, getCacheUsage(host) + size)
+ logInfo("Cache entry added: (%s, %s) on %s (size added: %s, available: %s)".format(
+ rddId, partition, host, Utils.memoryBytesToString(size),
+ Utils.memoryBytesToString(getCacheAvailable(host))))
+ } else {
+ logInfo("Cache entry added: (%s, %s) on %s".format(rddId, partition, host))
+ }
locs(rddId)(partition) = host :: locs(rddId)(partition)
reply('OK)
- case DroppedFromCache(rddId, partition, host) =>
- logInfo("Cache entry removed: (%s, %s) on %s".format(rddId, partition, host))
+ case DroppedFromCache(rddId, partition, host, size) =>
+ if (size > 0) {
+ logInfo("Cache entry removed: (%s, %s) on %s (size dropped: %s, available: %s)".format(
+ rddId, partition, host, Utils.memoryBytesToString(size),
+ Utils.memoryBytesToString(getCacheAvailable(host))))
+ slaveUsage.put(host, getCacheUsage(host) - size)
+
+ // Do a sanity check to make sure usage is greater than 0.
+ val usage = getCacheUsage(host)
+ if (usage < 0) {
+ logError("Cache usage on %s is negative (%d)".format(host, usage))
+ }
+ } else {
+ logInfo("Cache entry removed: (%s, %s) on %s".format(rddId, partition, host))
+ }
locs(rddId)(partition) = locs(rddId)(partition).filterNot(_ == host)
-
+ reply('OK)
+
case MemoryCacheLost(host) =>
logInfo("Memory cache lost on " + host)
// TODO: Drop host from the memory locations list of all RDDs
case GetCacheLocations =>
logInfo("Asked for current cache locations")
- val locsCopy = new HashMap[Int, Array[List[String]]]
- for ((rddId, array) <- locs) {
- locsCopy(rddId) = array.clone()
- }
- reply(locsCopy)
+ reply(locs.map{case (rrdId, array) => (rrdId -> array.clone())})
+
+ case GetCacheStatus =>
+ val status = slaveCapacity.map { case (host,capacity) =>
+ (host, capacity, getCacheUsage(host))
+ }.toSeq
+ reply(status)
case StopCacheTracker =>
reply('OK)
@@ -60,10 +106,16 @@ class CacheTrackerActor extends DaemonActor with Logging {
}
}
+
class CacheTracker(isMaster: Boolean, theCache: Cache) extends Logging {
// Tracker actor on the master, or remote reference to it on workers
var trackerActor: AbstractActor = null
-
+
+ val registeredRddIds = new HashSet[Int]
+
+ // Stores map results for various splits locally
+ val cache = theCache.newKeySpace()
+
if (isMaster) {
val tracker = new CacheTrackerActor
tracker.start()
@@ -74,10 +126,8 @@ class CacheTracker(isMaster: Boolean, theCache: Cache) extends Logging {
trackerActor = RemoteActor.select(Node(host, port), 'CacheTracker)
}
- val registeredRddIds = new HashSet[Int]
-
- // Stores map results for various splits locally
- val cache = theCache.newKeySpace()
+ // Report the cache being started.
+ trackerActor !? SlaveCacheStarted(Utils.getHost, cache.getCapacity)
// Remembers which splits are currently being loaded (on worker nodes)
val loading = new HashSet[(Int, Int)]
@@ -92,65 +142,92 @@ class CacheTracker(isMaster: Boolean, theCache: Cache) extends Logging {
}
}
}
-
+
// Get a snapshot of the currently known locations
def getLocationsSnapshot(): HashMap[Int, Array[List[String]]] = {
(trackerActor !? GetCacheLocations) match {
- case h: HashMap[_, _] =>
- h.asInstanceOf[HashMap[Int, Array[List[String]]]]
-
- case _ =>
- throw new SparkException("Internal error: CacheTrackerActor did not reply with a HashMap")
+ case h: HashMap[_, _] => h.asInstanceOf[HashMap[Int, Array[List[String]]]]
+
+ case _ => throw new SparkException("Internal error: CacheTrackerActor did not reply with a HashMap")
+ }
+ }
+
+ // Get the usage status of slave caches. Each tuple in the returned sequence
+ // is in the form of (host name, capacity, usage).
+ def getCacheStatus(): Seq[(String, Long, Long)] = {
+ (trackerActor !? GetCacheStatus) match {
+ case h: Seq[(String, Long, Long)] => h.asInstanceOf[Seq[(String, Long, Long)]]
+
+ case _ =>
+ throw new SparkException(
+ "Internal error: CacheTrackerActor did not reply with a Seq[Tuple3[String, Long, Long]")
}
}
// Gets or computes an RDD split
def getOrCompute[T](rdd: RDD[T], split: Split)(implicit m: ClassManifest[T]): Iterator[T] = {
- val key = (rdd.id, split.index)
- logInfo("CachedRDD partition key is " + key)
- val cachedVal = cache.get(key)
+ logInfo("Looking for RDD partition %d:%d".format(rdd.id, split.index))
+ val cachedVal = cache.get(rdd.id, split.index)
if (cachedVal != null) {
// Split is in cache, so just return its values
logInfo("Found partition in cache!")
return cachedVal.asInstanceOf[Array[T]].iterator
} else {
// Mark the split as loading (unless someone else marks it first)
+ val key = (rdd.id, split.index)
loading.synchronized {
- if (loading.contains(key)) {
- while (loading.contains(key)) {
- try {loading.wait()} catch {case _ =>}
- }
- return cache.get(key).asInstanceOf[Array[T]].iterator
- } else {
- loading.add(key)
+ while (loading.contains(key)) {
+ // Someone else is loading it; let's wait for them
+ try { loading.wait() } catch { case _ => }
}
+ // See whether someone else has successfully loaded it. The main way this would fail
+ // is for the RDD-level cache eviction policy if someone else has loaded the same RDD
+ // partition but we didn't want to make space for it. However, that case is unlikely
+ // because it's unlikely that two threads would work on the same RDD partition. One
+ // downside of the current code is that threads wait serially if this does happen.
+ val cachedVal = cache.get(rdd.id, split.index)
+ if (cachedVal != null) {
+ return cachedVal.asInstanceOf[Array[T]].iterator
+ }
+ // Nobody's loading it and it's not in the cache; let's load it ourselves
+ loading.add(key)
}
// If we got here, we have to load the split
// Tell the master that we're doing so
- val host = System.getProperty("spark.hostname", Utils.localHostName)
- val future = trackerActor !! AddedToCache(rdd.id, split.index, host)
+
// TODO: fetch any remote copy of the split that may be available
- // TODO: also register a listener for when it unloads
logInfo("Computing partition " + split)
- val array = rdd.compute(split).toArray(m)
- cache.put(key, array)
- loading.synchronized {
- loading.remove(key)
- loading.notifyAll()
+ var array: Array[T] = null
+ var putResponse: CachePutResponse = null
+ try {
+ array = rdd.compute(split).toArray(m)
+ putResponse = cache.put(rdd.id, split.index, array)
+ } finally {
+ // Tell other threads that we've finished our attempt to load the key (whether or not
+ // we've actually succeeded to put it in the map)
+ loading.synchronized {
+ loading.remove(key)
+ loading.notifyAll()
+ }
+ }
+
+ putResponse match {
+ case CachePutSuccess(size) => {
+ // Tell the master that we added the entry. Don't return until it
+ // replies so it can properly schedule future tasks that use this RDD.
+ trackerActor !? AddedToCache(rdd.id, split.index, Utils.getHost, size)
+ }
+ case _ => null
}
- future.apply() // Wait for the reply from the cache tracker
return array.iterator
}
}
- // Reports that an entry has been dropped from the cache
- def dropEntry(key: Any) {
- key match {
- case (keySpaceId: Long, (rddId: Int, partition: Int)) =>
- val host = System.getProperty("spark.hostname", Utils.localHostName)
- trackerActor !! DroppedFromCache(rddId, partition, host)
- case _ =>
- logWarning("Unknown key format: %s".format(key))
+ // Called by the Cache to report that an entry has been dropped from it
+ def dropEntry(datasetId: Any, partition: Int) {
+ datasetId match {
+ //TODO - do we really want to use '!!' when nobody checks returned future? '!' seems to enough here.
+ case (cache.keySpaceId, rddId: Int) => trackerActor !! DroppedFromCache(rddId, partition, Utils.getHost)
}
}
diff --git a/core/src/main/scala/spark/DiskSpillingCache.scala b/core/src/main/scala/spark/DiskSpillingCache.scala
index 157e071c7f..e11466eb64 100644
--- a/core/src/main/scala/spark/DiskSpillingCache.scala
+++ b/core/src/main/scala/spark/DiskSpillingCache.scala
@@ -9,31 +9,31 @@ import java.util.UUID
// TODO: cache into a separate directory using Utils.createTempDir
// TODO: clean up disk cache afterwards
class DiskSpillingCache extends BoundedMemoryCache {
- private val diskMap = new LinkedHashMap[Any, File](32, 0.75f, true)
+ private val diskMap = new LinkedHashMap[(Any, Int), File](32, 0.75f, true)
- override def get(key: Any): Any = {
+ override def get(datasetId: Any, partition: Int): Any = {
synchronized {
val ser = SparkEnv.get.serializer.newInstance()
- super.get(key) match {
+ super.get(datasetId, partition) match {
case bytes: Any => // found in memory
ser.deserialize(bytes.asInstanceOf[Array[Byte]])
- case _ => diskMap.get(key) match {
+ case _ => diskMap.get((datasetId, partition)) match {
case file: Any => // found on disk
try {
val startTime = System.currentTimeMillis
val bytes = new Array[Byte](file.length.toInt)
new FileInputStream(file).read(bytes)
val timeTaken = System.currentTimeMillis - startTime
- logInfo("Reading key %s of size %d bytes from disk took %d ms".format(
- key, file.length, timeTaken))
- super.put(key, bytes)
+ logInfo("Reading key (%s, %d) of size %d bytes from disk took %d ms".format(
+ datasetId, partition, file.length, timeTaken))
+ super.put(datasetId, partition, bytes)
ser.deserialize(bytes.asInstanceOf[Array[Byte]])
} catch {
case e: IOException =>
- logWarning("Failed to read key %s from disk at %s: %s".format(
- key, file.getPath(), e.getMessage()))
- diskMap.remove(key) // remove dead entry
+ logWarning("Failed to read key (%s, %d) from disk at %s: %s".format(
+ datasetId, partition, file.getPath(), e.getMessage()))
+ diskMap.remove((datasetId, partition)) // remove dead entry
null
}
@@ -44,18 +44,18 @@ class DiskSpillingCache extends BoundedMemoryCache {
}
}
- override def put(key: Any, value: Any) {
+ override def put(datasetId: Any, partition: Int, value: Any): CachePutResponse = {
var ser = SparkEnv.get.serializer.newInstance()
- super.put(key, ser.serialize(value))
+ super.put(datasetId, partition, ser.serialize(value))
}
/**
* Spill the given entry to disk. Assumes that a lock is held on the
* DiskSpillingCache. Assumes that entry.value is a byte array.
*/
- override protected def dropEntry(key: Any, entry: Entry) {
- logInfo("Spilling key %s of size %d to make space".format(
- key, entry.size))
+ override protected def reportEntryDropped(datasetId: Any, partition: Int, entry: Entry) {
+ logInfo("Spilling key (%s, %d) of size %d to make space".format(
+ datasetId, partition, entry.size))
val cacheDir = System.getProperty(
"spark.diskSpillingCache.cacheDir",
System.getProperty("java.io.tmpdir"))
@@ -64,11 +64,11 @@ class DiskSpillingCache extends BoundedMemoryCache {
val stream = new FileOutputStream(file)
stream.write(entry.value.asInstanceOf[Array[Byte]])
stream.close()
- diskMap.put(key, file)
+ diskMap.put((datasetId, partition), file)
} catch {
case e: IOException =>
- logWarning("Failed to spill key %s to disk at %s: %s".format(
- key, file.getPath(), e.getMessage()))
+ logWarning("Failed to spill key (%s, %d) to disk at %s: %s".format(
+ datasetId, partition, file.getPath(), e.getMessage()))
// Do nothing and let the entry be discarded
}
}
diff --git a/core/src/main/scala/spark/Executor.scala b/core/src/main/scala/spark/Executor.scala
index de45137a4f..c795b6c351 100644
--- a/core/src/main/scala/spark/Executor.scala
+++ b/core/src/main/scala/spark/Executor.scala
@@ -65,16 +65,17 @@ class Executor extends org.apache.mesos.Executor with Logging {
extends Runnable {
override def run() = {
val tid = info.getTaskId.getValue
+ SparkEnv.set(env)
+ Thread.currentThread.setContextClassLoader(classLoader)
+ val ser = SparkEnv.get.closureSerializer.newInstance()
logInfo("Running task ID " + tid)
d.sendStatusUpdate(TaskStatus.newBuilder()
.setTaskId(info.getTaskId)
.setState(TaskState.TASK_RUNNING)
.build())
try {
- SparkEnv.set(env)
- Thread.currentThread.setContextClassLoader(classLoader)
Accumulators.clear
- val task = Utils.deserialize[Task[Any]](info.getData.toByteArray, classLoader)
+ val task = ser.deserialize[Task[Any]](info.getData.toByteArray, classLoader)
for (gen <- task.generation) {// Update generation if any is set
env.mapOutputTracker.updateGeneration(gen)
}
@@ -84,7 +85,7 @@ class Executor extends org.apache.mesos.Executor with Logging {
d.sendStatusUpdate(TaskStatus.newBuilder()
.setTaskId(info.getTaskId)
.setState(TaskState.TASK_FINISHED)
- .setData(ByteString.copyFrom(Utils.serialize(result)))
+ .setData(ByteString.copyFrom(ser.serialize(result)))
.build())
logInfo("Finished task ID " + tid)
} catch {
@@ -93,7 +94,7 @@ class Executor extends org.apache.mesos.Executor with Logging {
d.sendStatusUpdate(TaskStatus.newBuilder()
.setTaskId(info.getTaskId)
.setState(TaskState.TASK_FAILED)
- .setData(ByteString.copyFrom(Utils.serialize(reason)))
+ .setData(ByteString.copyFrom(ser.serialize(reason)))
.build())
}
case t: Throwable => {
@@ -101,7 +102,7 @@ class Executor extends org.apache.mesos.Executor with Logging {
d.sendStatusUpdate(TaskStatus.newBuilder()
.setTaskId(info.getTaskId)
.setState(TaskState.TASK_FAILED)
- .setData(ByteString.copyFrom(Utils.serialize(reason)))
+ .setData(ByteString.copyFrom(ser.serialize(reason)))
.build())
// TODO: Handle errors in tasks less dramatically
diff --git a/core/src/main/scala/spark/JavaSerializer.scala b/core/src/main/scala/spark/JavaSerializer.scala
index e7cd4364ee..80f615eeb0 100644
--- a/core/src/main/scala/spark/JavaSerializer.scala
+++ b/core/src/main/scala/spark/JavaSerializer.scala
@@ -34,6 +34,15 @@ class JavaSerializerInstance extends SerializerInstance {
in.readObject().asInstanceOf[T]
}
+ def deserialize[T](bytes: Array[Byte], loader: ClassLoader): T = {
+ val bis = new ByteArrayInputStream(bytes)
+ val ois = new ObjectInputStream(bis) {
+ override def resolveClass(desc: ObjectStreamClass) =
+ Class.forName(desc.getName, false, loader)
+ }
+ return ois.readObject.asInstanceOf[T]
+ }
+
def outputStream(s: OutputStream): SerializationStream = {
new JavaSerializationStream(s)
}
diff --git a/core/src/main/scala/spark/KryoSerializer.scala b/core/src/main/scala/spark/KryoSerializer.scala
index 7d25b965d2..5693613d6d 100644
--- a/core/src/main/scala/spark/KryoSerializer.scala
+++ b/core/src/main/scala/spark/KryoSerializer.scala
@@ -9,6 +9,7 @@ import scala.collection.mutable
import com.esotericsoftware.kryo._
import com.esotericsoftware.kryo.{Serializer => KSerializer}
+import com.esotericsoftware.kryo.serialize.ClassSerializer
import de.javakaffee.kryoserializers.KryoReflectionFactorySupport
/**
@@ -100,6 +101,14 @@ class KryoSerializerInstance(ks: KryoSerializer) extends SerializerInstance {
buf.readClassAndObject(bytes).asInstanceOf[T]
}
+ def deserialize[T](bytes: Array[Byte], loader: ClassLoader): T = {
+ val oldClassLoader = ks.kryo.getClassLoader
+ ks.kryo.setClassLoader(loader)
+ val obj = buf.readClassAndObject(bytes).asInstanceOf[T]
+ ks.kryo.setClassLoader(oldClassLoader)
+ obj
+ }
+
def outputStream(s: OutputStream): SerializationStream = {
new KryoSerializationStream(ks.kryo, ks.threadByteBuf.get(), s)
}
@@ -129,6 +138,8 @@ class KryoSerializer extends Serializer with Logging {
}
def createKryo(): Kryo = {
+ // This is used so we can serialize/deserialize objects without a zero-arg
+ // constructor.
val kryo = new KryoReflectionFactorySupport()
// Register some commonly used classes
@@ -150,6 +161,10 @@ class KryoSerializer extends Serializer with Logging {
kryo.register(obj.getClass)
}
+ // Register the following classes for passing closures.
+ kryo.register(classOf[Class[_]], new ClassSerializer(kryo))
+ kryo.setRegistrationOptional(true)
+
// Register some commonly used Scala singleton objects. Because these
// are singletons, we must return the exact same local object when we
// deserialize rather than returning a clone as FieldSerializer would.
diff --git a/core/src/main/scala/spark/LocalScheduler.scala b/core/src/main/scala/spark/LocalScheduler.scala
index 0cbc68ffc5..3910c7b09e 100644
--- a/core/src/main/scala/spark/LocalScheduler.scala
+++ b/core/src/main/scala/spark/LocalScheduler.scala
@@ -38,14 +38,23 @@ private class LocalScheduler(threads: Int, maxFailures: Int) extends DAGSchedule
// Serialize and deserialize the task so that accumulators are changed to thread-local ones;
// this adds a bit of unnecessary overhead but matches how the Mesos Executor works.
Accumulators.clear
- val bytes = Utils.serialize(task)
- logInfo("Size of task " + idInJob + " is " + bytes.size + " bytes")
- val deserializedTask = Utils.deserialize[Task[_]](
- bytes, Thread.currentThread.getContextClassLoader)
+ val ser = SparkEnv.get.closureSerializer.newInstance()
+ val startTime = System.currentTimeMillis
+ val bytes = ser.serialize(task)
+ val timeTaken = System.currentTimeMillis - startTime
+ logInfo("Size of task %d is %d bytes and took %d ms to serialize".format(
+ idInJob, bytes.size, timeTaken))
+ val deserializedTask = ser.deserialize[Task[_]](bytes, currentThread.getContextClassLoader)
val result: Any = deserializedTask.run(attemptId)
+
+ // Serialize and deserialize the result to emulate what the mesos
+ // executor does. This is useful to catch serialization errors early
+ // on in development (so when users move their local Spark programs
+ // to the cluster, they don't get surprised by serialization errors).
+ val resultToReturn = ser.deserialize[Any](ser.serialize(result))
val accumUpdates = Accumulators.values
logInfo("Finished task " + idInJob)
- taskEnded(task, Success, result, accumUpdates)
+ taskEnded(task, Success, resultToReturn, accumUpdates)
} catch {
case t: Throwable => {
logError("Exception in task " + idInJob, t)
@@ -55,7 +64,7 @@ private class LocalScheduler(threads: Int, maxFailures: Int) extends DAGSchedule
submitTask(task, idInJob)
} else {
// TODO: Do something nicer here to return all the way to the user
- System.exit(1)
+ taskEnded(task, new ExceptionFailure(t), null, null)
}
}
}
diff --git a/core/src/main/scala/spark/MesosScheduler.scala b/core/src/main/scala/spark/MesosScheduler.scala
index dc58299a1d..a7711e0d35 100644
--- a/core/src/main/scala/spark/MesosScheduler.scala
+++ b/core/src/main/scala/spark/MesosScheduler.scala
@@ -42,7 +42,7 @@ private class MesosScheduler(
// Memory used by each executor (in megabytes)
val EXECUTOR_MEMORY = {
if (System.getenv("SPARK_MEM") != null) {
- memoryStringToMb(System.getenv("SPARK_MEM"))
+ MesosScheduler.memoryStringToMb(System.getenv("SPARK_MEM"))
// TODO: Might need to add some extra memory for the non-heap parts of the JVM
} else {
512
@@ -81,9 +81,7 @@ private class MesosScheduler(
// Sorts jobs in reverse order of run ID for use in our priority queue (so lower IDs run first)
private val jobOrdering = new Ordering[Job] {
- override def compare(j1: Job, j2: Job): Int = {
- return j2.runId - j1.runId
- }
+ override def compare(j1: Job, j2: Job): Int = j2.runId - j1.runId
}
def newJobId(): Int = this.synchronized {
@@ -162,7 +160,7 @@ private class MesosScheduler(
activeJobs(jobId) = myJob
activeJobsQueue += myJob
logInfo("Adding job with ID " + jobId)
- jobTasks(jobId) = new HashSet()
+ jobTasks(jobId) = HashSet.empty[String]
}
driver.reviveOffers();
}
@@ -390,23 +388,26 @@ private class MesosScheduler(
}
override def offerRescinded(d: SchedulerDriver, o: OfferID) {}
+}
+object MesosScheduler {
/**
- * Convert a Java memory parameter passed to -Xmx (such as 300m or 1g) to a number of megabytes.
- * This is used to figure out how much memory to claim from Mesos based on the SPARK_MEM
+ * Convert a Java memory parameter passed to -Xmx (such as 300m or 1g) to a number of megabytes.
+ * This is used to figure out how much memory to claim from Mesos based on the SPARK_MEM
* environment variable.
*/
def memoryStringToMb(str: String): Int = {
val lower = str.toLowerCase
if (lower.endsWith("k")) {
- (lower.substring(0, lower.length-1).toLong / 1024).toInt
+ (lower.substring(0, lower.length - 1).toLong / 1024).toInt
} else if (lower.endsWith("m")) {
- lower.substring(0, lower.length-1).toInt
+ lower.substring(0, lower.length - 1).toInt
} else if (lower.endsWith("g")) {
- lower.substring(0, lower.length-1).toInt * 1024
+ lower.substring(0, lower.length - 1).toInt * 1024
} else if (lower.endsWith("t")) {
- lower.substring(0, lower.length-1).toInt * 1024 * 1024
- } else {// no suffix, so it's just a number in bytes
+ lower.substring(0, lower.length - 1).toInt * 1024 * 1024
+ } else {
+ // no suffix, so it's just a number in bytes
(lower.toLong / 1024 / 1024).toInt
}
}
diff --git a/core/src/main/scala/spark/PipedRDD.scala b/core/src/main/scala/spark/PipedRDD.scala
index 3f993d895a..8a5de3d7e9 100644
--- a/core/src/main/scala/spark/PipedRDD.scala
+++ b/core/src/main/scala/spark/PipedRDD.scala
@@ -3,6 +3,7 @@ package spark
import java.io.PrintWriter
import java.util.StringTokenizer
+import scala.collection.JavaConversions._
import scala.collection.mutable.ArrayBuffer
import scala.io.Source
@@ -10,8 +11,12 @@ import scala.io.Source
* An RDD that pipes the contents of each parent partition through an external command
* (printing them one per line) and returns the output as a collection of strings.
*/
-class PipedRDD[T: ClassManifest](parent: RDD[T], command: Seq[String])
+class PipedRDD[T: ClassManifest](
+ parent: RDD[T], command: Seq[String], envVars: Map[String, String])
extends RDD[String](parent.context) {
+
+ def this(parent: RDD[T], command: Seq[String]) = this(parent, command, Map())
+
// Similar to Runtime.exec(), if we are given a single string, split it into words
// using a standard StringTokenizer (i.e. by spaces)
def this(parent: RDD[T], command: String) = this(parent, PipedRDD.tokenize(command))
@@ -21,7 +26,12 @@ class PipedRDD[T: ClassManifest](parent: RDD[T], command: Seq[String])
override val dependencies = List(new OneToOneDependency(parent))
override def compute(split: Split): Iterator[String] = {
- val proc = Runtime.getRuntime.exec(command.toArray)
+ val pb = new ProcessBuilder(command)
+ // Add the environmental variables to the process.
+ val currentEnvVars = pb.environment()
+ envVars.foreach { case(variable, value) => currentEnvVars.put(variable, value) }
+
+ val proc = pb.start()
val env = SparkEnv.get
// Start a thread to print the process's stderr to ours
diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala
index 1160de5fd1..fa53d9be2c 100644
--- a/core/src/main/scala/spark/RDD.scala
+++ b/core/src/main/scala/spark/RDD.scala
@@ -9,8 +9,6 @@ import java.util.Random
import java.util.Date
import scala.collection.mutable.ArrayBuffer
-import scala.collection.mutable.Map
-import scala.collection.mutable.HashMap
import org.apache.hadoop.io.BytesWritable
import org.apache.hadoop.io.NullWritable
@@ -50,7 +48,7 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
// Methods that must be implemented by subclasses
def splits: Array[Split]
def compute(split: Split): Iterator[T]
- val dependencies: List[Dependency[_]]
+ @transient val dependencies: List[Dependency[_]]
// Optionally overridden by subclasses to specify how they are partitioned
val partitioner: Option[Partitioner] = None
@@ -146,6 +144,9 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
def pipe(command: Seq[String]): RDD[String] = new PipedRDD(this, command)
+ def pipe(command: Seq[String], env: Map[String, String]): RDD[String] =
+ new PipedRDD(this, command, env)
+
def mapPartitions[U: ClassManifest](f: Iterator[T] => Iterator[U]): RDD[U] =
new MapPartitionsRDD(this, sc.clean(f))
diff --git a/core/src/main/scala/spark/Serializer.scala b/core/src/main/scala/spark/Serializer.scala
index 15fca9fcda..2429bbfeb9 100644
--- a/core/src/main/scala/spark/Serializer.scala
+++ b/core/src/main/scala/spark/Serializer.scala
@@ -16,6 +16,7 @@ trait Serializer {
trait SerializerInstance {
def serialize[T](t: T): Array[Byte]
def deserialize[T](bytes: Array[Byte]): T
+ def deserialize[T](bytes: Array[Byte], loader: ClassLoader): T
def outputStream(s: OutputStream): SerializationStream
def inputStream(s: InputStream): DeserializationStream
}
diff --git a/core/src/main/scala/spark/SerializingCache.scala b/core/src/main/scala/spark/SerializingCache.scala
index a74922ec4c..3d192f2403 100644
--- a/core/src/main/scala/spark/SerializingCache.scala
+++ b/core/src/main/scala/spark/SerializingCache.scala
@@ -9,13 +9,13 @@ import java.io._
class SerializingCache extends Cache with Logging {
val bmc = new BoundedMemoryCache
- override def put(key: Any, value: Any) {
+ override def put(datasetId: Any, partition: Int, value: Any): CachePutResponse = {
val ser = SparkEnv.get.serializer.newInstance()
- bmc.put(key, ser.serialize(value))
+ bmc.put(datasetId, partition, ser.serialize(value))
}
- override def get(key: Any): Any = {
- val bytes = bmc.get(key)
+ override def get(datasetId: Any, partition: Int): Any = {
+ val bytes = bmc.get(datasetId, partition)
if (bytes != null) {
val ser = SparkEnv.get.serializer.newInstance()
return ser.deserialize(bytes.asInstanceOf[Array[Byte]])
diff --git a/core/src/main/scala/spark/SimpleJob.scala b/core/src/main/scala/spark/SimpleJob.scala
index 796498cfe4..01c7efff1e 100644
--- a/core/src/main/scala/spark/SimpleJob.scala
+++ b/core/src/main/scala/spark/SimpleJob.scala
@@ -30,6 +30,9 @@ class SimpleJob(
// Maximum times a task is allowed to fail before failing the job
val MAX_TASK_FAILURES = 4
+ // Serializer for closures and tasks.
+ val ser = SparkEnv.get.closureSerializer.newInstance()
+
val callingThread = Thread.currentThread
val tasks = tasksSeq.toArray
val numTasks = tasks.length
@@ -170,8 +173,14 @@ class SimpleJob(
.setType(Value.Type.SCALAR)
.setScalar(Value.Scalar.newBuilder().setValue(CPUS_PER_TASK).build())
.build()
- val serializedTask = Utils.serialize(task)
- logDebug("Serialized size: " + serializedTask.size)
+
+ val startTime = System.currentTimeMillis
+ val serializedTask = ser.serialize(task)
+ val timeTaken = System.currentTimeMillis - startTime
+
+ logInfo("Size of task %d:%d is %d bytes and took %d ms to serialize by %s"
+ .format(jobId, index, serializedTask.size, timeTaken, ser.getClass.getName))
+
val taskName = "task %d:%d".format(jobId, index)
return Some(TaskInfo.newBuilder()
.setTaskId(taskId)
@@ -209,7 +218,8 @@ class SimpleJob(
tasksFinished += 1
logInfo("Finished TID %s (progress: %d/%d)".format(tid, tasksFinished, numTasks))
// Deserialize task result
- val result = Utils.deserialize[TaskResult[_]](status.getData.toByteArray)
+ val result = ser.deserialize[TaskResult[_]](
+ status.getData.toByteArray, getClass.getClassLoader)
sched.taskEnded(tasks(index), Success, result.value, result.accumUpdates)
// Mark finished and stop if we've finished all the tasks
finished(index) = true
@@ -231,7 +241,8 @@ class SimpleJob(
// Check if the problem is a map output fetch failure. In that case, this
// task will never succeed on any node, so tell the scheduler about it.
if (status.getData != null && status.getData.size > 0) {
- val reason = Utils.deserialize[TaskEndReason](status.getData.toByteArray)
+ val reason = ser.deserialize[TaskEndReason](
+ status.getData.toByteArray, getClass.getClassLoader)
reason match {
case fetchFailed: FetchFailed =>
logInfo("Loss was due to fetch failure from " + fetchFailed.serverUri)
diff --git a/core/src/main/scala/spark/SoftReferenceCache.scala b/core/src/main/scala/spark/SoftReferenceCache.scala
index e84aa57efa..ce9370c5d7 100644
--- a/core/src/main/scala/spark/SoftReferenceCache.scala
+++ b/core/src/main/scala/spark/SoftReferenceCache.scala
@@ -8,6 +8,11 @@ import com.google.common.collect.MapMaker
class SoftReferenceCache extends Cache {
val map = new MapMaker().softValues().makeMap[Any, Any]()
- override def get(key: Any): Any = map.get(key)
- override def put(key: Any, value: Any) = map.put(key, value)
+ override def get(datasetId: Any, partition: Int): Any =
+ map.get((datasetId, partition))
+
+ override def put(datasetId: Any, partition: Int, value: Any): CachePutResponse = {
+ map.put((datasetId, partition), value)
+ return CachePutSuccess(0)
+ }
}
diff --git a/core/src/main/scala/spark/SparkEnv.scala b/core/src/main/scala/spark/SparkEnv.scala
index e2d1562e35..cd752f8b65 100644
--- a/core/src/main/scala/spark/SparkEnv.scala
+++ b/core/src/main/scala/spark/SparkEnv.scala
@@ -3,6 +3,7 @@ package spark
class SparkEnv (
val cache: Cache,
val serializer: Serializer,
+ val closureSerializer: Serializer,
val cacheTracker: CacheTracker,
val mapOutputTracker: MapOutputTracker,
val shuffleFetcher: ShuffleFetcher,
@@ -27,6 +28,11 @@ object SparkEnv {
val serializerClass = System.getProperty("spark.serializer", "spark.JavaSerializer")
val serializer = Class.forName(serializerClass).newInstance().asInstanceOf[Serializer]
+ val closureSerializerClass =
+ System.getProperty("spark.closure.serializer", "spark.JavaSerializer")
+ val closureSerializer =
+ Class.forName(closureSerializerClass).newInstance().asInstanceOf[Serializer]
+
val cacheTracker = new CacheTracker(isMaster, cache)
val mapOutputTracker = new MapOutputTracker(isMaster)
@@ -38,6 +44,13 @@ object SparkEnv {
val shuffleMgr = new ShuffleManager()
- new SparkEnv(cache, serializer, cacheTracker, mapOutputTracker, shuffleFetcher, shuffleMgr)
+ new SparkEnv(
+ cache,
+ serializer,
+ closureSerializer,
+ cacheTracker,
+ mapOutputTracker,
+ shuffleFetcher,
+ shuffleMgr)
}
}
diff --git a/core/src/main/scala/spark/UnionRDD.scala b/core/src/main/scala/spark/UnionRDD.scala
index 6fded339ee..4c0f255e6b 100644
--- a/core/src/main/scala/spark/UnionRDD.scala
+++ b/core/src/main/scala/spark/UnionRDD.scala
@@ -16,7 +16,7 @@ class UnionSplit[T: ClassManifest](
class UnionRDD[T: ClassManifest](
sc: SparkContext,
- rdds: Seq[RDD[T]])
+ @transient rdds: Seq[RDD[T]])
extends RDD[T](sc)
with Serializable {
@@ -33,7 +33,7 @@ class UnionRDD[T: ClassManifest](
override def splits = splits_
- override val dependencies = {
+ @transient override val dependencies = {
val deps = new ArrayBuffer[Dependency[_]]
var pos = 0
for ((rdd, index) <- rdds.zipWithIndex) {
@@ -47,4 +47,4 @@ class UnionRDD[T: ClassManifest](
override def preferredLocations(s: Split): Seq[String] =
s.asInstanceOf[UnionSplit[T]].preferredLocations()
-} \ No newline at end of file
+}
diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala
index 58b5fa6bbd..cfd6dc8b2a 100644
--- a/core/src/main/scala/spark/Utils.scala
+++ b/core/src/main/scala/spark/Utils.scala
@@ -2,11 +2,11 @@ package spark
import java.io._
import java.net.InetAddress
-import java.util.UUID
import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor}
import scala.collection.mutable.ArrayBuffer
import scala.util.Random
+import java.util.{Locale, UUID}
/**
* Various utility methods used by Spark.
@@ -157,9 +157,12 @@ object Utils {
/**
* Get the local machine's hostname.
*/
- def localHostName(): String = {
- return InetAddress.getLocalHost().getHostName
- }
+ def localHostName(): String = InetAddress.getLocalHost.getHostName
+
+ /**
+ * Get current host
+ */
+ def getHost = System.getProperty("spark.hostname", localHostName())
/**
* Delete a file or directory and its contents recursively.
@@ -174,4 +177,28 @@ object Utils {
throw new IOException("Failed to delete: " + file)
}
}
+
+ /**
+ * Use unit suffixes (Byte, Kilobyte, Megabyte, Gigabyte, Terabyte and
+ * Petabyte) in order to reduce the number of digits to four or less. For
+ * example, 4,000,000 is returned as 4MB.
+ */
+ def memoryBytesToString(size: Long): String = {
+ val GB = 1L << 30
+ val MB = 1L << 20
+ val KB = 1L << 10
+
+ val (value, unit) = {
+ if (size >= 2*GB) {
+ (size.asInstanceOf[Double] / GB, "GB")
+ } else if (size >= 2*MB) {
+ (size.asInstanceOf[Double] / MB, "MB")
+ } else if (size >= 2*KB) {
+ (size.asInstanceOf[Double] / KB, "KB")
+ } else {
+ (size.asInstanceOf[Double], "B")
+ }
+ }
+ "%.1f%s".formatLocal(Locale.US, value, unit)
+ }
}
diff --git a/core/src/main/scala/spark/WeakReferenceCache.scala b/core/src/main/scala/spark/WeakReferenceCache.scala
deleted file mode 100644
index ddca065454..0000000000
--- a/core/src/main/scala/spark/WeakReferenceCache.scala
+++ /dev/null
@@ -1,14 +0,0 @@
-package spark
-
-import com.google.common.collect.MapMaker
-
-/**
- * An implementation of Cache that uses weak references.
- */
-class WeakReferenceCache extends Cache {
- val map = new MapMaker().weakValues().makeMap[Any, Any]()
-
- override def get(key: Any): Any = map.get(key)
- override def put(key: Any, value: Any) = map.put(key, value)
-}
-
diff --git a/core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala b/core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala
index 6960339bf8..5a873dca3d 100644
--- a/core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala
+++ b/core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala
@@ -16,7 +16,7 @@ extends Broadcast[T] with Logging with Serializable {
def value = value_
BitTorrentBroadcast.synchronized {
- BitTorrentBroadcast.values.put(uuid, value_)
+ BitTorrentBroadcast.values.put(uuid, 0, value_)
}
@transient var arrayOfBlocks: Array[BroadcastBlock] = null
@@ -130,7 +130,7 @@ extends Broadcast[T] with Logging with Serializable {
private def readObject(in: ObjectInputStream): Unit = {
in.defaultReadObject
BitTorrentBroadcast.synchronized {
- val cachedVal = BitTorrentBroadcast.values.get(uuid)
+ val cachedVal = BitTorrentBroadcast.values.get(uuid, 0)
if (cachedVal != null) {
value_ = cachedVal.asInstanceOf[T]
@@ -152,12 +152,12 @@ extends Broadcast[T] with Logging with Serializable {
// If does not succeed, then get from HDFS copy
if (receptionSucceeded) {
value_ = Broadcast.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks)
- BitTorrentBroadcast.values.put(uuid, value_)
+ BitTorrentBroadcast.values.put(uuid, 0, value_)
} else {
// TODO: This part won't work, cause HDFS writing is turned OFF
val fileIn = new ObjectInputStream(DfsBroadcast.openFileForReading(uuid))
value_ = fileIn.readObject.asInstanceOf[T]
- BitTorrentBroadcast.values.put(uuid, value_)
+ BitTorrentBroadcast.values.put(uuid, 0, value_)
fileIn.close()
}
diff --git a/core/src/main/scala/spark/broadcast/ChainedBroadcast.scala b/core/src/main/scala/spark/broadcast/ChainedBroadcast.scala
index e33ef78e8a..64da650142 100644
--- a/core/src/main/scala/spark/broadcast/ChainedBroadcast.scala
+++ b/core/src/main/scala/spark/broadcast/ChainedBroadcast.scala
@@ -15,7 +15,7 @@ extends Broadcast[T] with Logging with Serializable {
def value = value_
ChainedBroadcast.synchronized {
- ChainedBroadcast.values.put(uuid, value_)
+ ChainedBroadcast.values.put(uuid, 0, value_)
}
@transient var arrayOfBlocks: Array[BroadcastBlock] = null
@@ -101,7 +101,7 @@ extends Broadcast[T] with Logging with Serializable {
private def readObject(in: ObjectInputStream): Unit = {
in.defaultReadObject
ChainedBroadcast.synchronized {
- val cachedVal = ChainedBroadcast.values.get(uuid)
+ val cachedVal = ChainedBroadcast.values.get(uuid, 0)
if (cachedVal != null) {
value_ = cachedVal.asInstanceOf[T]
} else {
@@ -121,11 +121,11 @@ extends Broadcast[T] with Logging with Serializable {
// If does not succeed, then get from HDFS copy
if (receptionSucceeded) {
value_ = Broadcast.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks)
- ChainedBroadcast.values.put(uuid, value_)
+ ChainedBroadcast.values.put(uuid, 0, value_)
} else {
val fileIn = new ObjectInputStream(DfsBroadcast.openFileForReading(uuid))
value_ = fileIn.readObject.asInstanceOf[T]
- ChainedBroadcast.values.put(uuid, value_)
+ ChainedBroadcast.values.put(uuid, 0, value_)
fileIn.close()
}
diff --git a/core/src/main/scala/spark/broadcast/DfsBroadcast.scala b/core/src/main/scala/spark/broadcast/DfsBroadcast.scala
index 076f18afac..b053e2b62e 100644
--- a/core/src/main/scala/spark/broadcast/DfsBroadcast.scala
+++ b/core/src/main/scala/spark/broadcast/DfsBroadcast.scala
@@ -17,7 +17,7 @@ extends Broadcast[T] with Logging with Serializable {
def value = value_
DfsBroadcast.synchronized {
- DfsBroadcast.values.put(uuid, value_)
+ DfsBroadcast.values.put(uuid, 0, value_)
}
if (!isLocal) {
@@ -34,7 +34,7 @@ extends Broadcast[T] with Logging with Serializable {
private def readObject(in: ObjectInputStream): Unit = {
in.defaultReadObject
DfsBroadcast.synchronized {
- val cachedVal = DfsBroadcast.values.get(uuid)
+ val cachedVal = DfsBroadcast.values.get(uuid, 0)
if (cachedVal != null) {
value_ = cachedVal.asInstanceOf[T]
} else {
@@ -43,7 +43,7 @@ extends Broadcast[T] with Logging with Serializable {
val fileIn = new ObjectInputStream(DfsBroadcast.openFileForReading(uuid))
value_ = fileIn.readObject.asInstanceOf[T]
- DfsBroadcast.values.put(uuid, value_)
+ DfsBroadcast.values.put(uuid, 0, value_)
fileIn.close
val time = (System.nanoTime - start) / 1e9
diff --git a/core/src/main/scala/spark/broadcast/TreeBroadcast.scala b/core/src/main/scala/spark/broadcast/TreeBroadcast.scala
index 945d8cd8a4..374389def5 100644
--- a/core/src/main/scala/spark/broadcast/TreeBroadcast.scala
+++ b/core/src/main/scala/spark/broadcast/TreeBroadcast.scala
@@ -15,7 +15,7 @@ extends Broadcast[T] with Logging with Serializable {
def value = value_
TreeBroadcast.synchronized {
- TreeBroadcast.values.put(uuid, value_)
+ TreeBroadcast.values.put(uuid, 0, value_)
}
@transient var arrayOfBlocks: Array[BroadcastBlock] = null
@@ -104,7 +104,7 @@ extends Broadcast[T] with Logging with Serializable {
private def readObject(in: ObjectInputStream): Unit = {
in.defaultReadObject
TreeBroadcast.synchronized {
- val cachedVal = TreeBroadcast.values.get(uuid)
+ val cachedVal = TreeBroadcast.values.get(uuid, 0)
if (cachedVal != null) {
value_ = cachedVal.asInstanceOf[T]
} else {
@@ -124,11 +124,11 @@ extends Broadcast[T] with Logging with Serializable {
// If does not succeed, then get from HDFS copy
if (receptionSucceeded) {
value_ = Broadcast.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks)
- TreeBroadcast.values.put(uuid, value_)
+ TreeBroadcast.values.put(uuid, 0, value_)
} else {
val fileIn = new ObjectInputStream(DfsBroadcast.openFileForReading(uuid))
value_ = fileIn.readObject.asInstanceOf[T]
- TreeBroadcast.values.put(uuid, value_)
+ TreeBroadcast.values.put(uuid, 0, value_)
fileIn.close()
}
diff --git a/core/src/test/scala/spark/BoundedMemoryCacheTest.scala b/core/src/test/scala/spark/BoundedMemoryCacheTest.scala
new file mode 100644
index 0000000000..344a733ab3
--- /dev/null
+++ b/core/src/test/scala/spark/BoundedMemoryCacheTest.scala
@@ -0,0 +1,31 @@
+package spark
+
+import org.scalatest.FunSuite
+
+class BoundedMemoryCacheTest extends FunSuite {
+ test("constructor test") {
+ val cache = new BoundedMemoryCache(40)
+ expect(40)(cache.getCapacity)
+ }
+
+ test("caching") {
+ val cache = new BoundedMemoryCache(40) {
+ //TODO sorry about this, but there is not better way how to skip 'cacheTracker.dropEntry'
+ override protected def reportEntryDropped(datasetId: Any, partition: Int, entry: Entry) {
+ logInfo("Dropping key (%s, %d) of size %d to make space".format(datasetId, partition, entry.size))
+ }
+ }
+ //should be OK
+ expect(CachePutSuccess(30))(cache.put("1", 0, "Meh"))
+
+ //we cannot add this to cache (there is not enough space in cache) & we cannot evict the only value from
+ //cache because it's from the same dataset
+ expect(CachePutFailure())(cache.put("1", 1, "Meh"))
+
+ //should be OK, dataset '1' can be evicted from cache
+ expect(CachePutSuccess(30))(cache.put("2", 0, "Meh"))
+
+ //should fail, cache should obey it's capacity
+ expect(CachePutFailure())(cache.put("3", 0, "Very_long_and_useless_string"))
+ }
+}
diff --git a/core/src/test/scala/spark/CacheTrackerSuite.scala b/core/src/test/scala/spark/CacheTrackerSuite.scala
new file mode 100644
index 0000000000..60290d14ca
--- /dev/null
+++ b/core/src/test/scala/spark/CacheTrackerSuite.scala
@@ -0,0 +1,97 @@
+package spark
+
+import org.scalatest.FunSuite
+import collection.mutable.HashMap
+
+class CacheTrackerSuite extends FunSuite {
+
+ test("CacheTrackerActor slave initialization & cache status") {
+ System.setProperty("spark.master.port", "1345")
+ val initialSize = 2L << 20
+
+ val tracker = new CacheTrackerActor
+ tracker.start()
+
+ tracker !? SlaveCacheStarted("host001", initialSize)
+
+ assert(tracker !? GetCacheStatus == Seq(("host001", 2097152L, 0L)))
+
+ tracker !? StopCacheTracker
+ }
+
+ test("RegisterRDD") {
+ System.setProperty("spark.master.port", "1345")
+ val initialSize = 2L << 20
+
+ val tracker = new CacheTrackerActor
+ tracker.start()
+
+ tracker !? SlaveCacheStarted("host001", initialSize)
+
+ tracker !? RegisterRDD(1, 3)
+ tracker !? RegisterRDD(2, 1)
+
+ assert(getCacheLocations(tracker) == Map(1 -> List(List(), List(), List()), 2 -> List(List())))
+
+ tracker !? StopCacheTracker
+ }
+
+ test("AddedToCache") {
+ System.setProperty("spark.master.port", "1345")
+ val initialSize = 2L << 20
+
+ val tracker = new CacheTrackerActor
+ tracker.start()
+
+ tracker !? SlaveCacheStarted("host001", initialSize)
+
+ tracker !? RegisterRDD(1, 2)
+ tracker !? RegisterRDD(2, 1)
+
+ tracker !? AddedToCache(1, 0, "host001", 2L << 15)
+ tracker !? AddedToCache(1, 1, "host001", 2L << 11)
+ tracker !? AddedToCache(2, 0, "host001", 3L << 10)
+
+ assert(tracker !? GetCacheStatus == Seq(("host001", 2097152L, 72704L)))
+
+ assert(getCacheLocations(tracker) == Map(1 -> List(List("host001"), List("host001")), 2 -> List(List("host001"))))
+
+ tracker !? StopCacheTracker
+ }
+
+ test("DroppedFromCache") {
+ System.setProperty("spark.master.port", "1345")
+ val initialSize = 2L << 20
+
+ val tracker = new CacheTrackerActor
+ tracker.start()
+
+ tracker !? SlaveCacheStarted("host001", initialSize)
+
+ tracker !? RegisterRDD(1, 2)
+ tracker !? RegisterRDD(2, 1)
+
+ tracker !? AddedToCache(1, 0, "host001", 2L << 15)
+ tracker !? AddedToCache(1, 1, "host001", 2L << 11)
+ tracker !? AddedToCache(2, 0, "host001", 3L << 10)
+
+ assert(tracker !? GetCacheStatus == Seq(("host001", 2097152L, 72704L)))
+ assert(getCacheLocations(tracker) == Map(1 -> List(List("host001"), List("host001")), 2 -> List(List("host001"))))
+
+ tracker !? DroppedFromCache(1, 1, "host001", 2L << 11)
+
+ assert(tracker !? GetCacheStatus == Seq(("host001", 2097152L, 68608L)))
+ assert(getCacheLocations(tracker) == Map(1 -> List(List("host001"),List()), 2 -> List(List("host001"))))
+
+ tracker !? StopCacheTracker
+ }
+
+ /**
+ * Helper function to get cacheLocations from CacheTracker
+ */
+ def getCacheLocations(tracker: CacheTrackerActor) = tracker !? GetCacheLocations match {
+ case h: HashMap[_, _] => h.asInstanceOf[HashMap[Int, Array[List[String]]]].map {
+ case (i, arr) => (i -> arr.toList)
+ }
+ }
+}
diff --git a/core/src/test/scala/spark/FailureSuite.scala b/core/src/test/scala/spark/FailureSuite.scala
index ab21f6a6f0..75df4bee09 100644
--- a/core/src/test/scala/spark/FailureSuite.scala
+++ b/core/src/test/scala/spark/FailureSuite.scala
@@ -65,5 +65,21 @@ class FailureSuite extends FunSuite {
FailureSuiteState.clear()
}
+ test("failure because task results are not serializable") {
+ val sc = new SparkContext("local[1,1]", "test")
+ val results = sc.makeRDD(1 to 3).map(x => new NonSerializable)
+
+ val thrown = intercept[spark.SparkException] {
+ results.collect()
+ }
+ assert(thrown.getClass === classOf[spark.SparkException])
+ assert(thrown.getMessage.contains("NotSerializableException"))
+
+ sc.stop()
+ FailureSuiteState.clear()
+ }
+
// TODO: Need to add tests with shuffle fetch failures.
}
+
+
diff --git a/core/src/test/scala/spark/MesosSchedulerSuite.scala b/core/src/test/scala/spark/MesosSchedulerSuite.scala
new file mode 100644
index 0000000000..0e6820cbdc
--- /dev/null
+++ b/core/src/test/scala/spark/MesosSchedulerSuite.scala
@@ -0,0 +1,28 @@
+package spark
+
+import org.scalatest.FunSuite
+
+class MesosSchedulerSuite extends FunSuite {
+ test("memoryStringToMb"){
+
+ assert(MesosScheduler.memoryStringToMb("1") == 0)
+ assert(MesosScheduler.memoryStringToMb("1048575") == 0)
+ assert(MesosScheduler.memoryStringToMb("3145728") == 3)
+
+ assert(MesosScheduler.memoryStringToMb("1024k") == 1)
+ assert(MesosScheduler.memoryStringToMb("5000k") == 4)
+ assert(MesosScheduler.memoryStringToMb("4024k") == MesosScheduler.memoryStringToMb("4024K"))
+
+ assert(MesosScheduler.memoryStringToMb("1024m") == 1024)
+ assert(MesosScheduler.memoryStringToMb("5000m") == 5000)
+ assert(MesosScheduler.memoryStringToMb("4024m") == MesosScheduler.memoryStringToMb("4024M"))
+
+ assert(MesosScheduler.memoryStringToMb("2g") == 2048)
+ assert(MesosScheduler.memoryStringToMb("3g") == MesosScheduler.memoryStringToMb("3G"))
+
+ assert(MesosScheduler.memoryStringToMb("2t") == 2097152)
+ assert(MesosScheduler.memoryStringToMb("3t") == MesosScheduler.memoryStringToMb("3T"))
+
+
+ }
+}
diff --git a/core/src/test/scala/spark/PipedRDDSuite.scala b/core/src/test/scala/spark/PipedRDDSuite.scala
new file mode 100644
index 0000000000..d5dc2efd91
--- /dev/null
+++ b/core/src/test/scala/spark/PipedRDDSuite.scala
@@ -0,0 +1,37 @@
+package spark
+
+import org.scalatest.FunSuite
+import SparkContext._
+
+class PipedRDDSuite extends FunSuite {
+
+ test("basic pipe") {
+ val sc = new SparkContext("local", "test")
+ val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
+
+ val piped = nums.pipe(Seq("cat"))
+
+ val c = piped.collect()
+ println(c.toSeq)
+ assert(c.size === 4)
+ assert(c(0) === "1")
+ assert(c(1) === "2")
+ assert(c(2) === "3")
+ assert(c(3) === "4")
+ sc.stop()
+ }
+
+ test("pipe with env variable") {
+ val sc = new SparkContext("local", "test")
+ val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
+ val piped = nums.pipe(Seq("printenv", "MY_TEST_ENV"), Map("MY_TEST_ENV" -> "LALALA"))
+ val c = piped.collect()
+ assert(c.size === 2)
+ assert(c(0) === "LALALA")
+ assert(c(1) === "LALALA")
+ sc.stop()
+ }
+
+}
+
+
diff --git a/core/src/test/scala/spark/UtilsSuite.scala b/core/src/test/scala/spark/UtilsSuite.scala
new file mode 100644
index 0000000000..f31251e509
--- /dev/null
+++ b/core/src/test/scala/spark/UtilsSuite.scala
@@ -0,0 +1,29 @@
+package spark
+
+import org.scalatest.FunSuite
+import java.io.{ByteArrayOutputStream, ByteArrayInputStream}
+import util.Random
+
+class UtilsSuite extends FunSuite {
+
+ test("memoryBytesToString") {
+ assert(Utils.memoryBytesToString(10) === "10.0B")
+ assert(Utils.memoryBytesToString(1500) === "1500.0B")
+ assert(Utils.memoryBytesToString(2000000) === "1953.1KB")
+ assert(Utils.memoryBytesToString(2097152) === "2.0MB")
+ assert(Utils.memoryBytesToString(2306867) === "2.2MB")
+ assert(Utils.memoryBytesToString(5368709120L) === "5.0GB")
+ }
+
+ test("copyStream") {
+ //input array initialization
+ val bytes = Array.ofDim[Byte](9000)
+ Random.nextBytes(bytes)
+
+ val os = new ByteArrayOutputStream()
+ Utils.copyStream(new ByteArrayInputStream(bytes), os)
+
+ assert(os.toByteArray.toList.equals(bytes.toList))
+ }
+}
+
diff --git a/repl/src/main/scala/spark/repl/Main.scala b/repl/src/main/scala/spark/repl/Main.scala
index b4a2bb05f9..58809ab646 100644
--- a/repl/src/main/scala/spark/repl/Main.scala
+++ b/repl/src/main/scala/spark/repl/Main.scala
@@ -7,7 +7,7 @@ object Main {
def interp = _interp
- private[repl] def interp_=(i: SparkILoop) { _interp = i }
+ def interp_=(i: SparkILoop) { _interp = i }
def main(args: Array[String]) {
_interp = new SparkILoop