author    Tathagata Das <tathagata.das1565@gmail.com>    2012-10-30 16:09:37 -0700
committer Tathagata Das <tathagata.das1565@gmail.com>    2012-10-30 16:09:37 -0700
commit    0dcd770fdc4d558972b635b6770ed0120280ef22 (patch)
tree      46dc6ea8e9cc829455105956a0446d614286b591
parent    ac12abc17ff90ec99192f3c3de4d1d390969e635 (diff)
Added checkpointing support to all RDDs, along with CheckpointSuite to test checkpointing in them.
-rw-r--r--  core/src/main/scala/spark/PairRDDFunctions.scala              |   4
-rw-r--r--  core/src/main/scala/spark/ParallelCollection.scala            |   4
-rw-r--r--  core/src/main/scala/spark/RDD.scala                           | 129
-rw-r--r--  core/src/main/scala/spark/SparkContext.scala                  |  21
-rw-r--r--  core/src/main/scala/spark/rdd/BlockRDD.scala                  |  13
-rw-r--r--  core/src/main/scala/spark/rdd/CartesianRDD.scala              |  38
-rw-r--r--  core/src/main/scala/spark/rdd/CoGroupedRDD.scala              |  19
-rw-r--r--  core/src/main/scala/spark/rdd/CoalescedRDD.scala              |  26
-rw-r--r--  core/src/main/scala/spark/rdd/FilteredRDD.scala               |   2
-rw-r--r--  core/src/main/scala/spark/rdd/FlatMappedRDD.scala             |   2
-rw-r--r--  core/src/main/scala/spark/rdd/GlommedRDD.scala                |   2
-rw-r--r--  core/src/main/scala/spark/rdd/HadoopRDD.scala                 |   2
-rw-r--r--  core/src/main/scala/spark/rdd/MapPartitionsRDD.scala          |   2
-rw-r--r--  core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala |   2
-rw-r--r--  core/src/main/scala/spark/rdd/MappedRDD.scala                 |   2
-rw-r--r--  core/src/main/scala/spark/rdd/NewHadoopRDD.scala              |   2
-rw-r--r--  core/src/main/scala/spark/rdd/PipedRDD.scala                  |   9
-rw-r--r--  core/src/main/scala/spark/rdd/SampledRDD.scala                |   2
-rw-r--r--  core/src/main/scala/spark/rdd/ShuffledRDD.scala               |   5
-rw-r--r--  core/src/main/scala/spark/rdd/UnionRDD.scala                  |  32
-rw-r--r--  core/src/test/scala/spark/CheckpointSuite.scala               | 116
-rw-r--r--  core/src/test/scala/spark/RDDSuite.scala                      |  25
22 files changed, 352 insertions, 107 deletions
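
For orientation, here is a minimal sketch of how the new checkpointing API is exercised end to end. It is not part of the patch; it mirrors CheckpointSuite and the "basic checkpointing" test below, the object name and temp-directory trick are illustrative only, and it has to live in package spark because checkpoint() is protected[spark] at this stage.

    package spark

    import java.io.File

    object CheckpointExample {
      def main(args: Array[String]) {
        // setCheckpointDir() throws if the path already exists, so start from a
        // fresh temp name (create the file, then delete it to keep only the name).
        val checkpointDir = File.createTempFile("checkpoint", "")
        checkpointDir.delete()

        val sc = new SparkContext("local", "checkpoint-example")
        sc.setCheckpointDir(checkpointDir.toString)

        val rdd = sc.makeRDD(1 to 4).flatMap(x => 1 to x)
        rdd.checkpoint()           // only marks the RDD; nothing is written yet

        val result = rdd.collect() // first job: runJob() then calls rdd.doCheckpoint()

        // The save runs on a background thread; once it finishes, isCheckpointed is
        // true and the lineage has been replaced by the RDD read from the checkpoint file.
        while (!rdd.isCheckpointed) Thread.sleep(50)
        assert(rdd.collect().toList == result.toList)

        sc.stop()
      }
    }
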
diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala
index f52af08125..1f82bd3ab8 100644
--- a/core/src/main/scala/spark/PairRDDFunctions.scala
+++ b/core/src/main/scala/spark/PairRDDFunctions.scala
@@ -625,7 +625,7 @@ class OrderedRDDFunctions[K <% Ordered[K]: ClassManifest, V: ClassManifest](
}
private[spark]
-class MappedValuesRDD[K, V, U](@transient prev: WeakReference[RDD[(K, V)]], f: V => U)
+class MappedValuesRDD[K, V, U](prev: WeakReference[RDD[(K, V)]], f: V => U)
extends RDD[(K, U)](prev.get) {
override def splits = firstParent[(K, V)].splits
@@ -634,7 +634,7 @@ class MappedValuesRDD[K, V, U](@transient prev: WeakReference[RDD[(K, V)]], f: V
}
private[spark]
-class FlatMappedValuesRDD[K, V, U](@transient prev: WeakReference[RDD[(K, V)]], f: V => TraversableOnce[U])
+class FlatMappedValuesRDD[K, V, U](prev: WeakReference[RDD[(K, V)]], f: V => TraversableOnce[U])
extends RDD[(K, U)](prev.get) {
override def splits = firstParent[(K, V)].splits
diff --git a/core/src/main/scala/spark/ParallelCollection.scala b/core/src/main/scala/spark/ParallelCollection.scala
index ad06ee9736..9725017b61 100644
--- a/core/src/main/scala/spark/ParallelCollection.scala
+++ b/core/src/main/scala/spark/ParallelCollection.scala
@@ -22,10 +22,10 @@ private[spark] class ParallelCollectionSplit[T: ClassManifest](
}
private[spark] class ParallelCollection[T: ClassManifest](
- @transient sc_ : SparkContext,
+ @transient sc : SparkContext,
@transient data: Seq[T],
numSlices: Int)
- extends RDD[T](sc_, Nil) {
+ extends RDD[T](sc, Nil) {
// TODO: Right now, each split sends along its full data, even if later down the RDD chain it gets
// cached. It might be worthwhile to write the data to a file in the DFS and read it in the split
// instead. UPDATE: With the new changes to enable checkpointing, this can be done.
diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala
index c9f3763f73..e272a0ede9 100644
--- a/core/src/main/scala/spark/RDD.scala
+++ b/core/src/main/scala/spark/RDD.scala
@@ -13,6 +13,7 @@ import scala.collection.Map
import scala.collection.mutable.HashMap
import scala.collection.JavaConversions.mapAsScalaMap
+import org.apache.hadoop.fs.Path
import org.apache.hadoop.io.BytesWritable
import org.apache.hadoop.io.NullWritable
import org.apache.hadoop.io.Text
@@ -74,7 +75,7 @@ import SparkContext._
*/
abstract class RDD[T: ClassManifest](
@transient var sc: SparkContext,
- @transient var dependencies_ : List[Dependency[_]] = Nil
+ var dependencies_ : List[Dependency[_]]
) extends Serializable {
@@ -91,7 +92,6 @@ abstract class RDD[T: ClassManifest](
/** How this RDD depends on any parent RDDs. */
def dependencies: List[Dependency[_]] = dependencies_
- //var dependencies: List[Dependency[_]] = dependencies_
/** Record user function generating this RDD. */
private[spark] val origin = Utils.getSparkCallSite
@@ -100,7 +100,13 @@ abstract class RDD[T: ClassManifest](
val partitioner: Option[Partitioner] = None
/** Optionally overridden by subclasses to specify placement preferences. */
- def preferredLocations(split: Split): Seq[String] = Nil
+ def preferredLocations(split: Split): Seq[String] = {
+ if (isCheckpointed) {
+ checkpointRDD.preferredLocations(split)
+ } else {
+ Nil
+ }
+ }
/** The [[spark.SparkContext]] that this RDD was created on. */
def context = sc
@@ -113,8 +119,23 @@ abstract class RDD[T: ClassManifest](
// Variables relating to persistence
private var storageLevel: StorageLevel = StorageLevel.NONE
- private[spark] def firstParent[U: ClassManifest] = dependencies.head.rdd.asInstanceOf[RDD[U]]
- private[spark] def parent[U: ClassManifest](id: Int) = dependencies(id).rdd.asInstanceOf[RDD[U]]
+ /** Returns the first parent RDD */
+ private[spark] def firstParent[U: ClassManifest] = {
+ dependencies.head.rdd.asInstanceOf[RDD[U]]
+ }
+
+ /** Returns the `i` th parent RDD */
+ private[spark] def parent[U: ClassManifest](i: Int) = dependencies(i).rdd.asInstanceOf[RDD[U]]
+
+ // Variables relating to checkpointing
+ val isCheckpointable = true // override to set this to false to avoid checkpointing an RDD
+ var shouldCheckpoint = false // set to true when an RDD is marked for checkpointing
+ var isCheckpointInProgress = false // set to true when checkpointing is in progress
+ var isCheckpointed = false // set to true after checkpointing is completed
+
+ var checkpointFile: String = null // set to the checkpoint file after checkpointing is completed
+ var checkpointRDD: RDD[T] = null // set to the HadoopRDD of the checkpoint file
+ var checkpointRDDSplits: Seq[Split] = null // set to the splits of the Hadoop RDD
// Methods available on all RDDs:
@@ -141,32 +162,94 @@ abstract class RDD[T: ClassManifest](
/** Get the RDD's current storage level, or StorageLevel.NONE if none is set. */
def getStorageLevel = storageLevel
- private[spark] def checkpoint(level: StorageLevel = StorageLevel.MEMORY_AND_DISK_2): RDD[T] = {
- if (!level.useDisk && level.replication < 2) {
- throw new Exception("Cannot checkpoint without using disk or replication (level requested was " + level + ")")
- }
-
- // This is a hack. Ideally this should re-use the code used by the CacheTracker
- // to generate the key.
- def getSplitKey(split: Split) = "rdd_%d_%d".format(this.id, split.index)
-
- persist(level)
- sc.runJob(this, (iter: Iterator[T]) => {} )
-
- val p = this.partitioner
-
- new BlockRDD[T](sc, splits.map(getSplitKey).toArray) {
- override val partitioner = p
+ /**
+ * Mark this RDD for checkpointing. The RDD will be saved to a file inside `checkpointDir`
+ * (set using setCheckpointDir()) and all references to its parent RDDs will be removed.
+ * This is used to truncate very long lineages. In the current implementation, Spark will save
+ * this RDD to a file (using saveAsObjectFile()) after the first job using this RDD is done.
+ * Hence, it is strongly recommended to use checkpoint() on RDDs when
+ * (i) checkpoint() is called before any job has been executed on this RDD, and
+ * (ii) this RDD has been persisted in memory; otherwise saving it to a file will
+ * require recomputation.
+ */
+ protected[spark] def checkpoint() {
+ synchronized {
+ if (isCheckpointed || shouldCheckpoint || isCheckpointInProgress) {
+ // do nothing
+ } else if (isCheckpointable) {
+ shouldCheckpoint = true
+ } else {
+ throw new Exception(this + " cannot be checkpointed")
+ }
}
}
-
+
+ /**
+ * Performs the checkpointing of this RDD by saving it to a file. It is called by the DAGScheduler after a job
+ * using this RDD has completed (therefore the RDD has been materialized and
+ * potentially stored in memory). In case this RDD is not marked for checkpointing,
+ * doCheckpoint() is called recursively on the parent RDDs.
+ */
+ private[spark] def doCheckpoint() {
+ val startCheckpoint = synchronized {
+ if (isCheckpointable && shouldCheckpoint && !isCheckpointInProgress) {
+ isCheckpointInProgress = true
+ true
+ } else {
+ false
+ }
+ }
+
+ if (startCheckpoint) {
+ val rdd = this
+ val env = SparkEnv.get
+
+ // Spawn a new thread to do the checkpoint, as it takes some time to write the RDD to a file
+ val th = new Thread() {
+ override def run() {
+ // Save the RDD to a file, create a new HadoopRDD from it,
+ // and change the dependencies from the original parents to the new RDD
+ SparkEnv.set(env)
+ rdd.checkpointFile = new Path(context.checkpointDir, "rdd-" + id).toString
+ rdd.saveAsObjectFile(checkpointFile)
+ rdd.synchronized {
+ rdd.checkpointRDD = context.objectFile[T](checkpointFile)
+ rdd.checkpointRDDSplits = rdd.checkpointRDD.splits
+ rdd.changeDependencies(rdd.checkpointRDD)
+ rdd.shouldCheckpoint = false
+ rdd.isCheckpointInProgress = false
+ rdd.isCheckpointed = true
+ }
+ }
+ }
+ th.start()
+ } else {
+ // Recursively call doCheckpoint() on the parent RDDs in case they are marked for checkpointing
+ dependencies.foreach(_.rdd.doCheckpoint())
+ }
+ }
+
+ /**
+ * Changes the dependencies of this RDD from its original parents to the new [[spark.rdd.HadoopRDD]]
+ * (`newRDD`) created from the checkpoint file. This method must ensure that all references
+ * to the original parent RDDs are removed so that the parent RDDs can be garbage
+ * collected. Subclasses of RDD may override this method to implement their own dependency-changing
+ * logic. See [[spark.rdd.UnionRDD]] and [[spark.rdd.ShuffledRDD]] to get a better idea.
+ */
+ protected def changeDependencies(newRDD: RDD[_]) {
+ dependencies_ = List(new OneToOneDependency(newRDD))
+ }
+
/**
* Internal method to this RDD; will read from cache if applicable, or otherwise compute it.
* This should ''not'' be called by users directly, but is available for implementors of custom
* subclasses of RDD.
*/
final def iterator(split: Split): Iterator[T] = {
- if (storageLevel != StorageLevel.NONE) {
+ if (isCheckpointed) {
+ // ASSUMPTION: The checkpoint Hadoop RDD will have the same number of splits as the original RDD
+ checkpointRDD.iterator(checkpointRDDSplits(split.index))
+ } else if (storageLevel != StorageLevel.NONE) {
SparkEnv.get.cacheTracker.getOrCompute[T](this, split, storageLevel)
} else {
compute(split)
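
The RDD.scala changes above revolve around four new flags and a background save thread. The following self-contained sketch is a simplified, hypothetical model of just those flag transitions (no Spark classes; the class name is made up and the file save is replaced by a sleep), intended to make the hand-off between checkpoint(), doCheckpoint() and the saving thread easier to follow. In the real code, SparkContext.runJob() calls doCheckpoint() after every job, and the thread also swaps the dependencies via changeDependencies().

    class CheckpointStateSketch {
      val isCheckpointable = true
      var shouldCheckpoint = false        // set by checkpoint()
      var isCheckpointInProgress = false  // set by doCheckpoint() just before the save starts
      var isCheckpointed = false          // set once the save has finished

      def checkpoint() {
        synchronized {
          if (!isCheckpointed && !shouldCheckpoint && !isCheckpointInProgress) {
            if (isCheckpointable) shouldCheckpoint = true
            else throw new Exception(this + " cannot be checkpointed")
          }
        }
      }

      def doCheckpoint() {
        val startCheckpoint = synchronized {
          if (isCheckpointable && shouldCheckpoint && !isCheckpointInProgress) {
            isCheckpointInProgress = true
            true
          } else {
            false
          }
        }
        if (startCheckpoint) {
          val self = this
          new Thread() {
            override def run() {
              Thread.sleep(100)           // stands in for saveAsObjectFile() + objectFile()
              self.synchronized {
                shouldCheckpoint = false
                isCheckpointInProgress = false
                isCheckpointed = true     // from now on iterator() would read the checkpoint
              }
            }
          }.start()
        }
      }
    }
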
diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala
index 6b957a6356..79ceab5f4f 100644
--- a/core/src/main/scala/spark/SparkContext.scala
+++ b/core/src/main/scala/spark/SparkContext.scala
@@ -188,6 +188,8 @@ class SparkContext(
private var dagScheduler = new DAGScheduler(taskScheduler)
+ private[spark] var checkpointDir: String = null
+
// Methods for creating RDDs
/** Distribute a local Scala collection to form an RDD. */
@@ -519,6 +521,7 @@ class SparkContext(
val start = System.nanoTime
val result = dagScheduler.runJob(rdd, func, partitions, callSite, allowLocal)
logInfo("Job finished: " + callSite + ", took " + (System.nanoTime - start) / 1e9 + " s")
+ rdd.doCheckpoint()
result
}
@@ -575,6 +578,24 @@ class SparkContext(
return f
}
+ /**
+ * Set the directory under which RDDs are going to be checkpointed. This method will
+ * create the directory and will throw an exception if the path already exists (to prevent
+ * existing files from being overwritten). The directory will be deleted on exit
+ * if deleteOnExit is set.
+ */
+ def setCheckpointDir(dir: String, deleteOnExit: Boolean = false) {
+ val path = new Path(dir)
+ val fs = path.getFileSystem(new Configuration())
+ if (fs.exists(path)) {
+ throw new Exception("Checkpoint directory '" + path + "' already exists.")
+ } else {
+ fs.mkdirs(path)
+ if (deleteOnExit) fs.deleteOnExit(path)
+ }
+ checkpointDir = dir
+ }
+
/** Default level of parallelism to use when not given by user (e.g. for reduce tasks) */
def defaultParallelism: Int = taskScheduler.defaultParallelism
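
Because setCheckpointDir() refuses a path that already exists, callers typically generate a fresh directory name first. A minimal sketch of that pattern, assuming an existing SparkContext named sc (it mirrors what CheckpointSuite does below; the prefix string is arbitrary):

    import java.io.File

    val checkpointDir = File.createTempFile("spark-checkpoint", "")
    checkpointDir.delete()                        // keep only the unique name
    sc.setCheckpointDir(checkpointDir.toString)   // the method creates the directory
    // or, to have it cleaned up automatically:
    // sc.setCheckpointDir(checkpointDir.toString, deleteOnExit = true)
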
diff --git a/core/src/main/scala/spark/rdd/BlockRDD.scala b/core/src/main/scala/spark/rdd/BlockRDD.scala
index cb73976aed..f4c3f99011 100644
--- a/core/src/main/scala/spark/rdd/BlockRDD.scala
+++ b/core/src/main/scala/spark/rdd/BlockRDD.scala
@@ -14,7 +14,7 @@ private[spark] class BlockRDDSplit(val blockId: String, idx: Int) extends Split
private[spark]
class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[String])
- extends RDD[T](sc) {
+ extends RDD[T](sc, Nil) {
@transient
val splits_ = (0 until blockIds.size).map(i => {
@@ -41,9 +41,12 @@ class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[St
}
}
- override def preferredLocations(split: Split) =
- locations_(split.asInstanceOf[BlockRDDSplit].blockId)
-
- override val dependencies: List[Dependency[_]] = Nil
+ override def preferredLocations(split: Split) = {
+ if (isCheckpointed) {
+ checkpointRDD.preferredLocations(split)
+ } else {
+ locations_(split.asInstanceOf[BlockRDDSplit].blockId)
+ }
+ }
}
diff --git a/core/src/main/scala/spark/rdd/CartesianRDD.scala b/core/src/main/scala/spark/rdd/CartesianRDD.scala
index c97b835630..458ad38d55 100644
--- a/core/src/main/scala/spark/rdd/CartesianRDD.scala
+++ b/core/src/main/scala/spark/rdd/CartesianRDD.scala
@@ -1,9 +1,6 @@
package spark.rdd
-import spark.NarrowDependency
-import spark.RDD
-import spark.SparkContext
-import spark.Split
+import spark._
import java.lang.ref.WeakReference
private[spark]
@@ -14,19 +11,15 @@ class CartesianSplit(idx: Int, val s1: Split, val s2: Split) extends Split with
private[spark]
class CartesianRDD[T: ClassManifest, U:ClassManifest](
sc: SparkContext,
- rdd1_ : WeakReference[RDD[T]],
- rdd2_ : WeakReference[RDD[U]])
- extends RDD[Pair[T, U]](sc)
+ var rdd1 : RDD[T],
+ var rdd2 : RDD[U])
+ extends RDD[Pair[T, U]](sc, Nil)
with Serializable {
- def rdd1 = rdd1_.get
- def rdd2 = rdd2_.get
-
val numSplitsInRdd2 = rdd2.splits.size
- // TODO: make this null when finishing checkpoint
@transient
- val splits_ = {
+ var splits_ = {
// create the cross product split
val array = new Array[Split](rdd1.splits.size * rdd2.splits.size)
for (s1 <- rdd1.splits; s2 <- rdd2.splits) {
@@ -36,12 +29,15 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest](
array
}
- // TODO: make this return checkpoint Hadoop RDDs split when checkpointed
override def splits = splits_
override def preferredLocations(split: Split) = {
- val currSplit = split.asInstanceOf[CartesianSplit]
- rdd1.preferredLocations(currSplit.s1) ++ rdd2.preferredLocations(currSplit.s2)
+ if (isCheckpointed) {
+ checkpointRDD.preferredLocations(split)
+ } else {
+ val currSplit = split.asInstanceOf[CartesianSplit]
+ rdd1.preferredLocations(currSplit.s1) ++ rdd2.preferredLocations(currSplit.s2)
+ }
}
override def compute(split: Split) = {
@@ -49,8 +45,7 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest](
for (x <- rdd1.iterator(currSplit.s1); y <- rdd2.iterator(currSplit.s2)) yield (x, y)
}
- // TODO: make this null when finishing checkpoint
- var deps = List(
+ var deps_ = List(
new NarrowDependency(rdd1) {
def getParents(id: Int): Seq[Int] = List(id / numSplitsInRdd2)
},
@@ -59,5 +54,12 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest](
}
)
- override def dependencies = deps
+ override def dependencies = deps_
+
+ override protected def changeDependencies(newRDD: RDD[_]) {
+ deps_ = List(new OneToOneDependency(newRDD.asInstanceOf[RDD[Any]]))
+ splits_ = newRDD.splits
+ rdd1 = null
+ rdd2 = null
+ }
}
diff --git a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
index af54ac2fa0..a313ebcbe8 100644
--- a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
+++ b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
@@ -30,14 +30,13 @@ private[spark] class CoGroupAggregator
{ (b1, b2) => b1 ++ b2 })
with Serializable
-class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner)
+class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner)
extends RDD[(K, Seq[Seq[_]])](rdds.head.context, Nil) with Logging {
val aggr = new CoGroupAggregator
- // TODO: make this null when finishing checkpoint
@transient
- var deps = {
+ var deps_ = {
val deps = new ArrayBuffer[Dependency[_]]
for ((rdd, index) <- rdds.zipWithIndex) {
val mapSideCombinedRDD = rdd.mapPartitions(aggr.combineValuesByKey(_), true)
@@ -52,11 +51,10 @@ class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner)
deps.toList
}
- override def dependencies = deps
+ override def dependencies = deps_
- // TODO: make this null when finishing checkpoint
@transient
- val splits_ : Array[Split] = {
+ var splits_ : Array[Split] = {
val firstRdd = rdds.head
val array = new Array[Split](part.numPartitions)
for (i <- 0 until array.size) {
@@ -72,13 +70,10 @@ class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner)
array
}
- // TODO: make this return checkpoint Hadoop RDDs split when checkpointed
override def splits = splits_
override val partitioner = Some(part)
- override def preferredLocations(s: Split) = Nil
-
override def compute(s: Split): Iterator[(K, Seq[Seq[_]])] = {
val split = s.asInstanceOf[CoGroupSplit]
val numRdds = split.deps.size
@@ -106,4 +101,10 @@ class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner)
}
map.iterator
}
+
+ override protected def changeDependencies(newRDD: RDD[_]) {
+ deps_ = List(new OneToOneDependency(newRDD.asInstanceOf[RDD[Any]]))
+ splits_ = newRDD.splits
+ rdds = null
+ }
}
diff --git a/core/src/main/scala/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/spark/rdd/CoalescedRDD.scala
index 573acf8893..5b5f72ddeb 100644
--- a/core/src/main/scala/spark/rdd/CoalescedRDD.scala
+++ b/core/src/main/scala/spark/rdd/CoalescedRDD.scala
@@ -1,8 +1,7 @@
package spark.rdd
-import spark.NarrowDependency
-import spark.RDD
-import spark.Split
+import spark._
+import java.lang.ref.WeakReference
private class CoalescedRDDSplit(val index: Int, val parents: Array[Split]) extends Split
@@ -15,13 +14,12 @@ private class CoalescedRDDSplit(val index: Int, val parents: Array[Split]) exten
* or to avoid having a large number of small tasks when processing a directory with many files.
*/
class CoalescedRDD[T: ClassManifest](
- @transient prev: RDD[T], // TODO: Make this a weak reference
+ var prev: RDD[T],
maxPartitions: Int)
extends RDD[T](prev.context, Nil) { // Nil, so the dependencies_ var does not refer to parent RDDs
- // TODO: make this null when finishing checkpoint
- @transient val splits_ : Array[Split] = {
- val prevSplits = firstParent[T].splits
+ @transient var splits_ : Array[Split] = {
+ val prevSplits = prev.splits
if (prevSplits.length < maxPartitions) {
prevSplits.zipWithIndex.map{ case (s, idx) => new CoalescedRDDSplit(idx, Array(s)) }
} else {
@@ -33,7 +31,6 @@ class CoalescedRDD[T: ClassManifest](
}
}
- // TODO: make this return checkpoint Hadoop RDDs split when checkpointed
override def splits = splits_
override def compute(split: Split): Iterator[T] = {
@@ -42,13 +39,18 @@ class CoalescedRDD[T: ClassManifest](
}
}
- // TODO: make this null when finishing checkpoint
- var deps = List(
- new NarrowDependency(firstParent) {
+ var deps_ : List[Dependency[_]] = List(
+ new NarrowDependency(prev) {
def getParents(id: Int): Seq[Int] =
splits(id).asInstanceOf[CoalescedRDDSplit].parents.map(_.index)
}
)
- override def dependencies = deps
+ override def dependencies = deps_
+
+ override protected def changeDependencies(newRDD: RDD[_]) {
+ deps_ = List(new OneToOneDependency(newRDD))
+ splits_ = newRDD.splits
+ prev = null
+ }
}
diff --git a/core/src/main/scala/spark/rdd/FilteredRDD.scala b/core/src/main/scala/spark/rdd/FilteredRDD.scala
index cc2a3acd3a..1370cf6faf 100644
--- a/core/src/main/scala/spark/rdd/FilteredRDD.scala
+++ b/core/src/main/scala/spark/rdd/FilteredRDD.scala
@@ -7,7 +7,7 @@ import java.lang.ref.WeakReference
private[spark]
class FilteredRDD[T: ClassManifest](
- @transient prev: WeakReference[RDD[T]],
+ prev: WeakReference[RDD[T]],
f: T => Boolean)
extends RDD[T](prev.get) {
diff --git a/core/src/main/scala/spark/rdd/FlatMappedRDD.scala b/core/src/main/scala/spark/rdd/FlatMappedRDD.scala
index 34bd784c13..6b2cc67568 100644
--- a/core/src/main/scala/spark/rdd/FlatMappedRDD.scala
+++ b/core/src/main/scala/spark/rdd/FlatMappedRDD.scala
@@ -7,7 +7,7 @@ import java.lang.ref.WeakReference
private[spark]
class FlatMappedRDD[U: ClassManifest, T: ClassManifest](
- @transient prev: WeakReference[RDD[T]],
+ prev: WeakReference[RDD[T]],
f: T => TraversableOnce[U])
extends RDD[U](prev.get) {
diff --git a/core/src/main/scala/spark/rdd/GlommedRDD.scala b/core/src/main/scala/spark/rdd/GlommedRDD.scala
index 9321e89dcd..0f0b6ab0ff 100644
--- a/core/src/main/scala/spark/rdd/GlommedRDD.scala
+++ b/core/src/main/scala/spark/rdd/GlommedRDD.scala
@@ -6,7 +6,7 @@ import spark.Split
import java.lang.ref.WeakReference
private[spark]
-class GlommedRDD[T: ClassManifest](@transient prev: WeakReference[RDD[T]])
+class GlommedRDD[T: ClassManifest](prev: WeakReference[RDD[T]])
extends RDD[Array[T]](prev.get) {
override def splits = firstParent[T].splits
override def compute(split: Split) = Array(firstParent[T].iterator(split).toArray).iterator
diff --git a/core/src/main/scala/spark/rdd/HadoopRDD.scala b/core/src/main/scala/spark/rdd/HadoopRDD.scala
index a12531ea89..19ed56d9c0 100644
--- a/core/src/main/scala/spark/rdd/HadoopRDD.scala
+++ b/core/src/main/scala/spark/rdd/HadoopRDD.scala
@@ -115,4 +115,6 @@ class HadoopRDD[K, V](
val hadoopSplit = split.asInstanceOf[HadoopSplit]
hadoopSplit.inputSplit.value.getLocations.filter(_ != "localhost")
}
+
+ override val isCheckpointable = false
}
diff --git a/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala b/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala
index bad872c430..b04f56cfcc 100644
--- a/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala
+++ b/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala
@@ -7,7 +7,7 @@ import java.lang.ref.WeakReference
private[spark]
class MapPartitionsRDD[U: ClassManifest, T: ClassManifest](
- @transient prev: WeakReference[RDD[T]],
+ prev: WeakReference[RDD[T]],
f: Iterator[T] => Iterator[U],
preservesPartitioning: Boolean = false)
extends RDD[U](prev.get) {
diff --git a/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala b/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala
index d7b238b05d..7a4b6ffb03 100644
--- a/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala
+++ b/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala
@@ -12,7 +12,7 @@ import java.lang.ref.WeakReference
*/
private[spark]
class MapPartitionsWithSplitRDD[U: ClassManifest, T: ClassManifest](
- @transient prev: WeakReference[RDD[T]],
+ prev: WeakReference[RDD[T]],
f: (Int, Iterator[T]) => Iterator[U])
extends RDD[U](prev.get) {
diff --git a/core/src/main/scala/spark/rdd/MappedRDD.scala b/core/src/main/scala/spark/rdd/MappedRDD.scala
index 126c6f332b..8fa1872e0a 100644
--- a/core/src/main/scala/spark/rdd/MappedRDD.scala
+++ b/core/src/main/scala/spark/rdd/MappedRDD.scala
@@ -7,7 +7,7 @@ import java.lang.ref.WeakReference
private[spark]
class MappedRDD[U: ClassManifest, T: ClassManifest](
- @transient prev: WeakReference[RDD[T]],
+ prev: WeakReference[RDD[T]],
f: T => U)
extends RDD[U](prev.get) {
diff --git a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala
index c12df5839e..2875abb2db 100644
--- a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala
+++ b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala
@@ -93,4 +93,6 @@ class NewHadoopRDD[K, V](
val theSplit = split.asInstanceOf[NewHadoopSplit]
theSplit.serializableHadoopSplit.value.getLocations.filter(_ != "localhost")
}
+
+ override val isCheckpointable = false
}
diff --git a/core/src/main/scala/spark/rdd/PipedRDD.scala b/core/src/main/scala/spark/rdd/PipedRDD.scala
index d54579d6d1..d9293a9d1a 100644
--- a/core/src/main/scala/spark/rdd/PipedRDD.scala
+++ b/core/src/main/scala/spark/rdd/PipedRDD.scala
@@ -12,6 +12,7 @@ import spark.OneToOneDependency
import spark.RDD
import spark.SparkEnv
import spark.Split
+import java.lang.ref.WeakReference
/**
@@ -19,16 +20,16 @@ import spark.Split
* (printing them one per line) and returns the output as a collection of strings.
*/
class PipedRDD[T: ClassManifest](
- @transient prev: RDD[T],
+ prev: WeakReference[RDD[T]],
command: Seq[String],
envVars: Map[String, String])
- extends RDD[String](prev) {
+ extends RDD[String](prev.get) {
- def this(@transient prev: RDD[T], command: Seq[String]) = this(prev, command, Map())
+ def this(prev: WeakReference[RDD[T]], command: Seq[String]) = this(prev, command, Map())
// Similar to Runtime.exec(), if we are given a single string, split it into words
// using a standard StringTokenizer (i.e. by spaces)
- def this(@transient prev: RDD[T], command: String) = this(prev, PipedRDD.tokenize(command))
+ def this(prev: WeakReference[RDD[T]], command: String) = this(prev, PipedRDD.tokenize(command))
override def splits = firstParent[T].splits
diff --git a/core/src/main/scala/spark/rdd/SampledRDD.scala b/core/src/main/scala/spark/rdd/SampledRDD.scala
index 00b521b130..f273f257f8 100644
--- a/core/src/main/scala/spark/rdd/SampledRDD.scala
+++ b/core/src/main/scala/spark/rdd/SampledRDD.scala
@@ -15,7 +15,7 @@ class SampledRDDSplit(val prev: Split, val seed: Int) extends Split with Seriali
}
class SampledRDD[T: ClassManifest](
- @transient prev: WeakReference[RDD[T]],
+ prev: WeakReference[RDD[T]],
withReplacement: Boolean,
frac: Double,
seed: Int)
diff --git a/core/src/main/scala/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/spark/rdd/ShuffledRDD.scala
index 62867dab4f..b7d843c26d 100644
--- a/core/src/main/scala/spark/rdd/ShuffledRDD.scala
+++ b/core/src/main/scala/spark/rdd/ShuffledRDD.scala
@@ -31,11 +31,6 @@ class ShuffledRDD[K, V](
override def splits = splits_
- override def preferredLocations(split: Split) = Nil
-
- //val dep = new ShuffleDependency(parent, part)
- //override val dependencies = List(dep)
-
override def compute(split: Split): Iterator[(K, V)] = {
val shuffledId = dependencies.head.asInstanceOf[ShuffleDependency[K, V]].shuffleId
SparkEnv.get.shuffleFetcher.fetch[K, V](shuffledId, split.index)
diff --git a/core/src/main/scala/spark/rdd/UnionRDD.scala b/core/src/main/scala/spark/rdd/UnionRDD.scala
index 0a61a2d1f5..643a174160 100644
--- a/core/src/main/scala/spark/rdd/UnionRDD.scala
+++ b/core/src/main/scala/spark/rdd/UnionRDD.scala
@@ -2,11 +2,7 @@ package spark.rdd
import scala.collection.mutable.ArrayBuffer
-import spark.Dependency
-import spark.RangeDependency
-import spark.RDD
-import spark.SparkContext
-import spark.Split
+import spark._
import java.lang.ref.WeakReference
private[spark] class UnionSplit[T: ClassManifest](
@@ -23,12 +19,11 @@ private[spark] class UnionSplit[T: ClassManifest](
class UnionRDD[T: ClassManifest](
sc: SparkContext,
- @transient rdds: Seq[RDD[T]]) // TODO: Make this a weak reference
+ @transient var rdds: Seq[RDD[T]])
extends RDD[T](sc, Nil) { // Nil, so the dependencies_ var does not refer to parent RDDs
- // TODO: make this null when finishing checkpoint
@transient
- val splits_ : Array[Split] = {
+ var splits_ : Array[Split] = {
val array = new Array[Split](rdds.map(_.splits.size).sum)
var pos = 0
for (rdd <- rdds; split <- rdd.splits) {
@@ -38,11 +33,9 @@ class UnionRDD[T: ClassManifest](
array
}
- // TODO: make this return checkpoint Hadoop RDDs split when checkpointed
override def splits = splits_
- // TODO: make this null when finishing checkpoint
- @transient var deps = {
+ @transient var deps_ = {
val deps = new ArrayBuffer[Dependency[_]]
var pos = 0
for (rdd <- rdds) {
@@ -52,10 +45,21 @@ class UnionRDD[T: ClassManifest](
deps.toList
}
- override def dependencies = deps
+ override def dependencies = deps_
override def compute(s: Split): Iterator[T] = s.asInstanceOf[UnionSplit[T]].iterator()
- override def preferredLocations(s: Split): Seq[String] =
- s.asInstanceOf[UnionSplit[T]].preferredLocations()
+ override def preferredLocations(s: Split): Seq[String] = {
+ if (isCheckpointed) {
+ checkpointRDD.preferredLocations(s)
+ } else {
+ s.asInstanceOf[UnionSplit[T]].preferredLocations()
+ }
+ }
+
+ override protected def changeDependencies(newRDD: RDD[_]) {
+ deps_ = List(new OneToOneDependency(newRDD))
+ splits_ = newRDD.splits
+ rdds = null
+ }
}
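
The subclass overrides of changeDependencies() above all follow the same shape. As a recap, a hypothetical one-parent RDD written in that style might look as follows (not part of the patch; the class name is made up): mutable splits_ and deps_ fields plus a mutable parent reference, all swapped out once the checkpoint RDD takes over so that the old lineage can be garbage collected.

    package spark.rdd

    import spark._

    private[spark] class IdentityRDD[T: ClassManifest](var prev: RDD[T])
      extends RDD[T](prev.context, Nil) {  // Nil, so dependencies_ does not pin the parent

      @transient var splits_ : Array[Split] = prev.splits
      var deps_ : List[Dependency[_]] = List(new OneToOneDependency(prev))

      override def splits = splits_
      override def dependencies = deps_
      override def compute(split: Split): Iterator[T] = prev.iterator(split)

      override protected def changeDependencies(newRDD: RDD[_]) {
        deps_ = List(new OneToOneDependency(newRDD.asInstanceOf[RDD[Any]]))
        splits_ = newRDD.splits
        prev = null  // drop the parent reference; iterator() now reads from checkpointRDD
      }
    }
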
diff --git a/core/src/test/scala/spark/CheckpointSuite.scala b/core/src/test/scala/spark/CheckpointSuite.scala
new file mode 100644
index 0000000000..0e5ca7dc21
--- /dev/null
+++ b/core/src/test/scala/spark/CheckpointSuite.scala
@@ -0,0 +1,116 @@
+package spark
+
+import org.scalatest.{BeforeAndAfter, FunSuite}
+import java.io.File
+import rdd.{BlockRDD, CoalescedRDD, MapPartitionsWithSplitRDD}
+import spark.SparkContext._
+import storage.StorageLevel
+
+class CheckpointSuite extends FunSuite with BeforeAndAfter {
+
+ var sc: SparkContext = _
+ var checkpointDir: File = _
+
+ before {
+ checkpointDir = File.createTempFile("temp", "")
+ checkpointDir.delete()
+
+ sc = new SparkContext("local", "test")
+ sc.setCheckpointDir(checkpointDir.toString)
+ }
+
+ after {
+ if (sc != null) {
+ sc.stop()
+ sc = null
+ }
+ // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
+ System.clearProperty("spark.master.port")
+
+ if (checkpointDir != null) {
+ checkpointDir.delete()
+ }
+ }
+
+ test("ParallelCollection") {
+ val parCollection = sc.makeRDD(1 to 4)
+ parCollection.checkpoint()
+ assert(parCollection.dependencies === Nil)
+ val result = parCollection.collect()
+ sleep(parCollection) // allow slightly extra time, as loading classes for the first time can take a while
+ assert(sc.objectFile[Int](parCollection.checkpointFile).collect() === result)
+ assert(parCollection.dependencies != Nil)
+ assert(parCollection.collect() === result)
+ }
+
+ test("BlockRDD") {
+ val blockId = "id"
+ val blockManager = SparkEnv.get.blockManager
+ blockManager.putSingle(blockId, "test", StorageLevel.MEMORY_ONLY)
+ val blockRDD = new BlockRDD[String](sc, Array(blockId))
+ blockRDD.checkpoint()
+ val result = blockRDD.collect()
+ sleep(blockRDD)
+ assert(sc.objectFile[String](blockRDD.checkpointFile).collect() === result)
+ assert(blockRDD.dependencies != Nil)
+ assert(blockRDD.collect() === result)
+ }
+
+ test("RDDs with one-to-one dependencies") {
+ testCheckpointing(_.map(x => x.toString))
+ testCheckpointing(_.flatMap(x => 1 to x))
+ testCheckpointing(_.filter(_ % 2 == 0))
+ testCheckpointing(_.sample(false, 0.5, 0))
+ testCheckpointing(_.glom())
+ testCheckpointing(_.mapPartitions(_.map(_.toString)))
+ testCheckpointing(r => new MapPartitionsWithSplitRDD(r,
+ (i: Int, iter: Iterator[Int]) => iter.map(_.toString) ))
+ testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).mapValues(_.toString), 1000)
+ testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).flatMapValues(x => 1 to x), 1000)
+ testCheckpointing(_.pipe(Seq("cat")))
+ }
+
+ test("ShuffledRDD") {
+ testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _))
+ }
+
+ test("UnionRDD") {
+ testCheckpointing(_.union(sc.makeRDD(5 to 6, 4)))
+ }
+
+ test("CartesianRDD") {
+ testCheckpointing(_.cartesian(sc.makeRDD(5 to 6, 4)), 1000)
+ }
+
+ test("CoalescedRDD") {
+ testCheckpointing(new CoalescedRDD(_, 2))
+ }
+
+ test("CoGroupedRDD") {
+ val rdd2 = sc.makeRDD(5 to 6, 4).map(x => (x % 2, 1))
+ testCheckpointing(rdd1 => rdd1.map(x => (x % 2, 1)).cogroup(rdd2))
+ testCheckpointing(rdd1 => rdd1.map(x => (x % 2, x)).join(rdd2))
+ }
+
+ def testCheckpointing[U: ClassManifest](op: (RDD[Int]) => RDD[U], sleepTime: Long = 500) {
+ val parCollection = sc.makeRDD(1 to 4, 4)
+ val operatedRDD = op(parCollection)
+ operatedRDD.checkpoint()
+ val parentRDD = operatedRDD.dependencies.head.rdd
+ val result = operatedRDD.collect()
+ sleep(operatedRDD)
+ //println(parentRDD + ", " + operatedRDD.dependencies.head.rdd )
+ assert(sc.objectFile[U](operatedRDD.checkpointFile).collect() === result)
+ assert(operatedRDD.dependencies.head.rdd != parentRDD)
+ assert(operatedRDD.collect() === result)
+ }
+
+ def sleep(rdd: RDD[_]) {
+ val startTime = System.currentTimeMillis()
+ val maxWaitTime = 5000
+ while(rdd.isCheckpointed == false && System.currentTimeMillis() < startTime + maxWaitTime) {
+ Thread.sleep(50)
+ }
+ assert(rdd.isCheckpointed === true, "Waiting for checkpoint to complete took more than " + maxWaitTime + " ms")
+ }
+}
diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala
index 37a0ff0947..8ac7c8451a 100644
--- a/core/src/test/scala/spark/RDDSuite.scala
+++ b/core/src/test/scala/spark/RDDSuite.scala
@@ -19,7 +19,7 @@ class RDDSuite extends FunSuite with BeforeAndAfter {
// To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
System.clearProperty("spark.master.port")
}
-
+
test("basic operations") {
sc = new SparkContext("local", "test")
val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
@@ -70,10 +70,23 @@ class RDDSuite extends FunSuite with BeforeAndAfter {
assert(result.toSet === Set(("a", 6), ("b", 2), ("c", 5)))
}
- test("checkpointing") {
+ test("basic checkpointing") {
+ import java.io.File
+ val checkpointDir = File.createTempFile("temp", "")
+ checkpointDir.delete()
+
sc = new SparkContext("local", "test")
- val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).flatMap(x => 1 to x).checkpoint()
- assert(rdd.collect().toList === List(1, 1, 2, 1, 2, 3, 1, 2, 3, 4))
+ sc.setCheckpointDir(checkpointDir.toString)
+ val parCollection = sc.makeRDD(1 to 4)
+ val flatMappedRDD = parCollection.flatMap(x => 1 to x)
+ flatMappedRDD.checkpoint()
+ assert(flatMappedRDD.dependencies.head.rdd == parCollection)
+ val result = flatMappedRDD.collect()
+ Thread.sleep(1000)
+ assert(flatMappedRDD.dependencies.head.rdd != parCollection)
+ assert(flatMappedRDD.collect() === result)
+
+ checkpointDir.deleteOnExit()
}
test("basic caching") {
@@ -94,8 +107,8 @@ class RDDSuite extends FunSuite with BeforeAndAfter {
List(List(1, 2, 3, 4, 5), List(6, 7, 8, 9, 10)))
// Check that the narrow dependency is also specified correctly
- assert(coalesced1.dependencies.head.getParents(0).toList === List(0, 1, 2, 3, 4))
- assert(coalesced1.dependencies.head.getParents(1).toList === List(5, 6, 7, 8, 9))
+ assert(coalesced1.dependencies.head.asInstanceOf[NarrowDependency[_]].getParents(0).toList === List(0, 1, 2, 3, 4))
+ assert(coalesced1.dependencies.head.asInstanceOf[NarrowDependency[_]].getParents(1).toList === List(5, 6, 7, 8, 9))
val coalesced2 = new CoalescedRDD(data, 3)
assert(coalesced2.collect().toList === (1 to 10).toList)