aboutsummaryrefslogtreecommitdiff
path: root/examples/src
diff options
context:
space:
mode:
authorYuhao Yang <hhbyyh@gmail.com>2015-05-09 15:40:46 -0700
committerJoseph K. Bradley <joseph@databricks.com>2015-05-09 15:40:46 -0700
commitb13162b364aeff35e3bdeea9c9a31e5ce66f8c9a (patch)
tree2a4d463ccba951d81b0c634f423d7422672f05ab /examples/src
parentbd74301ff87f545e5808e13dd50dea12edd3db92 (diff)
downloadspark-b13162b364aeff35e3bdeea9c9a31e5ce66f8c9a.tar.gz
spark-b13162b364aeff35e3bdeea9c9a31e5ce66f8c9a.tar.bz2
spark-b13162b364aeff35e3bdeea9c9a31e5ce66f8c9a.zip
[SPARK-7475] [MLLIB] adjust ldaExample for online LDA
jira: https://issues.apache.org/jira/browse/SPARK-7475 Add a new argument to specify the algorithm applied to LDA, to exhibit the basic usage of LDAOptimizer. cc jkbradley Author: Yuhao Yang <hhbyyh@gmail.com> Closes #6000 from hhbyyh/ldaExample and squashes the following commits: 0a7e2bc [Yuhao Yang] fix according to comments 5810b0f [Yuhao Yang] adjust ldaExample for online LDA
Diffstat (limited to 'examples/src')
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala31
1 files changed, 25 insertions, 6 deletions
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala
index a1850390c0..31d629f853 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala
@@ -26,7 +26,7 @@ import scopt.OptionParser
import org.apache.log4j.{Level, Logger}
import org.apache.spark.{SparkContext, SparkConf}
-import org.apache.spark.mllib.clustering.{DistributedLDAModel, LDA}
+import org.apache.spark.mllib.clustering.{EMLDAOptimizer, OnlineLDAOptimizer, DistributedLDAModel, LDA}
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.rdd.RDD
@@ -48,6 +48,7 @@ object LDAExample {
topicConcentration: Double = -1,
vocabSize: Int = 10000,
stopwordFile: String = "",
+ algorithm: String = "em",
checkpointDir: Option[String] = None,
checkpointInterval: Int = 10) extends AbstractParams[Params]
@@ -78,6 +79,10 @@ object LDAExample {
.text(s"filepath for a list of stopwords. Note: This must fit on a single machine." +
s" default: ${defaultParams.stopwordFile}")
.action((x, c) => c.copy(stopwordFile = x))
+ opt[String]("algorithm")
+ .text(s"inference algorithm to use. em and online are supported." +
+ s" default: ${defaultParams.algorithm}")
+ .action((x, c) => c.copy(algorithm = x))
opt[String]("checkpointDir")
.text(s"Directory for checkpointing intermediate results." +
s" Checkpointing helps with recovery and eliminates temporary shuffle files on disk." +
@@ -128,7 +133,17 @@ object LDAExample {
// Run LDA.
val lda = new LDA()
- lda.setK(params.k)
+
+ val optimizer = params.algorithm.toLowerCase match {
+ case "em" => new EMLDAOptimizer
+ // add (1.0 / actualCorpusSize) to MiniBatchFraction be more robust on tiny datasets.
+ case "online" => new OnlineLDAOptimizer().setMiniBatchFraction(0.05 + 1.0 / actualCorpusSize)
+ case _ => throw new IllegalArgumentException(
+ s"Only em, online are supported but got ${params.algorithm}.")
+ }
+
+ lda.setOptimizer(optimizer)
+ .setK(params.k)
.setMaxIterations(params.maxIterations)
.setDocConcentration(params.docConcentration)
.setTopicConcentration(params.topicConcentration)
@@ -137,14 +152,18 @@ object LDAExample {
sc.setCheckpointDir(params.checkpointDir.get)
}
val startTime = System.nanoTime()
- val ldaModel = lda.run(corpus).asInstanceOf[DistributedLDAModel]
+ val ldaModel = lda.run(corpus)
val elapsed = (System.nanoTime() - startTime) / 1e9
println(s"Finished training LDA model. Summary:")
println(s"\t Training time: $elapsed sec")
- val avgLogLikelihood = ldaModel.logLikelihood / actualCorpusSize.toDouble
- println(s"\t Training data average log likelihood: $avgLogLikelihood")
- println()
+
+ if (ldaModel.isInstanceOf[DistributedLDAModel]) {
+ val distLDAModel = ldaModel.asInstanceOf[DistributedLDAModel]
+ val avgLogLikelihood = distLDAModel.logLikelihood / actualCorpusSize.toDouble
+ println(s"\t Training data average log likelihood: $avgLogLikelihood")
+ println()
+ }
// Print the topics, showing the top-weighted terms for each topic.
val topicIndices = ldaModel.describeTopics(maxTermsPerTopic = 10)