Diffstat (limited to 'mllib/src')
7 files changed, 191 insertions, 32 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
index 60cc345565..727b724708 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
@@ -17,18 +17,19 @@
 
 package org.apache.spark.ml.clustering
 
-import org.apache.hadoop.fs.Path
+import org.apache.hadoop.fs.{FileSystem, Path}
 
-import org.apache.spark.annotation.{Experimental, Since}
+import org.apache.spark.annotation.{DeveloperApi, Experimental, Since}
 import org.apache.spark.internal.Logging
 import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared.{HasCheckpointInterval, HasFeaturesCol, HasMaxIter, HasSeed}
 import org.apache.spark.ml.util._
 import org.apache.spark.mllib.clustering.{DistributedLDAModel => OldDistributedLDAModel,
-  EMLDAOptimizer => OldEMLDAOptimizer, LDA => OldLDA, LDAModel => OldLDAModel,
-  LDAOptimizer => OldLDAOptimizer, LocalLDAModel => OldLocalLDAModel,
-  OnlineLDAOptimizer => OldOnlineLDAOptimizer}
+  EMLDAOptimizer => OldEMLDAOptimizer, LDA => OldLDA, LDAModel => OldLDAModel,
+  LDAOptimizer => OldLDAOptimizer, LocalLDAModel => OldLocalLDAModel,
+  OnlineLDAOptimizer => OldOnlineLDAOptimizer}
+import org.apache.spark.mllib.impl.PeriodicCheckpointer
 import org.apache.spark.mllib.linalg.{Matrix, Vector, Vectors, VectorUDT}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Row, SQLContext}
@@ -41,6 +42,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
 
   /**
    * Param for the number of topics (clusters) to infer. Must be > 1. Default: 10.
+   *
    * @group param
    */
   @Since("1.6.0")
@@ -173,6 +175,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
    * This uses a variational approximation following Hoffman et al. (2010), where the approximate
    * distribution is called "gamma." Technically, this method returns this approximation "gamma"
    * for each document.
+   *
    * @group param
    */
   @Since("1.6.0")
@@ -191,6 +194,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
    * iterations count less.
    * This is called "tau0" in the Online LDA paper (Hoffman et al., 2010)
    * Default: 1024, following Hoffman et al.
+   *
    * @group expertParam
    */
   @Since("1.6.0")
@@ -207,6 +211,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
    * This should be between (0.5, 1.0] to guarantee asymptotic convergence.
    * This is called "kappa" in the Online LDA paper (Hoffman et al., 2010).
    * Default: 0.51, based on Hoffman et al.
+   *
    * @group expertParam
    */
   @Since("1.6.0")
@@ -230,6 +235,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
    * [[org.apache.spark.mllib.clustering.OnlineLDAOptimizer]].
    *
    * Default: 0.05, i.e., 5% of total documents.
+   *
    * @group param
    */
   @Since("1.6.0")
@@ -246,6 +252,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
    * document-topic distribution) will be optimized during training.
    * Setting this to true will make the model more expressive and fit the training data better.
    * Default: false
+   *
    * @group expertParam
    */
   @Since("1.6.0")
@@ -258,7 +265,31 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
   def getOptimizeDocConcentration: Boolean = $(optimizeDocConcentration)
 
   /**
+   * For EM optimizer, if using checkpointing, this indicates whether to keep the last
+   * checkpoint. If false, then the checkpoint will be deleted. Deleting the checkpoint can
+   * cause failures if a data partition is lost, so set this bit with care.
+   * Note that checkpoints will be cleaned up via reference counting, regardless.
+   *
+   * See [[DistributedLDAModel.getCheckpointFiles]] for getting remaining checkpoints and
+   * [[DistributedLDAModel.deleteCheckpointFiles]] for removing remaining checkpoints.
+   *
+   * Default: true
+   *
+   * @group expertParam
+   */
+  @Since("2.0.0")
+  final val keepLastCheckpoint = new BooleanParam(this, "keepLastCheckpoint",
+    "For EM optimizer, if using checkpointing, this indicates whether to keep the last" +
+    " checkpoint. If false, then the checkpoint will be deleted. Deleting the checkpoint can" +
+    " cause failures if a data partition is lost, so set this bit with care.")
+
+  /** @group expertGetParam */
+  @Since("2.0.0")
+  def getKeepLastCheckpoint: Boolean = $(keepLastCheckpoint)
+
+  /**
    * Validates and transforms the input schema.
+   *
    * @param schema input schema
    * @return output schema
    */
@@ -303,6 +334,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
         .setOptimizeDocConcentration($(optimizeDocConcentration))
       case "em" =>
         new OldEMLDAOptimizer()
+          .setKeepLastCheckpoint($(keepLastCheckpoint))
     }
   }
 
@@ -341,6 +373,7 @@ sealed abstract class LDAModel private[ml] (
   /**
    * The features for LDA should be a [[Vector]] representing the word counts in a document.
    * The vector should be of length vocabSize, with counts for each term (word).
+   *
    * @group setParam
    */
   @Since("1.6.0")
@@ -619,6 +652,35 @@ class DistributedLDAModel private[ml] (
   @Since("1.6.0")
   lazy val logPrior: Double = oldDistributedModel.logPrior
 
+  private var _checkpointFiles: Array[String] = oldDistributedModel.checkpointFiles
+
+  /**
+   * If using checkpointing and [[LDA.keepLastCheckpoint]] is set to true, then there may be
+   * saved checkpoint files. This method is provided so that users can manage those files.
+   *
+   * Note that removing the checkpoints can cause failures if a partition is lost and is needed
+   * by certain [[DistributedLDAModel]] methods. Reference counting will clean up the checkpoints
+   * when this model and derivative data go out of scope.
+   *
+   * @return Checkpoint files from training
+   */
+  @DeveloperApi
+  @Since("2.0.0")
+  def getCheckpointFiles: Array[String] = _checkpointFiles
+
+  /**
+   * Remove any remaining checkpoint files from training.
+   *
+   * @see [[getCheckpointFiles]]
+   */
+  @DeveloperApi
+  @Since("2.0.0")
+  def deleteCheckpointFiles(): Unit = {
+    val fs = FileSystem.get(sqlContext.sparkContext.hadoopConfiguration)
+    _checkpointFiles.foreach(PeriodicCheckpointer.removeCheckpointFile(_, fs))
+    _checkpointFiles = Array.empty[String]
+  }
+
   @Since("1.6.0")
   override def write: MLWriter = new DistributedLDAModel.DistributedWriter(this)
 }
@@ -696,11 +758,12 @@ class LDA @Since("1.6.0") (
 
   setDefault(maxIter -> 20, k -> 10, optimizer -> "online", checkpointInterval -> 10,
     learningOffset -> 1024, learningDecay -> 0.51, subsamplingRate -> 0.05,
-    optimizeDocConcentration -> true)
+    optimizeDocConcentration -> true, keepLastCheckpoint -> true)
 
   /**
    * The features for LDA should be a [[Vector]] representing the word counts in a document.
    * The vector should be of length vocabSize, with counts for each term (word).
+   *
    * @group setParam
    */
   @Since("1.6.0")
@@ -758,6 +821,10 @@ class LDA @Since("1.6.0") (
   @Since("1.6.0")
   def setOptimizeDocConcentration(value: Boolean): this.type = set(optimizeDocConcentration, value)
 
+  /** @group expertSetParam */
+  @Since("2.0.0")
+  def setKeepLastCheckpoint(value: Boolean): this.type = set(keepLastCheckpoint, value)
+
   @Since("1.6.0")
   override def copy(extra: ParamMap): LDA = defaultCopy(extra)
 
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
index 25d67a3756..27b4004927 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
@@ -534,7 +534,8 @@ class DistributedLDAModel private[clustering] (
     @Since("1.5.0") override val docConcentration: Vector,
     @Since("1.5.0") override val topicConcentration: Double,
     private[spark] val iterationTimes: Array[Double],
-    override protected[clustering] val gammaShape: Double = 100)
+    override protected[clustering] val gammaShape: Double = DistributedLDAModel.defaultGammaShape,
+    private[spark] val checkpointFiles: Array[String] = Array.empty[String])
   extends LDAModel {
 
   import LDA._
@@ -806,11 +807,9 @@ class DistributedLDAModel private[clustering] (
 
   override protected def formatVersion = "1.0"
 
-  /**
-   * Java-friendly version of [[topicDistributions]]
-   */
   @Since("1.5.0")
   override def save(sc: SparkContext, path: String): Unit = {
+    // Note: This intentionally does not save checkpointFiles.
     DistributedLDAModel.SaveLoadV1_0.save(
       sc, path, graph, globalTopicTotals, k, vocabSize, docConcentration, topicConcentration,
       iterationTimes, gammaShape)
@@ -822,6 +821,12 @@ class DistributedLDAModel private[clustering] (
 @Since("1.5.0")
 object DistributedLDAModel extends Loader[DistributedLDAModel] {
 
+  /**
+   * The [[DistributedLDAModel]] constructor's default arguments assume gammaShape = 100
+   * to ensure equivalence in LDAModel.toLocal conversion.
+   */
+  private[clustering] val defaultGammaShape: Double = 100
+
   private object SaveLoadV1_0 {
 
     val thisFormatVersion = "1.0"
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
index 2b404a8651..6418f0d3b3 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
@@ -80,9 +80,29 @@ final class EMLDAOptimizer extends LDAOptimizer {
 
   import LDA._
 
+  // Adjustable parameters
+  private var keepLastCheckpoint: Boolean = true
+
   /**
-   * The following fields will only be initialized through the initialize() method
+   * If using checkpointing, this indicates whether to keep the last checkpoint (vs clean up).
+   */
+  @Since("2.0.0")
+  def getKeepLastCheckpoint: Boolean = this.keepLastCheckpoint
+
+  /**
+   * If using checkpointing, this indicates whether to keep the last checkpoint (vs clean up).
+   * Deleting the checkpoint can cause failures if a data partition is lost, so set this bit with
+   * care. Note that checkpoints will be cleaned up via reference counting, regardless.
+   *
+   * Default: true
    */
+  @Since("2.0.0")
+  def setKeepLastCheckpoint(keepLastCheckpoint: Boolean): this.type = {
+    this.keepLastCheckpoint = keepLastCheckpoint
+    this
+  }
+
+  // The following fields will only be initialized through the initialize() method
   private[clustering] var graph: Graph[TopicCounts, TokenCount] = null
   private[clustering] var k: Int = 0
   private[clustering] var vocabSize: Int = 0
@@ -208,12 +228,18 @@ final class EMLDAOptimizer extends LDAOptimizer {
 
   override private[clustering] def getLDAModel(iterationTimes: Array[Double]): LDAModel = {
     require(graph != null, "graph is null, EMLDAOptimizer not initialized.")
-    this.graphCheckpointer.deleteAllCheckpoints()
+    val checkpointFiles: Array[String] = if (keepLastCheckpoint) {
+      this.graphCheckpointer.deleteAllCheckpointsButLast()
+      this.graphCheckpointer.getAllCheckpointFiles
+    } else {
+      this.graphCheckpointer.deleteAllCheckpoints()
+      Array.empty[String]
+    }
     // The constructor's default arguments assume gammaShape = 100 to ensure equivalence in
-    // LDAModel.toLocal conversion
+    // LDAModel.toLocal conversion.
     new DistributedLDAModel(this.graph, this.globalTopicTotals, this.k, this.vocabSize,
       Vectors.dense(Array.fill(this.k)(this.docConcentration)), this.topicConcentration,
-      iterationTimes)
+      iterationTimes, DistributedLDAModel.defaultGammaShape, checkpointFiles)
   }
 }
 
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala
index 391f89aa14..cbc8f60112 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala
@@ -134,6 +134,24 @@ private[mllib] abstract class PeriodicCheckpointer[T](
   }
 
   /**
+   * Call this at the end to delete any remaining checkpoint files, except for the last checkpoint.
+   * Note that there may not be any checkpoints at all.
+   */
+  def deleteAllCheckpointsButLast(): Unit = {
+    while (checkpointQueue.size > 1) {
+      removeCheckpointFile()
+    }
+  }
+
+  /**
+   * Get all current checkpoint files.
+   * This is useful in combination with [[deleteAllCheckpointsButLast()]].
+   */
+  def getAllCheckpointFiles: Array[String] = {
+    checkpointQueue.flatMap(getCheckpointFiles).toArray
+  }
+
+  /**
    * Dequeue the oldest checkpointed Dataset, and remove its checkpoint files.
    * This prints a warning but does not fail if the files cannot be removed.
    */
@@ -141,15 +159,20 @@ private[mllib] abstract class PeriodicCheckpointer[T](
     val old = checkpointQueue.dequeue()
     // Since the old checkpoint is not deleted by Spark, we manually delete it.
     val fs = FileSystem.get(sc.hadoopConfiguration)
-    getCheckpointFiles(old).foreach { checkpointFile =>
-      try {
-        fs.delete(new Path(checkpointFile), true)
-      } catch {
-        case e: Exception =>
-          logWarning("PeriodicCheckpointer could not remove old checkpoint file: " +
-            checkpointFile)
-      }
-    }
+    getCheckpointFiles(old).foreach(PeriodicCheckpointer.removeCheckpointFile(_, fs))
   }
+}
+
+private[spark] object PeriodicCheckpointer extends Logging {
+
+  /** Delete a checkpoint file, and log a warning if deletion fails. */
+  def removeCheckpointFile(path: String, fs: FileSystem): Unit = {
+    try {
+      fs.delete(new Path(path), true)
+    } catch {
+      case e: Exception =>
+        logWarning("PeriodicCheckpointer could not remove old checkpoint file: " +
+          path)
+    }
+  }
 }
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
index dd3f4c6e53..a1c93891c7 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.ml.clustering
 
+import org.apache.hadoop.fs.{FileSystem, Path}
+
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
@@ -261,4 +263,30 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
     testEstimatorAndModelReadWrite(lda, dataset,
       LDASuite.allParamSettings ++ Map("optimizer" -> "em"), checkModelData)
   }
+
+  test("EM LDA checkpointing: save last checkpoint") {
+    // Checkpoint dir is set by MLlibTestSparkContext
+    val lda = new LDA().setK(2).setSeed(1).setOptimizer("em").setMaxIter(3).setCheckpointInterval(1)
+    val model_ = lda.fit(dataset)
+    assert(model_.isInstanceOf[DistributedLDAModel])
+    val model = model_.asInstanceOf[DistributedLDAModel]
+
+    // There should be 1 checkpoint remaining.
+    assert(model.getCheckpointFiles.length === 1)
+    val fs = FileSystem.get(sqlContext.sparkContext.hadoopConfiguration)
+    assert(fs.exists(new Path(model.getCheckpointFiles.head)))
+    model.deleteCheckpointFiles()
+    assert(model.getCheckpointFiles.isEmpty)
+  }
+
+  test("EM LDA checkpointing: remove last checkpoint") {
+    // Checkpoint dir is set by MLlibTestSparkContext
+    val lda = new LDA().setK(2).setSeed(1).setOptimizer("em").setMaxIter(3).setCheckpointInterval(1)
+      .setKeepLastCheckpoint(false)
+    val model_ = lda.fit(dataset)
+    assert(model_.isInstanceOf[DistributedLDAModel])
+    val model = model_.asInstanceOf[DistributedLDAModel]
+
+    assert(model.getCheckpointFiles.isEmpty)
+  }
 }
diff --git a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala
index 114a238462..0bd14978b2 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala
@@ -28,8 +28,9 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.SaveMode
 import org.apache.spark.util.Utils
 
+
 class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext {
-  var tempDir: File = _
+  // Path for dataset
   var path: String = _
 
   override def beforeAll(): Unit = {
@@ -40,15 +41,15 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext {
         |0
         |0 2:4.0 4:5.0 6:6.0
       """.stripMargin
-    tempDir = Utils.createTempDir()
-    val file = new File(tempDir, "part-00000")
+    val dir = Utils.createDirectory(tempDir.getCanonicalPath, "data")
+    val file = new File(dir, "part-00000")
     Files.write(lines, file, StandardCharsets.UTF_8)
-    path = tempDir.toURI.toString
+    path = dir.toURI.toString
   }
 
   override def afterAll(): Unit = {
     try {
-      Utils.deleteRecursively(tempDir)
+      Utils.deleteRecursively(new File(path))
     } finally {
       super.afterAll()
     }
@@ -86,7 +87,7 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext {
 
   test("write libsvm data and read it again") {
     val df = sqlContext.read.format("libsvm").load(path)
-    val tempDir2 = Utils.createTempDir()
+    val tempDir2 = new File(tempDir, "read_write_test")
     val writepath = tempDir2.toURI.toString
     // TODO: Remove requirement to coalesce by supporting multiple reads.
     df.coalesce(1).write.format("libsvm").mode(SaveMode.Overwrite).save(writepath)
@@ -99,7 +100,7 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext {
 
   test("write libsvm data failed due to invalid schema") {
     val df = sqlContext.read.format("text").load(path)
-    val e = intercept[SparkException] {
+    intercept[SparkException] {
       df.write.format("libsvm").save(path + "_2")
     }
   }
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala
index ebcd591465..cb1efd5251 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala
@@ -17,14 +17,20 @@
 
 package org.apache.spark.mllib.util
 
-import org.scalatest.{BeforeAndAfterAll, Suite}
+import java.io.File
+
+import org.apache.hadoop.fs.Path
+import org.scalatest.Suite
 
 import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.ml.util.TempDirectory
 import org.apache.spark.sql.SQLContext
+import org.apache.spark.util.Utils
 
-trait MLlibTestSparkContext extends BeforeAndAfterAll { self: Suite =>
+trait MLlibTestSparkContext extends TempDirectory { self: Suite =>
   @transient var sc: SparkContext = _
   @transient var sqlContext: SQLContext = _
+  @transient var checkpointDir: String = _
 
   override def beforeAll() {
     super.beforeAll()
@@ -35,10 +41,13 @@ trait MLlibTestSparkContext extends BeforeAndAfterAll { self: Suite =>
     SQLContext.clearActive()
     sqlContext = new SQLContext(sc)
     SQLContext.setActive(sqlContext)
+    checkpointDir = Utils.createDirectory(tempDir.getCanonicalPath, "checkpoints").toString
+    sc.setCheckpointDir(checkpointDir)
  }
 
   override def afterAll() {
     try {
+      Utils.deleteRecursively(new File(checkpointDir))
       sqlContext = null
       SQLContext.clearActive()
       if (sc != null) {
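The patch wires one option through three layers: PeriodicCheckpointer gains deleteAllCheckpointsButLast() and getAllCheckpointFiles, EMLDAOptimizer exposes setKeepLastCheckpoint, and DistributedLDAModel carries the surviving checkpoint paths. A minimal usage sketch in Scala, using only the public API added in this patch; it assumes an existing SparkContext `sc` and a `dataset: DataFrame` with a "features" column of word-count Vectors, and the checkpoint directory path is illustrative:

    import org.apache.spark.ml.clustering.{DistributedLDAModel, LDA}

    // EM LDA checkpoints its document-term graph, so a checkpoint dir must be set.
    sc.setCheckpointDir("/tmp/lda-checkpoints")  // hypothetical path

    val lda = new LDA()
      .setK(10)
      .setOptimizer("em")
      .setCheckpointInterval(2)      // checkpoint every 2 iterations
      .setKeepLastCheckpoint(true)   // new expert param; true is the default

    // With optimizer = "em", fit() returns a DistributedLDAModel.
    val model = lda.fit(dataset).asInstanceOf[DistributedLDAModel]

    // The last checkpoint is kept so that lost partitions of the model's graph
    // can be recomputed; list it, then delete it once the model is no longer needed.
    model.getCheckpointFiles.foreach(println)
    model.deleteCheckpointFiles()

With setKeepLastCheckpoint(false), getLDAModel instead calls deleteAllCheckpoints() and getCheckpointFiles returns an empty array, which is exactly what the second LDASuite test above verifies.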