Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala | 47
1 file changed, 36 insertions, 11 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala
index 391f89aa14..5c12c9305b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala
@@ -52,7 +52,8 @@ import org.apache.spark.storage.StorageLevel
  *  - This class removes checkpoint files once later Datasets have been checkpointed.
  *    However, references to the older Datasets will still return isCheckpointed = true.
  *
- * @param checkpointInterval  Datasets will be checkpointed at this interval
+ * @param checkpointInterval  Datasets will be checkpointed at this interval.
+ *                            If this interval was set as -1, then checkpointing will be disabled.
  * @param sc  SparkContext for the Datasets given to this checkpointer
  * @tparam T  Dataset type, such as RDD[Double]
  */
@@ -89,7 +90,8 @@ private[mllib] abstract class PeriodicCheckpointer[T](
     updateCount += 1
 
     // Handle checkpointing (after persisting)
-    if ((updateCount % checkpointInterval) == 0 && sc.getCheckpointDir.nonEmpty) {
+    if (checkpointInterval != -1 && (updateCount % checkpointInterval) == 0
+      && sc.getCheckpointDir.nonEmpty) {
       // Add new checkpoint before removing old checkpoints.
       checkpoint(newData)
       checkpointQueue.enqueue(newData)
@@ -134,6 +136,24 @@ private[mllib] abstract class PeriodicCheckpointer[T](
   }
 
   /**
+   * Call this at the end to delete any remaining checkpoint files, except for the last checkpoint.
+   * Note that there may not be any checkpoints at all.
+   */
+  def deleteAllCheckpointsButLast(): Unit = {
+    while (checkpointQueue.size > 1) {
+      removeCheckpointFile()
+    }
+  }
+
+  /**
+   * Get all current checkpoint files.
+   * This is useful in combination with [[deleteAllCheckpointsButLast()]].
+   */
+  def getAllCheckpointFiles: Array[String] = {
+    checkpointQueue.flatMap(getCheckpointFiles).toArray
+  }
+
+  /**
    * Dequeue the oldest checkpointed Dataset, and remove its checkpoint files.
    * This prints a warning but does not fail if the files cannot be removed.
    */
@@ -141,15 +161,20 @@ private[mllib] abstract class PeriodicCheckpointer[T](
     val old = checkpointQueue.dequeue()
     // Since the old checkpoint is not deleted by Spark, we manually delete it.
     val fs = FileSystem.get(sc.hadoopConfiguration)
-    getCheckpointFiles(old).foreach { checkpointFile =>
-      try {
-        fs.delete(new Path(checkpointFile), true)
-      } catch {
-        case e: Exception =>
-          logWarning("PeriodicCheckpointer could not remove old checkpoint file: " +
-            checkpointFile)
-      }
-    }
+    getCheckpointFiles(old).foreach(PeriodicCheckpointer.removeCheckpointFile(_, fs))
   }
 
 }
+
+private[spark] object PeriodicCheckpointer extends Logging {
+  /** Delete a checkpoint file, and log a warning if deletion fails. */
+  def removeCheckpointFile(path: String, fs: FileSystem): Unit = {
+    try {
+      fs.delete(new Path(path), true)
+    } catch {
+      case e: Exception =>
+        logWarning("PeriodicCheckpointer could not remove old checkpoint file: " +
+          path)
+    }
+  }
+}
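
Taken together, the patch makes checkpointing optional (a checkpointInterval of -1 now disables it) and gives iterative algorithms a way to clean up after themselves: deleteAllCheckpointsButLast() trims the queue down to the most recent checkpoint, getAllCheckpointFiles exposes whatever files remain, and the new companion-object helper removeCheckpointFile(path, fs) deletes a single file while only logging a warning on failure. The sketch below shows roughly how a caller might drive the new API. It assumes code living inside Spark's own source tree (these classes are package-private, so it would not compile in user code); PeriodicRDDCheckpointer is an existing concrete subclass, but the driver object, the toy RDD, and the update loop are illustrative assumptions, not part of this change.

import org.apache.hadoop.fs.FileSystem
import org.apache.spark.SparkContext
import org.apache.spark.mllib.impl.{PeriodicCheckpointer, PeriodicRDDCheckpointer}
import org.apache.spark.rdd.RDD

object CheckpointerUsageSketch {
  // Drives a toy iterative computation, checkpointing every `interval` updates.
  // Passing interval = -1 exercises the new "checkpointing disabled" path.
  // Assumes sc.setCheckpointDir(...) has already been called.
  def run(sc: SparkContext, iterations: Int, interval: Int): Unit = {
    val checkpointer = new PeriodicRDDCheckpointer[Double](interval, sc)

    var data: RDD[Double] = sc.parallelize(Seq(0.0, 1.0, 2.0))
    for (_ <- 1 to iterations) {
      data = data.map(_ + 1.0)   // stand-in for a real iterative update
      checkpointer.update(data)  // persists, and checkpoints every `interval`-th call
    }

    // New in this patch: drop every checkpoint on disk except the most recent one.
    checkpointer.deleteAllCheckpointsButLast()

    // Once the final result no longer depends on the checkpoint (e.g. it has been
    // materialized or saved), the remaining files can be removed with the new
    // companion-object helper.
    val fs = FileSystem.get(sc.hadoopConfiguration)
    checkpointer.getAllCheckpointFiles.foreach { file =>
      PeriodicCheckpointer.removeCheckpointFile(file, fs)
    }
  }
}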