-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala                     79
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala             13
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala         34
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala       41
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala                28
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala  15
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala      13
-rw-r--r--  project/MimaExcludes.scala                                                         3
8 files changed, 194 insertions(+), 32 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
index 60cc345565..727b724708 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
@@ -17,18 +17,19 @@
package org.apache.spark.ml.clustering
-import org.apache.hadoop.fs.Path
+import org.apache.hadoop.fs.{FileSystem, Path}
-import org.apache.spark.annotation.{Experimental, Since}
+import org.apache.spark.annotation.{DeveloperApi, Experimental, Since}
import org.apache.spark.internal.Logging
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.{HasCheckpointInterval, HasFeaturesCol, HasMaxIter, HasSeed}
import org.apache.spark.ml.util._
import org.apache.spark.mllib.clustering.{DistributedLDAModel => OldDistributedLDAModel,
- EMLDAOptimizer => OldEMLDAOptimizer, LDA => OldLDA, LDAModel => OldLDAModel,
- LDAOptimizer => OldLDAOptimizer, LocalLDAModel => OldLocalLDAModel,
- OnlineLDAOptimizer => OldOnlineLDAOptimizer}
+ EMLDAOptimizer => OldEMLDAOptimizer, LDA => OldLDA, LDAModel => OldLDAModel,
+ LDAOptimizer => OldLDAOptimizer, LocalLDAModel => OldLocalLDAModel,
+ OnlineLDAOptimizer => OldOnlineLDAOptimizer}
+import org.apache.spark.mllib.impl.PeriodicCheckpointer
import org.apache.spark.mllib.linalg.{Matrix, Vector, Vectors, VectorUDT}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row, SQLContext}
@@ -41,6 +42,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
/**
* Param for the number of topics (clusters) to infer. Must be > 1. Default: 10.
+ *
* @group param
*/
@Since("1.6.0")
@@ -173,6 +175,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
* This uses a variational approximation following Hoffman et al. (2010), where the approximate
* distribution is called "gamma." Technically, this method returns this approximation "gamma"
* for each document.
+ *
* @group param
*/
@Since("1.6.0")
@@ -191,6 +194,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
* iterations count less.
* This is called "tau0" in the Online LDA paper (Hoffman et al., 2010)
* Default: 1024, following Hoffman et al.
+ *
* @group expertParam
*/
@Since("1.6.0")
@@ -207,6 +211,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
* This should be between (0.5, 1.0] to guarantee asymptotic convergence.
* This is called "kappa" in the Online LDA paper (Hoffman et al., 2010).
* Default: 0.51, based on Hoffman et al.
+ *
* @group expertParam
*/
@Since("1.6.0")
@@ -230,6 +235,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
* [[org.apache.spark.mllib.clustering.OnlineLDAOptimizer]].
*
* Default: 0.05, i.e., 5% of total documents.
+ *
* @group param
*/
@Since("1.6.0")
@@ -246,6 +252,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
* document-topic distribution) will be optimized during training.
* Setting this to true will make the model more expressive and fit the training data better.
* Default: false
+ *
* @group expertParam
*/
@Since("1.6.0")
@@ -258,7 +265,31 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
def getOptimizeDocConcentration: Boolean = $(optimizeDocConcentration)
/**
+ * For EM optimizer, if using checkpointing, this indicates whether to keep the last
+ * checkpoint. If false, then the checkpoint will be deleted. Deleting the checkpoint can
+ * cause failures if a data partition is lost, so set this bit with care.
+ * Note that checkpoints will be cleaned up via reference counting, regardless.
+ *
+ * See [[DistributedLDAModel.getCheckpointFiles]] for getting remaining checkpoints and
+ * [[DistributedLDAModel.deleteCheckpointFiles]] for removing remaining checkpoints.
+ *
+ * Default: true
+ *
+ * @group expertParam
+ */
+ @Since("2.0.0")
+ final val keepLastCheckpoint = new BooleanParam(this, "keepLastCheckpoint",
+ "For EM optimizer, if using checkpointing, this indicates whether to keep the last" +
+ " checkpoint. If false, then the checkpoint will be deleted. Deleting the checkpoint can" +
+ " cause failures if a data partition is lost, so set this bit with care.")
+
+ /** @group expertGetParam */
+ @Since("2.0.0")
+ def getKeepLastCheckpoint: Boolean = $(keepLastCheckpoint)
+
+ /**
* Validates and transforms the input schema.
+ *
* @param schema input schema
* @return output schema
*/
@@ -303,6 +334,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
.setOptimizeDocConcentration($(optimizeDocConcentration))
case "em" =>
new OldEMLDAOptimizer()
+ .setKeepLastCheckpoint($(keepLastCheckpoint))
}
}
@@ -341,6 +373,7 @@ sealed abstract class LDAModel private[ml] (
/**
* The features for LDA should be a [[Vector]] representing the word counts in a document.
* The vector should be of length vocabSize, with counts for each term (word).
+ *
* @group setParam
*/
@Since("1.6.0")
@@ -619,6 +652,35 @@ class DistributedLDAModel private[ml] (
@Since("1.6.0")
lazy val logPrior: Double = oldDistributedModel.logPrior
+ private var _checkpointFiles: Array[String] = oldDistributedModel.checkpointFiles
+
+ /**
+ * If using checkpointing and [[LDA.keepLastCheckpoint]] is set to true, then there may be
+ * saved checkpoint files. This method is provided so that users can manage those files.
+ *
+ * Note that removing the checkpoints can cause failures if a partition is lost and is needed
+ * by certain [[DistributedLDAModel]] methods. Reference counting will clean up the checkpoints
+ * when this model and derivative data go out of scope.
+ *
+ * @return Checkpoint files from training
+ */
+ @DeveloperApi
+ @Since("2.0.0")
+ def getCheckpointFiles: Array[String] = _checkpointFiles
+
+ /**
+ * Remove any remaining checkpoint files from training.
+ *
+ * @see [[getCheckpointFiles]]
+ */
+ @DeveloperApi
+ @Since("2.0.0")
+ def deleteCheckpointFiles(): Unit = {
+ val fs = FileSystem.get(sqlContext.sparkContext.hadoopConfiguration)
+ _checkpointFiles.foreach(PeriodicCheckpointer.removeCheckpointFile(_, fs))
+ _checkpointFiles = Array.empty[String]
+ }
+
@Since("1.6.0")
override def write: MLWriter = new DistributedLDAModel.DistributedWriter(this)
}
@@ -696,11 +758,12 @@ class LDA @Since("1.6.0") (
setDefault(maxIter -> 20, k -> 10, optimizer -> "online", checkpointInterval -> 10,
learningOffset -> 1024, learningDecay -> 0.51, subsamplingRate -> 0.05,
- optimizeDocConcentration -> true)
+ optimizeDocConcentration -> true, keepLastCheckpoint -> true)
/**
* The features for LDA should be a [[Vector]] representing the word counts in a document.
* The vector should be of length vocabSize, with counts for each term (word).
+ *
* @group setParam
*/
@Since("1.6.0")
@@ -758,6 +821,10 @@ class LDA @Since("1.6.0") (
@Since("1.6.0")
def setOptimizeDocConcentration(value: Boolean): this.type = set(optimizeDocConcentration, value)
+ /** @group expertSetParam */
+ @Since("2.0.0")
+ def setKeepLastCheckpoint(value: Boolean): this.type = set(keepLastCheckpoint, value)
+
@Since("1.6.0")
override def copy(extra: ParamMap): LDA = defaultCopy(extra)
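For reference (not part of the patch), a minimal sketch of how the new spark.ml param and checkpoint-file accessors above might be used. It assumes an existing SparkContext `sc` and a DataFrame `docs` with a "features" column of term-count Vectors; both names are illustrative.

    import org.apache.spark.ml.clustering.{DistributedLDAModel, LDA}

    sc.setCheckpointDir("/tmp/lda-checkpoints")   // checkpointing is only active once a dir is set

    val lda = new LDA()
      .setK(10)
      .setOptimizer("em")
      .setMaxIter(50)
      .setCheckpointInterval(10)
      .setKeepLastCheckpoint(true)                // default: keep the final checkpoint after fit()

    val model = lda.fit(docs).asInstanceOf[DistributedLDAModel]

    // The surviving checkpoint backs the model's graph if a partition is lost.
    // Once the model and derived data are no longer needed, remove it explicitly:
    model.getCheckpointFiles.foreach(println)
    model.deleteCheckpointFiles()

Keeping the last checkpoint trades a little temporary checkpoint-dir space for resilience; deleting it immediately via setKeepLastCheckpoint(false) is only safe if the model's lineage never needs recomputation.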
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
index 25d67a3756..27b4004927 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
@@ -534,7 +534,8 @@ class DistributedLDAModel private[clustering] (
@Since("1.5.0") override val docConcentration: Vector,
@Since("1.5.0") override val topicConcentration: Double,
private[spark] val iterationTimes: Array[Double],
- override protected[clustering] val gammaShape: Double = 100)
+ override protected[clustering] val gammaShape: Double = DistributedLDAModel.defaultGammaShape,
+ private[spark] val checkpointFiles: Array[String] = Array.empty[String])
extends LDAModel {
import LDA._
@@ -806,11 +807,9 @@ class DistributedLDAModel private[clustering] (
override protected def formatVersion = "1.0"
- /**
- * Java-friendly version of [[topicDistributions]]
- */
@Since("1.5.0")
override def save(sc: SparkContext, path: String): Unit = {
+ // Note: This intentionally does not save checkpointFiles.
DistributedLDAModel.SaveLoadV1_0.save(
sc, path, graph, globalTopicTotals, k, vocabSize, docConcentration, topicConcentration,
iterationTimes, gammaShape)
@@ -822,6 +821,12 @@ class DistributedLDAModel private[clustering] (
@Since("1.5.0")
object DistributedLDAModel extends Loader[DistributedLDAModel] {
+ /**
+ * The [[DistributedLDAModel]] constructor's default arguments assume gammaShape = 100
+ * to ensure equivalence in LDAModel.toLocal conversion.
+ */
+ private[clustering] val defaultGammaShape: Double = 100
+
private object SaveLoadV1_0 {
val thisFormatVersion = "1.0"
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
index 2b404a8651..6418f0d3b3 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
@@ -80,9 +80,29 @@ final class EMLDAOptimizer extends LDAOptimizer {
import LDA._
+ // Adjustable parameters
+ private var keepLastCheckpoint: Boolean = true
+
/**
- * The following fields will only be initialized through the initialize() method
+ * If using checkpointing, this indicates whether to keep the last checkpoint (vs clean up).
+ */
+ @Since("2.0.0")
+ def getKeepLastCheckpoint: Boolean = this.keepLastCheckpoint
+
+ /**
+ * If using checkpointing, this indicates whether to keep the last checkpoint (vs clean up).
+ * Deleting the checkpoint can cause failures if a data partition is lost, so set this bit with
+ * care. Note that checkpoints will be cleaned up via reference counting, regardless.
+ *
+ * Default: true
*/
+ @Since("2.0.0")
+ def setKeepLastCheckpoint(keepLastCheckpoint: Boolean): this.type = {
+ this.keepLastCheckpoint = keepLastCheckpoint
+ this
+ }
+
+ // The following fields will only be initialized through the initialize() method
private[clustering] var graph: Graph[TopicCounts, TokenCount] = null
private[clustering] var k: Int = 0
private[clustering] var vocabSize: Int = 0
@@ -208,12 +228,18 @@ final class EMLDAOptimizer extends LDAOptimizer {
override private[clustering] def getLDAModel(iterationTimes: Array[Double]): LDAModel = {
require(graph != null, "graph is null, EMLDAOptimizer not initialized.")
- this.graphCheckpointer.deleteAllCheckpoints()
+ val checkpointFiles: Array[String] = if (keepLastCheckpoint) {
+ this.graphCheckpointer.deleteAllCheckpointsButLast()
+ this.graphCheckpointer.getAllCheckpointFiles
+ } else {
+ this.graphCheckpointer.deleteAllCheckpoints()
+ Array.empty[String]
+ }
// The constructor's default arguments assume gammaShape = 100 to ensure equivalence in
- // LDAModel.toLocal conversion
+ // LDAModel.toLocal conversion.
new DistributedLDAModel(this.graph, this.globalTopicTotals, this.k, this.vocabSize,
Vectors.dense(Array.fill(this.k)(this.docConcentration)), this.topicConcentration,
- iterationTimes)
+ iterationTimes, DistributedLDAModel.defaultGammaShape, checkpointFiles)
}
}
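A corresponding sketch at the spark.mllib level (not part of the patch), assuming `sc` has a checkpoint dir set and `corpus` is an RDD[(Long, Vector)] of (docId, termCounts) pairs; names are illustrative.

    import org.apache.spark.mllib.clustering.{DistributedLDAModel, EMLDAOptimizer, LDA}

    val optimizer = new EMLDAOptimizer()
      .setKeepLastCheckpoint(false)               // delete every checkpoint once training finishes

    val model = new LDA()
      .setK(10)
      .setMaxIterations(50)
      .setCheckpointInterval(10)
      .setOptimizer(optimizer)
      .run(corpus)
      .asInstanceOf[DistributedLDAModel]

    // With keepLastCheckpoint = false, the model's checkpointFiles array is empty; since that
    // field is private[spark], user-facing access goes through the spark.ml DistributedLDAModel.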
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala
index 391f89aa14..cbc8f60112 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala
@@ -134,6 +134,24 @@ private[mllib] abstract class PeriodicCheckpointer[T](
}
/**
+ * Call this at the end to delete any remaining checkpoint files, except for the last checkpoint.
+ * Note that there may not be any checkpoints at all.
+ */
+ def deleteAllCheckpointsButLast(): Unit = {
+ while (checkpointQueue.size > 1) {
+ removeCheckpointFile()
+ }
+ }
+
+ /**
+ * Get all current checkpoint files.
+ * This is useful in combination with [[deleteAllCheckpointsButLast()]].
+ */
+ def getAllCheckpointFiles: Array[String] = {
+ checkpointQueue.flatMap(getCheckpointFiles).toArray
+ }
+
+ /**
* Dequeue the oldest checkpointed Dataset, and remove its checkpoint files.
* This prints a warning but does not fail if the files cannot be removed.
*/
@@ -141,15 +159,20 @@ private[mllib] abstract class PeriodicCheckpointer[T](
val old = checkpointQueue.dequeue()
// Since the old checkpoint is not deleted by Spark, we manually delete it.
val fs = FileSystem.get(sc.hadoopConfiguration)
- getCheckpointFiles(old).foreach { checkpointFile =>
- try {
- fs.delete(new Path(checkpointFile), true)
- } catch {
- case e: Exception =>
- logWarning("PeriodicCheckpointer could not remove old checkpoint file: " +
- checkpointFile)
- }
- }
+ getCheckpointFiles(old).foreach(PeriodicCheckpointer.removeCheckpointFile(_, fs))
}
+}
+
+private[spark] object PeriodicCheckpointer extends Logging {
+ /** Delete a checkpoint file, and log a warning if deletion fails. */
+ def removeCheckpointFile(path: String, fs: FileSystem): Unit = {
+ try {
+ fs.delete(new Path(path), true)
+ } catch {
+ case e: Exception =>
+ logWarning("PeriodicCheckpointer could not remove old checkpoint file: " +
+ path)
+ }
+ }
}
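The two new PeriodicCheckpointer methods are internal (private[mllib] class, private[spark] object), so there is no user-facing example; a sketch of the contract they give callers such as EMLDAOptimizer.getLDAModel, assuming a checkpointer instance `cp` and a Hadoop FileSystem `fs`:

    // Keep only the newest checkpoint on disk, then hand its path(s) to the caller,
    // which becomes responsible for eventual deletion.
    cp.deleteAllCheckpointsButLast()
    val remaining: Array[String] = cp.getAllCheckpointFiles
    // ... later, once the checkpointed data is no longer needed:
    remaining.foreach(PeriodicCheckpointer.removeCheckpointFile(_, fs))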
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
index dd3f4c6e53..a1c93891c7 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
@@ -17,6 +17,8 @@
package org.apache.spark.ml.clustering
+import org.apache.hadoop.fs.{FileSystem, Path}
+
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.linalg.{Vector, Vectors}
@@ -261,4 +263,30 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
testEstimatorAndModelReadWrite(lda, dataset,
LDASuite.allParamSettings ++ Map("optimizer" -> "em"), checkModelData)
}
+
+ test("EM LDA checkpointing: save last checkpoint") {
+ // Checkpoint dir is set by MLlibTestSparkContext
+ val lda = new LDA().setK(2).setSeed(1).setOptimizer("em").setMaxIter(3).setCheckpointInterval(1)
+ val model_ = lda.fit(dataset)
+ assert(model_.isInstanceOf[DistributedLDAModel])
+ val model = model_.asInstanceOf[DistributedLDAModel]
+
+ // There should be 1 checkpoint remaining.
+ assert(model.getCheckpointFiles.length === 1)
+ val fs = FileSystem.get(sqlContext.sparkContext.hadoopConfiguration)
+ assert(fs.exists(new Path(model.getCheckpointFiles.head)))
+ model.deleteCheckpointFiles()
+ assert(model.getCheckpointFiles.isEmpty)
+ }
+
+ test("EM LDA checkpointing: remove last checkpoint") {
+ // Checkpoint dir is set by MLlibTestSparkContext
+ val lda = new LDA().setK(2).setSeed(1).setOptimizer("em").setMaxIter(3).setCheckpointInterval(1)
+ .setKeepLastCheckpoint(false)
+ val model_ = lda.fit(dataset)
+ assert(model_.isInstanceOf[DistributedLDAModel])
+ val model = model_.asInstanceOf[DistributedLDAModel]
+
+ assert(model.getCheckpointFiles.isEmpty)
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala
index 114a238462..0bd14978b2 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala
@@ -28,8 +28,9 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.SaveMode
import org.apache.spark.util.Utils
+
class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext {
- var tempDir: File = _
+ // Path for dataset
var path: String = _
override def beforeAll(): Unit = {
@@ -40,15 +41,15 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext {
|0
|0 2:4.0 4:5.0 6:6.0
""".stripMargin
- tempDir = Utils.createTempDir()
- val file = new File(tempDir, "part-00000")
+ val dir = Utils.createDirectory(tempDir.getCanonicalPath, "data")
+ val file = new File(dir, "part-00000")
Files.write(lines, file, StandardCharsets.UTF_8)
- path = tempDir.toURI.toString
+ path = dir.toURI.toString
}
override def afterAll(): Unit = {
try {
- Utils.deleteRecursively(tempDir)
+ Utils.deleteRecursively(new File(path))
} finally {
super.afterAll()
}
@@ -86,7 +87,7 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext {
test("write libsvm data and read it again") {
val df = sqlContext.read.format("libsvm").load(path)
- val tempDir2 = Utils.createTempDir()
+ val tempDir2 = new File(tempDir, "read_write_test")
val writepath = tempDir2.toURI.toString
// TODO: Remove requirement to coalesce by supporting multiple reads.
df.coalesce(1).write.format("libsvm").mode(SaveMode.Overwrite).save(writepath)
@@ -99,7 +100,7 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext {
test("write libsvm data failed due to invalid schema") {
val df = sqlContext.read.format("text").load(path)
- val e = intercept[SparkException] {
+ intercept[SparkException] {
df.write.format("libsvm").save(path + "_2")
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala
index ebcd591465..cb1efd5251 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala
@@ -17,14 +17,20 @@
package org.apache.spark.mllib.util
-import org.scalatest.{BeforeAndAfterAll, Suite}
+import java.io.File
+
+import org.apache.hadoop.fs.Path
+import org.scalatest.Suite
import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.ml.util.TempDirectory
import org.apache.spark.sql.SQLContext
+import org.apache.spark.util.Utils
-trait MLlibTestSparkContext extends BeforeAndAfterAll { self: Suite =>
+trait MLlibTestSparkContext extends TempDirectory { self: Suite =>
@transient var sc: SparkContext = _
@transient var sqlContext: SQLContext = _
+ @transient var checkpointDir: String = _
override def beforeAll() {
super.beforeAll()
@@ -35,10 +41,13 @@ trait MLlibTestSparkContext extends BeforeAndAfterAll { self: Suite =>
SQLContext.clearActive()
sqlContext = new SQLContext(sc)
SQLContext.setActive(sqlContext)
+ checkpointDir = Utils.createDirectory(tempDir.getCanonicalPath, "checkpoints").toString
+ sc.setCheckpointDir(checkpointDir)
}
override def afterAll() {
try {
+ Utils.deleteRecursively(new File(checkpointDir))
sqlContext = null
SQLContext.clearActive()
if (sc != null) {
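With this trait change, any suite mixing in MLlibTestSparkContext gets a per-suite checkpoint directory automatically, so tests such as the LDA checkpointing tests above need no setup of their own. A minimal illustrative sketch (hypothetical suite name):

    class MyCheckpointingSuite extends SparkFunSuite with MLlibTestSparkContext {
      test("checkpoint dir is configured by the shared test context") {
        // setCheckpointDir was called in beforeAll(), so this is defined without per-test setup
        assert(sc.getCheckpointDir.isDefined)
      }
    }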
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index fbadc563b8..a53161dc9a 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -614,6 +614,9 @@ object MimaExcludes {
) ++ Seq(
// [SPARK-13430][ML] moved featureCol from LinearRegressionModelSummary to LinearRegressionSummary
ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.regression.LinearRegressionSummary.this")
+ ) ++ Seq(
+ // [SPARK-13048][ML][MLLIB] keepLastCheckpoint option for LDA EM optimizer
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.clustering.DistributedLDAModel.this")
)
case v if v.startsWith("1.6") =>
Seq(