commit    024905f682b6b683b6cb0ca8e1ea3e277fbd6c9d
tree      9d6835e5f3a3be9b2874fa73eb3ba4a369e48dd4
parent    d1b7f41671feb6e17e98383b1770757b4941cc3b
author    Tathagata Das <tathagata.das1565@gmail.com>  2012-07-27 12:00:49 -0700
committer Tathagata Das <tathagata.das1565@gmail.com>  2012-07-27 12:00:49 -0700
Added BlockRDD and a first-cut version of checkpoint() to the RDD class.
 core/src/main/scala/spark/BlockRDD.scala | 42 ++++++++++++++++++++++++++++++
 core/src/main/scala/spark/RDD.scala      | 14 ++++++++++
 core/src/test/scala/spark/RDDSuite.scala |  7 +++++
 3 files changed, 63 insertions(+), 0 deletions(-)
diff --git a/core/src/main/scala/spark/BlockRDD.scala b/core/src/main/scala/spark/BlockRDD.scala
new file mode 100644
index 0000000000..ea009f0f4f
--- /dev/null
+++ b/core/src/main/scala/spark/BlockRDD.scala
@@ -0,0 +1,42 @@
+package spark
+
+import scala.collection.mutable.HashMap
+
+class BlockRDDSplit(val blockId: String, idx: Int) extends Split {
+  val index = idx
+}
+
+
+class BlockRDD[T: ClassManifest](sc: SparkContext, blockIds: Array[String]) extends RDD[T](sc) {
+
+  @transient
+  val splits_ = (0 until blockIds.size).map(i => {
+    new BlockRDDSplit(blockIds(i), i).asInstanceOf[Split]
+  }).toArray
+
+  @transient
+  lazy val locations_ = {
+    val blockManager = SparkEnv.get.blockManager
+    /*val locations = blockIds.map(id => blockManager.getLocations(id))*/
+    val locations = blockManager.getLocations(blockIds)
+    HashMap(blockIds.zip(locations):_*)
+  }
+
+  override def splits = splits_
+
+  override def compute(split: Split): Iterator[T] = {
+    val blockManager = SparkEnv.get.blockManager
+    val blockId = split.asInstanceOf[BlockRDDSplit].blockId
+    blockManager.get(blockId) match {
+      case Some(block) => block.asInstanceOf[Iterator[T]]
+      case None =>
+        throw new Exception("Could not compute split, block " + blockId + " not found")
+    }
+  }
+
+  override def preferredLocations(split: Split) =
+    locations_(split.asInstanceOf[BlockRDDSplit].blockId)
+
+  override val dependencies: List[Dependency[_]] = Nil
+}
+
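The new BlockRDD is a thin wrapper that turns a set of BlockManager block IDs into an RDD: each block ID becomes one split, compute() fetches that block from the BlockManager, and preferredLocations() is answered from the locations map built eagerly at construction time. A minimal usage sketch follows; the block IDs are hypothetical, and the blocks are assumed to already exist in the BlockManager (e.g. written by an earlier persisted job), with `sc` a live SparkContext.

// Hypothetical sketch: the IDs follow the "rdd:<rddId>:<splitIndex>"
// convention that checkpoint() in this patch uses when storing blocks.
val blockIds = Array("rdd:42:0", "rdd:42:1")
val restored = new BlockRDD[Int](sc, blockIds)

// One split per block ID; reading a split fetches the matching block and
// throws if the block is no longer present in any BlockManager.
assert(restored.splits.size == blockIds.size)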
diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala
index 1191523ccc..1190e64f8f 100644
--- a/core/src/main/scala/spark/RDD.scala
+++ b/core/src/main/scala/spark/RDD.scala
@@ -94,6 +94,20 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
 
   def getStorageLevel = storageLevel
 
+  def checkpoint(level: StorageLevel = StorageLevel.DISK_AND_MEMORY_DESER): RDD[T] = {
+    if (!level.useDisk && level.replication < 2) {
+      throw new Exception("Cannot checkpoint without using disk or replication (level requested was " + level + ")")
+    }
+
+    // This is a hack. Ideally this should re-use the code used by the CacheTracker
+    // to generate the key.
+    def getSplitKey(split: Split) = "rdd:%d:%d".format(this.id, split.index)
+
+    persist(level)
+    sc.runJob(this, (iter: Iterator[T]) => {} )
+    new BlockRDD[T](sc, splits.map(getSplitKey).toArray)
+  }
+
   // Read this RDD; will read from cache if applicable, or otherwise compute
   final def iterator(split: Split): Iterator[T] = {
     if (storageLevel != StorageLevel.NONE) {
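checkpoint() works in three steps: it persists the RDD at a level durable enough to survive recomputation (disk, or replication of at least 2), runs a no-op job so every partition is actually computed and stored under the "rdd:<id>:<index>" cache key, and then returns a BlockRDD over those keys. Since BlockRDD declares no dependencies, the returned RDD is cut off from the original lineage. A minimal sketch of the intended call pattern, assuming a live SparkContext `sc`:

// The default level, DISK_AND_MEMORY_DESER, passes the durability check
// because it uses disk.
val stable = sc.makeRDD(1 to 100, 4).map(_ * 2).checkpoint()

// The result is backed by a BlockRDD with an empty dependency list, so
// recomputing it re-reads the stored blocks rather than re-running map().
assert(stable.dependencies == Nil)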
diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala
index 7199b634b7..8f39820178 100644
--- a/core/src/test/scala/spark/RDDSuite.scala
+++ b/core/src/test/scala/spark/RDDSuite.scala
@@ -42,4 +42,11 @@ class RDDSuite extends FunSuite {
     assert(result.toSet === Set(("a", 6), ("b", 2), ("c", 5)))
     sc.stop()
   }
+
+  test("checkpointing") {
+    val sc = new SparkContext("local", "test")
+    val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).flatMap(x => 1 to x).checkpoint()
+    assert(rdd.collect().toList === List(1, 1, 2, 1, 2, 3, 1, 2, 3, 4))
+    sc.stop()
+  }
 }
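The expected list in the test is just the flatMap expansion, independent of checkpointing: each element x of Array(1, 2, 3, 4) contributes the sequence 1 to x, concatenated in order, so the test checks that checkpoint() preserves both the values and their ordering across the persist/read-back round trip. The expected value can be reproduced in plain Scala with no Spark at all:

// Each x expands to 1..x: 1 | 1,2 | 1,2,3 | 1,2,3,4.
val expanded = Array(1, 2, 3, 4).flatMap(x => 1 to x).toList
// => List(1, 1, 2, 1, 2, 3, 1, 2, 3, 4)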