diff options
author | Joseph K. Bradley <joseph@databricks.com> | 2016-04-07 19:48:33 -0700 |
---|---|---|
committer | Joseph K. Bradley <joseph@databricks.com> | 2016-04-07 19:48:33 -0700 |
commit | 953ff897e422570a329d0aec98d573d3fb66ab9a (patch) | |
tree | c71211492dc024e469a834c07440713e78f7c981 /mllib/src/test/scala/org | |
parent | 692c74840bc53debbb842db5372702f58207412c (diff) | |
download | spark-953ff897e422570a329d0aec98d573d3fb66ab9a.tar.gz spark-953ff897e422570a329d0aec98d573d3fb66ab9a.tar.bz2 spark-953ff897e422570a329d0aec98d573d3fb66ab9a.zip |
[SPARK-13048][ML][MLLIB] keepLastCheckpoint option for LDA EM optimizer
## What changes were proposed in this pull request?
The EMLDAOptimizer should generally not delete its last checkpoint since that can cause failures when DistributedLDAModel methods are called (if any partitions need to be recovered from the checkpoint).
This PR adds a "deleteLastCheckpoint" option which defaults to false. This is a change in behavior from Spark 1.6, in that the last checkpoint will not be removed by default.
This involves adding the deleteLastCheckpoint option to both spark.ml and spark.mllib, and modifying PeriodicCheckpointer to support the option.
This also:
* Makes MLlibTestSparkContext extend TempDirectory and set the checkpointDir to tempDir
* Updates LibSVMRelationSuite because of a name conflict with "tempDir" (and fixes a bug where it failed to delete a temp directory)
* Adds a MIMA exclude for DistributedLDAModel constructor, which is already ```private[clustering]```
## How was this patch tested?
Added 2 new unit tests to spark.ml LDASuite, which calls into spark.mllib.
Author: Joseph K. Bradley <joseph@databricks.com>
Closes #12166 from jkbradley/emlda-save-checkpoint.
Diffstat (limited to 'mllib/src/test/scala/org')
3 files changed, 47 insertions, 9 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala index dd3f4c6e53..a1c93891c7 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.ml.clustering +import org.apache.hadoop.fs.{FileSystem, Path} + import org.apache.spark.SparkFunSuite import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.linalg.{Vector, Vectors} @@ -261,4 +263,30 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead testEstimatorAndModelReadWrite(lda, dataset, LDASuite.allParamSettings ++ Map("optimizer" -> "em"), checkModelData) } + + test("EM LDA checkpointing: save last checkpoint") { + // Checkpoint dir is set by MLlibTestSparkContext + val lda = new LDA().setK(2).setSeed(1).setOptimizer("em").setMaxIter(3).setCheckpointInterval(1) + val model_ = lda.fit(dataset) + assert(model_.isInstanceOf[DistributedLDAModel]) + val model = model_.asInstanceOf[DistributedLDAModel] + + // There should be 1 checkpoint remaining. + assert(model.getCheckpointFiles.length === 1) + val fs = FileSystem.get(sqlContext.sparkContext.hadoopConfiguration) + assert(fs.exists(new Path(model.getCheckpointFiles.head))) + model.deleteCheckpointFiles() + assert(model.getCheckpointFiles.isEmpty) + } + + test("EM LDA checkpointing: remove last checkpoint") { + // Checkpoint dir is set by MLlibTestSparkContext + val lda = new LDA().setK(2).setSeed(1).setOptimizer("em").setMaxIter(3).setCheckpointInterval(1) + .setKeepLastCheckpoint(false) + val model_ = lda.fit(dataset) + assert(model_.isInstanceOf[DistributedLDAModel]) + val model = model_.asInstanceOf[DistributedLDAModel] + + assert(model.getCheckpointFiles.isEmpty) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala index 114a238462..0bd14978b2 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala @@ -28,8 +28,9 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.SaveMode import org.apache.spark.util.Utils + class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext { - var tempDir: File = _ + // Path for dataset var path: String = _ override def beforeAll(): Unit = { @@ -40,15 +41,15 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext { |0 |0 2:4.0 4:5.0 6:6.0 """.stripMargin - tempDir = Utils.createTempDir() - val file = new File(tempDir, "part-00000") + val dir = Utils.createDirectory(tempDir.getCanonicalPath, "data") + val file = new File(dir, "part-00000") Files.write(lines, file, StandardCharsets.UTF_8) - path = tempDir.toURI.toString + path = dir.toURI.toString } override def afterAll(): Unit = { try { - Utils.deleteRecursively(tempDir) + Utils.deleteRecursively(new File(path)) } finally { super.afterAll() } @@ -86,7 +87,7 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext { test("write libsvm data and read it again") { val df = sqlContext.read.format("libsvm").load(path) - val tempDir2 = Utils.createTempDir() + val tempDir2 = new File(tempDir, "read_write_test") val writepath = tempDir2.toURI.toString // TODO: Remove requirement to coalesce by supporting multiple reads. df.coalesce(1).write.format("libsvm").mode(SaveMode.Overwrite).save(writepath) @@ -99,7 +100,7 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext { test("write libsvm data failed due to invalid schema") { val df = sqlContext.read.format("text").load(path) - val e = intercept[SparkException] { + intercept[SparkException] { df.write.format("libsvm").save(path + "_2") } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala index ebcd591465..cb1efd5251 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala @@ -17,14 +17,20 @@ package org.apache.spark.mllib.util -import org.scalatest.{BeforeAndAfterAll, Suite} +import java.io.File + +import org.apache.hadoop.fs.Path +import org.scalatest.Suite import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.ml.util.TempDirectory import org.apache.spark.sql.SQLContext +import org.apache.spark.util.Utils -trait MLlibTestSparkContext extends BeforeAndAfterAll { self: Suite => +trait MLlibTestSparkContext extends TempDirectory { self: Suite => @transient var sc: SparkContext = _ @transient var sqlContext: SQLContext = _ + @transient var checkpointDir: String = _ override def beforeAll() { super.beforeAll() @@ -35,10 +41,13 @@ trait MLlibTestSparkContext extends BeforeAndAfterAll { self: Suite => SQLContext.clearActive() sqlContext = new SQLContext(sc) SQLContext.setActive(sqlContext) + checkpointDir = Utils.createDirectory(tempDir.getCanonicalPath, "checkpoints").toString + sc.setCheckpointDir(checkpointDir) } override def afterAll() { try { + Utils.deleteRecursively(new File(checkpointDir)) sqlContext = null SQLContext.clearActive() if (sc != null) { |