| author | Matei Zaharia <matei@eecs.berkeley.edu> | 2012-06-03 17:44:04 -0700 |
|---|---|---|
| committer | Matei Zaharia <matei@eecs.berkeley.edu> | 2012-06-03 17:44:04 -0700 |
| commit | dbc3c86ae37b815e0c4e2431bc218cb79fa1a4be (patch) | |
| tree | 2818303050911934a1104954aa5bdfa67bce144d | |
| parent | bd2ab635a784ea031c52421587ffcfd0e7711267 (diff) | |
| parent | 1dd7d3dffffff907b47098117e0b09b993000629 (diff) | |
| download | spark-dbc3c86ae37b815e0c4e2431bc218cb79fa1a4be.tar.gz, spark-dbc3c86ae37b815e0c4e2431bc218cb79fa1a4be.tar.bz2, spark-dbc3c86ae37b815e0c4e2431bc218cb79fa1a4be.zip | |
Merge branch 'master' into mesos-0.9
Conflicts:
core/src/main/scala/spark/Executor.scala
31 files changed, 655 insertions(+), 185 deletions(-)
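The core of this merge is the reworked `Cache` interface in `core/src/main/scala/spark/Cache.scala`: entries are now keyed by a `(datasetId, partition)` pair, and `put` returns a `CachePutResponse` so that `BoundedMemoryCache` can refuse a partition instead of evicting other partitions of the same dataset. Below is a minimal sketch of the new interface as it appears in the patch; the `MapCache` implementation and the `CacheApiSketch` driver are hypothetical, added only to illustrate how a caller reacts to the put response.

```scala
// Trait and response types copied from the patched spark/Cache.scala;
// MapCache and CacheApiSketch are illustrative only, not part of the commit.
import scala.collection.mutable.HashMap

sealed trait CachePutResponse
case class CachePutSuccess(size: Long) extends CachePutResponse
case class CachePutFailure() extends CachePutResponse

abstract class Cache {
  def get(datasetId: Any, partition: Int): Any
  def put(datasetId: Any, partition: Int, value: Any): CachePutResponse
  def getCapacity: Long = 0L
}

// Hypothetical unbounded cache, just enough to exercise the new interface.
class MapCache extends Cache {
  private val map = new HashMap[(Any, Int), Any]

  override def get(datasetId: Any, partition: Int): Any =
    map.getOrElse((datasetId, partition), null)

  override def put(datasetId: Any, partition: Int, value: Any): CachePutResponse = {
    map.put((datasetId, partition), value)
    CachePutSuccess(0) // no size estimation in this sketch
  }
}

object CacheApiSketch {
  def main(args: Array[String]) {
    val cache = new MapCache
    // Key the entry by (datasetId = 1, partition = 0) and inspect the response.
    cache.put(1, 0, Array(1, 2, 3)) match {
      case CachePutSuccess(size) => println("Cached partition (1, 0), reported size " + size)
      case CachePutFailure() => println("Cache refused partition (1, 0)")
    }
    println(cache.get(1, 0).asInstanceOf[Array[Int]].mkString(", "))
  }
}
```

In the real `BoundedMemoryCache` from this patch, a `CachePutFailure()` is returned when making room would require evicting another partition of the same dataset (see the new `BoundedMemoryCacheTest` further down).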
diff --git a/bagel/src/main/scala/spark/bagel/examples/WikipediaPageRankStandalone.scala b/bagel/src/main/scala/spark/bagel/examples/WikipediaPageRankStandalone.scala index 2e38376499..7084ff97d9 100644 --- a/bagel/src/main/scala/spark/bagel/examples/WikipediaPageRankStandalone.scala +++ b/bagel/src/main/scala/spark/bagel/examples/WikipediaPageRankStandalone.scala @@ -126,6 +126,10 @@ class WPRSerializerInstance extends SerializerInstance { throw new UnsupportedOperationException() } + def deserialize[T](bytes: Array[Byte], loader: ClassLoader): T = { + throw new UnsupportedOperationException() + } + def outputStream(s: OutputStream): SerializationStream = { new WPRSerializationStream(s) } diff --git a/core/src/main/scala/spark/BoundedMemoryCache.scala b/core/src/main/scala/spark/BoundedMemoryCache.scala index e8e50ac360..1162e34ab0 100644 --- a/core/src/main/scala/spark/BoundedMemoryCache.scala +++ b/core/src/main/scala/spark/BoundedMemoryCache.scala @@ -9,19 +9,19 @@ import java.util.LinkedHashMap * some cache entries have pointers to a shared object. Nonetheless, this Cache should work well * when most of the space is used by arrays of primitives or of simple classes. */ -class BoundedMemoryCache extends Cache with Logging { - private val maxBytes: Long = getMaxBytes() +class BoundedMemoryCache(maxBytes: Long) extends Cache with Logging { logInfo("BoundedMemoryCache.maxBytes = " + maxBytes) - private var currentBytes = 0L - private val map = new LinkedHashMap[Any, Entry](32, 0.75f, true) + def this() { + this(BoundedMemoryCache.getMaxBytes) + } - // An entry in our map; stores a cached object and its size in bytes - class Entry(val value: Any, val size: Long) {} + private var currentBytes = 0L + private val map = new LinkedHashMap[(Any, Int), Entry](32, 0.75f, true) - override def get(key: Any): Any = { + override def get(datasetId: Any, partition: Int): Any = { synchronized { - val entry = map.get(key) + val entry = map.get((datasetId, partition)) if (entry != null) { entry.value } else { @@ -30,46 +30,80 @@ class BoundedMemoryCache extends Cache with Logging { } } - override def put(key: Any, value: Any) { + override def put(datasetId: Any, partition: Int, value: Any): CachePutResponse = { + val key = (datasetId, partition) logInfo("Asked to add key " + key) + val size = estimateValueSize(key, value) + synchronized { + if (size > getCapacity) { + return CachePutFailure() + } else if (ensureFreeSpace(datasetId, size)) { + logInfo("Adding key " + key) + map.put(key, new Entry(value, size)) + currentBytes += size + logInfo("Number of entries is now " + map.size) + return CachePutSuccess(size) + } else { + logInfo("Didn't add key " + key + " because we would have evicted part of same dataset") + return CachePutFailure() + } + } + } + + override def getCapacity: Long = maxBytes + + /** + * Estimate sizeOf 'value' + */ + private def estimateValueSize(key: (Any, Int), value: Any) = { val startTime = System.currentTimeMillis val size = SizeEstimator.estimate(value.asInstanceOf[AnyRef]) val timeTaken = System.currentTimeMillis - startTime logInfo("Estimated size for key %s is %d".format(key, size)) logInfo("Size estimation for key %s took %d ms".format(key, timeTaken)) - synchronized { - ensureFreeSpace(size) - logInfo("Adding key " + key) - map.put(key, new Entry(value, size)) - currentBytes += size - logInfo("Number of entries is now " + map.size) - } - } - - private def getMaxBytes(): Long = { - val memoryFractionToUse = System.getProperty( - "spark.boundedMemoryCache.memoryFraction", 
"0.66").toDouble - (Runtime.getRuntime.maxMemory * memoryFractionToUse).toLong + size } /** - * Remove least recently used entries from the map until at least space bytes are free. Assumes + * Remove least recently used entries from the map until at least space bytes are free, in order + * to make space for a partition from the given dataset ID. If this cannot be done without + * evicting other data from the same dataset, returns false; otherwise, returns true. Assumes * that a lock is held on the BoundedMemoryCache. */ - private def ensureFreeSpace(space: Long) { - logInfo("ensureFreeSpace(%d) called with curBytes=%d, maxBytes=%d".format( - space, currentBytes, maxBytes)) - val iter = map.entrySet.iterator + private def ensureFreeSpace(datasetId: Any, space: Long): Boolean = { + logInfo("ensureFreeSpace(%s, %d) called with curBytes=%d, maxBytes=%d".format( + datasetId, space, currentBytes, maxBytes)) + val iter = map.entrySet.iterator // Will give entries in LRU order while (maxBytes - currentBytes < space && iter.hasNext) { val mapEntry = iter.next() - dropEntry(mapEntry.getKey, mapEntry.getValue) + val (entryDatasetId, entryPartition) = mapEntry.getKey + if (entryDatasetId == datasetId) { + // Cannot make space without removing part of the same dataset, or a more recently used one + return false + } + reportEntryDropped(entryDatasetId, entryPartition, mapEntry.getValue) currentBytes -= mapEntry.getValue.size iter.remove() } + return true } - protected def dropEntry(key: Any, entry: Entry) { - logInfo("Dropping key %s of size %d to make space".format(key, entry.size)) - SparkEnv.get.cacheTracker.dropEntry(key) + protected def reportEntryDropped(datasetId: Any, partition: Int, entry: Entry) { + logInfo("Dropping key (%s, %d) of size %d to make space".format(datasetId, partition, entry.size)) + SparkEnv.get.cacheTracker.dropEntry(datasetId, partition) } } + +// An entry in our map; stores a cached object and its size in bytes +case class Entry(value: Any, size: Long) + +object BoundedMemoryCache { + /** + * Get maximum cache capacity from system configuration + */ + def getMaxBytes: Long = { + val memoryFractionToUse = System.getProperty("spark.boundedMemoryCache.memoryFraction", "0.66").toDouble + (Runtime.getRuntime.maxMemory * memoryFractionToUse).toLong + } +} + diff --git a/core/src/main/scala/spark/Cache.scala b/core/src/main/scala/spark/Cache.scala index 696fff4e5e..150fe14e2c 100644 --- a/core/src/main/scala/spark/Cache.scala +++ b/core/src/main/scala/spark/Cache.scala @@ -1,10 +1,16 @@ package spark -import java.util.concurrent.atomic.AtomicLong +import java.util.concurrent.atomic.AtomicInteger + +sealed trait CachePutResponse +case class CachePutSuccess(size: Long) extends CachePutResponse +case class CachePutFailure() extends CachePutResponse /** * An interface for caches in Spark, to allow for multiple implementations. Caches are used to store - * both partitions of cached RDDs and broadcast variables on Spark executors. + * both partitions of cached RDDs and broadcast variables on Spark executors. Caches are also aware + * of which entries are part of the same dataset (for example, partitions in the same RDD). The key + * for each value in a cache is a (datasetID, partition) pair. * * A single Cache instance gets created on each machine and is shared by all caches (i.e. both the * RDD split cache and the broadcast variable cache), to enable global replacement policies. @@ -17,19 +23,41 @@ import java.util.concurrent.atomic.AtomicLong * keys that are unique across modules. 
*/ abstract class Cache { - private val nextKeySpaceId = new AtomicLong(0) + private val nextKeySpaceId = new AtomicInteger(0) private def newKeySpaceId() = nextKeySpaceId.getAndIncrement() def newKeySpace() = new KeySpace(this, newKeySpaceId()) - def get(key: Any): Any - def put(key: Any, value: Any): Unit + /** + * Get the value for a given (datasetId, partition), or null if it is not + * found. + */ + def get(datasetId: Any, partition: Int): Any + + /** + * Attempt to put a value in the cache; returns CachePutFailure if this was + * not successful (e.g. because the cache replacement policy forbids it), and + * CachePutSuccess if successful. If size estimation is available, the cache + * implementation should set the size field in CachePutSuccess. + */ + def put(datasetId: Any, partition: Int, value: Any): CachePutResponse + + /** + * Report the capacity of the cache partition. By default this just reports + * zero. Specific implementations can choose to provide the capacity number. + */ + def getCapacity: Long = 0L } /** * A key namespace in a Cache. */ -class KeySpace(cache: Cache, id: Long) { - def get(key: Any): Any = cache.get((id, key)) - def put(key: Any, value: Any): Unit = cache.put((id, key), value) +class KeySpace(cache: Cache, val keySpaceId: Int) { + def get(datasetId: Any, partition: Int): Any = + cache.get((keySpaceId, datasetId), partition) + + def put(datasetId: Any, partition: Int, value: Any): CachePutResponse = + cache.put((keySpaceId, datasetId), partition, value) + + def getCapacity: Long = cache.getCapacity } diff --git a/core/src/main/scala/spark/CacheTracker.scala b/core/src/main/scala/spark/CacheTracker.scala index 5b6eed743f..4867829c17 100644 --- a/core/src/main/scala/spark/CacheTracker.scala +++ b/core/src/main/scala/spark/CacheTracker.scala @@ -7,16 +7,32 @@ import scala.collection.mutable.HashMap import scala.collection.mutable.HashSet sealed trait CacheTrackerMessage -case class AddedToCache(rddId: Int, partition: Int, host: String) extends CacheTrackerMessage -case class DroppedFromCache(rddId: Int, partition: Int, host: String) extends CacheTrackerMessage +case class AddedToCache(rddId: Int, partition: Int, host: String, size: Long = 0L) + extends CacheTrackerMessage +case class DroppedFromCache(rddId: Int, partition: Int, host: String, size: Long = 0L) + extends CacheTrackerMessage case class MemoryCacheLost(host: String) extends CacheTrackerMessage case class RegisterRDD(rddId: Int, numPartitions: Int) extends CacheTrackerMessage +case class SlaveCacheStarted(host: String, size: Long) extends CacheTrackerMessage +case object GetCacheStatus extends CacheTrackerMessage case object GetCacheLocations extends CacheTrackerMessage case object StopCacheTracker extends CacheTrackerMessage + class CacheTrackerActor extends DaemonActor with Logging { - val locs = new HashMap[Int, Array[List[String]]] + private val locs = new HashMap[Int, Array[List[String]]] + + /** + * A map from the slave's host name to its cache size. 
+ */ + private val slaveCapacity = new HashMap[String, Long] + private val slaveUsage = new HashMap[String, Long] + // TODO: Should probably store (String, CacheType) tuples + + private def getCacheUsage(host: String): Long = slaveUsage.getOrElse(host, 0L) + private def getCacheCapacity(host: String): Long = slaveCapacity.getOrElse(host, 0L) + private def getCacheAvailable(host: String): Long = getCacheCapacity(host) - getCacheUsage(host) def act() { val port = System.getProperty("spark.master.port").toInt @@ -26,31 +42,61 @@ class CacheTrackerActor extends DaemonActor with Logging { loop { react { + case SlaveCacheStarted(host: String, size: Long) => + logInfo("Started slave cache (size %s) on %s".format( + Utils.memoryBytesToString(size), host)) + slaveCapacity.put(host, size) + slaveUsage.put(host, 0) + reply('OK) + case RegisterRDD(rddId: Int, numPartitions: Int) => logInfo("Registering RDD " + rddId + " with " + numPartitions + " partitions") locs(rddId) = Array.fill[List[String]](numPartitions)(Nil) reply('OK) - case AddedToCache(rddId, partition, host) => - logInfo("Cache entry added: (%s, %s) on %s".format(rddId, partition, host)) + case AddedToCache(rddId, partition, host, size) => + if (size > 0) { + slaveUsage.put(host, getCacheUsage(host) + size) + logInfo("Cache entry added: (%s, %s) on %s (size added: %s, available: %s)".format( + rddId, partition, host, Utils.memoryBytesToString(size), + Utils.memoryBytesToString(getCacheAvailable(host)))) + } else { + logInfo("Cache entry added: (%s, %s) on %s".format(rddId, partition, host)) + } locs(rddId)(partition) = host :: locs(rddId)(partition) reply('OK) - case DroppedFromCache(rddId, partition, host) => - logInfo("Cache entry removed: (%s, %s) on %s".format(rddId, partition, host)) + case DroppedFromCache(rddId, partition, host, size) => + if (size > 0) { + logInfo("Cache entry removed: (%s, %s) on %s (size dropped: %s, available: %s)".format( + rddId, partition, host, Utils.memoryBytesToString(size), + Utils.memoryBytesToString(getCacheAvailable(host)))) + slaveUsage.put(host, getCacheUsage(host) - size) + + // Do a sanity check to make sure usage is greater than 0. 
+ val usage = getCacheUsage(host) + if (usage < 0) { + logError("Cache usage on %s is negative (%d)".format(host, usage)) + } + } else { + logInfo("Cache entry removed: (%s, %s) on %s".format(rddId, partition, host)) + } locs(rddId)(partition) = locs(rddId)(partition).filterNot(_ == host) - + reply('OK) + case MemoryCacheLost(host) => logInfo("Memory cache lost on " + host) // TODO: Drop host from the memory locations list of all RDDs case GetCacheLocations => logInfo("Asked for current cache locations") - val locsCopy = new HashMap[Int, Array[List[String]]] - for ((rddId, array) <- locs) { - locsCopy(rddId) = array.clone() - } - reply(locsCopy) + reply(locs.map{case (rrdId, array) => (rrdId -> array.clone())}) + + case GetCacheStatus => + val status = slaveCapacity.map { case (host,capacity) => + (host, capacity, getCacheUsage(host)) + }.toSeq + reply(status) case StopCacheTracker => reply('OK) @@ -60,10 +106,16 @@ class CacheTrackerActor extends DaemonActor with Logging { } } + class CacheTracker(isMaster: Boolean, theCache: Cache) extends Logging { // Tracker actor on the master, or remote reference to it on workers var trackerActor: AbstractActor = null - + + val registeredRddIds = new HashSet[Int] + + // Stores map results for various splits locally + val cache = theCache.newKeySpace() + if (isMaster) { val tracker = new CacheTrackerActor tracker.start() @@ -74,10 +126,8 @@ class CacheTracker(isMaster: Boolean, theCache: Cache) extends Logging { trackerActor = RemoteActor.select(Node(host, port), 'CacheTracker) } - val registeredRddIds = new HashSet[Int] - - // Stores map results for various splits locally - val cache = theCache.newKeySpace() + // Report the cache being started. + trackerActor !? SlaveCacheStarted(Utils.getHost, cache.getCapacity) // Remembers which splits are currently being loaded (on worker nodes) val loading = new HashSet[(Int, Int)] @@ -92,65 +142,92 @@ class CacheTracker(isMaster: Boolean, theCache: Cache) extends Logging { } } } - + // Get a snapshot of the currently known locations def getLocationsSnapshot(): HashMap[Int, Array[List[String]]] = { (trackerActor !? GetCacheLocations) match { - case h: HashMap[_, _] => - h.asInstanceOf[HashMap[Int, Array[List[String]]]] - - case _ => - throw new SparkException("Internal error: CacheTrackerActor did not reply with a HashMap") + case h: HashMap[_, _] => h.asInstanceOf[HashMap[Int, Array[List[String]]]] + + case _ => throw new SparkException("Internal error: CacheTrackerActor did not reply with a HashMap") + } + } + + // Get the usage status of slave caches. Each tuple in the returned sequence + // is in the form of (host name, capacity, usage). + def getCacheStatus(): Seq[(String, Long, Long)] = { + (trackerActor !? 
GetCacheStatus) match { + case h: Seq[(String, Long, Long)] => h.asInstanceOf[Seq[(String, Long, Long)]] + + case _ => + throw new SparkException( + "Internal error: CacheTrackerActor did not reply with a Seq[Tuple3[String, Long, Long]") } } // Gets or computes an RDD split def getOrCompute[T](rdd: RDD[T], split: Split)(implicit m: ClassManifest[T]): Iterator[T] = { - val key = (rdd.id, split.index) - logInfo("CachedRDD partition key is " + key) - val cachedVal = cache.get(key) + logInfo("Looking for RDD partition %d:%d".format(rdd.id, split.index)) + val cachedVal = cache.get(rdd.id, split.index) if (cachedVal != null) { // Split is in cache, so just return its values logInfo("Found partition in cache!") return cachedVal.asInstanceOf[Array[T]].iterator } else { // Mark the split as loading (unless someone else marks it first) + val key = (rdd.id, split.index) loading.synchronized { - if (loading.contains(key)) { - while (loading.contains(key)) { - try {loading.wait()} catch {case _ =>} - } - return cache.get(key).asInstanceOf[Array[T]].iterator - } else { - loading.add(key) + while (loading.contains(key)) { + // Someone else is loading it; let's wait for them + try { loading.wait() } catch { case _ => } } + // See whether someone else has successfully loaded it. The main way this would fail + // is for the RDD-level cache eviction policy if someone else has loaded the same RDD + // partition but we didn't want to make space for it. However, that case is unlikely + // because it's unlikely that two threads would work on the same RDD partition. One + // downside of the current code is that threads wait serially if this does happen. + val cachedVal = cache.get(rdd.id, split.index) + if (cachedVal != null) { + return cachedVal.asInstanceOf[Array[T]].iterator + } + // Nobody's loading it and it's not in the cache; let's load it ourselves + loading.add(key) } // If we got here, we have to load the split // Tell the master that we're doing so - val host = System.getProperty("spark.hostname", Utils.localHostName) - val future = trackerActor !! AddedToCache(rdd.id, split.index, host) + // TODO: fetch any remote copy of the split that may be available - // TODO: also register a listener for when it unloads logInfo("Computing partition " + split) - val array = rdd.compute(split).toArray(m) - cache.put(key, array) - loading.synchronized { - loading.remove(key) - loading.notifyAll() + var array: Array[T] = null + var putResponse: CachePutResponse = null + try { + array = rdd.compute(split).toArray(m) + putResponse = cache.put(rdd.id, split.index, array) + } finally { + // Tell other threads that we've finished our attempt to load the key (whether or not + // we've actually succeeded to put it in the map) + loading.synchronized { + loading.remove(key) + loading.notifyAll() + } + } + + putResponse match { + case CachePutSuccess(size) => { + // Tell the master that we added the entry. Don't return until it + // replies so it can properly schedule future tasks that use this RDD. + trackerActor !? AddedToCache(rdd.id, split.index, Utils.getHost, size) + } + case _ => null } - future.apply() // Wait for the reply from the cache tracker return array.iterator } } - // Reports that an entry has been dropped from the cache - def dropEntry(key: Any) { - key match { - case (keySpaceId: Long, (rddId: Int, partition: Int)) => - val host = System.getProperty("spark.hostname", Utils.localHostName) - trackerActor !! 
DroppedFromCache(rddId, partition, host) - case _ => - logWarning("Unknown key format: %s".format(key)) + // Called by the Cache to report that an entry has been dropped from it + def dropEntry(datasetId: Any, partition: Int) { + datasetId match { + //TODO - do we really want to use '!!' when nobody checks returned future? '!' seems to enough here. + case (cache.keySpaceId, rddId: Int) => trackerActor !! DroppedFromCache(rddId, partition, Utils.getHost) } } diff --git a/core/src/main/scala/spark/DiskSpillingCache.scala b/core/src/main/scala/spark/DiskSpillingCache.scala index 157e071c7f..e11466eb64 100644 --- a/core/src/main/scala/spark/DiskSpillingCache.scala +++ b/core/src/main/scala/spark/DiskSpillingCache.scala @@ -9,31 +9,31 @@ import java.util.UUID // TODO: cache into a separate directory using Utils.createTempDir // TODO: clean up disk cache afterwards class DiskSpillingCache extends BoundedMemoryCache { - private val diskMap = new LinkedHashMap[Any, File](32, 0.75f, true) + private val diskMap = new LinkedHashMap[(Any, Int), File](32, 0.75f, true) - override def get(key: Any): Any = { + override def get(datasetId: Any, partition: Int): Any = { synchronized { val ser = SparkEnv.get.serializer.newInstance() - super.get(key) match { + super.get(datasetId, partition) match { case bytes: Any => // found in memory ser.deserialize(bytes.asInstanceOf[Array[Byte]]) - case _ => diskMap.get(key) match { + case _ => diskMap.get((datasetId, partition)) match { case file: Any => // found on disk try { val startTime = System.currentTimeMillis val bytes = new Array[Byte](file.length.toInt) new FileInputStream(file).read(bytes) val timeTaken = System.currentTimeMillis - startTime - logInfo("Reading key %s of size %d bytes from disk took %d ms".format( - key, file.length, timeTaken)) - super.put(key, bytes) + logInfo("Reading key (%s, %d) of size %d bytes from disk took %d ms".format( + datasetId, partition, file.length, timeTaken)) + super.put(datasetId, partition, bytes) ser.deserialize(bytes.asInstanceOf[Array[Byte]]) } catch { case e: IOException => - logWarning("Failed to read key %s from disk at %s: %s".format( - key, file.getPath(), e.getMessage())) - diskMap.remove(key) // remove dead entry + logWarning("Failed to read key (%s, %d) from disk at %s: %s".format( + datasetId, partition, file.getPath(), e.getMessage())) + diskMap.remove((datasetId, partition)) // remove dead entry null } @@ -44,18 +44,18 @@ class DiskSpillingCache extends BoundedMemoryCache { } } - override def put(key: Any, value: Any) { + override def put(datasetId: Any, partition: Int, value: Any): CachePutResponse = { var ser = SparkEnv.get.serializer.newInstance() - super.put(key, ser.serialize(value)) + super.put(datasetId, partition, ser.serialize(value)) } /** * Spill the given entry to disk. Assumes that a lock is held on the * DiskSpillingCache. Assumes that entry.value is a byte array. 
*/ - override protected def dropEntry(key: Any, entry: Entry) { - logInfo("Spilling key %s of size %d to make space".format( - key, entry.size)) + override protected def reportEntryDropped(datasetId: Any, partition: Int, entry: Entry) { + logInfo("Spilling key (%s, %d) of size %d to make space".format( + datasetId, partition, entry.size)) val cacheDir = System.getProperty( "spark.diskSpillingCache.cacheDir", System.getProperty("java.io.tmpdir")) @@ -64,11 +64,11 @@ class DiskSpillingCache extends BoundedMemoryCache { val stream = new FileOutputStream(file) stream.write(entry.value.asInstanceOf[Array[Byte]]) stream.close() - diskMap.put(key, file) + diskMap.put((datasetId, partition), file) } catch { case e: IOException => - logWarning("Failed to spill key %s to disk at %s: %s".format( - key, file.getPath(), e.getMessage())) + logWarning("Failed to spill key (%s, %d) to disk at %s: %s".format( + datasetId, partition, file.getPath(), e.getMessage())) // Do nothing and let the entry be discarded } } diff --git a/core/src/main/scala/spark/Executor.scala b/core/src/main/scala/spark/Executor.scala index de45137a4f..c795b6c351 100644 --- a/core/src/main/scala/spark/Executor.scala +++ b/core/src/main/scala/spark/Executor.scala @@ -65,16 +65,17 @@ class Executor extends org.apache.mesos.Executor with Logging { extends Runnable { override def run() = { val tid = info.getTaskId.getValue + SparkEnv.set(env) + Thread.currentThread.setContextClassLoader(classLoader) + val ser = SparkEnv.get.closureSerializer.newInstance() logInfo("Running task ID " + tid) d.sendStatusUpdate(TaskStatus.newBuilder() .setTaskId(info.getTaskId) .setState(TaskState.TASK_RUNNING) .build()) try { - SparkEnv.set(env) - Thread.currentThread.setContextClassLoader(classLoader) Accumulators.clear - val task = Utils.deserialize[Task[Any]](info.getData.toByteArray, classLoader) + val task = ser.deserialize[Task[Any]](info.getData.toByteArray, classLoader) for (gen <- task.generation) {// Update generation if any is set env.mapOutputTracker.updateGeneration(gen) } @@ -84,7 +85,7 @@ class Executor extends org.apache.mesos.Executor with Logging { d.sendStatusUpdate(TaskStatus.newBuilder() .setTaskId(info.getTaskId) .setState(TaskState.TASK_FINISHED) - .setData(ByteString.copyFrom(Utils.serialize(result))) + .setData(ByteString.copyFrom(ser.serialize(result))) .build()) logInfo("Finished task ID " + tid) } catch { @@ -93,7 +94,7 @@ class Executor extends org.apache.mesos.Executor with Logging { d.sendStatusUpdate(TaskStatus.newBuilder() .setTaskId(info.getTaskId) .setState(TaskState.TASK_FAILED) - .setData(ByteString.copyFrom(Utils.serialize(reason))) + .setData(ByteString.copyFrom(ser.serialize(reason))) .build()) } case t: Throwable => { @@ -101,7 +102,7 @@ class Executor extends org.apache.mesos.Executor with Logging { d.sendStatusUpdate(TaskStatus.newBuilder() .setTaskId(info.getTaskId) .setState(TaskState.TASK_FAILED) - .setData(ByteString.copyFrom(Utils.serialize(reason))) + .setData(ByteString.copyFrom(ser.serialize(reason))) .build()) // TODO: Handle errors in tasks less dramatically diff --git a/core/src/main/scala/spark/JavaSerializer.scala b/core/src/main/scala/spark/JavaSerializer.scala index e7cd4364ee..80f615eeb0 100644 --- a/core/src/main/scala/spark/JavaSerializer.scala +++ b/core/src/main/scala/spark/JavaSerializer.scala @@ -34,6 +34,15 @@ class JavaSerializerInstance extends SerializerInstance { in.readObject().asInstanceOf[T] } + def deserialize[T](bytes: Array[Byte], loader: ClassLoader): T = { + val bis = new 
ByteArrayInputStream(bytes) + val ois = new ObjectInputStream(bis) { + override def resolveClass(desc: ObjectStreamClass) = + Class.forName(desc.getName, false, loader) + } + return ois.readObject.asInstanceOf[T] + } + def outputStream(s: OutputStream): SerializationStream = { new JavaSerializationStream(s) } diff --git a/core/src/main/scala/spark/KryoSerializer.scala b/core/src/main/scala/spark/KryoSerializer.scala index 7d25b965d2..5693613d6d 100644 --- a/core/src/main/scala/spark/KryoSerializer.scala +++ b/core/src/main/scala/spark/KryoSerializer.scala @@ -9,6 +9,7 @@ import scala.collection.mutable import com.esotericsoftware.kryo._ import com.esotericsoftware.kryo.{Serializer => KSerializer} +import com.esotericsoftware.kryo.serialize.ClassSerializer import de.javakaffee.kryoserializers.KryoReflectionFactorySupport /** @@ -100,6 +101,14 @@ class KryoSerializerInstance(ks: KryoSerializer) extends SerializerInstance { buf.readClassAndObject(bytes).asInstanceOf[T] } + def deserialize[T](bytes: Array[Byte], loader: ClassLoader): T = { + val oldClassLoader = ks.kryo.getClassLoader + ks.kryo.setClassLoader(loader) + val obj = buf.readClassAndObject(bytes).asInstanceOf[T] + ks.kryo.setClassLoader(oldClassLoader) + obj + } + def outputStream(s: OutputStream): SerializationStream = { new KryoSerializationStream(ks.kryo, ks.threadByteBuf.get(), s) } @@ -129,6 +138,8 @@ class KryoSerializer extends Serializer with Logging { } def createKryo(): Kryo = { + // This is used so we can serialize/deserialize objects without a zero-arg + // constructor. val kryo = new KryoReflectionFactorySupport() // Register some commonly used classes @@ -150,6 +161,10 @@ class KryoSerializer extends Serializer with Logging { kryo.register(obj.getClass) } + // Register the following classes for passing closures. + kryo.register(classOf[Class[_]], new ClassSerializer(kryo)) + kryo.setRegistrationOptional(true) + // Register some commonly used Scala singleton objects. Because these // are singletons, we must return the exact same local object when we // deserialize rather than returning a clone as FieldSerializer would. diff --git a/core/src/main/scala/spark/LocalScheduler.scala b/core/src/main/scala/spark/LocalScheduler.scala index 0cbc68ffc5..3910c7b09e 100644 --- a/core/src/main/scala/spark/LocalScheduler.scala +++ b/core/src/main/scala/spark/LocalScheduler.scala @@ -38,14 +38,23 @@ private class LocalScheduler(threads: Int, maxFailures: Int) extends DAGSchedule // Serialize and deserialize the task so that accumulators are changed to thread-local ones; // this adds a bit of unnecessary overhead but matches how the Mesos Executor works. Accumulators.clear - val bytes = Utils.serialize(task) - logInfo("Size of task " + idInJob + " is " + bytes.size + " bytes") - val deserializedTask = Utils.deserialize[Task[_]]( - bytes, Thread.currentThread.getContextClassLoader) + val ser = SparkEnv.get.closureSerializer.newInstance() + val startTime = System.currentTimeMillis + val bytes = ser.serialize(task) + val timeTaken = System.currentTimeMillis - startTime + logInfo("Size of task %d is %d bytes and took %d ms to serialize".format( + idInJob, bytes.size, timeTaken)) + val deserializedTask = ser.deserialize[Task[_]](bytes, currentThread.getContextClassLoader) val result: Any = deserializedTask.run(attemptId) + + // Serialize and deserialize the result to emulate what the mesos + // executor does. 
This is useful to catch serialization errors early + // on in development (so when users move their local Spark programs + // to the cluster, they don't get surprised by serialization errors). + val resultToReturn = ser.deserialize[Any](ser.serialize(result)) val accumUpdates = Accumulators.values logInfo("Finished task " + idInJob) - taskEnded(task, Success, result, accumUpdates) + taskEnded(task, Success, resultToReturn, accumUpdates) } catch { case t: Throwable => { logError("Exception in task " + idInJob, t) @@ -55,7 +64,7 @@ private class LocalScheduler(threads: Int, maxFailures: Int) extends DAGSchedule submitTask(task, idInJob) } else { // TODO: Do something nicer here to return all the way to the user - System.exit(1) + taskEnded(task, new ExceptionFailure(t), null, null) } } } diff --git a/core/src/main/scala/spark/MesosScheduler.scala b/core/src/main/scala/spark/MesosScheduler.scala index dc58299a1d..a7711e0d35 100644 --- a/core/src/main/scala/spark/MesosScheduler.scala +++ b/core/src/main/scala/spark/MesosScheduler.scala @@ -42,7 +42,7 @@ private class MesosScheduler( // Memory used by each executor (in megabytes) val EXECUTOR_MEMORY = { if (System.getenv("SPARK_MEM") != null) { - memoryStringToMb(System.getenv("SPARK_MEM")) + MesosScheduler.memoryStringToMb(System.getenv("SPARK_MEM")) // TODO: Might need to add some extra memory for the non-heap parts of the JVM } else { 512 @@ -81,9 +81,7 @@ private class MesosScheduler( // Sorts jobs in reverse order of run ID for use in our priority queue (so lower IDs run first) private val jobOrdering = new Ordering[Job] { - override def compare(j1: Job, j2: Job): Int = { - return j2.runId - j1.runId - } + override def compare(j1: Job, j2: Job): Int = j2.runId - j1.runId } def newJobId(): Int = this.synchronized { @@ -162,7 +160,7 @@ private class MesosScheduler( activeJobs(jobId) = myJob activeJobsQueue += myJob logInfo("Adding job with ID " + jobId) - jobTasks(jobId) = new HashSet() + jobTasks(jobId) = HashSet.empty[String] } driver.reviveOffers(); } @@ -390,23 +388,26 @@ private class MesosScheduler( } override def offerRescinded(d: SchedulerDriver, o: OfferID) {} +} +object MesosScheduler { /** - * Convert a Java memory parameter passed to -Xmx (such as 300m or 1g) to a number of megabytes. - * This is used to figure out how much memory to claim from Mesos based on the SPARK_MEM + * Convert a Java memory parameter passed to -Xmx (such as 300m or 1g) to a number of megabytes. + * This is used to figure out how much memory to claim from Mesos based on the SPARK_MEM * environment variable. 
*/ def memoryStringToMb(str: String): Int = { val lower = str.toLowerCase if (lower.endsWith("k")) { - (lower.substring(0, lower.length-1).toLong / 1024).toInt + (lower.substring(0, lower.length - 1).toLong / 1024).toInt } else if (lower.endsWith("m")) { - lower.substring(0, lower.length-1).toInt + lower.substring(0, lower.length - 1).toInt } else if (lower.endsWith("g")) { - lower.substring(0, lower.length-1).toInt * 1024 + lower.substring(0, lower.length - 1).toInt * 1024 } else if (lower.endsWith("t")) { - lower.substring(0, lower.length-1).toInt * 1024 * 1024 - } else {// no suffix, so it's just a number in bytes + lower.substring(0, lower.length - 1).toInt * 1024 * 1024 + } else { + // no suffix, so it's just a number in bytes (lower.toLong / 1024 / 1024).toInt } } diff --git a/core/src/main/scala/spark/PipedRDD.scala b/core/src/main/scala/spark/PipedRDD.scala index 3f993d895a..8a5de3d7e9 100644 --- a/core/src/main/scala/spark/PipedRDD.scala +++ b/core/src/main/scala/spark/PipedRDD.scala @@ -3,6 +3,7 @@ package spark import java.io.PrintWriter import java.util.StringTokenizer +import scala.collection.JavaConversions._ import scala.collection.mutable.ArrayBuffer import scala.io.Source @@ -10,8 +11,12 @@ import scala.io.Source * An RDD that pipes the contents of each parent partition through an external command * (printing them one per line) and returns the output as a collection of strings. */ -class PipedRDD[T: ClassManifest](parent: RDD[T], command: Seq[String]) +class PipedRDD[T: ClassManifest]( + parent: RDD[T], command: Seq[String], envVars: Map[String, String]) extends RDD[String](parent.context) { + + def this(parent: RDD[T], command: Seq[String]) = this(parent, command, Map()) + // Similar to Runtime.exec(), if we are given a single string, split it into words // using a standard StringTokenizer (i.e. by spaces) def this(parent: RDD[T], command: String) = this(parent, PipedRDD.tokenize(command)) @@ -21,7 +26,12 @@ class PipedRDD[T: ClassManifest](parent: RDD[T], command: Seq[String]) override val dependencies = List(new OneToOneDependency(parent)) override def compute(split: Split): Iterator[String] = { - val proc = Runtime.getRuntime.exec(command.toArray) + val pb = new ProcessBuilder(command) + // Add the environmental variables to the process. 
+ val currentEnvVars = pb.environment() + envVars.foreach { case(variable, value) => currentEnvVars.put(variable, value) } + + val proc = pb.start() val env = SparkEnv.get // Start a thread to print the process's stderr to ours diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index 1160de5fd1..fa53d9be2c 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -9,8 +9,6 @@ import java.util.Random import java.util.Date import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.Map -import scala.collection.mutable.HashMap import org.apache.hadoop.io.BytesWritable import org.apache.hadoop.io.NullWritable @@ -50,7 +48,7 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial // Methods that must be implemented by subclasses def splits: Array[Split] def compute(split: Split): Iterator[T] - val dependencies: List[Dependency[_]] + @transient val dependencies: List[Dependency[_]] // Optionally overridden by subclasses to specify how they are partitioned val partitioner: Option[Partitioner] = None @@ -146,6 +144,9 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial def pipe(command: Seq[String]): RDD[String] = new PipedRDD(this, command) + def pipe(command: Seq[String], env: Map[String, String]): RDD[String] = + new PipedRDD(this, command, env) + def mapPartitions[U: ClassManifest](f: Iterator[T] => Iterator[U]): RDD[U] = new MapPartitionsRDD(this, sc.clean(f)) diff --git a/core/src/main/scala/spark/Serializer.scala b/core/src/main/scala/spark/Serializer.scala index 15fca9fcda..2429bbfeb9 100644 --- a/core/src/main/scala/spark/Serializer.scala +++ b/core/src/main/scala/spark/Serializer.scala @@ -16,6 +16,7 @@ trait Serializer { trait SerializerInstance { def serialize[T](t: T): Array[Byte] def deserialize[T](bytes: Array[Byte]): T + def deserialize[T](bytes: Array[Byte], loader: ClassLoader): T def outputStream(s: OutputStream): SerializationStream def inputStream(s: InputStream): DeserializationStream } diff --git a/core/src/main/scala/spark/SerializingCache.scala b/core/src/main/scala/spark/SerializingCache.scala index a74922ec4c..3d192f2403 100644 --- a/core/src/main/scala/spark/SerializingCache.scala +++ b/core/src/main/scala/spark/SerializingCache.scala @@ -9,13 +9,13 @@ import java.io._ class SerializingCache extends Cache with Logging { val bmc = new BoundedMemoryCache - override def put(key: Any, value: Any) { + override def put(datasetId: Any, partition: Int, value: Any): CachePutResponse = { val ser = SparkEnv.get.serializer.newInstance() - bmc.put(key, ser.serialize(value)) + bmc.put(datasetId, partition, ser.serialize(value)) } - override def get(key: Any): Any = { - val bytes = bmc.get(key) + override def get(datasetId: Any, partition: Int): Any = { + val bytes = bmc.get(datasetId, partition) if (bytes != null) { val ser = SparkEnv.get.serializer.newInstance() return ser.deserialize(bytes.asInstanceOf[Array[Byte]]) diff --git a/core/src/main/scala/spark/SimpleJob.scala b/core/src/main/scala/spark/SimpleJob.scala index 796498cfe4..01c7efff1e 100644 --- a/core/src/main/scala/spark/SimpleJob.scala +++ b/core/src/main/scala/spark/SimpleJob.scala @@ -30,6 +30,9 @@ class SimpleJob( // Maximum times a task is allowed to fail before failing the job val MAX_TASK_FAILURES = 4 + // Serializer for closures and tasks. 
+ val ser = SparkEnv.get.closureSerializer.newInstance() + val callingThread = Thread.currentThread val tasks = tasksSeq.toArray val numTasks = tasks.length @@ -170,8 +173,14 @@ class SimpleJob( .setType(Value.Type.SCALAR) .setScalar(Value.Scalar.newBuilder().setValue(CPUS_PER_TASK).build()) .build() - val serializedTask = Utils.serialize(task) - logDebug("Serialized size: " + serializedTask.size) + + val startTime = System.currentTimeMillis + val serializedTask = ser.serialize(task) + val timeTaken = System.currentTimeMillis - startTime + + logInfo("Size of task %d:%d is %d bytes and took %d ms to serialize by %s" + .format(jobId, index, serializedTask.size, timeTaken, ser.getClass.getName)) + val taskName = "task %d:%d".format(jobId, index) return Some(TaskInfo.newBuilder() .setTaskId(taskId) @@ -209,7 +218,8 @@ class SimpleJob( tasksFinished += 1 logInfo("Finished TID %s (progress: %d/%d)".format(tid, tasksFinished, numTasks)) // Deserialize task result - val result = Utils.deserialize[TaskResult[_]](status.getData.toByteArray) + val result = ser.deserialize[TaskResult[_]]( + status.getData.toByteArray, getClass.getClassLoader) sched.taskEnded(tasks(index), Success, result.value, result.accumUpdates) // Mark finished and stop if we've finished all the tasks finished(index) = true @@ -231,7 +241,8 @@ class SimpleJob( // Check if the problem is a map output fetch failure. In that case, this // task will never succeed on any node, so tell the scheduler about it. if (status.getData != null && status.getData.size > 0) { - val reason = Utils.deserialize[TaskEndReason](status.getData.toByteArray) + val reason = ser.deserialize[TaskEndReason]( + status.getData.toByteArray, getClass.getClassLoader) reason match { case fetchFailed: FetchFailed => logInfo("Loss was due to fetch failure from " + fetchFailed.serverUri) diff --git a/core/src/main/scala/spark/SoftReferenceCache.scala b/core/src/main/scala/spark/SoftReferenceCache.scala index e84aa57efa..ce9370c5d7 100644 --- a/core/src/main/scala/spark/SoftReferenceCache.scala +++ b/core/src/main/scala/spark/SoftReferenceCache.scala @@ -8,6 +8,11 @@ import com.google.common.collect.MapMaker class SoftReferenceCache extends Cache { val map = new MapMaker().softValues().makeMap[Any, Any]() - override def get(key: Any): Any = map.get(key) - override def put(key: Any, value: Any) = map.put(key, value) + override def get(datasetId: Any, partition: Int): Any = + map.get((datasetId, partition)) + + override def put(datasetId: Any, partition: Int, value: Any): CachePutResponse = { + map.put((datasetId, partition), value) + return CachePutSuccess(0) + } } diff --git a/core/src/main/scala/spark/SparkEnv.scala b/core/src/main/scala/spark/SparkEnv.scala index e2d1562e35..cd752f8b65 100644 --- a/core/src/main/scala/spark/SparkEnv.scala +++ b/core/src/main/scala/spark/SparkEnv.scala @@ -3,6 +3,7 @@ package spark class SparkEnv ( val cache: Cache, val serializer: Serializer, + val closureSerializer: Serializer, val cacheTracker: CacheTracker, val mapOutputTracker: MapOutputTracker, val shuffleFetcher: ShuffleFetcher, @@ -27,6 +28,11 @@ object SparkEnv { val serializerClass = System.getProperty("spark.serializer", "spark.JavaSerializer") val serializer = Class.forName(serializerClass).newInstance().asInstanceOf[Serializer] + val closureSerializerClass = + System.getProperty("spark.closure.serializer", "spark.JavaSerializer") + val closureSerializer = + Class.forName(closureSerializerClass).newInstance().asInstanceOf[Serializer] + val cacheTracker = new 
CacheTracker(isMaster, cache) val mapOutputTracker = new MapOutputTracker(isMaster) @@ -38,6 +44,13 @@ object SparkEnv { val shuffleMgr = new ShuffleManager() - new SparkEnv(cache, serializer, cacheTracker, mapOutputTracker, shuffleFetcher, shuffleMgr) + new SparkEnv( + cache, + serializer, + closureSerializer, + cacheTracker, + mapOutputTracker, + shuffleFetcher, + shuffleMgr) } } diff --git a/core/src/main/scala/spark/UnionRDD.scala b/core/src/main/scala/spark/UnionRDD.scala index 6fded339ee..4c0f255e6b 100644 --- a/core/src/main/scala/spark/UnionRDD.scala +++ b/core/src/main/scala/spark/UnionRDD.scala @@ -16,7 +16,7 @@ class UnionSplit[T: ClassManifest]( class UnionRDD[T: ClassManifest]( sc: SparkContext, - rdds: Seq[RDD[T]]) + @transient rdds: Seq[RDD[T]]) extends RDD[T](sc) with Serializable { @@ -33,7 +33,7 @@ class UnionRDD[T: ClassManifest]( override def splits = splits_ - override val dependencies = { + @transient override val dependencies = { val deps = new ArrayBuffer[Dependency[_]] var pos = 0 for ((rdd, index) <- rdds.zipWithIndex) { @@ -47,4 +47,4 @@ class UnionRDD[T: ClassManifest]( override def preferredLocations(s: Split): Seq[String] = s.asInstanceOf[UnionSplit[T]].preferredLocations() -}
\ No newline at end of file +} diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala index 58b5fa6bbd..cfd6dc8b2a 100644 --- a/core/src/main/scala/spark/Utils.scala +++ b/core/src/main/scala/spark/Utils.scala @@ -2,11 +2,11 @@ package spark import java.io._ import java.net.InetAddress -import java.util.UUID import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor} import scala.collection.mutable.ArrayBuffer import scala.util.Random +import java.util.{Locale, UUID} /** * Various utility methods used by Spark. @@ -157,9 +157,12 @@ object Utils { /** * Get the local machine's hostname. */ - def localHostName(): String = { - return InetAddress.getLocalHost().getHostName - } + def localHostName(): String = InetAddress.getLocalHost.getHostName + + /** + * Get current host + */ + def getHost = System.getProperty("spark.hostname", localHostName()) /** * Delete a file or directory and its contents recursively. @@ -174,4 +177,28 @@ object Utils { throw new IOException("Failed to delete: " + file) } } + + /** + * Use unit suffixes (Byte, Kilobyte, Megabyte, Gigabyte, Terabyte and + * Petabyte) in order to reduce the number of digits to four or less. For + * example, 4,000,000 is returned as 4MB. + */ + def memoryBytesToString(size: Long): String = { + val GB = 1L << 30 + val MB = 1L << 20 + val KB = 1L << 10 + + val (value, unit) = { + if (size >= 2*GB) { + (size.asInstanceOf[Double] / GB, "GB") + } else if (size >= 2*MB) { + (size.asInstanceOf[Double] / MB, "MB") + } else if (size >= 2*KB) { + (size.asInstanceOf[Double] / KB, "KB") + } else { + (size.asInstanceOf[Double], "B") + } + } + "%.1f%s".formatLocal(Locale.US, value, unit) + } } diff --git a/core/src/main/scala/spark/WeakReferenceCache.scala b/core/src/main/scala/spark/WeakReferenceCache.scala deleted file mode 100644 index ddca065454..0000000000 --- a/core/src/main/scala/spark/WeakReferenceCache.scala +++ /dev/null @@ -1,14 +0,0 @@ -package spark - -import com.google.common.collect.MapMaker - -/** - * An implementation of Cache that uses weak references. 
- */ -class WeakReferenceCache extends Cache { - val map = new MapMaker().weakValues().makeMap[Any, Any]() - - override def get(key: Any): Any = map.get(key) - override def put(key: Any, value: Any) = map.put(key, value) -} - diff --git a/core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala b/core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala index 6960339bf8..5a873dca3d 100644 --- a/core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala +++ b/core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala @@ -16,7 +16,7 @@ extends Broadcast[T] with Logging with Serializable { def value = value_ BitTorrentBroadcast.synchronized { - BitTorrentBroadcast.values.put(uuid, value_) + BitTorrentBroadcast.values.put(uuid, 0, value_) } @transient var arrayOfBlocks: Array[BroadcastBlock] = null @@ -130,7 +130,7 @@ extends Broadcast[T] with Logging with Serializable { private def readObject(in: ObjectInputStream): Unit = { in.defaultReadObject BitTorrentBroadcast.synchronized { - val cachedVal = BitTorrentBroadcast.values.get(uuid) + val cachedVal = BitTorrentBroadcast.values.get(uuid, 0) if (cachedVal != null) { value_ = cachedVal.asInstanceOf[T] @@ -152,12 +152,12 @@ extends Broadcast[T] with Logging with Serializable { // If does not succeed, then get from HDFS copy if (receptionSucceeded) { value_ = Broadcast.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks) - BitTorrentBroadcast.values.put(uuid, value_) + BitTorrentBroadcast.values.put(uuid, 0, value_) } else { // TODO: This part won't work, cause HDFS writing is turned OFF val fileIn = new ObjectInputStream(DfsBroadcast.openFileForReading(uuid)) value_ = fileIn.readObject.asInstanceOf[T] - BitTorrentBroadcast.values.put(uuid, value_) + BitTorrentBroadcast.values.put(uuid, 0, value_) fileIn.close() } diff --git a/core/src/main/scala/spark/broadcast/ChainedBroadcast.scala b/core/src/main/scala/spark/broadcast/ChainedBroadcast.scala index e33ef78e8a..64da650142 100644 --- a/core/src/main/scala/spark/broadcast/ChainedBroadcast.scala +++ b/core/src/main/scala/spark/broadcast/ChainedBroadcast.scala @@ -15,7 +15,7 @@ extends Broadcast[T] with Logging with Serializable { def value = value_ ChainedBroadcast.synchronized { - ChainedBroadcast.values.put(uuid, value_) + ChainedBroadcast.values.put(uuid, 0, value_) } @transient var arrayOfBlocks: Array[BroadcastBlock] = null @@ -101,7 +101,7 @@ extends Broadcast[T] with Logging with Serializable { private def readObject(in: ObjectInputStream): Unit = { in.defaultReadObject ChainedBroadcast.synchronized { - val cachedVal = ChainedBroadcast.values.get(uuid) + val cachedVal = ChainedBroadcast.values.get(uuid, 0) if (cachedVal != null) { value_ = cachedVal.asInstanceOf[T] } else { @@ -121,11 +121,11 @@ extends Broadcast[T] with Logging with Serializable { // If does not succeed, then get from HDFS copy if (receptionSucceeded) { value_ = Broadcast.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks) - ChainedBroadcast.values.put(uuid, value_) + ChainedBroadcast.values.put(uuid, 0, value_) } else { val fileIn = new ObjectInputStream(DfsBroadcast.openFileForReading(uuid)) value_ = fileIn.readObject.asInstanceOf[T] - ChainedBroadcast.values.put(uuid, value_) + ChainedBroadcast.values.put(uuid, 0, value_) fileIn.close() } diff --git a/core/src/main/scala/spark/broadcast/DfsBroadcast.scala b/core/src/main/scala/spark/broadcast/DfsBroadcast.scala index 076f18afac..b053e2b62e 100644 --- a/core/src/main/scala/spark/broadcast/DfsBroadcast.scala +++ 
b/core/src/main/scala/spark/broadcast/DfsBroadcast.scala @@ -17,7 +17,7 @@ extends Broadcast[T] with Logging with Serializable { def value = value_ DfsBroadcast.synchronized { - DfsBroadcast.values.put(uuid, value_) + DfsBroadcast.values.put(uuid, 0, value_) } if (!isLocal) { @@ -34,7 +34,7 @@ extends Broadcast[T] with Logging with Serializable { private def readObject(in: ObjectInputStream): Unit = { in.defaultReadObject DfsBroadcast.synchronized { - val cachedVal = DfsBroadcast.values.get(uuid) + val cachedVal = DfsBroadcast.values.get(uuid, 0) if (cachedVal != null) { value_ = cachedVal.asInstanceOf[T] } else { @@ -43,7 +43,7 @@ extends Broadcast[T] with Logging with Serializable { val fileIn = new ObjectInputStream(DfsBroadcast.openFileForReading(uuid)) value_ = fileIn.readObject.asInstanceOf[T] - DfsBroadcast.values.put(uuid, value_) + DfsBroadcast.values.put(uuid, 0, value_) fileIn.close val time = (System.nanoTime - start) / 1e9 diff --git a/core/src/main/scala/spark/broadcast/TreeBroadcast.scala b/core/src/main/scala/spark/broadcast/TreeBroadcast.scala index 945d8cd8a4..374389def5 100644 --- a/core/src/main/scala/spark/broadcast/TreeBroadcast.scala +++ b/core/src/main/scala/spark/broadcast/TreeBroadcast.scala @@ -15,7 +15,7 @@ extends Broadcast[T] with Logging with Serializable { def value = value_ TreeBroadcast.synchronized { - TreeBroadcast.values.put(uuid, value_) + TreeBroadcast.values.put(uuid, 0, value_) } @transient var arrayOfBlocks: Array[BroadcastBlock] = null @@ -104,7 +104,7 @@ extends Broadcast[T] with Logging with Serializable { private def readObject(in: ObjectInputStream): Unit = { in.defaultReadObject TreeBroadcast.synchronized { - val cachedVal = TreeBroadcast.values.get(uuid) + val cachedVal = TreeBroadcast.values.get(uuid, 0) if (cachedVal != null) { value_ = cachedVal.asInstanceOf[T] } else { @@ -124,11 +124,11 @@ extends Broadcast[T] with Logging with Serializable { // If does not succeed, then get from HDFS copy if (receptionSucceeded) { value_ = Broadcast.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks) - TreeBroadcast.values.put(uuid, value_) + TreeBroadcast.values.put(uuid, 0, value_) } else { val fileIn = new ObjectInputStream(DfsBroadcast.openFileForReading(uuid)) value_ = fileIn.readObject.asInstanceOf[T] - TreeBroadcast.values.put(uuid, value_) + TreeBroadcast.values.put(uuid, 0, value_) fileIn.close() } diff --git a/core/src/test/scala/spark/BoundedMemoryCacheTest.scala b/core/src/test/scala/spark/BoundedMemoryCacheTest.scala new file mode 100644 index 0000000000..344a733ab3 --- /dev/null +++ b/core/src/test/scala/spark/BoundedMemoryCacheTest.scala @@ -0,0 +1,31 @@ +package spark + +import org.scalatest.FunSuite + +class BoundedMemoryCacheTest extends FunSuite { + test("constructor test") { + val cache = new BoundedMemoryCache(40) + expect(40)(cache.getCapacity) + } + + test("caching") { + val cache = new BoundedMemoryCache(40) { + //TODO sorry about this, but there is not better way how to skip 'cacheTracker.dropEntry' + override protected def reportEntryDropped(datasetId: Any, partition: Int, entry: Entry) { + logInfo("Dropping key (%s, %d) of size %d to make space".format(datasetId, partition, entry.size)) + } + } + //should be OK + expect(CachePutSuccess(30))(cache.put("1", 0, "Meh")) + + //we cannot add this to cache (there is not enough space in cache) & we cannot evict the only value from + //cache because it's from the same dataset + expect(CachePutFailure())(cache.put("1", 1, "Meh")) + + //should be OK, dataset '1' can be evicted 
from cache + expect(CachePutSuccess(30))(cache.put("2", 0, "Meh")) + + //should fail, cache should obey it's capacity + expect(CachePutFailure())(cache.put("3", 0, "Very_long_and_useless_string")) + } +} diff --git a/core/src/test/scala/spark/CacheTrackerSuite.scala b/core/src/test/scala/spark/CacheTrackerSuite.scala new file mode 100644 index 0000000000..60290d14ca --- /dev/null +++ b/core/src/test/scala/spark/CacheTrackerSuite.scala @@ -0,0 +1,97 @@ +package spark + +import org.scalatest.FunSuite +import collection.mutable.HashMap + +class CacheTrackerSuite extends FunSuite { + + test("CacheTrackerActor slave initialization & cache status") { + System.setProperty("spark.master.port", "1345") + val initialSize = 2L << 20 + + val tracker = new CacheTrackerActor + tracker.start() + + tracker !? SlaveCacheStarted("host001", initialSize) + + assert(tracker !? GetCacheStatus == Seq(("host001", 2097152L, 0L))) + + tracker !? StopCacheTracker + } + + test("RegisterRDD") { + System.setProperty("spark.master.port", "1345") + val initialSize = 2L << 20 + + val tracker = new CacheTrackerActor + tracker.start() + + tracker !? SlaveCacheStarted("host001", initialSize) + + tracker !? RegisterRDD(1, 3) + tracker !? RegisterRDD(2, 1) + + assert(getCacheLocations(tracker) == Map(1 -> List(List(), List(), List()), 2 -> List(List()))) + + tracker !? StopCacheTracker + } + + test("AddedToCache") { + System.setProperty("spark.master.port", "1345") + val initialSize = 2L << 20 + + val tracker = new CacheTrackerActor + tracker.start() + + tracker !? SlaveCacheStarted("host001", initialSize) + + tracker !? RegisterRDD(1, 2) + tracker !? RegisterRDD(2, 1) + + tracker !? AddedToCache(1, 0, "host001", 2L << 15) + tracker !? AddedToCache(1, 1, "host001", 2L << 11) + tracker !? AddedToCache(2, 0, "host001", 3L << 10) + + assert(tracker !? GetCacheStatus == Seq(("host001", 2097152L, 72704L))) + + assert(getCacheLocations(tracker) == Map(1 -> List(List("host001"), List("host001")), 2 -> List(List("host001")))) + + tracker !? StopCacheTracker + } + + test("DroppedFromCache") { + System.setProperty("spark.master.port", "1345") + val initialSize = 2L << 20 + + val tracker = new CacheTrackerActor + tracker.start() + + tracker !? SlaveCacheStarted("host001", initialSize) + + tracker !? RegisterRDD(1, 2) + tracker !? RegisterRDD(2, 1) + + tracker !? AddedToCache(1, 0, "host001", 2L << 15) + tracker !? AddedToCache(1, 1, "host001", 2L << 11) + tracker !? AddedToCache(2, 0, "host001", 3L << 10) + + assert(tracker !? GetCacheStatus == Seq(("host001", 2097152L, 72704L))) + assert(getCacheLocations(tracker) == Map(1 -> List(List("host001"), List("host001")), 2 -> List(List("host001")))) + + tracker !? DroppedFromCache(1, 1, "host001", 2L << 11) + + assert(tracker !? GetCacheStatus == Seq(("host001", 2097152L, 68608L))) + assert(getCacheLocations(tracker) == Map(1 -> List(List("host001"),List()), 2 -> List(List("host001")))) + + tracker !? StopCacheTracker + } + + /** + * Helper function to get cacheLocations from CacheTracker + */ + def getCacheLocations(tracker: CacheTrackerActor) = tracker !? 
GetCacheLocations match { + case h: HashMap[_, _] => h.asInstanceOf[HashMap[Int, Array[List[String]]]].map { + case (i, arr) => (i -> arr.toList) + } + } +} diff --git a/core/src/test/scala/spark/FailureSuite.scala b/core/src/test/scala/spark/FailureSuite.scala index ab21f6a6f0..75df4bee09 100644 --- a/core/src/test/scala/spark/FailureSuite.scala +++ b/core/src/test/scala/spark/FailureSuite.scala @@ -65,5 +65,21 @@ class FailureSuite extends FunSuite { FailureSuiteState.clear() } + test("failure because task results are not serializable") { + val sc = new SparkContext("local[1,1]", "test") + val results = sc.makeRDD(1 to 3).map(x => new NonSerializable) + + val thrown = intercept[spark.SparkException] { + results.collect() + } + assert(thrown.getClass === classOf[spark.SparkException]) + assert(thrown.getMessage.contains("NotSerializableException")) + + sc.stop() + FailureSuiteState.clear() + } + // TODO: Need to add tests with shuffle fetch failures. } + + diff --git a/core/src/test/scala/spark/MesosSchedulerSuite.scala b/core/src/test/scala/spark/MesosSchedulerSuite.scala new file mode 100644 index 0000000000..0e6820cbdc --- /dev/null +++ b/core/src/test/scala/spark/MesosSchedulerSuite.scala @@ -0,0 +1,28 @@ +package spark + +import org.scalatest.FunSuite + +class MesosSchedulerSuite extends FunSuite { + test("memoryStringToMb"){ + + assert(MesosScheduler.memoryStringToMb("1") == 0) + assert(MesosScheduler.memoryStringToMb("1048575") == 0) + assert(MesosScheduler.memoryStringToMb("3145728") == 3) + + assert(MesosScheduler.memoryStringToMb("1024k") == 1) + assert(MesosScheduler.memoryStringToMb("5000k") == 4) + assert(MesosScheduler.memoryStringToMb("4024k") == MesosScheduler.memoryStringToMb("4024K")) + + assert(MesosScheduler.memoryStringToMb("1024m") == 1024) + assert(MesosScheduler.memoryStringToMb("5000m") == 5000) + assert(MesosScheduler.memoryStringToMb("4024m") == MesosScheduler.memoryStringToMb("4024M")) + + assert(MesosScheduler.memoryStringToMb("2g") == 2048) + assert(MesosScheduler.memoryStringToMb("3g") == MesosScheduler.memoryStringToMb("3G")) + + assert(MesosScheduler.memoryStringToMb("2t") == 2097152) + assert(MesosScheduler.memoryStringToMb("3t") == MesosScheduler.memoryStringToMb("3T")) + + + } +} diff --git a/core/src/test/scala/spark/PipedRDDSuite.scala b/core/src/test/scala/spark/PipedRDDSuite.scala new file mode 100644 index 0000000000..d5dc2efd91 --- /dev/null +++ b/core/src/test/scala/spark/PipedRDDSuite.scala @@ -0,0 +1,37 @@ +package spark + +import org.scalatest.FunSuite +import SparkContext._ + +class PipedRDDSuite extends FunSuite { + + test("basic pipe") { + val sc = new SparkContext("local", "test") + val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) + + val piped = nums.pipe(Seq("cat")) + + val c = piped.collect() + println(c.toSeq) + assert(c.size === 4) + assert(c(0) === "1") + assert(c(1) === "2") + assert(c(2) === "3") + assert(c(3) === "4") + sc.stop() + } + + test("pipe with env variable") { + val sc = new SparkContext("local", "test") + val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) + val piped = nums.pipe(Seq("printenv", "MY_TEST_ENV"), Map("MY_TEST_ENV" -> "LALALA")) + val c = piped.collect() + assert(c.size === 2) + assert(c(0) === "LALALA") + assert(c(1) === "LALALA") + sc.stop() + } + +} + + diff --git a/core/src/test/scala/spark/UtilsSuite.scala b/core/src/test/scala/spark/UtilsSuite.scala new file mode 100644 index 0000000000..f31251e509 --- /dev/null +++ b/core/src/test/scala/spark/UtilsSuite.scala @@ -0,0 +1,29 @@ +package spark + +import 
org.scalatest.FunSuite +import java.io.{ByteArrayOutputStream, ByteArrayInputStream} +import util.Random + +class UtilsSuite extends FunSuite { + + test("memoryBytesToString") { + assert(Utils.memoryBytesToString(10) === "10.0B") + assert(Utils.memoryBytesToString(1500) === "1500.0B") + assert(Utils.memoryBytesToString(2000000) === "1953.1KB") + assert(Utils.memoryBytesToString(2097152) === "2.0MB") + assert(Utils.memoryBytesToString(2306867) === "2.2MB") + assert(Utils.memoryBytesToString(5368709120L) === "5.0GB") + } + + test("copyStream") { + //input array initialization + val bytes = Array.ofDim[Byte](9000) + Random.nextBytes(bytes) + + val os = new ByteArrayOutputStream() + Utils.copyStream(new ByteArrayInputStream(bytes), os) + + assert(os.toByteArray.toList.equals(bytes.toList)) + } +} + diff --git a/repl/src/main/scala/spark/repl/Main.scala b/repl/src/main/scala/spark/repl/Main.scala index b4a2bb05f9..58809ab646 100644 --- a/repl/src/main/scala/spark/repl/Main.scala +++ b/repl/src/main/scala/spark/repl/Main.scala @@ -7,7 +7,7 @@ object Main { def interp = _interp - private[repl] def interp_=(i: SparkILoop) { _interp = i } + def interp_=(i: SparkILoop) { _interp = i } def main(args: Array[String]) { _interp = new SparkILoop |
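As a closing usage note (not part of the commit itself), the merge also adds environment-variable support to `PipedRDD`. The sketch below follows the new `PipedRDDSuite` test above; the application name is illustrative, while the `printenv`/`MY_TEST_ENV` usage mirrors the test.

```scala
import spark.SparkContext

object PipeEnvSketch {
  def main(args: Array[String]) {
    val sc = new SparkContext("local", "pipe-env-sketch")
    val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
    // The second argument (new in this merge) supplies environment variables to the
    // external process that each partition is piped through.
    val piped = nums.pipe(Seq("printenv", "MY_TEST_ENV"), Map("MY_TEST_ENV" -> "LALALA"))
    piped.collect().foreach(println) // expected: "LALALA" printed once per partition
    sc.stop()
  }
}
```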