author     Tathagata Das <tathagata.das1565@gmail.com>    2012-11-02 17:05:22 -0700
committer  Tathagata Das <tathagata.das1565@gmail.com>    2012-11-02 17:05:22 -0700
commit     596154eabe51961733789a18a47067748fb72e8e
tree       b6baf2d1c6f8fa75625630c65fa19632cee2eec9
parent     3fb5c9ee24302edf02df130bd0dfd0463cf6c0a4
parent     34e569f40e184a6a4f21e9d79b0e8979d8f9541f
Merge branch 'dev-checkpoint' into dev
22 files changed, 539 insertions, 164 deletions
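
A minimal usage sketch of the checkpointing API this merge introduces, modeled on the new "basic checkpointing" test in RDDSuite below. In this commit checkpoint() is still protected[spark], so the sketch assumes it lives in the spark package (as the test suites do); the checkpoint directory path is illustrative and must not already exist, since setCheckpointDir() throws if it does.

package spark

import spark.SparkContext._

object CheckpointSketch {
  def main(args: Array[String]) {
    val sc = new SparkContext("local", "checkpoint-sketch")
    sc.setCheckpointDir("/tmp/spark-checkpoints")      // hypothetical path; must not already exist

    val rdd = sc.makeRDD(1 to 4).flatMap(x => 1 to x)
    rdd.checkpoint()                                   // only marks the RDD; nothing is written yet

    // The first job triggers doCheckpoint(): the RDD is written with saveAsObjectFile()
    // on a background thread and its lineage is replaced by a dependency on the saved file.
    val result = rdd.collect()

    // Recomputing after checkpointing reads back from the checkpoint file, not the parents.
    assert(rdd.collect().toList == result.toList)
    sc.stop()
  }
}

The full diff follows.
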
diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala index e5bb639cfd..1f82bd3ab8 100644 --- a/core/src/main/scala/spark/PairRDDFunctions.scala +++ b/core/src/main/scala/spark/PairRDDFunctions.scala @@ -23,6 +23,7 @@ import spark.partial.BoundedDouble import spark.partial.PartialResult import spark.rdd._ import spark.SparkContext._ +import java.lang.ref.WeakReference /** * Extra functions available on RDDs of (key, value) pairs through an implicit conversion. @@ -624,23 +625,22 @@ class OrderedRDDFunctions[K <% Ordered[K]: ClassManifest, V: ClassManifest]( } private[spark] -class MappedValuesRDD[K, V, U](prev: RDD[(K, V)], f: V => U) extends RDD[(K, U)](prev.context) { - override def splits = prev.splits - override val dependencies = List(new OneToOneDependency(prev)) - override val partitioner = prev.partitioner - override def compute(split: Split) = prev.iterator(split).map{case (k, v) => (k, f(v))} +class MappedValuesRDD[K, V, U](prev: WeakReference[RDD[(K, V)]], f: V => U) + extends RDD[(K, U)](prev.get) { + + override def splits = firstParent[(K, V)].splits + override val partitioner = firstParent[(K, V)].partitioner + override def compute(split: Split) = firstParent[(K, V)].iterator(split).map{case (k, v) => (k, f(v))} } private[spark] -class FlatMappedValuesRDD[K, V, U](prev: RDD[(K, V)], f: V => TraversableOnce[U]) - extends RDD[(K, U)](prev.context) { - - override def splits = prev.splits - override val dependencies = List(new OneToOneDependency(prev)) - override val partitioner = prev.partitioner +class FlatMappedValuesRDD[K, V, U](prev: WeakReference[RDD[(K, V)]], f: V => TraversableOnce[U]) + extends RDD[(K, U)](prev.get) { + override def splits = firstParent[(K, V)].splits + override val partitioner = firstParent[(K, V)].partitioner override def compute(split: Split) = { - prev.iterator(split).flatMap { case (k, v) => f(v).map(x => (k, x)) } + firstParent[(K, V)].iterator(split).flatMap { case (k, v) => f(v).map(x => (k, x)) } } } diff --git a/core/src/main/scala/spark/ParallelCollection.scala b/core/src/main/scala/spark/ParallelCollection.scala index 9b57ae3b4f..9725017b61 100644 --- a/core/src/main/scala/spark/ParallelCollection.scala +++ b/core/src/main/scala/spark/ParallelCollection.scala @@ -22,13 +22,13 @@ private[spark] class ParallelCollectionSplit[T: ClassManifest]( } private[spark] class ParallelCollection[T: ClassManifest]( - sc: SparkContext, + @transient sc : SparkContext, @transient data: Seq[T], numSlices: Int) - extends RDD[T](sc) { + extends RDD[T](sc, Nil) { // TODO: Right now, each split sends along its full data, even if later down the RDD chain it gets // cached. It might be worthwhile to write the data to a file in the DFS and read it in the split - // instead. + // instead. UPDATE: With the new changes to enable checkpointing, this an be done. 
@transient val splits_ = { @@ -41,8 +41,6 @@ private[spark] class ParallelCollection[T: ClassManifest]( override def compute(s: Split) = s.asInstanceOf[ParallelCollectionSplit[T]].iterator override def preferredLocations(s: Split): Seq[String] = Nil - - override val dependencies: List[Dependency[_]] = Nil } private object ParallelCollection { diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index 338dff4061..7b59a6f09e 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -1,8 +1,7 @@ package spark -import java.io.EOFException +import java.io.{ObjectOutputStream, IOException, EOFException, ObjectInputStream} import java.net.URL -import java.io.ObjectInputStream import java.util.concurrent.atomic.AtomicLong import java.util.Random import java.util.Date @@ -13,6 +12,7 @@ import scala.collection.Map import scala.collection.mutable.HashMap import scala.collection.JavaConversions.mapAsScalaMap +import org.apache.hadoop.fs.Path import org.apache.hadoop.io.BytesWritable import org.apache.hadoop.io.NullWritable import org.apache.hadoop.io.Text @@ -72,7 +72,14 @@ import SparkContext._ * [[http://www.cs.berkeley.edu/~matei/papers/2012/nsdi_spark.pdf Spark paper]] for more details * on RDD internals. */ -abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serializable { +abstract class RDD[T: ClassManifest]( + @transient var sc: SparkContext, + var dependencies_ : List[Dependency[_]] + ) extends Serializable { + + + def this(@transient oneParent: RDD[_]) = + this(oneParent.context , List(new OneToOneDependency(oneParent))) // Methods that must be implemented by subclasses: @@ -83,10 +90,8 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial def compute(split: Split): Iterator[T] /** How this RDD depends on any parent RDDs. */ - @transient val dependencies: List[Dependency[_]] + def dependencies: List[Dependency[_]] = dependencies_ - // Methods available on all RDDs: - /** Record user function generating this RDD. */ private[spark] val origin = Utils.getSparkCallSite @@ -94,7 +99,13 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial val partitioner: Option[Partitioner] = None /** Optionally overridden by subclasses to specify placement preferences. */ - def preferredLocations(split: Split): Seq[String] = Nil + def preferredLocations(split: Split): Seq[String] = { + if (isCheckpointed) { + checkpointRDD.preferredLocations(split) + } else { + Nil + } + } /** The [[spark.SparkContext]] that this RDD was created on. 
*/ def context = sc @@ -106,8 +117,28 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial // Variables relating to persistence private var storageLevel: StorageLevel = StorageLevel.NONE - - /** + + /** Returns the first parent RDD */ + private[spark] def firstParent[U: ClassManifest] = { + dependencies.head.rdd.asInstanceOf[RDD[U]] + } + + /** Returns the `i` th parent RDD */ + private[spark] def parent[U: ClassManifest](i: Int) = dependencies(i).rdd.asInstanceOf[RDD[U]] + + // Variables relating to checkpointing + val isCheckpointable = true // override to set this to false to avoid checkpointing an RDD + var shouldCheckpoint = false // set to true when an RDD is marked for checkpointing + var isCheckpointInProgress = false // set to true when checkpointing is in progress + var isCheckpointed = false // set to true after checkpointing is completed + + var checkpointFile: String = null // set to the checkpoint file after checkpointing is completed + var checkpointRDD: RDD[T] = null // set to the HadoopRDD of the checkpoint file + var checkpointRDDSplits: Seq[Split] = null // set to the splits of the Hadoop RDD + + // Methods available on all RDDs: + + /** * Set this RDD's storage level to persist its values across operations after the first time * it is computed. Can only be called once on each RDD. */ @@ -129,33 +160,95 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial /** Get the RDD's current storage level, or StorageLevel.NONE if none is set. */ def getStorageLevel = storageLevel - - private[spark] def checkpoint(level: StorageLevel = StorageLevel.MEMORY_AND_DISK_2): RDD[T] = { - if (!level.useDisk && level.replication < 2) { - throw new Exception("Cannot checkpoint without using disk or replication (level requested was " + level + ")") - } - - // This is a hack. Ideally this should re-use the code used by the CacheTracker - // to generate the key. - def getSplitKey(split: Split) = "rdd_%d_%d".format(this.id, split.index) - - persist(level) - sc.runJob(this, (iter: Iterator[T]) => {} ) - - val p = this.partitioner - - new BlockRDD[T](sc, splits.map(getSplitKey).toArray) { - override val partitioner = p + + /** + * Mark this RDD for checkpointing. The RDD will be saved to a file inside `checkpointDir` + * (set using setCheckpointDir()) and all references to its parent RDDs will be removed. + * This is used to truncate very long lineages. In the current implementation, Spark will save + * this RDD to a file (using saveAsObjectFile()) after the first job using this RDD is done. + * Hence, it is strongly recommended to use checkpoint() on RDDs when + * (i) Checkpoint() is called before the any job has been executed on this RDD. + * (ii) This RDD has been made to persist in memory. Otherwise saving it on a file will + * require recomputation. + */ + protected[spark] def checkpoint() { + synchronized { + if (isCheckpointed || shouldCheckpoint || isCheckpointInProgress) { + // do nothing + } else if (isCheckpointable) { + shouldCheckpoint = true + } else { + throw new Exception(this + " cannot be checkpointed") + } } } - + + /** + * Performs the checkpointing of this RDD by saving this . It is called by the DAGScheduler after a job + * using this RDD has completed (therefore the RDD has been materialized and + * potentially stored in memory). In case this RDD is not marked for checkpointing, + * doCheckpoint() is called recursively on the parent RDDs. 
+ */ + private[spark] def doCheckpoint() { + val startCheckpoint = synchronized { + if (isCheckpointable && shouldCheckpoint && !isCheckpointInProgress) { + isCheckpointInProgress = true + true + } else { + false + } + } + + if (startCheckpoint) { + val rdd = this + val env = SparkEnv.get + + // Spawn a new thread to do the checkpoint as it takes sometime to write the RDD to file + val th = new Thread() { + override def run() { + // Save the RDD to a file, create a new HadoopRDD from it, + // and change the dependencies from the original parents to the new RDD + SparkEnv.set(env) + rdd.checkpointFile = new Path(context.checkpointDir, "rdd-" + id).toString + rdd.saveAsObjectFile(checkpointFile) + rdd.synchronized { + rdd.checkpointRDD = context.objectFile[T](checkpointFile) + rdd.checkpointRDDSplits = rdd.checkpointRDD.splits + rdd.changeDependencies(rdd.checkpointRDD) + rdd.shouldCheckpoint = false + rdd.isCheckpointInProgress = false + rdd.isCheckpointed = true + } + } + } + th.start() + } else { + // Recursively call doCheckpoint() to perform checkpointing on parent RDD if they are marked + dependencies.foreach(_.rdd.doCheckpoint()) + } + } + + /** + * Changes the dependencies of this RDD from its original parents to the new [[spark.rdd.HadoopRDD]] + * (`newRDD`) created from the checkpoint file. This method must ensure that all references + * to the original parent RDDs must be removed to enable the parent RDDs to be garbage + * collected. Subclasses of RDD may override this method for implementing their own changing + * logic. See [[spark.rdd.UnionRDD]] and [[spark.rdd.ShuffledRDD]] to get a better idea. + */ + protected def changeDependencies(newRDD: RDD[_]) { + dependencies_ = List(new OneToOneDependency(newRDD)) + } + /** * Internal method to this RDD; will read from cache if applicable, or otherwise compute it. * This should ''not'' be called by users directly, but is available for implementors of custom * subclasses of RDD. 
*/ final def iterator(split: Split): Iterator[T] = { - if (storageLevel != StorageLevel.NONE) { + if (isCheckpointed) { + // ASSUMPTION: Checkpoint Hadoop RDD will have same number of splits as original + checkpointRDD.iterator(checkpointRDDSplits(split.index)) + } else if (storageLevel != StorageLevel.NONE) { SparkEnv.get.cacheTracker.getOrCompute[T](this, split, storageLevel) } else { compute(split) @@ -495,4 +588,19 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial private[spark] def collectPartitions(): Array[Array[T]] = { sc.runJob(this, (iter: Iterator[T]) => iter.toArray) } + + @throws(classOf[IOException]) + private def writeObject(oos: ObjectOutputStream) { + synchronized { + oos.defaultWriteObject() + } + } + + @throws(classOf[IOException]) + private def readObject(ois: ObjectInputStream) { + synchronized { + ois.defaultReadObject() + } + } + } diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 0d37075ef3..79ceab5f4f 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -3,6 +3,7 @@ package spark import java.io._ import java.util.concurrent.atomic.AtomicInteger import java.net.{URI, URLClassLoader} +import java.lang.ref.WeakReference import scala.collection.Map import scala.collection.generic.Growable @@ -187,6 +188,8 @@ class SparkContext( private var dagScheduler = new DAGScheduler(taskScheduler) + private[spark] var checkpointDir: String = null + // Methods for creating RDDs /** Distribute a local Scala collection to form an RDD. */ @@ -518,6 +521,7 @@ class SparkContext( val start = System.nanoTime val result = dagScheduler.runJob(rdd, func, partitions, callSite, allowLocal) logInfo("Job finished: " + callSite + ", took " + (System.nanoTime - start) / 1e9 + " s") + rdd.doCheckpoint() result } @@ -574,6 +578,24 @@ class SparkContext( return f } + /** + * Set the directory under which RDDs are going to be checkpointed. This method will + * create this directory and will throw an exception of the path already exists (to avoid + * overwriting existing files may be overwritten). The directory will be deleted on exit + * if indicated. + */ + def setCheckpointDir(dir: String, deleteOnExit: Boolean = false) { + val path = new Path(dir) + val fs = path.getFileSystem(new Configuration()) + if (fs.exists(path)) { + throw new Exception("Checkpoint directory '" + path + "' already exists.") + } else { + fs.mkdirs(path) + if (deleteOnExit) fs.deleteOnExit(path) + } + checkpointDir = dir + } + /** Default level of parallelism to use when not given by user (e.g. 
for reduce tasks) */ def defaultParallelism: Int = taskScheduler.defaultParallelism @@ -695,6 +717,9 @@ object SparkContext { /** Find the JAR that contains the class of a particular object */ def jarOfObject(obj: AnyRef): Seq[String] = jarOfClass(obj.getClass) + + implicit def rddToWeakRefRDD[T: ClassManifest](rdd: RDD[T]) = new WeakReference(rdd) + } diff --git a/core/src/main/scala/spark/rdd/BlockRDD.scala b/core/src/main/scala/spark/rdd/BlockRDD.scala index cb73976aed..f4c3f99011 100644 --- a/core/src/main/scala/spark/rdd/BlockRDD.scala +++ b/core/src/main/scala/spark/rdd/BlockRDD.scala @@ -14,7 +14,7 @@ private[spark] class BlockRDDSplit(val blockId: String, idx: Int) extends Split private[spark] class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[String]) - extends RDD[T](sc) { + extends RDD[T](sc, Nil) { @transient val splits_ = (0 until blockIds.size).map(i => { @@ -41,9 +41,12 @@ class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[St } } - override def preferredLocations(split: Split) = - locations_(split.asInstanceOf[BlockRDDSplit].blockId) - - override val dependencies: List[Dependency[_]] = Nil + override def preferredLocations(split: Split) = { + if (isCheckpointed) { + checkpointRDD.preferredLocations(split) + } else { + locations_(split.asInstanceOf[BlockRDDSplit].blockId) + } + } } diff --git a/core/src/main/scala/spark/rdd/CartesianRDD.scala b/core/src/main/scala/spark/rdd/CartesianRDD.scala index 7c354b6b2e..458ad38d55 100644 --- a/core/src/main/scala/spark/rdd/CartesianRDD.scala +++ b/core/src/main/scala/spark/rdd/CartesianRDD.scala @@ -1,9 +1,7 @@ package spark.rdd -import spark.NarrowDependency -import spark.RDD -import spark.SparkContext -import spark.Split +import spark._ +import java.lang.ref.WeakReference private[spark] class CartesianSplit(idx: Int, val s1: Split, val s2: Split) extends Split with Serializable { @@ -13,15 +11,15 @@ class CartesianSplit(idx: Int, val s1: Split, val s2: Split) extends Split with private[spark] class CartesianRDD[T: ClassManifest, U:ClassManifest]( sc: SparkContext, - rdd1: RDD[T], - rdd2: RDD[U]) - extends RDD[Pair[T, U]](sc) + var rdd1 : RDD[T], + var rdd2 : RDD[U]) + extends RDD[Pair[T, U]](sc, Nil) with Serializable { - + val numSplitsInRdd2 = rdd2.splits.size - + @transient - val splits_ = { + var splits_ = { // create the cross product split val array = new Array[Split](rdd1.splits.size * rdd2.splits.size) for (s1 <- rdd1.splits; s2 <- rdd2.splits) { @@ -34,16 +32,20 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest]( override def splits = splits_ override def preferredLocations(split: Split) = { - val currSplit = split.asInstanceOf[CartesianSplit] - rdd1.preferredLocations(currSplit.s1) ++ rdd2.preferredLocations(currSplit.s2) + if (isCheckpointed) { + checkpointRDD.preferredLocations(split) + } else { + val currSplit = split.asInstanceOf[CartesianSplit] + rdd1.preferredLocations(currSplit.s1) ++ rdd2.preferredLocations(currSplit.s2) + } } override def compute(split: Split) = { val currSplit = split.asInstanceOf[CartesianSplit] for (x <- rdd1.iterator(currSplit.s1); y <- rdd2.iterator(currSplit.s2)) yield (x, y) } - - override val dependencies = List( + + var deps_ = List( new NarrowDependency(rdd1) { def getParents(id: Int): Seq[Int] = List(id / numSplitsInRdd2) }, @@ -51,4 +53,13 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest]( def getParents(id: Int): Seq[Int] = List(id % numSplitsInRdd2) } ) + + override def dependencies = deps_ + + override protected def 
changeDependencies(newRDD: RDD[_]) { + deps_ = List(new OneToOneDependency(newRDD.asInstanceOf[RDD[Any]])) + splits_ = newRDD.splits + rdd1 = null + rdd2 = null + } } diff --git a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala index 50bec9e63b..a313ebcbe8 100644 --- a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala @@ -30,13 +30,13 @@ private[spark] class CoGroupAggregator { (b1, b2) => b1 ++ b2 }) with Serializable -class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner) - extends RDD[(K, Seq[Seq[_]])](rdds.head.context) with Logging { +class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner) + extends RDD[(K, Seq[Seq[_]])](rdds.head.context, Nil) with Logging { val aggr = new CoGroupAggregator - + @transient - override val dependencies = { + var deps_ = { val deps = new ArrayBuffer[Dependency[_]] for ((rdd, index) <- rdds.zipWithIndex) { val mapSideCombinedRDD = rdd.mapPartitions(aggr.combineValuesByKey(_), true) @@ -50,9 +50,11 @@ class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner) } deps.toList } - + + override def dependencies = deps_ + @transient - val splits_ : Array[Split] = { + var splits_ : Array[Split] = { val firstRdd = rdds.head val array = new Array[Split](part.numPartitions) for (i <- 0 until array.size) { @@ -72,8 +74,6 @@ class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner) override val partitioner = Some(part) - override def preferredLocations(s: Split) = Nil - override def compute(s: Split): Iterator[(K, Seq[Seq[_]])] = { val split = s.asInstanceOf[CoGroupSplit] val numRdds = split.deps.size @@ -101,4 +101,10 @@ class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner) } map.iterator } + + override protected def changeDependencies(newRDD: RDD[_]) { + deps_ = List(new OneToOneDependency(newRDD.asInstanceOf[RDD[Any]])) + splits_ = newRDD.splits + rdds = null + } } diff --git a/core/src/main/scala/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/spark/rdd/CoalescedRDD.scala index 0967f4f5df..5b5f72ddeb 100644 --- a/core/src/main/scala/spark/rdd/CoalescedRDD.scala +++ b/core/src/main/scala/spark/rdd/CoalescedRDD.scala @@ -1,8 +1,7 @@ package spark.rdd -import spark.NarrowDependency -import spark.RDD -import spark.Split +import spark._ +import java.lang.ref.WeakReference private class CoalescedRDDSplit(val index: Int, val parents: Array[Split]) extends Split @@ -14,10 +13,12 @@ private class CoalescedRDDSplit(val index: Int, val parents: Array[Split]) exten * This transformation is useful when an RDD with many partitions gets filtered into a smaller one, * or to avoid having a large number of small tasks when processing a directory with many files. 
*/ -class CoalescedRDD[T: ClassManifest](prev: RDD[T], maxPartitions: Int) - extends RDD[T](prev.context) { +class CoalescedRDD[T: ClassManifest]( + var prev: RDD[T], + maxPartitions: Int) + extends RDD[T](prev.context, Nil) { // Nil, so the dependencies_ var does not refer to parent RDDs - @transient val splits_ : Array[Split] = { + @transient var splits_ : Array[Split] = { val prevSplits = prev.splits if (prevSplits.length < maxPartitions) { prevSplits.zipWithIndex.map{ case (s, idx) => new CoalescedRDDSplit(idx, Array(s)) } @@ -34,14 +35,22 @@ class CoalescedRDD[T: ClassManifest](prev: RDD[T], maxPartitions: Int) override def compute(split: Split): Iterator[T] = { split.asInstanceOf[CoalescedRDDSplit].parents.iterator.flatMap { - parentSplit => prev.iterator(parentSplit) + parentSplit => firstParent[T].iterator(parentSplit) } } - val dependencies = List( + var deps_ : List[Dependency[_]] = List( new NarrowDependency(prev) { def getParents(id: Int): Seq[Int] = splits(id).asInstanceOf[CoalescedRDDSplit].parents.map(_.index) } ) + + override def dependencies = deps_ + + override protected def changeDependencies(newRDD: RDD[_]) { + deps_ = List(new OneToOneDependency(newRDD)) + splits_ = newRDD.splits + prev = null + } } diff --git a/core/src/main/scala/spark/rdd/FilteredRDD.scala b/core/src/main/scala/spark/rdd/FilteredRDD.scala index dfe9dc73f3..1370cf6faf 100644 --- a/core/src/main/scala/spark/rdd/FilteredRDD.scala +++ b/core/src/main/scala/spark/rdd/FilteredRDD.scala @@ -3,10 +3,14 @@ package spark.rdd import spark.OneToOneDependency import spark.RDD import spark.Split +import java.lang.ref.WeakReference private[spark] -class FilteredRDD[T: ClassManifest](prev: RDD[T], f: T => Boolean) extends RDD[T](prev.context) { - override def splits = prev.splits - override val dependencies = List(new OneToOneDependency(prev)) - override def compute(split: Split) = prev.iterator(split).filter(f) +class FilteredRDD[T: ClassManifest]( + prev: WeakReference[RDD[T]], + f: T => Boolean) + extends RDD[T](prev.get) { + + override def splits = firstParent[T].splits + override def compute(split: Split) = firstParent[T].iterator(split).filter(f) }
\ No newline at end of file diff --git a/core/src/main/scala/spark/rdd/FlatMappedRDD.scala b/core/src/main/scala/spark/rdd/FlatMappedRDD.scala index 3534dc8057..6b2cc67568 100644 --- a/core/src/main/scala/spark/rdd/FlatMappedRDD.scala +++ b/core/src/main/scala/spark/rdd/FlatMappedRDD.scala @@ -3,14 +3,14 @@ package spark.rdd import spark.OneToOneDependency import spark.RDD import spark.Split +import java.lang.ref.WeakReference private[spark] class FlatMappedRDD[U: ClassManifest, T: ClassManifest]( - prev: RDD[T], + prev: WeakReference[RDD[T]], f: T => TraversableOnce[U]) - extends RDD[U](prev.context) { + extends RDD[U](prev.get) { - override def splits = prev.splits - override val dependencies = List(new OneToOneDependency(prev)) - override def compute(split: Split) = prev.iterator(split).flatMap(f) + override def splits = firstParent[T].splits + override def compute(split: Split) = firstParent[T].iterator(split).flatMap(f) } diff --git a/core/src/main/scala/spark/rdd/GlommedRDD.scala b/core/src/main/scala/spark/rdd/GlommedRDD.scala index e30564f2da..0f0b6ab0ff 100644 --- a/core/src/main/scala/spark/rdd/GlommedRDD.scala +++ b/core/src/main/scala/spark/rdd/GlommedRDD.scala @@ -3,10 +3,11 @@ package spark.rdd import spark.OneToOneDependency import spark.RDD import spark.Split +import java.lang.ref.WeakReference private[spark] -class GlommedRDD[T: ClassManifest](prev: RDD[T]) extends RDD[Array[T]](prev.context) { - override def splits = prev.splits - override val dependencies = List(new OneToOneDependency(prev)) - override def compute(split: Split) = Array(prev.iterator(split).toArray).iterator +class GlommedRDD[T: ClassManifest](prev: WeakReference[RDD[T]]) + extends RDD[Array[T]](prev.get) { + override def splits = firstParent[T].splits + override def compute(split: Split) = Array(firstParent[T].iterator(split).toArray).iterator }
\ No newline at end of file diff --git a/core/src/main/scala/spark/rdd/HadoopRDD.scala b/core/src/main/scala/spark/rdd/HadoopRDD.scala index bf29a1f075..19ed56d9c0 100644 --- a/core/src/main/scala/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/spark/rdd/HadoopRDD.scala @@ -46,7 +46,7 @@ class HadoopRDD[K, V]( keyClass: Class[K], valueClass: Class[V], minSplits: Int) - extends RDD[(K, V)](sc) { + extends RDD[(K, V)](sc, Nil) { // A Hadoop JobConf can be about 10 KB, which is pretty big, so broadcast it val confBroadcast = sc.broadcast(new SerializableWritable(conf)) @@ -115,6 +115,6 @@ class HadoopRDD[K, V]( val hadoopSplit = split.asInstanceOf[HadoopSplit] hadoopSplit.inputSplit.value.getLocations.filter(_ != "localhost") } - - override val dependencies: List[Dependency[_]] = Nil + + override val isCheckpointable = false } diff --git a/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala b/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala index a904ef62c3..b04f56cfcc 100644 --- a/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala +++ b/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala @@ -3,17 +3,17 @@ package spark.rdd import spark.OneToOneDependency import spark.RDD import spark.Split +import java.lang.ref.WeakReference private[spark] class MapPartitionsRDD[U: ClassManifest, T: ClassManifest]( - prev: RDD[T], + prev: WeakReference[RDD[T]], f: Iterator[T] => Iterator[U], preservesPartitioning: Boolean = false) - extends RDD[U](prev.context) { + extends RDD[U](prev.get) { - override val partitioner = if (preservesPartitioning) prev.partitioner else None + override val partitioner = if (preservesPartitioning) firstParent[T].partitioner else None - override def splits = prev.splits - override val dependencies = List(new OneToOneDependency(prev)) - override def compute(split: Split) = f(prev.iterator(split)) + override def splits = firstParent[T].splits + override def compute(split: Split) = f(firstParent[T].iterator(split)) }
\ No newline at end of file diff --git a/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala b/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala index adc541694e..7a4b6ffb03 100644 --- a/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala +++ b/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala @@ -3,6 +3,7 @@ package spark.rdd import spark.OneToOneDependency import spark.RDD import spark.Split +import java.lang.ref.WeakReference /** * A variant of the MapPartitionsRDD that passes the split index into the @@ -11,11 +12,10 @@ import spark.Split */ private[spark] class MapPartitionsWithSplitRDD[U: ClassManifest, T: ClassManifest]( - prev: RDD[T], + prev: WeakReference[RDD[T]], f: (Int, Iterator[T]) => Iterator[U]) - extends RDD[U](prev.context) { + extends RDD[U](prev.get) { - override def splits = prev.splits - override val dependencies = List(new OneToOneDependency(prev)) - override def compute(split: Split) = f(split.index, prev.iterator(split)) + override def splits = firstParent[T].splits + override def compute(split: Split) = f(split.index, firstParent[T].iterator(split)) }
\ No newline at end of file diff --git a/core/src/main/scala/spark/rdd/MappedRDD.scala b/core/src/main/scala/spark/rdd/MappedRDD.scala index 59bedad8ef..8fa1872e0a 100644 --- a/core/src/main/scala/spark/rdd/MappedRDD.scala +++ b/core/src/main/scala/spark/rdd/MappedRDD.scala @@ -3,14 +3,14 @@ package spark.rdd import spark.OneToOneDependency import spark.RDD import spark.Split +import java.lang.ref.WeakReference private[spark] class MappedRDD[U: ClassManifest, T: ClassManifest]( - prev: RDD[T], + prev: WeakReference[RDD[T]], f: T => U) - extends RDD[U](prev.context) { - - override def splits = prev.splits - override val dependencies = List(new OneToOneDependency(prev)) - override def compute(split: Split) = prev.iterator(split).map(f) + extends RDD[U](prev.get) { + + override def splits = firstParent[T].splits + override def compute(split: Split) = firstParent[T].iterator(split).map(f) }
\ No newline at end of file diff --git a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala index 7a1a0fb87d..2875abb2db 100644 --- a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala @@ -23,11 +23,12 @@ class NewHadoopSplit(rddId: Int, val index: Int, @transient rawSplit: InputSplit } class NewHadoopRDD[K, V]( - sc: SparkContext, + sc : SparkContext, inputFormatClass: Class[_ <: InputFormat[K, V]], - keyClass: Class[K], valueClass: Class[V], + keyClass: Class[K], + valueClass: Class[V], @transient conf: Configuration) - extends RDD[(K, V)](sc) + extends RDD[(K, V)](sc, Nil) with HadoopMapReduceUtil { // A Hadoop Configuration can be about 10 KB, which is pretty big, so broadcast it @@ -93,5 +94,5 @@ class NewHadoopRDD[K, V]( theSplit.serializableHadoopSplit.value.getLocations.filter(_ != "localhost") } - override val dependencies: List[Dependency[_]] = Nil + override val isCheckpointable = false } diff --git a/core/src/main/scala/spark/rdd/PipedRDD.scala b/core/src/main/scala/spark/rdd/PipedRDD.scala index 98ea0c92d6..d9293a9d1a 100644 --- a/core/src/main/scala/spark/rdd/PipedRDD.scala +++ b/core/src/main/scala/spark/rdd/PipedRDD.scala @@ -12,6 +12,7 @@ import spark.OneToOneDependency import spark.RDD import spark.SparkEnv import spark.Split +import java.lang.ref.WeakReference /** @@ -19,18 +20,18 @@ import spark.Split * (printing them one per line) and returns the output as a collection of strings. */ class PipedRDD[T: ClassManifest]( - parent: RDD[T], command: Seq[String], envVars: Map[String, String]) - extends RDD[String](parent.context) { + prev: WeakReference[RDD[T]], + command: Seq[String], + envVars: Map[String, String]) + extends RDD[String](prev.get) { - def this(parent: RDD[T], command: Seq[String]) = this(parent, command, Map()) + def this(prev: WeakReference[RDD[T]], command: Seq[String]) = this(prev, command, Map()) // Similar to Runtime.exec(), if we are given a single string, split it into words // using a standard StringTokenizer (i.e. 
by spaces) - def this(parent: RDD[T], command: String) = this(parent, PipedRDD.tokenize(command)) + def this(prev: WeakReference[RDD[T]], command: String) = this(prev, PipedRDD.tokenize(command)) - override def splits = parent.splits - - override val dependencies = List(new OneToOneDependency(parent)) + override def splits = firstParent[T].splits override def compute(split: Split): Iterator[String] = { val pb = new ProcessBuilder(command) @@ -55,7 +56,7 @@ class PipedRDD[T: ClassManifest]( override def run() { SparkEnv.set(env) val out = new PrintWriter(proc.getOutputStream) - for (elem <- parent.iterator(split)) { + for (elem <- firstParent[T].iterator(split)) { out.println(elem) } out.close() diff --git a/core/src/main/scala/spark/rdd/SampledRDD.scala b/core/src/main/scala/spark/rdd/SampledRDD.scala index 87a5268f27..f273f257f8 100644 --- a/core/src/main/scala/spark/rdd/SampledRDD.scala +++ b/core/src/main/scala/spark/rdd/SampledRDD.scala @@ -7,6 +7,7 @@ import cern.jet.random.engine.DRand import spark.RDD import spark.OneToOneDependency import spark.Split +import java.lang.ref.WeakReference private[spark] class SampledRDDSplit(val prev: Split, val seed: Int) extends Split with Serializable { @@ -14,24 +15,22 @@ class SampledRDDSplit(val prev: Split, val seed: Int) extends Split with Seriali } class SampledRDD[T: ClassManifest]( - prev: RDD[T], + prev: WeakReference[RDD[T]], withReplacement: Boolean, frac: Double, seed: Int) - extends RDD[T](prev.context) { + extends RDD[T](prev.get) { @transient val splits_ = { val rg = new Random(seed) - prev.splits.map(x => new SampledRDDSplit(x, rg.nextInt)) + firstParent[T].splits.map(x => new SampledRDDSplit(x, rg.nextInt)) } override def splits = splits_.asInstanceOf[Array[Split]] - override val dependencies = List(new OneToOneDependency(prev)) - override def preferredLocations(split: Split) = - prev.preferredLocations(split.asInstanceOf[SampledRDDSplit].prev) + firstParent[T].preferredLocations(split.asInstanceOf[SampledRDDSplit].prev) override def compute(splitIn: Split) = { val split = splitIn.asInstanceOf[SampledRDDSplit] @@ -39,7 +38,7 @@ class SampledRDD[T: ClassManifest]( // For large datasets, the expected number of occurrences of each element in a sample with // replacement is Poisson(frac). We use that to get a count for each element. val poisson = new Poisson(frac, new DRand(split.seed)) - prev.iterator(split.prev).flatMap { element => + firstParent[T].iterator(split.prev).flatMap { element => val count = poisson.nextInt() if (count == 0) { Iterator.empty // Avoid object allocation when we return 0 items, which is quite often @@ -49,7 +48,7 @@ class SampledRDD[T: ClassManifest]( } } else { // Sampling without replacement val rand = new Random(split.seed) - prev.iterator(split.prev).filter(x => (rand.nextDouble <= frac)) + firstParent[T].iterator(split.prev).filter(x => (rand.nextDouble <= frac)) } } } diff --git a/core/src/main/scala/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/spark/rdd/ShuffledRDD.scala index 145e419c53..31774585f4 100644 --- a/core/src/main/scala/spark/rdd/ShuffledRDD.scala +++ b/core/src/main/scala/spark/rdd/ShuffledRDD.scala @@ -5,6 +5,7 @@ import spark.RDD import spark.ShuffleDependency import spark.SparkEnv import spark.Split +import java.lang.ref.WeakReference private[spark] class ShuffledRDDSplit(val idx: Int) extends Split { override val index = idx @@ -19,22 +20,24 @@ private[spark] class ShuffledRDDSplit(val idx: Int) extends Split { * @tparam V the value class. 
*/ class ShuffledRDD[K, V]( - @transient parent: RDD[(K, V)], - part: Partitioner) extends RDD[(K, V)](parent.context) { + @transient prev: WeakReference[RDD[(K, V)]], + part: Partitioner) + extends RDD[(K, V)](prev.get.context, List(new ShuffleDependency(prev.get, part))) { override val partitioner = Some(part) @transient - val splits_ = Array.tabulate[Split](part.numPartitions)(i => new ShuffledRDDSplit(i)) + var splits_ = Array.tabulate[Split](part.numPartitions)(i => new ShuffledRDDSplit(i)) override def splits = splits_ - override def preferredLocations(split: Split) = Nil - - val dep = new ShuffleDependency(parent, part) - override val dependencies = List(dep) - override def compute(split: Split): Iterator[(K, V)] = { - SparkEnv.get.shuffleFetcher.fetch[K, V](dep.shuffleId, split.index) + val shuffledId = dependencies.head.asInstanceOf[ShuffleDependency[K, V]].shuffleId + SparkEnv.get.shuffleFetcher.fetch[K, V](shuffledId, split.index) + } + + override def changeDependencies(newRDD: RDD[_]) { + dependencies_ = Nil + splits_ = newRDD.splits } } diff --git a/core/src/main/scala/spark/rdd/UnionRDD.scala b/core/src/main/scala/spark/rdd/UnionRDD.scala index f0b9225f7c..643a174160 100644 --- a/core/src/main/scala/spark/rdd/UnionRDD.scala +++ b/core/src/main/scala/spark/rdd/UnionRDD.scala @@ -2,11 +2,8 @@ package spark.rdd import scala.collection.mutable.ArrayBuffer -import spark.Dependency -import spark.RangeDependency -import spark.RDD -import spark.SparkContext -import spark.Split +import spark._ +import java.lang.ref.WeakReference private[spark] class UnionSplit[T: ClassManifest]( idx: Int, @@ -22,12 +19,11 @@ private[spark] class UnionSplit[T: ClassManifest]( class UnionRDD[T: ClassManifest]( sc: SparkContext, - @transient rdds: Seq[RDD[T]]) - extends RDD[T](sc) - with Serializable { - + @transient var rdds: Seq[RDD[T]]) + extends RDD[T](sc, Nil) { // Nil, so the dependencies_ var does not refer to parent RDDs + @transient - val splits_ : Array[Split] = { + var splits_ : Array[Split] = { val array = new Array[Split](rdds.map(_.splits.size).sum) var pos = 0 for (rdd <- rdds; split <- rdd.splits) { @@ -39,19 +35,31 @@ class UnionRDD[T: ClassManifest]( override def splits = splits_ - @transient - override val dependencies = { + @transient var deps_ = { val deps = new ArrayBuffer[Dependency[_]] var pos = 0 for (rdd <- rdds) { - deps += new RangeDependency(rdd, 0, pos, rdd.splits.size) + deps += new RangeDependency(rdd, 0, pos, rdd.splits.size) pos += rdd.splits.size } deps.toList } - + + override def dependencies = deps_ + override def compute(s: Split): Iterator[T] = s.asInstanceOf[UnionSplit[T]].iterator() - override def preferredLocations(s: Split): Seq[String] = - s.asInstanceOf[UnionSplit[T]].preferredLocations() + override def preferredLocations(s: Split): Seq[String] = { + if (isCheckpointed) { + checkpointRDD.preferredLocations(s) + } else { + s.asInstanceOf[UnionSplit[T]].preferredLocations() + } + } + + override protected def changeDependencies(newRDD: RDD[_]) { + deps_ = List(new OneToOneDependency(newRDD)) + splits_ = newRDD.splits + rdds = null + } } diff --git a/core/src/test/scala/spark/CheckpointSuite.scala b/core/src/test/scala/spark/CheckpointSuite.scala new file mode 100644 index 0000000000..57dc43ddac --- /dev/null +++ b/core/src/test/scala/spark/CheckpointSuite.scala @@ -0,0 +1,185 @@ +package spark + +import org.scalatest.{BeforeAndAfter, FunSuite} +import java.io.File +import rdd.{BlockRDD, CoalescedRDD, MapPartitionsWithSplitRDD} +import spark.SparkContext._ +import 
storage.StorageLevel +import java.util.concurrent.Semaphore + +class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { + initLogging() + + var sc: SparkContext = _ + var checkpointDir: File = _ + + before { + checkpointDir = File.createTempFile("temp", "") + checkpointDir.delete() + + sc = new SparkContext("local", "test") + sc.setCheckpointDir(checkpointDir.toString) + } + + after { + if (sc != null) { + sc.stop() + sc = null + } + // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown + System.clearProperty("spark.master.port") + + if (checkpointDir != null) { + checkpointDir.delete() + } + } + + test("ParallelCollection") { + val parCollection = sc.makeRDD(1 to 4) + parCollection.checkpoint() + assert(parCollection.dependencies === Nil) + val result = parCollection.collect() + sleep(parCollection) // slightly extra time as loading classes for the first can take some time + assert(sc.objectFile[Int](parCollection.checkpointFile).collect() === result) + assert(parCollection.dependencies != Nil) + assert(parCollection.collect() === result) + } + + test("BlockRDD") { + val blockId = "id" + val blockManager = SparkEnv.get.blockManager + blockManager.putSingle(blockId, "test", StorageLevel.MEMORY_ONLY) + val blockRDD = new BlockRDD[String](sc, Array(blockId)) + blockRDD.checkpoint() + val result = blockRDD.collect() + sleep(blockRDD) + assert(sc.objectFile[String](blockRDD.checkpointFile).collect() === result) + assert(blockRDD.dependencies != Nil) + assert(blockRDD.collect() === result) + } + + test("RDDs with one-to-one dependencies") { + testCheckpointing(_.map(x => x.toString)) + testCheckpointing(_.flatMap(x => 1 to x)) + testCheckpointing(_.filter(_ % 2 == 0)) + testCheckpointing(_.sample(false, 0.5, 0)) + testCheckpointing(_.glom()) + testCheckpointing(_.mapPartitions(_.map(_.toString))) + testCheckpointing(r => new MapPartitionsWithSplitRDD(r, + (i: Int, iter: Iterator[Int]) => iter.map(_.toString) )) + testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).mapValues(_.toString), 1000) + testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).flatMapValues(x => 1 to x), 1000) + testCheckpointing(_.pipe(Seq("cat"))) + } + + test("ShuffledRDD") { + testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _)) + } + + test("UnionRDD") { + testCheckpointing(_.union(sc.makeRDD(5 to 6, 4))) + } + + test("CartesianRDD") { + testCheckpointing(_.cartesian(sc.makeRDD(5 to 6, 4)), 1000) + } + + test("CoalescedRDD") { + testCheckpointing(new CoalescedRDD(_, 2)) + } + + test("CoGroupedRDD") { + val rdd2 = sc.makeRDD(5 to 6, 4).map(x => (x % 2, 1)) + testCheckpointing(rdd1 => rdd1.map(x => (x % 2, 1)).cogroup(rdd2)) + testCheckpointing(rdd1 => rdd1.map(x => (x % 2, x)).join(rdd2)) + } + + /** + * This test forces two ResultTasks of the same job to be launched before and after + * the checkpointing of job's RDD is completed. + */ + test("Threading - ResultTasks") { + val op1 = (parCollection: RDD[Int]) => { + parCollection.map(x => { println("1st map running on " + x); Thread.sleep(500); (x % 2, x) }) + } + val op2 = (firstRDD: RDD[(Int, Int)]) => { + firstRDD.map(x => { println("2nd map running on " + x); Thread.sleep(500); x }) + } + testThreading(op1, op2) + } + + /** + * This test forces two ShuffleMapTasks of the same job to be launched before and after + * the checkpointing of job's RDD is completed. 
+ */ + test("Threading - ShuffleMapTasks") { + val op1 = (parCollection: RDD[Int]) => { + parCollection.map(x => { println("1st map running on " + x); Thread.sleep(500); (x % 2, x) }) + } + val op2 = (firstRDD: RDD[(Int, Int)]) => { + firstRDD.groupByKey(2).map(x => { println("2nd map running on " + x); Thread.sleep(500); x }) + } + testThreading(op1, op2) + } + + + def testCheckpointing[U: ClassManifest](op: (RDD[Int]) => RDD[U], sleepTime: Long = 500) { + val parCollection = sc.makeRDD(1 to 4, 4) + val operatedRDD = op(parCollection) + operatedRDD.checkpoint() + val parentRDD = operatedRDD.dependencies.head.rdd + val result = operatedRDD.collect() + sleep(operatedRDD) + //println(parentRDD + ", " + operatedRDD.dependencies.head.rdd ) + assert(sc.objectFile[U](operatedRDD.checkpointFile).collect() === result) + assert(operatedRDD.dependencies.head.rdd != parentRDD) + assert(operatedRDD.collect() === result) + } + + def testThreading[U: ClassManifest, V: ClassManifest](op1: (RDD[Int]) => RDD[U], op2: (RDD[U]) => RDD[V]) { + + val parCollection = sc.makeRDD(1 to 2, 2) + + // This is the RDD that is to be checkpointed + val firstRDD = op1(parCollection) + val parentRDD = firstRDD.dependencies.head.rdd + firstRDD.checkpoint() + + // This the RDD that uses firstRDD. This is designed to launch a + // ShuffleMapTask that uses firstRDD. + val secondRDD = op2(firstRDD) + + // Starting first job, to initiate the checkpointing + logInfo("\nLaunching 1st job to initiate checkpointing\n") + firstRDD.collect() + + // Checkpointing has started but not completed yet + Thread.sleep(100) + assert(firstRDD.dependencies.head.rdd === parentRDD) + + // Starting second job; first task of this job will be + // launched _before_ firstRDD is marked as checkpointed + // and the second task will be launched _after_ firstRDD + // is marked as checkpointed + logInfo("\nLaunching 2nd job that is designed to launch tasks " + + "before and after checkpointing is complete\n") + val result = secondRDD.collect() + + // Check whether firstRDD has been successfully checkpointed + assert(firstRDD.dependencies.head.rdd != parentRDD) + + logInfo("\nRecomputing 2nd job to verify the results of the previous computation\n") + // Check whether the result in the previous job was correct or not + val correctResult = secondRDD.collect() + assert(result === correctResult) + } + + def sleep(rdd: RDD[_]) { + val startTime = System.currentTimeMillis() + val maxWaitTime = 5000 + while(rdd.isCheckpointed == false && System.currentTimeMillis() < startTime + maxWaitTime) { + Thread.sleep(50) + } + assert(rdd.isCheckpointed === true, "Waiting for checkpoint to complete took more than " + maxWaitTime + " ms") + } +} diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala index 37a0ff0947..8ac7c8451a 100644 --- a/core/src/test/scala/spark/RDDSuite.scala +++ b/core/src/test/scala/spark/RDDSuite.scala @@ -19,7 +19,7 @@ class RDDSuite extends FunSuite with BeforeAndAfter { // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown System.clearProperty("spark.master.port") } - + test("basic operations") { sc = new SparkContext("local", "test") val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) @@ -70,10 +70,23 @@ class RDDSuite extends FunSuite with BeforeAndAfter { assert(result.toSet === Set(("a", 6), ("b", 2), ("c", 5))) } - test("checkpointing") { + test("basic checkpointing") { + import java.io.File + val checkpointDir = File.createTempFile("temp", "") + checkpointDir.delete() + 
sc = new SparkContext("local", "test") - val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).flatMap(x => 1 to x).checkpoint() - assert(rdd.collect().toList === List(1, 1, 2, 1, 2, 3, 1, 2, 3, 4)) + sc.setCheckpointDir(checkpointDir.toString) + val parCollection = sc.makeRDD(1 to 4) + val flatMappedRDD = parCollection.flatMap(x => 1 to x) + flatMappedRDD.checkpoint() + assert(flatMappedRDD.dependencies.head.rdd == parCollection) + val result = flatMappedRDD.collect() + Thread.sleep(1000) + assert(flatMappedRDD.dependencies.head.rdd != parCollection) + assert(flatMappedRDD.collect() === result) + + checkpointDir.deleteOnExit() } test("basic caching") { @@ -94,8 +107,8 @@ class RDDSuite extends FunSuite with BeforeAndAfter { List(List(1, 2, 3, 4, 5), List(6, 7, 8, 9, 10))) // Check that the narrow dependency is also specified correctly - assert(coalesced1.dependencies.head.getParents(0).toList === List(0, 1, 2, 3, 4)) - assert(coalesced1.dependencies.head.getParents(1).toList === List(5, 6, 7, 8, 9)) + assert(coalesced1.dependencies.head.asInstanceOf[NarrowDependency[_]].getParents(0).toList === List(0, 1, 2, 3, 4)) + assert(coalesced1.dependencies.head.asInstanceOf[NarrowDependency[_]].getParents(1).toList === List(5, 6, 7, 8, 9)) val coalesced2 = new CoalescedRDD(data, 3) assert(coalesced2.collect().toList === (1 to 10).toList) |
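
Because doCheckpoint() writes the RDD to its checkpoint file on a separate thread, the new CheckpointSuite polls the isCheckpointed flag before inspecting the truncated lineage. A minimal sketch of such a wait helper, adapted from the suite's sleep() method (the name awaitCheckpoint and its default timeout are illustrative, not part of the patch):

// Polls the isCheckpointed flag that the background checkpointing thread sets last.
def awaitCheckpoint(rdd: spark.RDD[_], maxWaitMs: Long = 5000) {
  val start = System.currentTimeMillis()
  while (!rdd.isCheckpointed && System.currentTimeMillis() < start + maxWaitMs) {
    Thread.sleep(50)
  }
  assert(rdd.isCheckpointed, "checkpoint did not complete within " + maxWaitMs + " ms")
}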