Diffstat (limited to 'mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala')
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala  43
1 file changed, 41 insertions(+), 2 deletions(-)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
index dd3f4c6e53..17d6e9fc2e 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
@@ -17,12 +17,14 @@
package org.apache.spark.ml.clustering
+import org.apache.hadoop.fs.{FileSystem, Path}
+
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
-import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+import org.apache.spark.sql.{DataFrame, Dataset, Row, SQLContext}
object LDASuite {
@@ -62,7 +64,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
val k: Int = 5
val vocabSize: Int = 30
- @transient var dataset: DataFrame = _
+ @transient var dataset: Dataset[_] = _
override def beforeAll(): Unit = {
super.beforeAll()
@@ -261,4 +263,41 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
testEstimatorAndModelReadWrite(lda, dataset,
LDASuite.allParamSettings ++ Map("optimizer" -> "em"), checkModelData)
}
+
+ test("EM LDA checkpointing: save last checkpoint") {
+ // Checkpoint dir is set by MLlibTestSparkContext
+ val lda = new LDA().setK(2).setSeed(1).setOptimizer("em").setMaxIter(3).setCheckpointInterval(1)
+ val model_ = lda.fit(dataset)
+ assert(model_.isInstanceOf[DistributedLDAModel])
+ val model = model_.asInstanceOf[DistributedLDAModel]
+
+ // There should be 1 checkpoint remaining.
+ assert(model.getCheckpointFiles.length === 1)
+ val fs = FileSystem.get(sqlContext.sparkContext.hadoopConfiguration)
+ assert(fs.exists(new Path(model.getCheckpointFiles.head)))
+ model.deleteCheckpointFiles()
+ assert(model.getCheckpointFiles.isEmpty)
+ }
+
+ test("EM LDA checkpointing: remove last checkpoint") {
+ // Checkpoint dir is set by MLlibTestSparkContext
+ val lda = new LDA().setK(2).setSeed(1).setOptimizer("em").setMaxIter(3).setCheckpointInterval(1)
+ .setKeepLastCheckpoint(false)
+ val model_ = lda.fit(dataset)
+ assert(model_.isInstanceOf[DistributedLDAModel])
+ val model = model_.asInstanceOf[DistributedLDAModel]
+
+ assert(model.getCheckpointFiles.isEmpty)
+ }
+
+ test("EM LDA disable checkpointing") {
+ // Checkpoint dir is set by MLlibTestSparkContext
+ val lda = new LDA().setK(2).setSeed(1).setOptimizer("em").setMaxIter(3)
+ .setCheckpointInterval(-1)
+ val model_ = lda.fit(dataset)
+ assert(model_.isInstanceOf[DistributedLDAModel])
+ val model = model_.asInstanceOf[DistributedLDAModel]
+
+ assert(model.getCheckpointFiles.isEmpty)
+ }
}
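
For context, a minimal usage sketch (not part of this patch) of the checkpointing parameters the new tests exercise. The SparkContext, the input dataset, and the checkpoint directory are assumed to be configured by the caller, much as MLlibTestSparkContext does for the suite:

import org.apache.spark.ml.clustering.{DistributedLDAModel, LDA}

// Assumes sc.setCheckpointDir(...) has been called and `dataset` holds
// (id, features) rows, e.g. the output of CountVectorizer.
val lda = new LDA()
  .setK(10)
  .setOptimizer("em")            // checkpointing applies to the EM optimizer
  .setMaxIter(50)
  .setCheckpointInterval(10)     // checkpoint every 10 iterations; -1 disables
  .setKeepLastCheckpoint(true)   // default: retain the final checkpoint

val model = lda.fit(dataset).asInstanceOf[DistributedLDAModel]

// The retained checkpoint keeps the model's data recoverable; remove it
// explicitly once the model has been persisted or is no longer needed.
if (model.getCheckpointFiles.nonEmpty) model.deleteCheckpointFiles()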