From b13162b364aeff35e3bdeea9c9a31e5ce66f8c9a Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Sat, 9 May 2015 15:40:46 -0700 Subject: [SPARK-7475] [MLLIB] adjust ldaExample for online LDA jira: https://issues.apache.org/jira/browse/SPARK-7475 Add a new argument to specify the algorithm applied to LDA, to exhibit the basic usage of LDAOptimizer. cc jkbradley Author: Yuhao Yang Closes #6000 from hhbyyh/ldaExample and squashes the following commits: 0a7e2bc [Yuhao Yang] fix according to comments 5810b0f [Yuhao Yang] adjust ldaExample for online LDA --- .../apache/spark/examples/mllib/LDAExample.scala | 31 +++++++++++++++++----- 1 file changed, 25 insertions(+), 6 deletions(-) (limited to 'examples/src') diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala index a1850390c0..31d629f853 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala @@ -26,7 +26,7 @@ import scopt.OptionParser import org.apache.log4j.{Level, Logger} import org.apache.spark.{SparkContext, SparkConf} -import org.apache.spark.mllib.clustering.{DistributedLDAModel, LDA} +import org.apache.spark.mllib.clustering.{EMLDAOptimizer, OnlineLDAOptimizer, DistributedLDAModel, LDA} import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.rdd.RDD @@ -48,6 +48,7 @@ object LDAExample { topicConcentration: Double = -1, vocabSize: Int = 10000, stopwordFile: String = "", + algorithm: String = "em", checkpointDir: Option[String] = None, checkpointInterval: Int = 10) extends AbstractParams[Params] @@ -78,6 +79,10 @@ object LDAExample { .text(s"filepath for a list of stopwords. Note: This must fit on a single machine." + s" default: ${defaultParams.stopwordFile}") .action((x, c) => c.copy(stopwordFile = x)) + opt[String]("algorithm") + .text(s"inference algorithm to use. em and online are supported." + + s" default: ${defaultParams.algorithm}") + .action((x, c) => c.copy(algorithm = x)) opt[String]("checkpointDir") .text(s"Directory for checkpointing intermediate results." + s" Checkpointing helps with recovery and eliminates temporary shuffle files on disk." + @@ -128,7 +133,17 @@ object LDAExample { // Run LDA. val lda = new LDA() - lda.setK(params.k) + + val optimizer = params.algorithm.toLowerCase match { + case "em" => new EMLDAOptimizer + // add (1.0 / actualCorpusSize) to MiniBatchFraction be more robust on tiny datasets. + case "online" => new OnlineLDAOptimizer().setMiniBatchFraction(0.05 + 1.0 / actualCorpusSize) + case _ => throw new IllegalArgumentException( + s"Only em, online are supported but got ${params.algorithm}.") + } + + lda.setOptimizer(optimizer) + .setK(params.k) .setMaxIterations(params.maxIterations) .setDocConcentration(params.docConcentration) .setTopicConcentration(params.topicConcentration) @@ -137,14 +152,18 @@ object LDAExample { sc.setCheckpointDir(params.checkpointDir.get) } val startTime = System.nanoTime() - val ldaModel = lda.run(corpus).asInstanceOf[DistributedLDAModel] + val ldaModel = lda.run(corpus) val elapsed = (System.nanoTime() - startTime) / 1e9 println(s"Finished training LDA model. Summary:") println(s"\t Training time: $elapsed sec") - val avgLogLikelihood = ldaModel.logLikelihood / actualCorpusSize.toDouble - println(s"\t Training data average log likelihood: $avgLogLikelihood") - println() + + if (ldaModel.isInstanceOf[DistributedLDAModel]) { + val distLDAModel = ldaModel.asInstanceOf[DistributedLDAModel] + val avgLogLikelihood = distLDAModel.logLikelihood / actualCorpusSize.toDouble + println(s"\t Training data average log likelihood: $avgLogLikelihood") + println() + } // Print the topics, showing the top-weighted terms for each topic. val topicIndices = ldaModel.describeTopics(maxTermsPerTopic = 10) -- cgit v1.2.3