diff options
Diffstat (limited to 'mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala')
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala | 43 |
1 files changed, 41 insertions, 2 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala index dd3f4c6e53..17d6e9fc2e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala @@ -17,12 +17,14 @@ package org.apache.spark.ml.clustering +import org.apache.hadoop.fs.{FileSystem, Path} + import org.apache.spark.SparkFunSuite import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -import org.apache.spark.sql.{DataFrame, Row, SQLContext} +import org.apache.spark.sql.{DataFrame, Dataset, Row, SQLContext} object LDASuite { @@ -62,7 +64,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead val k: Int = 5 val vocabSize: Int = 30 - @transient var dataset: DataFrame = _ + @transient var dataset: Dataset[_] = _ override def beforeAll(): Unit = { super.beforeAll() @@ -261,4 +263,41 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead testEstimatorAndModelReadWrite(lda, dataset, LDASuite.allParamSettings ++ Map("optimizer" -> "em"), checkModelData) } + + test("EM LDA checkpointing: save last checkpoint") { + // Checkpoint dir is set by MLlibTestSparkContext + val lda = new LDA().setK(2).setSeed(1).setOptimizer("em").setMaxIter(3).setCheckpointInterval(1) + val model_ = lda.fit(dataset) + assert(model_.isInstanceOf[DistributedLDAModel]) + val model = model_.asInstanceOf[DistributedLDAModel] + + // There should be 1 checkpoint remaining. + assert(model.getCheckpointFiles.length === 1) + val fs = FileSystem.get(sqlContext.sparkContext.hadoopConfiguration) + assert(fs.exists(new Path(model.getCheckpointFiles.head))) + model.deleteCheckpointFiles() + assert(model.getCheckpointFiles.isEmpty) + } + + test("EM LDA checkpointing: remove last checkpoint") { + // Checkpoint dir is set by MLlibTestSparkContext + val lda = new LDA().setK(2).setSeed(1).setOptimizer("em").setMaxIter(3).setCheckpointInterval(1) + .setKeepLastCheckpoint(false) + val model_ = lda.fit(dataset) + assert(model_.isInstanceOf[DistributedLDAModel]) + val model = model_.asInstanceOf[DistributedLDAModel] + + assert(model.getCheckpointFiles.isEmpty) + } + + test("EM LDA disable checkpointing") { + // Checkpoint dir is set by MLlibTestSparkContext + val lda = new LDA().setK(2).setSeed(1).setOptimizer("em").setMaxIter(3) + .setCheckpointInterval(-1) + val model_ = lda.fit(dataset) + assert(model_.isInstanceOf[DistributedLDAModel]) + val model = model_.asInstanceOf[DistributedLDAModel] + + assert(model.getCheckpointFiles.isEmpty) + } } |