Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala | 47 ++++++++++---
1 file changed, 36 insertions(+), 11 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala
index 391f89aa14..5c12c9305b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala
@@ -52,7 +52,8 @@ import org.apache.spark.storage.StorageLevel
* - This class removes checkpoint files once later Datasets have been checkpointed.
* However, references to the older Datasets will still return isCheckpointed = true.
*
- * @param checkpointInterval Datasets will be checkpointed at this interval
+ * @param checkpointInterval Datasets will be checkpointed at this interval.
+ *                           If set to -1, checkpointing is disabled.
* @param sc SparkContext for the Datasets given to this checkpointer
* @tparam T Dataset type, such as RDD[Double]
*/
@@ -89,7 +90,8 @@ private[mllib] abstract class PeriodicCheckpointer[T](
updateCount += 1
// Handle checkpointing (after persisting)
- if ((updateCount % checkpointInterval) == 0 && sc.getCheckpointDir.nonEmpty) {
+ if (checkpointInterval != -1 && (updateCount % checkpointInterval) == 0
+ && sc.getCheckpointDir.nonEmpty) {
// Add new checkpoint before removing old checkpoints.
checkpoint(newData)
checkpointQueue.enqueue(newData)
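
The order of the new guard matters: for any Int, (updateCount % -1) == 0, so without
the leading checkpointInterval != -1 test an interval of -1 would checkpoint on every
update rather than never. Below is a minimal usage sketch of the disabled mode,
assuming the PeriodicRDDCheckpointer subclass that lives alongside this class (both
are private[mllib], so the caller must sit inside org.apache.spark.mllib; run and
data are illustrative names):

    import org.apache.spark.SparkContext
    import org.apache.spark.mllib.impl.PeriodicRDDCheckpointer
    import org.apache.spark.rdd.RDD

    def run(sc: SparkContext, data: RDD[Double]): Unit = {
      // -1 disables checkpointing: the guard short-circuits before the modulo.
      val checkpointer = new PeriodicRDDCheckpointer[Double](-1, sc)
      checkpointer.update(data)  // data is persisted but never checkpointed
    }
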
@@ -134,6 +136,24 @@ private[mllib] abstract class PeriodicCheckpointer[T](
}
/**
+ * Call this at the end to delete any remaining checkpoint files, except for the last checkpoint.
+ * Note that there may not be any checkpoints at all.
+ */
+ def deleteAllCheckpointsButLast(): Unit = {
+ while (checkpointQueue.size > 1) {
+ removeCheckpointFile()
+ }
+ }
+
+ /**
+ * Get all current checkpoint files.
+ * This is useful in combination with [[deleteAllCheckpointsButLast()]].
+ */
+ def getAllCheckpointFiles: Array[String] = {
+ checkpointQueue.flatMap(getCheckpointFiles).toArray
+ }
+
+ /**
* Dequeue the oldest checkpointed Dataset, and remove its checkpoint files.
* This prints a warning but does not fail if the files cannot be removed.
*/
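
Together, the two new methods support a keep-only-the-final-checkpoint pattern at the
end of an iterative job. A sketch under the same assumptions as above
(PeriodicRDDCheckpointer; initial, step, and maxIter are hypothetical driver-side
values):

    val checkpointer = new PeriodicRDDCheckpointer[Double](10, sc)
    var current: RDD[Double] = initial       // initial: hypothetical starting RDD
    for (_ <- 0 until maxIter) {
      current = step(current)                // step: one iteration of the algorithm
      checkpointer.update(current)           // checkpoints every 10th update
    }
    checkpointer.deleteAllCheckpointsButLast()  // keep only the newest checkpoint
    val kept: Array[String] = checkpointer.getAllCheckpointFiles
    kept.foreach(f => println(s"retained checkpoint file: $f"))
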
@@ -141,15 +161,20 @@ private[mllib] abstract class PeriodicCheckpointer[T](
val old = checkpointQueue.dequeue()
// Since the old checkpoint is not deleted by Spark, we manually delete it.
val fs = FileSystem.get(sc.hadoopConfiguration)
- getCheckpointFiles(old).foreach { checkpointFile =>
- try {
- fs.delete(new Path(checkpointFile), true)
- } catch {
- case e: Exception =>
- logWarning("PeriodicCheckpointer could not remove old checkpoint file: " +
- checkpointFile)
- }
- }
+ getCheckpointFiles(old).foreach(PeriodicCheckpointer.removeCheckpointFile(_, fs))
}
+}
+
+private[spark] object PeriodicCheckpointer extends Logging {
+ /** Delete a checkpoint file, and log a warning if deletion fails. */
+ def removeCheckpointFile(path: String, fs: FileSystem): Unit = {
+ try {
+ fs.delete(new Path(path), true)
+ } catch {
+ case e: Exception =>
+ logWarning("PeriodicCheckpointer could not remove old checkpoint file: " +
+ path)
+ }
+ }
}
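
The extracted helper can also be called on its own from anywhere inside
org.apache.spark, since the companion object is private[spark]. A sketch with an
illustrative path, using the same Hadoop FileSystem API as the method above:

    import org.apache.hadoop.fs.FileSystem

    val fs = FileSystem.get(sc.hadoopConfiguration)
    // Recursively deletes the file or directory; a failure is logged, not thrown.
    PeriodicCheckpointer.removeCheckpointFile("hdfs:///tmp/checkpoints/rdd-42", fs)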