aboutsummaryrefslogtreecommitdiff
path: root/examples
diff options
context:
space:
mode:
authorYuhao Yang <hhbyyh@gmail.com>2015-04-27 19:02:51 -0700
committerJoseph K. Bradley <joseph@databricks.com>2015-04-27 19:02:51 -0700
commit4d9e560b5470029143926827b1cb9d72a0bfbeff (patch)
tree2507253e2cf6544aefbdca3db8a7b38ae84bb04f /examples
parent62888a4ded91b3c2cbb05936c374c7ebfc10799e (diff)
downloadspark-4d9e560b5470029143926827b1cb9d72a0bfbeff.tar.gz
spark-4d9e560b5470029143926827b1cb9d72a0bfbeff.tar.bz2
spark-4d9e560b5470029143926827b1cb9d72a0bfbeff.zip
[SPARK-7090] [MLLIB] Introduce LDAOptimizer to LDA to further improve extensibility
jira: https://issues.apache.org/jira/browse/SPARK-7090 LDA was implemented with extensibility in mind. And with the development of OnlineLDA and Gibbs Sampling, we are collecting more detailed requirements from different algorithms. As Joseph Bradley jkbradley proposed in https://github.com/apache/spark/pull/4807 and with some further discussion, we'd like to adjust the code structure a little to present the common interface and extension point clearly. Basically class LDA would be a common entrance for LDA computing. And each LDA object will refer to a LDAOptimizer for the concrete algorithm implementation. Users can customize LDAOptimizer with specific parameters and assign it to LDA. Concrete changes: 1. Add a trait `LDAOptimizer`, which defines the common iterface for concrete implementations. Each subClass is a wrapper for a specific LDA algorithm. 2. Move EMOptimizer to file LDAOptimizer and inherits from LDAOptimizer, rename to EMLDAOptimizer. (in case a more generic EMOptimizer comes in the future) -adjust the constructor of EMOptimizer, since all the parameters should be passed in through initialState method. This can avoid unwanted confusion or overwrite. -move the code from LDA.initalState to initalState of EMLDAOptimizer 3. Add property ldaOptimizer to LDA and its getter/setter, and EMLDAOptimizer is the default Optimizer. 4. Change the return type of LDA.run from DistributedLDAModel to LDAModel. Further work: add OnlineLDAOptimizer and other possible Optimizers once ready. Author: Yuhao Yang <hhbyyh@gmail.com> Closes #5661 from hhbyyh/ldaRefactor and squashes the following commits: 0e2e006 [Yuhao Yang] respond to review comments 08a45da [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into ldaRefactor e756ce4 [Yuhao Yang] solve mima exception d74fd8f [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into ldaRefactor 0bb8400 [Yuhao Yang] refactor LDA with Optimizer ec2f857 [Yuhao Yang] protoptype for discussion
Diffstat (limited to 'examples')
-rw-r--r--examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java2
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala4
2 files changed, 3 insertions, 3 deletions
diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java
index 36207ae38d..fd53c81cc4 100644
--- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java
@@ -58,7 +58,7 @@ public class JavaLDAExample {
corpus.cache();
// Cluster the documents into three topics using LDA
- DistributedLDAModel ldaModel = new LDA().setK(3).run(corpus);
+ DistributedLDAModel ldaModel = (DistributedLDAModel)new LDA().setK(3).run(corpus);
// Output topics. Each is a distribution over words (matching word count vectors)
System.out.println("Learned topics (as distributions over vocab of " + ldaModel.vocabSize()
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala
index 08a93595a2..a1850390c0 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala
@@ -26,7 +26,7 @@ import scopt.OptionParser
import org.apache.log4j.{Level, Logger}
import org.apache.spark.{SparkContext, SparkConf}
-import org.apache.spark.mllib.clustering.LDA
+import org.apache.spark.mllib.clustering.{DistributedLDAModel, LDA}
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.rdd.RDD
@@ -137,7 +137,7 @@ object LDAExample {
sc.setCheckpointDir(params.checkpointDir.get)
}
val startTime = System.nanoTime()
- val ldaModel = lda.run(corpus)
+ val ldaModel = lda.run(corpus).asInstanceOf[DistributedLDAModel]
val elapsed = (System.nanoTime() - startTime) / 1e9
println(s"Finished training LDA model. Summary:")