aboutsummaryrefslogtreecommitdiff
path: root/core
diff options
context:
space:
mode:
authorReynold Xin <rxin@cs.berkeley.edu>2013-05-03 01:02:32 -0700
committerReynold Xin <rxin@cs.berkeley.edu>2013-05-03 01:02:32 -0700
commit93091f6936262a4006d875bf69b3f8c31c291617 (patch)
tree64c5fb0b9250f91723aadbaf14113ea6717d3d42 /core
parent2bc895a829caa459e032e12e1d117994dd510a5c (diff)
parent6fe9d4e61e30622abdbf4877daf5653d7339e4e8 (diff)
downloadspark-93091f6936262a4006d875bf69b3f8c31c291617.tar.gz
spark-93091f6936262a4006d875bf69b3f8c31c291617.tar.bz2
spark-93091f6936262a4006d875bf69b3f8c31c291617.zip
Merge branch 'master' of github.com:mesos/spark into blockmanager
Diffstat (limited to 'core')
-rw-r--r--core/src/main/scala/spark/RDD.scala14
-rw-r--r--core/src/main/scala/spark/SparkContext.scala49
-rw-r--r--core/src/main/scala/spark/api/java/JavaRDD.scala16
-rw-r--r--core/src/main/scala/spark/network/ConnectionManager.scala166
-rw-r--r--core/src/main/scala/spark/storage/BlockManagerMaster.scala16
-rw-r--r--core/src/main/scala/spark/storage/DiskStore.scala3
-rw-r--r--core/src/main/scala/spark/storage/StorageUtils.scala33
-rw-r--r--core/src/test/scala/spark/DistributedSuite.scala34
-rw-r--r--core/src/test/scala/spark/RDDSuite.scala24
-rw-r--r--core/src/test/scala/spark/storage/BlockManagerSuite.scala25
10 files changed, 297 insertions, 83 deletions
diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala
index 09e52ebf3e..fd14ef17f1 100644
--- a/core/src/main/scala/spark/RDD.scala
+++ b/core/src/main/scala/spark/RDD.scala
@@ -107,7 +107,7 @@ abstract class RDD[T: ClassManifest](
// =======================================================================
/** A unique ID for this RDD (within its SparkContext). */
- val id = sc.newRddId()
+ val id: Int = sc.newRddId()
/** A friendly name for this RDD */
var name: String = null
@@ -120,7 +120,8 @@ abstract class RDD[T: ClassManifest](
/**
* Set this RDD's storage level to persist its values across operations after the first time
- * it is computed. Can only be called once on each RDD.
+ * it is computed. This can only be used to assign a new storage level if the RDD does not
+ * have a storage level set yet..
*/
def persist(newLevel: StorageLevel): RDD[T] = {
// TODO: Handle changes of StorageLevel
@@ -140,6 +141,15 @@ abstract class RDD[T: ClassManifest](
/** Persist this RDD with the default storage level (`MEMORY_ONLY`). */
def cache(): RDD[T] = persist()
+ /** Mark the RDD as non-persistent, and remove all blocks for it from memory and disk. */
+ def unpersist(): RDD[T] = {
+ logInfo("Removing RDD " + id + " from persistence list")
+ sc.env.blockManager.master.removeRdd(id)
+ sc.persistentRdds.remove(id)
+ storageLevel = StorageLevel.NONE
+ this
+ }
+
/** Get the RDD's current storage level, or StorageLevel.NONE if none is set. */
def getStorageLevel = storageLevel
diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala
index 5f5ec0b0f4..2ae4ad8659 100644
--- a/core/src/main/scala/spark/SparkContext.scala
+++ b/core/src/main/scala/spark/SparkContext.scala
@@ -1,47 +1,50 @@
package spark
import java.io._
-import java.util.concurrent.atomic.AtomicInteger
import java.net.URI
+import java.util.concurrent.ConcurrentHashMap
+import java.util.concurrent.atomic.AtomicInteger
+import scala.collection.JavaConversions._
import scala.collection.Map
import scala.collection.generic.Growable
-import scala.collection.mutable.HashMap
-import scala.collection.JavaConversions._
+import scala.collection.mutable.{ConcurrentMap, HashMap}
+
+import akka.actor.Actor._
-import org.apache.hadoop.fs.Path
import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.mapred.InputFormat
-import org.apache.hadoop.mapred.SequenceFileInputFormat
-import org.apache.hadoop.io.Writable
-import org.apache.hadoop.io.IntWritable
-import org.apache.hadoop.io.LongWritable
-import org.apache.hadoop.io.FloatWritable
-import org.apache.hadoop.io.DoubleWritable
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.io.ArrayWritable
import org.apache.hadoop.io.BooleanWritable
import org.apache.hadoop.io.BytesWritable
-import org.apache.hadoop.io.ArrayWritable
+import org.apache.hadoop.io.DoubleWritable
+import org.apache.hadoop.io.FloatWritable
+import org.apache.hadoop.io.IntWritable
+import org.apache.hadoop.io.LongWritable
import org.apache.hadoop.io.NullWritable
import org.apache.hadoop.io.Text
+import org.apache.hadoop.io.Writable
import org.apache.hadoop.mapred.FileInputFormat
+import org.apache.hadoop.mapred.InputFormat
import org.apache.hadoop.mapred.JobConf
+import org.apache.hadoop.mapred.SequenceFileInputFormat
import org.apache.hadoop.mapred.TextInputFormat
import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat}
-import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat}
import org.apache.hadoop.mapreduce.{Job => NewHadoopJob}
+import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat}
+
import org.apache.mesos.MesosNativeLibrary
-import spark.deploy.{SparkHadoopUtil, LocalSparkCluster}
-import spark.partial.ApproximateEvaluator
-import spark.partial.PartialResult
+import spark.deploy.{LocalSparkCluster, SparkHadoopUtil}
+import spark.partial.{ApproximateEvaluator, PartialResult}
import spark.rdd.{CheckpointRDD, HadoopRDD, NewHadoopRDD, UnionRDD, ParallelCollectionRDD}
-import spark.scheduler._
+import spark.scheduler.{DAGScheduler, ResultTask, ShuffleMapTask, SparkListener, SplitInfo, Stage, StageInfo, TaskScheduler}
+import spark.scheduler.cluster.{StandaloneSchedulerBackend, SparkDeploySchedulerBackend, ClusterScheduler}
import spark.scheduler.local.LocalScheduler
-import spark.scheduler.cluster.{StandaloneSchedulerBackend, SparkDeploySchedulerBackend, SchedulerBackend, ClusterScheduler}
import spark.scheduler.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend}
-import spark.storage.BlockManagerUI
+import spark.storage.{BlockManagerUI, StorageStatus, StorageUtils, RDDInfo}
import spark.util.{MetadataCleaner, TimeStampedHashMap}
-import spark.storage.{StorageStatus, StorageUtils, RDDInfo}
+
/**
* Main entry point for Spark functionality. A SparkContext represents the connection to a Spark
@@ -97,7 +100,7 @@ class SparkContext(
private[spark] val addedJars = HashMap[String, Long]()
// Keeps track of all persisted RDDs
- private[spark] val persistentRdds = new TimeStampedHashMap[Int, RDD[_]]()
+ private[spark] val persistentRdds = new TimeStampedHashMap[Int, RDD[_]]
private[spark] val metadataCleaner = new MetadataCleaner("SparkContext", this.cleanup)
@@ -505,7 +508,7 @@ class SparkContext(
* Return information about what RDDs are cached, if they are in mem or on disk, how much space
* they take, etc.
*/
- def getRDDStorageInfo : Array[RDDInfo] = {
+ def getRDDStorageInfo: Array[RDDInfo] = {
StorageUtils.rddInfoFromStorageStatus(getExecutorStorageStatus, this)
}
@@ -516,7 +519,7 @@ class SparkContext(
/**
* Return information about blocks stored in all of the slaves
*/
- def getExecutorStorageStatus : Array[StorageStatus] = {
+ def getExecutorStorageStatus: Array[StorageStatus] = {
env.blockManager.master.getStorageStatus
}
diff --git a/core/src/main/scala/spark/api/java/JavaRDD.scala b/core/src/main/scala/spark/api/java/JavaRDD.scala
index e29f1e5899..eb81ed64cd 100644
--- a/core/src/main/scala/spark/api/java/JavaRDD.scala
+++ b/core/src/main/scala/spark/api/java/JavaRDD.scala
@@ -14,12 +14,18 @@ JavaRDDLike[T, JavaRDD[T]] {
/** Persist this RDD with the default storage level (`MEMORY_ONLY`). */
def cache(): JavaRDD[T] = wrapRDD(rdd.cache())
- /**
+ /**
* Set this RDD's storage level to persist its values across operations after the first time
- * it is computed. Can only be called once on each RDD.
+ * it is computed. This can only be used to assign a new storage level if the RDD does not
+ * have a storage level set yet..
*/
def persist(newLevel: StorageLevel): JavaRDD[T] = wrapRDD(rdd.persist(newLevel))
+ /**
+ * Mark the RDD as non-persistent, and remove all blocks for it from memory and disk.
+ */
+ def unpersist(): JavaRDD[T] = wrapRDD(rdd.unpersist())
+
// Transformations (return a new RDD)
/**
@@ -31,7 +37,7 @@ JavaRDDLike[T, JavaRDD[T]] {
* Return a new RDD containing the distinct elements in this RDD.
*/
def distinct(numPartitions: Int): JavaRDD[T] = wrapRDD(rdd.distinct(numPartitions))
-
+
/**
* Return a new RDD containing only the elements that satisfy a predicate.
*/
@@ -54,7 +60,7 @@ JavaRDDLike[T, JavaRDD[T]] {
*/
def sample(withReplacement: Boolean, fraction: Double, seed: Int): JavaRDD[T] =
wrapRDD(rdd.sample(withReplacement, fraction, seed))
-
+
/**
* Return the union of this RDD and another one. Any identical elements will appear multiple
* times (use `.distinct()` to eliminate them).
@@ -63,7 +69,7 @@ JavaRDDLike[T, JavaRDD[T]] {
/**
* Return an RDD with the elements from `this` that are not in `other`.
- *
+ *
* Uses `this` partitioner/partition size, because even if `other` is huge, the resulting
* RDD will be <= us.
*/
diff --git a/core/src/main/scala/spark/network/ConnectionManager.scala b/core/src/main/scala/spark/network/ConnectionManager.scala
index 0c6bdb1559..0eb03630d0 100644
--- a/core/src/main/scala/spark/network/ConnectionManager.scala
+++ b/core/src/main/scala/spark/network/ConnectionManager.scala
@@ -188,6 +188,38 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
} )
}
+ // MUST be called within selector loop - else deadlock.
+ private def triggerForceCloseByException(key: SelectionKey, e: Exception) {
+ try {
+ key.interestOps(0)
+ } catch {
+ // ignore exceptions
+ case e: Exception => logDebug("Ignoring exception", e)
+ }
+
+ val conn = connectionsByKey.getOrElse(key, null)
+ if (conn == null) return
+
+ // Pushing to connect threadpool
+ handleConnectExecutor.execute(new Runnable {
+ override def run() {
+ try {
+ conn.callOnExceptionCallback(e)
+ } catch {
+ // ignore exceptions
+ case e: Exception => logDebug("Ignoring exception", e)
+ }
+ try {
+ conn.close()
+ } catch {
+ // ignore exceptions
+ case e: Exception => logDebug("Ignoring exception", e)
+ }
+ }
+ })
+ }
+
+
def run() {
try {
while(!selectorThread.isInterrupted) {
@@ -200,29 +232,76 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
while(!keyInterestChangeRequests.isEmpty) {
val (key, ops) = keyInterestChangeRequests.dequeue
- val connection = connectionsByKey.getOrElse(key, null)
- if (connection != null) {
- val lastOps = key.interestOps()
- key.interestOps(ops)
-
- // hot loop - prevent materialization of string if trace not enabled.
- if (isTraceEnabled()) {
- def intToOpStr(op: Int): String = {
- val opStrs = ArrayBuffer[String]()
- if ((op & SelectionKey.OP_READ) != 0) opStrs += "READ"
- if ((op & SelectionKey.OP_WRITE) != 0) opStrs += "WRITE"
- if ((op & SelectionKey.OP_CONNECT) != 0) opStrs += "CONNECT"
- if ((op & SelectionKey.OP_ACCEPT) != 0) opStrs += "ACCEPT"
- if (opStrs.size > 0) opStrs.reduceLeft(_ + " | " + _) else " "
- }
- logTrace("Changed key for connection to [" + connection.getRemoteConnectionManagerId() +
- "] changed from [" + intToOpStr(lastOps) + "] to [" + intToOpStr(ops) + "]")
+ try {
+ if (key.isValid) {
+ val connection = connectionsByKey.getOrElse(key, null)
+ if (connection != null) {
+ val lastOps = key.interestOps()
+ key.interestOps(ops)
+
+ // hot loop - prevent materialization of string if trace not enabled.
+ if (isTraceEnabled()) {
+ def intToOpStr(op: Int): String = {
+ val opStrs = ArrayBuffer[String]()
+ if ((op & SelectionKey.OP_READ) != 0) opStrs += "READ"
+ if ((op & SelectionKey.OP_WRITE) != 0) opStrs += "WRITE"
+ if ((op & SelectionKey.OP_CONNECT) != 0) opStrs += "CONNECT"
+ if ((op & SelectionKey.OP_ACCEPT) != 0) opStrs += "ACCEPT"
+ if (opStrs.size > 0) opStrs.reduceLeft(_ + " | " + _) else " "
+ }
+
+ logTrace("Changed key for connection to [" + connection.getRemoteConnectionManagerId() +
+ "] changed from [" + intToOpStr(lastOps) + "] to [" + intToOpStr(ops) + "]")
+ }
+ }
+ } else {
+ logInfo("Key not valid ? " + key)
+ throw new CancelledKeyException()
+ }
+ } catch {
+ case e: CancelledKeyException => {
+ logInfo("key already cancelled ? " + key, e)
+ triggerForceCloseByException(key, e)
+ }
+ case e: Exception => {
+ logError("Exception processing key " + key, e)
+ triggerForceCloseByException(key, e)
}
}
}
- val selectedKeysCount = selector.select()
+ val selectedKeysCount =
+ try {
+ selector.select()
+ } catch {
+ // Explicitly only dealing with CancelledKeyException here since other exceptions should be dealt with differently.
+ case e: CancelledKeyException => {
+ // Some keys within the selectors list are invalid/closed. clear them.
+ val allKeys = selector.keys().iterator()
+
+ while (allKeys.hasNext()) {
+ val key = allKeys.next()
+ try {
+ if (! key.isValid) {
+ logInfo("Key not valid ? " + key)
+ throw new CancelledKeyException()
+ }
+ } catch {
+ case e: CancelledKeyException => {
+ logInfo("key already cancelled ? " + key, e)
+ triggerForceCloseByException(key, e)
+ }
+ case e: Exception => {
+ logError("Exception processing key " + key, e)
+ triggerForceCloseByException(key, e)
+ }
+ }
+ }
+ }
+ 0
+ }
+
if (selectedKeysCount == 0) {
logDebug("Selector selected " + selectedKeysCount + " of " + selector.keys.size + " keys")
}
@@ -230,23 +309,40 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
logInfo("Selector thread was interrupted!")
return
}
-
- val selectedKeys = selector.selectedKeys().iterator()
- while (selectedKeys.hasNext()) {
- val key = selectedKeys.next
- selectedKeys.remove()
- if (key.isValid) {
- if (key.isAcceptable) {
- acceptConnection(key)
- } else
- if (key.isConnectable) {
- triggerConnect(key)
- } else
- if (key.isReadable) {
- triggerRead(key)
- } else
- if (key.isWritable) {
- triggerWrite(key)
+
+ if (0 != selectedKeysCount) {
+ val selectedKeys = selector.selectedKeys().iterator()
+ while (selectedKeys.hasNext()) {
+ val key = selectedKeys.next
+ selectedKeys.remove()
+ try {
+ if (key.isValid) {
+ if (key.isAcceptable) {
+ acceptConnection(key)
+ } else
+ if (key.isConnectable) {
+ triggerConnect(key)
+ } else
+ if (key.isReadable) {
+ triggerRead(key)
+ } else
+ if (key.isWritable) {
+ triggerWrite(key)
+ }
+ } else {
+ logInfo("Key not valid ? " + key)
+ throw new CancelledKeyException()
+ }
+ } catch {
+ // weird, but we saw this happening - even though key.isValid was true, key.isAcceptable would throw CancelledKeyException.
+ case e: CancelledKeyException => {
+ logInfo("key already cancelled ? " + key, e)
+ triggerForceCloseByException(key, e)
+ }
+ case e: Exception => {
+ logError("Exception processing key " + key, e)
+ triggerForceCloseByException(key, e)
+ }
}
}
}
diff --git a/core/src/main/scala/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/spark/storage/BlockManagerMaster.scala
index 6fae62d373..ac26c16867 100644
--- a/core/src/main/scala/spark/storage/BlockManagerMaster.scala
+++ b/core/src/main/scala/spark/storage/BlockManagerMaster.scala
@@ -15,6 +15,7 @@ import akka.util.duration._
import spark.{Logging, SparkException, Utils}
+
private[spark] class BlockManagerMaster(var driverActor: ActorRef) extends Logging {
val AKKA_RETRY_ATTEMPTS: Int = System.getProperty("spark.akka.num.retries", "3").toInt
@@ -88,6 +89,21 @@ private[spark] class BlockManagerMaster(var driverActor: ActorRef) extends Loggi
}
/**
+ * Remove all blocks belonging to the given RDD.
+ */
+ def removeRdd(rddId: Int) {
+ val rddBlockPrefix = "rdd_" + rddId + "_"
+ // Get the list of blocks in block manager, and remove ones that are part of this RDD.
+ // The runtime complexity is linear to the number of blocks persisted in the cluster.
+ // It could be expensive if the cluster is large and has a lot of blocks persisted.
+ getStorageStatus.flatMap(_.blocks).foreach { case(blockId, status) =>
+ if (blockId.startsWith(rddBlockPrefix)) {
+ removeBlock(blockId)
+ }
+ }
+ }
+
+ /**
* Return the memory status for each block manager, in the form of a map from
* the block manager's id to two long values. The first value is the maximum
* amount of memory allocated for the block manager, while the second is the
diff --git a/core/src/main/scala/spark/storage/DiskStore.scala b/core/src/main/scala/spark/storage/DiskStore.scala
index 498bc9eeb6..8154b8ca74 100644
--- a/core/src/main/scala/spark/storage/DiskStore.scala
+++ b/core/src/main/scala/spark/storage/DiskStore.scala
@@ -236,8 +236,7 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
localDirId = "%s-%04x".format(dateFormat.format(new Date), rand.nextInt(65536))
localDir = new File(rootDir, "spark-local-" + localDirId)
if (!localDir.exists) {
- localDir.mkdirs()
- foundLocalDir = true
+ foundLocalDir = localDir.mkdirs()
}
} catch {
case e: Exception =>
diff --git a/core/src/main/scala/spark/storage/StorageUtils.scala b/core/src/main/scala/spark/storage/StorageUtils.scala
index dec47a9d41..8f52168c24 100644
--- a/core/src/main/scala/spark/storage/StorageUtils.scala
+++ b/core/src/main/scala/spark/storage/StorageUtils.scala
@@ -4,9 +4,9 @@ import spark.{Utils, SparkContext}
import BlockManagerMasterActor.BlockStatus
private[spark]
-case class StorageStatus(blockManagerId: BlockManagerId, maxMem: Long,
+case class StorageStatus(blockManagerId: BlockManagerId, maxMem: Long,
blocks: Map[String, BlockStatus]) {
-
+
def memUsed(blockPrefix: String = "") = {
blocks.filterKeys(_.startsWith(blockPrefix)).values.map(_.memSize).
reduceOption(_+_).getOrElse(0l)
@@ -22,35 +22,40 @@ case class StorageStatus(blockManagerId: BlockManagerId, maxMem: Long,
}
case class RDDInfo(id: Int, name: String, storageLevel: StorageLevel,
- numCachedPartitions: Int, numPartitions: Int, memSize: Long, diskSize: Long) {
+ numCachedPartitions: Int, numPartitions: Int, memSize: Long, diskSize: Long)
+ extends Ordered[RDDInfo] {
override def toString = {
import Utils.memoryBytesToString
"RDD \"%s\" (%d) Storage: %s; CachedPartitions: %d; TotalPartitions: %d; MemorySize: %s; DiskSize: %s".format(name, id,
storageLevel.toString, numCachedPartitions, numPartitions, memoryBytesToString(memSize), memoryBytesToString(diskSize))
}
+
+ override def compare(that: RDDInfo) = {
+ this.id - that.id
+ }
}
/* Helper methods for storage-related objects */
private[spark]
object StorageUtils {
- /* Given the current storage status of the BlockManager, returns information for each RDD */
- def rddInfoFromStorageStatus(storageStatusList: Array[StorageStatus],
+ /* Given the current storage status of the BlockManager, returns information for each RDD */
+ def rddInfoFromStorageStatus(storageStatusList: Array[StorageStatus],
sc: SparkContext) : Array[RDDInfo] = {
- rddInfoFromBlockStatusList(storageStatusList.flatMap(_.blocks).toMap, sc)
+ rddInfoFromBlockStatusList(storageStatusList.flatMap(_.blocks).toMap, sc)
}
- /* Given a list of BlockStatus objets, returns information for each RDD */
- def rddInfoFromBlockStatusList(infos: Map[String, BlockStatus],
+ /* Given a list of BlockStatus objets, returns information for each RDD */
+ def rddInfoFromBlockStatusList(infos: Map[String, BlockStatus],
sc: SparkContext) : Array[RDDInfo] = {
// Group by rddId, ignore the partition name
- val groupedRddBlocks = infos.groupBy { case(k, v) =>
+ val groupedRddBlocks = infos.filterKeys(_.startsWith("rdd_")).groupBy { case(k, v) =>
k.substring(0,k.lastIndexOf('_'))
}.mapValues(_.values.toArray)
// For each RDD, generate an RDDInfo object
- groupedRddBlocks.map { case(rddKey, rddBlocks) =>
+ val rddInfos = groupedRddBlocks.map { case(rddKey, rddBlocks) =>
// Add up memory and disk sizes
val memSize = rddBlocks.map(_.memSize).reduce(_ + _)
@@ -65,10 +70,14 @@ object StorageUtils {
RDDInfo(rddId, rddName, rddStorageLevel, rddBlocks.length, rdd.partitions.size, memSize, diskSize)
}.toArray
+
+ scala.util.Sorting.quickSort(rddInfos)
+
+ rddInfos
}
- /* Removes all BlockStatus object that are not part of a block prefix */
- def filterStorageStatusByPrefix(storageStatusList: Array[StorageStatus],
+ /* Removes all BlockStatus object that are not part of a block prefix */
+ def filterStorageStatusByPrefix(storageStatusList: Array[StorageStatus],
prefix: String) : Array[StorageStatus] = {
storageStatusList.map { status =>
diff --git a/core/src/test/scala/spark/DistributedSuite.scala b/core/src/test/scala/spark/DistributedSuite.scala
index c9b4707def..ab3e197035 100644
--- a/core/src/test/scala/spark/DistributedSuite.scala
+++ b/core/src/test/scala/spark/DistributedSuite.scala
@@ -3,8 +3,10 @@ package spark
import network.ConnectionManagerId
import org.scalatest.FunSuite
import org.scalatest.BeforeAndAfter
+import org.scalatest.concurrent.Timeouts._
import org.scalatest.matchers.ShouldMatchers
import org.scalatest.prop.Checkers
+import org.scalatest.time.{Span, Millis}
import org.scalacheck.Arbitrary._
import org.scalacheck.Gen
import org.scalacheck.Prop._
@@ -252,12 +254,36 @@ class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter
assert(data2.count === 2)
}
}
+
+ test("unpersist RDDs") {
+ DistributedSuite.amMaster = true
+ sc = new SparkContext("local-cluster[3,1,512]", "test")
+ val data = sc.parallelize(Seq(true, false, false, false), 4)
+ data.persist(StorageLevel.MEMORY_ONLY_2)
+ data.count
+ assert(sc.persistentRdds.isEmpty == false)
+ data.unpersist()
+ assert(sc.persistentRdds.isEmpty == true)
+
+ failAfter(Span(3000, Millis)) {
+ try {
+ while (! sc.getRDDStorageInfo.isEmpty) {
+ Thread.sleep(200)
+ }
+ } catch {
+ case e: Exception =>
+ // Do nothing. We might see exceptions because block manager
+ // is racing this thread to remove entries from the driver.
+ }
+ }
+ assert(sc.getRDDStorageInfo.isEmpty == true)
+ }
}
object DistributedSuite {
// Indicates whether this JVM is marked for failure.
var mark = false
-
+
// Set by test to remember if we are in the driver program so we can assert
// that we are not.
var amMaster = false
@@ -274,9 +300,9 @@ object DistributedSuite {
// Act like an identity function, but if mark was set to true previously, fail,
// crashing the entire JVM.
def failOnMarkedIdentity(item: Boolean): Boolean = {
- if (mark) {
+ if (mark) {
System.exit(42)
- }
+ }
item
- }
+ }
}
diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala
index 7fbdd44340..cee6312572 100644
--- a/core/src/test/scala/spark/RDDSuite.scala
+++ b/core/src/test/scala/spark/RDDSuite.scala
@@ -2,6 +2,8 @@ package spark
import scala.collection.mutable.HashMap
import org.scalatest.FunSuite
+import org.scalatest.concurrent.Timeouts._
+import org.scalatest.time.{Span, Millis}
import spark.SparkContext._
import spark.rdd.{CoalescedRDD, CoGroupedRDD, PartitionPruningRDD, ShuffledRDD}
@@ -100,6 +102,28 @@ class RDDSuite extends FunSuite with LocalSparkContext {
assert(rdd.collect().toList === List(1, 2, 3, 4))
}
+ test("unpersist RDD") {
+ sc = new SparkContext("local", "test")
+ val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).cache()
+ rdd.count
+ assert(sc.persistentRdds.isEmpty == false)
+ rdd.unpersist()
+ assert(sc.persistentRdds.isEmpty == true)
+
+ failAfter(Span(3000, Millis)) {
+ try {
+ while (! sc.getRDDStorageInfo.isEmpty) {
+ Thread.sleep(200)
+ }
+ } catch {
+ case e: Exception =>
+ // Do nothing. We might see exceptions because block manager
+ // is racing this thread to remove entries from the driver.
+ }
+ }
+ assert(sc.getRDDStorageInfo.isEmpty == true)
+ }
+
test("caching with failures") {
sc = new SparkContext("local", "test")
val onlySplit = new Partition { override def index: Int = 0 }
diff --git a/core/src/test/scala/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/spark/storage/BlockManagerSuite.scala
index 5a11a4483b..9fe0de665c 100644
--- a/core/src/test/scala/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/spark/storage/BlockManagerSuite.scala
@@ -207,6 +207,31 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
}
+ test("removing rdd") {
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 2000)
+ val a1 = new Array[Byte](400)
+ val a2 = new Array[Byte](400)
+ val a3 = new Array[Byte](400)
+ // Putting a1, a2 and a3 in memory.
+ store.putSingle("rdd_0_0", a1, StorageLevel.MEMORY_ONLY)
+ store.putSingle("rdd_0_1", a2, StorageLevel.MEMORY_ONLY)
+ store.putSingle("nonrddblock", a3, StorageLevel.MEMORY_ONLY)
+ master.removeRdd(0)
+
+ eventually(timeout(1000 milliseconds), interval(10 milliseconds)) {
+ store.getSingle("rdd_0_0") should be (None)
+ master.getLocations("rdd_0_0") should have size 0
+ }
+ eventually(timeout(1000 milliseconds), interval(10 milliseconds)) {
+ store.getSingle("rdd_0_1") should be (None)
+ master.getLocations("rdd_0_1") should have size 0
+ }
+ eventually(timeout(1000 milliseconds), interval(10 milliseconds)) {
+ store.getSingle("nonrddblock") should not be (None)
+ master.getLocations("nonrddblock") should have size (1)
+ }
+ }
+
test("reregistration on heart beat") {
val heartBeat = PrivateMethod[Unit]('heartBeat)
store = new BlockManager("<driver>", actorSystem, master, serializer, 2000)