Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala             |   6
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala           | 154
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala      | 105
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala        |  97
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala |  16
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointerSuite.scala   | 173
6 files changed, 452 insertions, 99 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
index 7e75e7083a..4b90fbdf0c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
@@ -142,8 +142,8 @@ final class EMLDAOptimizer extends LDAOptimizer {
this.k = k
this.vocabSize = docs.take(1).head._2.size
this.checkpointInterval = lda.getCheckpointInterval
- this.graphCheckpointer = new
- PeriodicGraphCheckpointer[TopicCounts, TokenCount](graph, checkpointInterval)
+ this.graphCheckpointer = new PeriodicGraphCheckpointer[TopicCounts, TokenCount](
+ checkpointInterval, graph.vertices.sparkContext)
this.globalTopicTotals = computeGlobalTopicTotals()
this
}
@@ -188,7 +188,7 @@ final class EMLDAOptimizer extends LDAOptimizer {
// Update the vertex descriptors with the new counts.
val newGraph = GraphImpl.fromExistingRDDs(docTopicDistributions, graph.edges)
graph = newGraph
- graphCheckpointer.updateGraph(newGraph)
+ graphCheckpointer.update(newGraph)
globalTopicTotals = computeGlobalTopicTotals()
this
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala
new file mode 100644
index 0000000000..72d3aabc9b
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala
@@ -0,0 +1,154 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.impl
+
+import scala.collection.mutable
+
+import org.apache.hadoop.fs.{Path, FileSystem}
+
+import org.apache.spark.{SparkContext, Logging}
+import org.apache.spark.storage.StorageLevel
+
+
+/**
+ * This abstraction helps with persisting and checkpointing RDDs and types derived from RDDs
+ * (such as Graphs and DataFrames). In documentation, we use the phrase "Dataset" to refer to
+ * the distributed data type (RDD, Graph, etc.).
+ *
+ * Specifically, this abstraction automatically handles persisting and (optionally) checkpointing,
+ * as well as unpersisting and removing checkpoint files.
+ *
+ * Users should call update() when a new Dataset has been created,
+ * before the Dataset has been materialized. After updating [[PeriodicCheckpointer]], users are
+ * responsible for materializing the Dataset to ensure that persisting and checkpointing actually
+ * occur.
+ *
+ * When update() is called, this does the following:
+ * - Persist new Dataset (if not yet persisted), and put in queue of persisted Datasets.
+ * - Unpersist Datasets from queue until there are at most 3 persisted Datasets.
+ * - If using checkpointing and the checkpoint interval has been reached,
+ * - Checkpoint the new Dataset, and put in a queue of checkpointed Datasets.
+ * - Remove older checkpoints.
+ *
+ * WARNINGS:
+ * - This class should NOT be copied (since copies may conflict on which Datasets should be
+ * checkpointed).
+ * - This class removes checkpoint files once later Datasets have been checkpointed.
+ * However, references to the older Datasets will still return isCheckpointed = true.
+ *
+ * @param checkpointInterval Datasets will be checkpointed at this interval
+ * @param sc SparkContext for the Datasets given to this checkpointer
+ * @tparam T Dataset type, such as RDD[Double]
+ */
+private[mllib] abstract class PeriodicCheckpointer[T](
+ val checkpointInterval: Int,
+ val sc: SparkContext) extends Logging {
+
+ /** FIFO queue of past checkpointed Datasets */
+ private val checkpointQueue = mutable.Queue[T]()
+
+ /** FIFO queue of past persisted Datasets */
+ private val persistedQueue = mutable.Queue[T]()
+
+ /** Number of times [[update()]] has been called */
+ private var updateCount = 0
+
+ /**
+ * Update with a new Dataset. Handle persistence and checkpointing as needed.
+ * Since this handles persistence and checkpointing, this should be called before the Dataset
+ * has been materialized.
+ *
+ * @param newData New Dataset created from previous Datasets in the lineage.
+ */
+ def update(newData: T): Unit = {
+ persist(newData)
+ persistedQueue.enqueue(newData)
+ // Keep at most 3 Datasets in persistedQueue: the new (not yet materialized) Dataset plus
+ // the 2 most recently materialized ones. This supports the semantics of this class:
+ // users call [[update()]] with a new Dataset before that Dataset has been materialized.
+ while (persistedQueue.size > 3) {
+ val dataToUnpersist = persistedQueue.dequeue()
+ unpersist(dataToUnpersist)
+ }
+ updateCount += 1
+
+ // Handle checkpointing (after persisting)
+ if ((updateCount % checkpointInterval) == 0 && sc.getCheckpointDir.nonEmpty) {
+ // Add new checkpoint before removing old checkpoints.
+ checkpoint(newData)
+ checkpointQueue.enqueue(newData)
+ // Remove checkpoints before the latest one.
+ var canDelete = true
+ while (checkpointQueue.size > 1 && canDelete) {
+ // Delete the oldest checkpoint only if the next checkpoint exists.
+ if (isCheckpointed(checkpointQueue(1))) {
+ removeCheckpointFile()
+ } else {
+ canDelete = false
+ }
+ }
+ }
+ }
+
+ /** Checkpoint the Dataset */
+ protected def checkpoint(data: T): Unit
+
+ /** Return true iff the Dataset is checkpointed */
+ protected def isCheckpointed(data: T): Boolean
+
+ /**
+ * Persist the Dataset.
+ * Note: This should handle checking the current [[StorageLevel]] of the Dataset.
+ */
+ protected def persist(data: T): Unit
+
+ /** Unpersist the Dataset */
+ protected def unpersist(data: T): Unit
+
+ /** Get list of checkpoint files for this given Dataset */
+ protected def getCheckpointFiles(data: T): Iterable[String]
+
+ /**
+ * Call this at the end to delete any remaining checkpoint files.
+ */
+ def deleteAllCheckpoints(): Unit = {
+ while (checkpointQueue.nonEmpty) {
+ removeCheckpointFile()
+ }
+ }
+
+ /**
+ * Dequeue the oldest checkpointed Dataset, and remove its checkpoint files.
+ * This prints a warning but does not fail if the files cannot be removed.
+ */
+ private def removeCheckpointFile(): Unit = {
+ val old = checkpointQueue.dequeue()
+ // Since the old checkpoint is not deleted by Spark, we manually delete it.
+ val fs = FileSystem.get(sc.hadoopConfiguration)
+ getCheckpointFiles(old).foreach { checkpointFile =>
+ try {
+ fs.delete(new Path(checkpointFile), true)
+ } catch {
+ case e: Exception =>
+ logWarning("PeriodicCheckpointer could not remove old checkpoint file: " +
+ checkpointFile)
+ }
+ }
+ }
+
+}
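To make the contract described in the class doc concrete, here is a minimal usage sketch (not part of the patch) written against the PeriodicRDDCheckpointer subclass added further down; since these classes are private[mllib], code like this would live inside MLlib, and the iteration body is purely illustrative.

import org.apache.spark.SparkContext
import org.apache.spark.mllib.impl.PeriodicRDDCheckpointer
import org.apache.spark.rdd.RDD

// Illustrative iterative job: `sc` is an existing SparkContext with a checkpoint
// directory already configured via sc.setCheckpointDir(...).
def iterate(sc: SparkContext, numIterations: Int): RDD[Double] = {
  val checkpointer = new PeriodicRDDCheckpointer[Double](3, sc)
  var current: RDD[Double] = sc.parallelize(Seq(0.0, 1.0, 2.0, 3.0))
  checkpointer.update(current)   // register the new RDD before materializing it
  current.count()                // materialize so persisting/checkpointing take effect
  (1 to numIterations).foreach { _ =>
    current = current.map(_ + 1.0)   // new RDD derived from the previous lineage
    checkpointer.update(current)
    current.count()
  }
  checkpointer.deleteAllCheckpoints()  // remove any remaining checkpoint files
  current
}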
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala
index 6e5dd119dd..11a059536c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala
@@ -17,11 +17,7 @@
package org.apache.spark.mllib.impl
-import scala.collection.mutable
-
-import org.apache.hadoop.fs.{Path, FileSystem}
-
-import org.apache.spark.Logging
+import org.apache.spark.SparkContext
import org.apache.spark.graphx.Graph
import org.apache.spark.storage.StorageLevel
@@ -31,12 +27,12 @@ import org.apache.spark.storage.StorageLevel
* Specifically, it automatically handles persisting and (optionally) checkpointing, as well as
* unpersisting and removing checkpoint files.
*
- * Users should call [[PeriodicGraphCheckpointer.updateGraph()]] when a new graph has been created,
+ * Users should call update() when a new graph has been created,
* before the graph has been materialized. After updating [[PeriodicGraphCheckpointer]], users are
* responsible for materializing the graph to ensure that persisting and checkpointing actually
* occur.
*
- * When [[PeriodicGraphCheckpointer.updateGraph()]] is called, this does the following:
+ * When update() is called, this does the following:
* - Persist new graph (if not yet persisted), and put in queue of persisted graphs.
* - Unpersist graphs from queue until there are at most 3 persisted graphs.
* - If using checkpointing and the checkpoint interval has been reached,
@@ -52,7 +48,7 @@ import org.apache.spark.storage.StorageLevel
* Example usage:
* {{{
* val (graph1, graph2, graph3, ...) = ...
- * val cp = new PeriodicGraphCheckpointer(graph1, dir, 2)
+ * val cp = new PeriodicGraphCheckpointer(2, sc)
* graph1.vertices.count(); graph1.edges.count()
* // persisted: graph1
* cp.updateGraph(graph2)
@@ -73,99 +69,30 @@ import org.apache.spark.storage.StorageLevel
* // checkpointed: graph4
* }}}
*
- * @param currentGraph Initial graph
* @param checkpointInterval Graphs will be checkpointed at this interval
* @tparam VD Vertex descriptor type
* @tparam ED Edge descriptor type
*
- * TODO: Generalize this for Graphs and RDDs, and move it out of MLlib.
+ * TODO: Move this out of MLlib?
*/
private[mllib] class PeriodicGraphCheckpointer[VD, ED](
- var currentGraph: Graph[VD, ED],
- val checkpointInterval: Int) extends Logging {
-
- /** FIFO queue of past checkpointed RDDs */
- private val checkpointQueue = mutable.Queue[Graph[VD, ED]]()
-
- /** FIFO queue of past persisted RDDs */
- private val persistedQueue = mutable.Queue[Graph[VD, ED]]()
-
- /** Number of times [[updateGraph()]] has been called */
- private var updateCount = 0
-
- /**
- * Spark Context for the Graphs given to this checkpointer.
- * NOTE: This code assumes that only one SparkContext is used for the given graphs.
- */
- private val sc = currentGraph.vertices.sparkContext
+ checkpointInterval: Int,
+ sc: SparkContext)
+ extends PeriodicCheckpointer[Graph[VD, ED]](checkpointInterval, sc) {
- updateGraph(currentGraph)
+ override protected def checkpoint(data: Graph[VD, ED]): Unit = data.checkpoint()
- /**
- * Update [[currentGraph]] with a new graph. Handle persistence and checkpointing as needed.
- * Since this handles persistence and checkpointing, this should be called before the graph
- * has been materialized.
- *
- * @param newGraph New graph created from previous graphs in the lineage.
- */
- def updateGraph(newGraph: Graph[VD, ED]): Unit = {
- if (newGraph.vertices.getStorageLevel == StorageLevel.NONE) {
- newGraph.persist()
- }
- persistedQueue.enqueue(newGraph)
- // We try to maintain 2 Graphs in persistedQueue to support the semantics of this class:
- // Users should call [[updateGraph()]] when a new graph has been created,
- // before the graph has been materialized.
- while (persistedQueue.size > 3) {
- val graphToUnpersist = persistedQueue.dequeue()
- graphToUnpersist.unpersist(blocking = false)
- }
- updateCount += 1
+ override protected def isCheckpointed(data: Graph[VD, ED]): Boolean = data.isCheckpointed
- // Handle checkpointing (after persisting)
- if ((updateCount % checkpointInterval) == 0 && sc.getCheckpointDir.nonEmpty) {
- // Add new checkpoint before removing old checkpoints.
- newGraph.checkpoint()
- checkpointQueue.enqueue(newGraph)
- // Remove checkpoints before the latest one.
- var canDelete = true
- while (checkpointQueue.size > 1 && canDelete) {
- // Delete the oldest checkpoint only if the next checkpoint exists.
- if (checkpointQueue.get(1).get.isCheckpointed) {
- removeCheckpointFile()
- } else {
- canDelete = false
- }
- }
+ override protected def persist(data: Graph[VD, ED]): Unit = {
+ if (data.vertices.getStorageLevel == StorageLevel.NONE) {
+ data.persist()
}
}
- /**
- * Call this at the end to delete any remaining checkpoint files.
- */
- def deleteAllCheckpoints(): Unit = {
- while (checkpointQueue.size > 0) {
- removeCheckpointFile()
- }
- }
+ override protected def unpersist(data: Graph[VD, ED]): Unit = data.unpersist(blocking = false)
- /**
- * Dequeue the oldest checkpointed Graph, and remove its checkpoint files.
- * This prints a warning but does not fail if the files cannot be removed.
- */
- private def removeCheckpointFile(): Unit = {
- val old = checkpointQueue.dequeue()
- // Since the old checkpoint is not deleted by Spark, we manually delete it.
- val fs = FileSystem.get(sc.hadoopConfiguration)
- old.getCheckpointFiles.foreach { checkpointFile =>
- try {
- fs.delete(new Path(checkpointFile), true)
- } catch {
- case e: Exception =>
- logWarning("PeriodicGraphCheckpointer could not remove old checkpoint file: " +
- checkpointFile)
- }
- }
+ override protected def getCheckpointFiles(data: Graph[VD, ED]): Iterable[String] = {
+ data.getCheckpointFiles
}
-
}
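The scaladoc example above predates the rename and still shows cp.updateGraph(graph2); with this change the calls become cp.update(...), the vertex and edge types have to be spelled out (the constructor no longer takes a graph to infer them from), and the first graph is registered explicitly. A corrected usage sketch with toy graphs (illustrative only, not taken from the patch):

import org.apache.spark.SparkContext
import org.apache.spark.graphx.{Edge, Graph}
import org.apache.spark.mllib.impl.PeriodicGraphCheckpointer

// Sketch only: `sc` is an existing SparkContext with a checkpoint directory set.
def graphUsage(sc: SparkContext): Unit = {
  def newGraph(): Graph[Double, Double] =
    Graph.fromEdges(sc.parallelize(Seq(Edge(0L, 1L, 1.0), Edge(1L, 2L, 1.0))), 0.0)

  val cp = new PeriodicGraphCheckpointer[Double, Double](2, sc)
  val graph1 = newGraph()
  cp.update(graph1)
  graph1.vertices.count(); graph1.edges.count()   // materialize: graph1 persisted
  val graph2 = newGraph()
  cp.update(graph2)
  graph2.vertices.count(); graph2.edges.count()   // graph2 persisted and checkpointed (interval 2)
}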
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala
new file mode 100644
index 0000000000..f31ed2aa90
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.impl
+
+import org.apache.spark.SparkContext
+import org.apache.spark.rdd.RDD
+import org.apache.spark.storage.StorageLevel
+
+
+/**
+ * This class helps with persisting and checkpointing RDDs.
+ * Specifically, it automatically handles persisting and (optionally) checkpointing, as well as
+ * unpersisting and removing checkpoint files.
+ *
+ * Users should call update() when a new RDD has been created,
+ * before the RDD has been materialized. After updating [[PeriodicRDDCheckpointer]], users are
+ * responsible for materializing the RDD to ensure that persisting and checkpointing actually
+ * occur.
+ *
+ * When update() is called, this does the following:
+ * - Persist new RDD (if not yet persisted), and put in queue of persisted RDDs.
+ * - Unpersist RDDs from queue until there are at most 3 persisted RDDs.
+ * - If using checkpointing and the checkpoint interval has been reached,
+ * - Checkpoint the new RDD, and put in a queue of checkpointed RDDs.
+ * - Remove older checkpoints.
+ *
+ * WARNINGS:
+ * - This class should NOT be copied (since copies may conflict on which RDDs should be
+ * checkpointed).
+ * - This class removes checkpoint files once later RDDs have been checkpointed.
+ * However, references to the older RDDs will still return isCheckpointed = true.
+ *
+ * Example usage:
+ * {{{
+ * val (rdd1, rdd2, rdd3, ...) = ...
+ * val cp = new PeriodicRDDCheckpointer(2, sc)
+ * cp.update(rdd1)
+ * rdd1.count();
+ * // persisted: rdd1
+ * cp.update(rdd2)
+ * rdd2.count();
+ * // persisted: rdd1, rdd2
+ * // checkpointed: rdd2
+ * cp.update(rdd3)
+ * rdd3.count();
+ * // persisted: rdd1, rdd2, rdd3
+ * // checkpointed: rdd2
+ * cp.update(rdd4)
+ * rdd4.count();
+ * // persisted: rdd2, rdd3, rdd4
+ * // checkpointed: rdd4
+ * cp.update(rdd5)
+ * rdd5.count();
+ * // persisted: rdd3, rdd4, rdd5
+ * // checkpointed: rdd4
+ * }}}
+ *
+ * @param checkpointInterval RDDs will be checkpointed at this interval
+ * @tparam T RDD element type
+ *
+ * TODO: Move this out of MLlib?
+ */
+private[mllib] class PeriodicRDDCheckpointer[T](
+ checkpointInterval: Int,
+ sc: SparkContext)
+ extends PeriodicCheckpointer[RDD[T]](checkpointInterval, sc) {
+
+ override protected def checkpoint(data: RDD[T]): Unit = data.checkpoint()
+
+ override protected def isCheckpointed(data: RDD[T]): Boolean = data.isCheckpointed
+
+ override protected def persist(data: RDD[T]): Unit = {
+ if (data.getStorageLevel == StorageLevel.NONE) {
+ data.persist()
+ }
+ }
+
+ override protected def unpersist(data: RDD[T]): Unit = data.unpersist(blocking = false)
+
+ override protected def getCheckpointFiles(data: RDD[T]): Iterable[String] = {
+ // getCheckpointFile is an Option[String]; view it as an Iterable of 0 or 1 paths.
+ data.getCheckpointFile.map(x => x)
+ }
+}
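One behavior worth noting from the base class's update(): checkpointing only happens when the SparkContext has a checkpoint directory set (the sc.getCheckpointDir.nonEmpty check); without one, update() still persists and unpersists as usual. A minimal sketch, with an illustrative directory path:

import org.apache.spark.SparkContext
import org.apache.spark.mllib.impl.PeriodicRDDCheckpointer

// Sketch only: `sc` is an existing SparkContext; the directory is illustrative.
def withCheckpointDir(sc: SparkContext): Unit = {
  sc.setCheckpointDir("/tmp/periodic-checkpointer-demo")  // omit this and update() will only persist
  val cp = new PeriodicRDDCheckpointer[Int](2, sc)
  val rdd1 = sc.parallelize(1 to 100)
  cp.update(rdd1)
  rdd1.count()                 // persisted; not checkpointed yet (update #1 of interval 2)
  val rdd2 = rdd1.map(_ * 2)
  cp.update(rdd2)
  rdd2.count()                 // persisted and checkpointed (update #2)
}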
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala
index d34888af2d..e331c75989 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala
@@ -30,20 +30,20 @@ class PeriodicGraphCheckpointerSuite extends SparkFunSuite with MLlibTestSparkCo
import PeriodicGraphCheckpointerSuite._
- // TODO: Do I need to call count() on the graphs' RDDs?
-
test("Persisting") {
var graphsToCheck = Seq.empty[GraphToCheck]
val graph1 = createGraph(sc)
- val checkpointer = new PeriodicGraphCheckpointer(graph1, 10)
+ val checkpointer =
+ new PeriodicGraphCheckpointer[Double, Double](10, graph1.vertices.sparkContext)
+ checkpointer.update(graph1)
graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1)
checkPersistence(graphsToCheck, 1)
var iteration = 2
while (iteration < 9) {
val graph = createGraph(sc)
- checkpointer.updateGraph(graph)
+ checkpointer.update(graph)
graphsToCheck = graphsToCheck :+ GraphToCheck(graph, iteration)
checkPersistence(graphsToCheck, iteration)
iteration += 1
@@ -57,7 +57,9 @@ class PeriodicGraphCheckpointerSuite extends SparkFunSuite with MLlibTestSparkCo
var graphsToCheck = Seq.empty[GraphToCheck]
sc.setCheckpointDir(path)
val graph1 = createGraph(sc)
- val checkpointer = new PeriodicGraphCheckpointer(graph1, checkpointInterval)
+ val checkpointer = new PeriodicGraphCheckpointer[Double, Double](
+ checkpointInterval, graph1.vertices.sparkContext)
+ checkpointer.update(graph1)
graph1.edges.count()
graph1.vertices.count()
graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1)
@@ -66,7 +68,7 @@ class PeriodicGraphCheckpointerSuite extends SparkFunSuite with MLlibTestSparkCo
var iteration = 2
while (iteration < 9) {
val graph = createGraph(sc)
- checkpointer.updateGraph(graph)
+ checkpointer.update(graph)
graph.vertices.count()
graph.edges.count()
graphsToCheck = graphsToCheck :+ GraphToCheck(graph, iteration)
@@ -168,7 +170,7 @@ private object PeriodicGraphCheckpointerSuite {
} else {
// Graph should never be checkpointed
assert(!graph.isCheckpointed, "Graph should never have been checkpointed")
- assert(graph.getCheckpointFiles.length == 0, "Graph should not have any checkpoint files")
+ assert(graph.getCheckpointFiles.isEmpty, "Graph should not have any checkpoint files")
}
} catch {
case e: AssertionError =>
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointerSuite.scala
new file mode 100644
index 0000000000..b2a459a68b
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointerSuite.scala
@@ -0,0 +1,173 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.impl
+
+import org.apache.hadoop.fs.{FileSystem, Path}
+
+import org.apache.spark.{SparkContext, SparkFunSuite}
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.rdd.RDD
+import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.Utils
+
+
+class PeriodicRDDCheckpointerSuite extends SparkFunSuite with MLlibTestSparkContext {
+
+ import PeriodicRDDCheckpointerSuite._
+
+ test("Persisting") {
+ var rddsToCheck = Seq.empty[RDDToCheck]
+
+ val rdd1 = createRDD(sc)
+ val checkpointer = new PeriodicRDDCheckpointer[Double](10, rdd1.sparkContext)
+ checkpointer.update(rdd1)
+ rddsToCheck = rddsToCheck :+ RDDToCheck(rdd1, 1)
+ checkPersistence(rddsToCheck, 1)
+
+ var iteration = 2
+ while (iteration < 9) {
+ val rdd = createRDD(sc)
+ checkpointer.update(rdd)
+ rddsToCheck = rddsToCheck :+ RDDToCheck(rdd, iteration)
+ checkPersistence(rddsToCheck, iteration)
+ iteration += 1
+ }
+ }
+
+ test("Checkpointing") {
+ val tempDir = Utils.createTempDir()
+ val path = tempDir.toURI.toString
+ val checkpointInterval = 2
+ var rddsToCheck = Seq.empty[RDDToCheck]
+ sc.setCheckpointDir(path)
+ val rdd1 = createRDD(sc)
+ val checkpointer = new PeriodicRDDCheckpointer[Double](checkpointInterval, rdd1.sparkContext)
+ checkpointer.update(rdd1)
+ rdd1.count()
+ rddsToCheck = rddsToCheck :+ RDDToCheck(rdd1, 1)
+ checkCheckpoint(rddsToCheck, 1, checkpointInterval)
+
+ var iteration = 2
+ while (iteration < 9) {
+ val rdd = createRDD(sc)
+ checkpointer.update(rdd)
+ rdd.count()
+ rddsToCheck = rddsToCheck :+ RDDToCheck(rdd, iteration)
+ checkCheckpoint(rddsToCheck, iteration, checkpointInterval)
+ iteration += 1
+ }
+
+ checkpointer.deleteAllCheckpoints()
+ rddsToCheck.foreach { rdd =>
+ confirmCheckpointRemoved(rdd.rdd)
+ }
+
+ Utils.deleteRecursively(tempDir)
+ }
+}
+
+private object PeriodicRDDCheckpointerSuite {
+
+ case class RDDToCheck(rdd: RDD[Double], gIndex: Int)
+
+ def createRDD(sc: SparkContext): RDD[Double] = {
+ sc.parallelize(Seq(0.0, 1.0, 2.0, 3.0))
+ }
+
+ def checkPersistence(rdds: Seq[RDDToCheck], iteration: Int): Unit = {
+ rdds.foreach { g =>
+ checkPersistence(g.rdd, g.gIndex, iteration)
+ }
+ }
+
+ /**
+ * Check storage level of rdd.
+ * @param gIndex Index of rdd in order inserted into checkpointer (from 1).
+ * @param iteration Total number of rdds inserted into checkpointer.
+ */
+ def checkPersistence(rdd: RDD[_], gIndex: Int, iteration: Int): Unit = {
+ try {
+ if (gIndex + 2 < iteration) {
+ assert(rdd.getStorageLevel == StorageLevel.NONE)
+ } else {
+ assert(rdd.getStorageLevel != StorageLevel.NONE)
+ }
+ } catch {
+ case _: AssertionError =>
+ throw new Exception(s"PeriodicRDDCheckpointerSuite.checkPersistence failed with:\n" +
+ s"\t gIndex = $gIndex\n" +
+ s"\t iteration = $iteration\n" +
+ s"\t rdd.getStorageLevel = ${rdd.getStorageLevel}\n")
+ }
+ }
+
+ def checkCheckpoint(rdds: Seq[RDDToCheck], iteration: Int, checkpointInterval: Int): Unit = {
+ rdds.reverse.foreach { g =>
+ checkCheckpoint(g.rdd, g.gIndex, iteration, checkpointInterval)
+ }
+ }
+
+ def confirmCheckpointRemoved(rdd: RDD[_]): Unit = {
+ // Note: We cannot check rdd.isCheckpointed since that value is never updated.
+ // Instead, we check for the presence of the checkpoint files.
+ // This test should continue to work even after this rdd.isCheckpointed issue
+ // is fixed (though it can then be simplified and not look for the files).
+ val fs = FileSystem.get(rdd.sparkContext.hadoopConfiguration)
+ rdd.getCheckpointFile.foreach { checkpointFile =>
+ assert(!fs.exists(new Path(checkpointFile)), "RDD checkpoint file should have been removed")
+ }
+ }
+
+ /**
+ * Check checkpointed status of rdd.
+ * @param gIndex Index of rdd in order inserted into checkpointer (from 1).
+ * @param iteration Total number of rdds inserted into checkpointer.
+ */
+ def checkCheckpoint(
+ rdd: RDD[_],
+ gIndex: Int,
+ iteration: Int,
+ checkpointInterval: Int): Unit = {
+ try {
+ if (gIndex % checkpointInterval == 0) {
+ // We allow 2 checkpoint intervals since we perform an action (checkpointing a second rdd)
+ // only AFTER PeriodicRDDCheckpointer decides whether to remove the previous checkpoint.
+ if (iteration - 2 * checkpointInterval < gIndex && gIndex <= iteration) {
+ assert(rdd.isCheckpointed, "RDD should be checkpointed")
+ assert(rdd.getCheckpointFile.nonEmpty, "RDD should have a checkpoint file")
+ } else {
+ confirmCheckpointRemoved(rdd)
+ }
+ } else {
+ // RDD should never be checkpointed
+ assert(!rdd.isCheckpointed, "RDD should never have been checkpointed")
+ assert(rdd.getCheckpointFile.isEmpty, "RDD should not have any checkpoint files")
+ }
+ } catch {
+ case e: AssertionError =>
+ throw new Exception(s"PeriodicRDDCheckpointerSuite.checkCheckpoint failed with:\n" +
+ s"\t gIndex = $gIndex\n" +
+ s"\t iteration = $iteration\n" +
+ s"\t checkpointInterval = $checkpointInterval\n" +
+ s"\t rdd.isCheckpointed = ${rdd.isCheckpointed}\n" +
+ s"\t rdd.getCheckpointFile = ${rdd.getCheckpointFile.mkString(", ")}\n" +
+ s" AssertionError message: ${e.getMessage}")
+ }
+ }
+
+}
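The warning in the class docs (and the note in confirmCheckpointRemoved above) that references to older Datasets keep reporting isCheckpointed = true after their files are removed can be reproduced directly. A small sketch, with an illustrative checkpoint directory:

import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.spark.SparkContext

// Sketch only: `sc` is an existing SparkContext.
def staleIsCheckpointedFlag(sc: SparkContext): Unit = {
  sc.setCheckpointDir("/tmp/stale-flag-demo")
  val rdd = sc.parallelize(1 to 10)
  rdd.checkpoint()
  rdd.count()                                   // materialize, writing the checkpoint files
  val fs = FileSystem.get(sc.hadoopConfiguration)
  rdd.getCheckpointFile.foreach(f => fs.delete(new Path(f), true))
  assert(rdd.isCheckpointed)                    // still true: the flag is cached in the RDD;
                                                // only the files on disk are gone
}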