aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main
diff options
context:
space:
mode:
authorHolden Karau <holden@us.ibm.com>2016-05-03 00:18:10 -0700
committerXiangrui Meng <meng@databricks.com>2016-05-03 00:18:10 -0700
commitf10ae4b1e169495af11b8e8123c60dd96174477e (patch)
tree323357bac8e3b933780625c92292a53c7043e17a /mllib/src/main
parentd8f528ceb61e3c2ac7ac97cd8147dafbb625932f (diff)
downloadspark-f10ae4b1e169495af11b8e8123c60dd96174477e.tar.gz
spark-f10ae4b1e169495af11b8e8123c60dd96174477e.tar.bz2
spark-f10ae4b1e169495af11b8e8123c60dd96174477e.zip
[SPARK-6717][ML] Clear shuffle files after checkpointing in ALS
## What changes were proposed in this pull request? When ALS is run with a checkpoint interval, during the checkpoint materialize the current state and cleanup the previous shuffles (non-blocking). ## How was this patch tested? Existing ALS unit tests, new ALS checkpoint cleanup unit tests added & shuffle files checked after ALS w/checkpointing run. Author: Holden Karau <holden@us.ibm.com> Author: Holden Karau <holden@pigscanfly.ca> Closes #11919 from holdenk/SPARK-6717-clear-shuffle-files-after-checkpointing-in-ALS.
Diffstat (limited to 'mllib/src/main')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala35
1 files changed, 33 insertions, 2 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index 541923048a..509c944fed 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -30,7 +30,7 @@ import org.apache.hadoop.fs.{FileSystem, Path}
import org.json4s.DefaultFormats
import org.json4s.JsonDSL._
-import org.apache.spark.Partitioner
+import org.apache.spark.{Dependency, Partitioner, ShuffleDependency, SparkContext}
import org.apache.spark.annotation.{DeveloperApi, Experimental, Since}
import org.apache.spark.internal.Logging
import org.apache.spark.ml.{Estimator, Model}
@@ -706,13 +706,15 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
previousItemFactors.unpersist()
itemFactors.setName(s"itemFactors-$iter").persist(intermediateRDDStorageLevel)
// TODO: Generalize PeriodicGraphCheckpointer and use it here.
+ val deps = itemFactors.dependencies
if (shouldCheckpoint(iter)) {
- itemFactors.checkpoint() // itemFactors gets materialized in computeFactors.
+ itemFactors.checkpoint() // itemFactors gets materialized in computeFactors
}
val previousUserFactors = userFactors
userFactors = computeFactors(itemFactors, itemOutBlocks, userInBlocks, rank, regParam,
itemLocalIndexEncoder, implicitPrefs, alpha, solver)
if (shouldCheckpoint(iter)) {
+ ALS.cleanShuffleDependencies(sc, deps)
deletePreviousCheckpointFile()
previousCheckpointFile = itemFactors.getCheckpointFile
}
@@ -723,8 +725,10 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
itemFactors = computeFactors(userFactors, userOutBlocks, itemInBlocks, rank, regParam,
userLocalIndexEncoder, solver = solver)
if (shouldCheckpoint(iter)) {
+ val deps = itemFactors.dependencies
itemFactors.checkpoint()
itemFactors.count() // checkpoint item factors and cut lineage
+ ALS.cleanShuffleDependencies(sc, deps)
deletePreviousCheckpointFile()
previousCheckpointFile = itemFactors.getCheckpointFile
}
@@ -1355,4 +1359,31 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
* satisfies this requirement, we simply use a type alias here.
*/
private[recommendation] type ALSPartitioner = org.apache.spark.HashPartitioner
+
+ /**
+ * Private function to clean up all of the shuffles files from the dependencies and their parents.
+ */
+ private[spark] def cleanShuffleDependencies[T](
+ sc: SparkContext,
+ deps: Seq[Dependency[_]],
+ blocking: Boolean = false): Unit = {
+ // If there is no reference tracking we skip clean up.
+ sc.cleaner.foreach { cleaner =>
+ /**
+ * Clean the shuffles & all of its parents.
+ */
+ def cleanEagerly(dep: Dependency[_]): Unit = {
+ if (dep.isInstanceOf[ShuffleDependency[_, _, _]]) {
+ val shuffleId = dep.asInstanceOf[ShuffleDependency[_, _, _]].shuffleId
+ cleaner.doCleanupShuffle(shuffleId, blocking)
+ }
+ val rdd = dep.rdd
+ val rddDeps = rdd.dependencies
+ if (rdd.getStorageLevel == StorageLevel.NONE && rddDeps != null) {
+ rddDeps.foreach(cleanEagerly)
+ }
+ }
+ deps.foreach(cleanEagerly)
+ }
+ }
}