author     Joseph K. Bradley <joseph@databricks.com>  2016-04-07 19:48:33 -0700
committer  Joseph K. Bradley <joseph@databricks.com>  2016-04-07 19:48:33 -0700
commit     953ff897e422570a329d0aec98d573d3fb66ab9a (patch)
tree       c71211492dc024e469a834c07440713e78f7c981 /mllib/src/test/scala
parent     692c74840bc53debbb842db5372702f58207412c (diff)
download   spark-953ff897e422570a329d0aec98d573d3fb66ab9a.tar.gz
           spark-953ff897e422570a329d0aec98d573d3fb66ab9a.tar.bz2
           spark-953ff897e422570a329d0aec98d573d3fb66ab9a.zip
[SPARK-13048][ML][MLLIB] keepLastCheckpoint option for LDA EM optimizer
## What changes were proposed in this pull request?

The EMLDAOptimizer should generally not delete its last checkpoint, since doing so can cause failures when DistributedLDAModel methods are called (if any partitions need to be recovered from the checkpoint).

This PR adds a `keepLastCheckpoint` option, which defaults to `true` (matching the setter exercised in the new tests). This is a change in behavior from Spark 1.6: the last checkpoint is no longer removed by default.

This involves adding the `keepLastCheckpoint` option to both spark.ml and spark.mllib, and modifying PeriodicCheckpointer to support the option.

This also:
* Makes MLlibTestSparkContext extend TempDirectory and set the checkpoint dir to tempDir
* Updates LibSVMRelationSuite because of a name conflict with "tempDir" (and fixes a bug where it failed to delete a temp directory)
* Adds a MiMa exclude for the DistributedLDAModel constructor, which is already ```private[clustering]```

## How was this patch tested?

Added 2 new unit tests to the spark.ml LDASuite, which call into spark.mllib.

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #12166 from jkbradley/emlda-save-checkpoint.
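For reference, here is a minimal sketch of how a spark.ml user might exercise the new behavior, based on the params used in the tests below (the object name, paths, and toy data are made up for illustration):

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.ml.clustering.{DistributedLDAModel, LDA}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.sql.SQLContext

// Hypothetical driver: any DataFrame with a "features" vector column works.
object KeepLastCheckpointSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("lda-checkpoint").setMaster("local[2]"))
    val sqlContext = new SQLContext(sc)

    // Checkpointing requires a checkpoint directory on the SparkContext.
    sc.setCheckpointDir("/tmp/lda-checkpoints")

    // Toy corpus of term-count vectors.
    val rdd = sc.parallelize(Seq(
      Vectors.dense(1.0, 0.0, 2.0),
      Vectors.dense(0.0, 3.0, 1.0),
      Vectors.dense(2.0, 1.0, 0.0)
    )).map(Tuple1.apply)
    val dataset = sqlContext.createDataFrame(rdd).toDF("features")

    // With the EM optimizer, the last checkpoint is now kept by default so that
    // DistributedLDAModel methods can recover lost partitions if needed.
    val lda = new LDA()
      .setK(2)
      .setOptimizer("em")
      .setMaxIter(10)
      .setCheckpointInterval(2)

    val model = lda.fit(dataset).asInstanceOf[DistributedLDAModel]

    // Once the model has been saved elsewhere, the remaining checkpoint files
    // can be removed explicitly.
    println(model.getCheckpointFiles.mkString(", "))
    model.deleteCheckpointFiles()

    sc.stop()
  }
}
```

Calling `setKeepLastCheckpoint(false)` instead restores the Spark 1.6 behavior of removing the final checkpoint automatically, as the second new test below verifies.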
Diffstat (limited to 'mllib/src/test/scala')
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala               | 28
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala | 15
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala     | 13
3 files changed, 47 insertions(+), 9 deletions(-)
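Since MLlibTestSparkContext now provisions a per-suite checkpoint directory (see the MLlibTestSparkContext diff at the end of this page), any suite that mixes it in can run checkpoint-based code with no extra setup. A hypothetical example suite:

```scala
import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext

// Hypothetical suite: the checkpoint directory is created under the suite's
// temp directory in beforeAll() and cleaned up in afterAll().
class CheckpointingExampleSuite extends SparkFunSuite with MLlibTestSparkContext {

  test("RDD checkpointing works without extra setup") {
    val rdd = sc.parallelize(1 to 100, numSlices = 4)
    rdd.checkpoint()               // uses the checkpointDir set by the trait
    assert(rdd.count() === 100)    // the action materializes the checkpoint
    assert(rdd.isCheckpointed)
  }
}
```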
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
index dd3f4c6e53..a1c93891c7 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
@@ -17,6 +17,8 @@
package org.apache.spark.ml.clustering
+import org.apache.hadoop.fs.{FileSystem, Path}
+
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.linalg.{Vector, Vectors}
@@ -261,4 +263,30 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
testEstimatorAndModelReadWrite(lda, dataset,
LDASuite.allParamSettings ++ Map("optimizer" -> "em"), checkModelData)
}
+
+ test("EM LDA checkpointing: save last checkpoint") {
+ // Checkpoint dir is set by MLlibTestSparkContext
+ val lda = new LDA().setK(2).setSeed(1).setOptimizer("em").setMaxIter(3).setCheckpointInterval(1)
+ val model_ = lda.fit(dataset)
+ assert(model_.isInstanceOf[DistributedLDAModel])
+ val model = model_.asInstanceOf[DistributedLDAModel]
+
+ // There should be 1 checkpoint remaining.
+ assert(model.getCheckpointFiles.length === 1)
+ val fs = FileSystem.get(sqlContext.sparkContext.hadoopConfiguration)
+ assert(fs.exists(new Path(model.getCheckpointFiles.head)))
+ model.deleteCheckpointFiles()
+ assert(model.getCheckpointFiles.isEmpty)
+ }
+
+ test("EM LDA checkpointing: remove last checkpoint") {
+ // Checkpoint dir is set by MLlibTestSparkContext
+ val lda = new LDA().setK(2).setSeed(1).setOptimizer("em").setMaxIter(3).setCheckpointInterval(1)
+ .setKeepLastCheckpoint(false)
+ val model_ = lda.fit(dataset)
+ assert(model_.isInstanceOf[DistributedLDAModel])
+ val model = model_.asInstanceOf[DistributedLDAModel]
+
+ assert(model.getCheckpointFiles.isEmpty)
+ }
}
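The tests above drive the feature through the spark.ml API. A rough sketch of the spark.mllib path follows; the EMLDAOptimizer setter name and the mllib model's checkpoint-file helpers are assumed to mirror the spark.ml names exercised above, so treat them as assumptions rather than confirmed signatures:

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.clustering.{DistributedLDAModel, EMLDAOptimizer, LDA}
import org.apache.spark.mllib.linalg.Vectors

// Sketch of the spark.mllib path; setter and helper names are assumed to
// match their spark.ml counterparts ("keepLastCheckpoint", getCheckpointFiles).
object MLlibEmLdaCheckpointSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("em-lda").setMaster("local[2]"))
    sc.setCheckpointDir("/tmp/em-lda-checkpoints")  // required for checkpointing

    // Tiny toy corpus of (docId, termCountVector) pairs.
    val corpus = sc.parallelize(Seq(
      (0L, Vectors.dense(1.0, 0.0, 2.0)),
      (1L, Vectors.dense(0.0, 3.0, 1.0)),
      (2L, Vectors.dense(2.0, 1.0, 0.0))
    ))

    val optimizer = new EMLDAOptimizer()
      .setKeepLastCheckpoint(true)  // assumed default after this change

    val model = new LDA()
      .setK(2)
      .setMaxIterations(10)
      .setCheckpointInterval(2)
      .setOptimizer(optimizer)
      .run(corpus)
      .asInstanceOf[DistributedLDAModel]

    // The last checkpoint is retained so model methods that touch the graph
    // can recover lost partitions; clean it up once the model is saved.
    println(model.getCheckpointFiles.mkString(", "))
    model.deleteCheckpointFiles()

    sc.stop()
  }
}
```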
diff --git a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala
index 114a238462..0bd14978b2 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala
@@ -28,8 +28,9 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.SaveMode
import org.apache.spark.util.Utils
+
class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext {
- var tempDir: File = _
+ // Path for dataset
var path: String = _
override def beforeAll(): Unit = {
@@ -40,15 +41,15 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext {
|0
|0 2:4.0 4:5.0 6:6.0
""".stripMargin
- tempDir = Utils.createTempDir()
- val file = new File(tempDir, "part-00000")
+ val dir = Utils.createDirectory(tempDir.getCanonicalPath, "data")
+ val file = new File(dir, "part-00000")
Files.write(lines, file, StandardCharsets.UTF_8)
- path = tempDir.toURI.toString
+ path = dir.toURI.toString
}
override def afterAll(): Unit = {
try {
- Utils.deleteRecursively(tempDir)
+ Utils.deleteRecursively(new File(path))
} finally {
super.afterAll()
}
@@ -86,7 +87,7 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext {
test("write libsvm data and read it again") {
val df = sqlContext.read.format("libsvm").load(path)
- val tempDir2 = Utils.createTempDir()
+ val tempDir2 = new File(tempDir, "read_write_test")
val writepath = tempDir2.toURI.toString
// TODO: Remove requirement to coalesce by supporting multiple reads.
df.coalesce(1).write.format("libsvm").mode(SaveMode.Overwrite).save(writepath)
@@ -99,7 +100,7 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext {
test("write libsvm data failed due to invalid schema") {
val df = sqlContext.read.format("text").load(path)
- val e = intercept[SparkException] {
+ intercept[SparkException] {
df.write.format("libsvm").save(path + "_2")
}
}
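For context, the LIBSVM data source exercised by this suite can be used the same way from user code. A small hypothetical round trip (paths are made up):

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.{SQLContext, SaveMode}

// Hypothetical round trip over a directory of LIBSVM-formatted text.
object LibSVMRoundTrip {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("libsvm-roundtrip").setMaster("local[2]"))
    val sqlContext = new SQLContext(sc)

    // Reading yields a DataFrame with "label" and "features" columns.
    val df = sqlContext.read.format("libsvm").load("/tmp/sample_libsvm_data")

    // Coalesce to one partition before writing, mirroring the suite above
    // (see the TODO about supporting multiple part files on read).
    df.coalesce(1)
      .write.format("libsvm")
      .mode(SaveMode.Overwrite)
      .save("/tmp/sample_libsvm_data_copy")

    sc.stop()
  }
}
```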
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala
index ebcd591465..cb1efd5251 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala
@@ -17,14 +17,20 @@
package org.apache.spark.mllib.util
-import org.scalatest.{BeforeAndAfterAll, Suite}
+import java.io.File
+
+import org.apache.hadoop.fs.Path
+import org.scalatest.Suite
import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.ml.util.TempDirectory
import org.apache.spark.sql.SQLContext
+import org.apache.spark.util.Utils
-trait MLlibTestSparkContext extends BeforeAndAfterAll { self: Suite =>
+trait MLlibTestSparkContext extends TempDirectory { self: Suite =>
@transient var sc: SparkContext = _
@transient var sqlContext: SQLContext = _
+ @transient var checkpointDir: String = _
override def beforeAll() {
super.beforeAll()
@@ -35,10 +41,13 @@ trait MLlibTestSparkContext extends BeforeAndAfterAll { self: Suite =>
SQLContext.clearActive()
sqlContext = new SQLContext(sc)
SQLContext.setActive(sqlContext)
+ checkpointDir = Utils.createDirectory(tempDir.getCanonicalPath, "checkpoints").toString
+ sc.setCheckpointDir(checkpointDir)
}
override def afterAll() {
try {
+ Utils.deleteRecursively(new File(checkpointDir))
sqlContext = null
SQLContext.clearActive()
if (sc != null) {