author     Yuhao Yang <hhbyyh@gmail.com>              2015-04-27 19:02:51 -0700
committer  Joseph K. Bradley <joseph@databricks.com>  2015-04-27 19:02:51 -0700
commit     4d9e560b5470029143926827b1cb9d72a0bfbeff (patch)
tree       2507253e2cf6544aefbdca3db8a7b38ae84bb04f
parent     62888a4ded91b3c2cbb05936c374c7ebfc10799e (diff)
[SPARK-7090] [MLLIB] Introduce LDAOptimizer to LDA to further improve extensibility
jira: https://issues.apache.org/jira/browse/SPARK-7090

LDA was implemented with extensibility in mind. With the development of OnlineLDA and Gibbs Sampling, we are collecting more detailed requirements from different algorithms. As Joseph Bradley (jkbradley) proposed in https://github.com/apache/spark/pull/4807, and after some further discussion, we'd like to adjust the code structure a little to present the common interface and extension point clearly. Basically, class LDA is the common entry point for LDA computation, and each LDA object refers to an LDAOptimizer for the concrete algorithm implementation. Users can customize an LDAOptimizer with specific parameters and assign it to the LDA instance.

Concrete changes:
1. Add a trait `LDAOptimizer`, which defines the common interface for concrete implementations. Each subclass is a wrapper for a specific LDA algorithm.
2. Move EMOptimizer to the file LDAOptimizer.scala, make it inherit from LDAOptimizer, and rename it to EMLDAOptimizer (in case a more generic EMOptimizer comes along in the future).
   - Adjust the constructor of EMOptimizer, since all parameters should be passed in through the initialState method. This avoids unwanted confusion or overwrites.
   - Move the code from LDA.initialState to initialState of EMLDAOptimizer.
3. Add the property ldaOptimizer to LDA along with its getter/setter; EMLDAOptimizer is the default optimizer.
4. Change the return type of LDA.run from DistributedLDAModel to LDAModel.

Further work: add OnlineLDAOptimizer and other possible optimizers once ready.

Author: Yuhao Yang <hhbyyh@gmail.com>

Closes #5661 from hhbyyh/ldaRefactor and squashes the following commits:

0e2e006 [Yuhao Yang] respond to review comments
08a45da [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into ldaRefactor
e756ce4 [Yuhao Yang] solve mima exception
d74fd8f [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into ldaRefactor
0bb8400 [Yuhao Yang] refactor LDA with Optimizer
ec2f857 [Yuhao Yang] prototype for discussion
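For context, here is a minimal sketch of how the adjusted API is used after this change, based on the example and test updates in the diff below. The helper name trainExample and the corpus RDD are placeholders; the points taken from the commit are that run() now returns the base LDAModel type (so callers needing the EM-specific model cast to DistributedLDAModel) and that the optimizer can be chosen via setOptimizer.

  import org.apache.spark.mllib.clustering.{DistributedLDAModel, EMLDAOptimizer, LDA}
  import org.apache.spark.mllib.linalg.Vector
  import org.apache.spark.rdd.RDD

  // `corpus` is a placeholder RDD of (documentId, termCountVector) pairs.
  def trainExample(corpus: RDD[(Long, Vector)]): Unit = {
    val lda = new LDA()
      .setK(3)
      .setMaxIterations(20)
      .setOptimizer(new EMLDAOptimizer)  // or .setOptimizer("em"); EM is the default

    // run() now returns the common LDAModel type ...
    val model = lda.run(corpus)

    // ... so a cast is needed to reach DistributedLDAModel-specific methods.
    val distributedModel = model.asInstanceOf[DistributedLDAModel]
    println(s"Learned ${distributedModel.k} topics over a vocabulary of " +
      s"${distributedModel.vocabSize} terms.")
  }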
-rw-r--r--  examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java  |   2
-rw-r--r--  examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala    |   4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala            | 181
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala       |   2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala   | 210
-rw-r--r--  mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java     |   2
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala       |   2
-rw-r--r--  project/MimaExcludes.scala                                                   |   4
8 files changed, 256 insertions, 151 deletions
diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java
index 36207ae38d..fd53c81cc4 100644
--- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java
@@ -58,7 +58,7 @@ public class JavaLDAExample {
corpus.cache();
// Cluster the documents into three topics using LDA
- DistributedLDAModel ldaModel = new LDA().setK(3).run(corpus);
+ DistributedLDAModel ldaModel = (DistributedLDAModel)new LDA().setK(3).run(corpus);
// Output topics. Each is a distribution over words (matching word count vectors)
System.out.println("Learned topics (as distributions over vocab of " + ldaModel.vocabSize()
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala
index 08a93595a2..a1850390c0 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala
@@ -26,7 +26,7 @@ import scopt.OptionParser
import org.apache.log4j.{Level, Logger}
import org.apache.spark.{SparkContext, SparkConf}
-import org.apache.spark.mllib.clustering.LDA
+import org.apache.spark.mllib.clustering.{DistributedLDAModel, LDA}
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.rdd.RDD
@@ -137,7 +137,7 @@ object LDAExample {
sc.setCheckpointDir(params.checkpointDir.get)
}
val startTime = System.nanoTime()
- val ldaModel = lda.run(corpus)
+ val ldaModel = lda.run(corpus).asInstanceOf[DistributedLDAModel]
val elapsed = (System.nanoTime() - startTime) / 1e9
println(s"Finished training LDA model. Summary:")
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala
index d006b39acb..37bf88b73b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala
@@ -17,16 +17,11 @@
package org.apache.spark.mllib.clustering
-import java.util.Random
-
-import breeze.linalg.{DenseVector => BDV, normalize}
-
+import breeze.linalg.{DenseVector => BDV}
import org.apache.spark.Logging
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaPairRDD
import org.apache.spark.graphx._
-import org.apache.spark.graphx.impl.GraphImpl
-import org.apache.spark.mllib.impl.PeriodicGraphCheckpointer
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.rdd.RDD
import org.apache.spark.util.Utils
@@ -42,16 +37,9 @@ import org.apache.spark.util.Utils
* - "token": instance of a term appearing in a document
* - "topic": multinomial distribution over words representing some concept
*
- * Currently, the underlying implementation uses Expectation-Maximization (EM), implemented
- * according to the Asuncion et al. (2009) paper referenced below.
- *
* References:
* - Original LDA paper (journal version):
* Blei, Ng, and Jordan. "Latent Dirichlet Allocation." JMLR, 2003.
- * - This class implements their "smoothed" LDA model.
- * - Paper which clearly explains several algorithms, including EM:
- * Asuncion, Welling, Smyth, and Teh.
- * "On Smoothing and Inference for Topic Models." UAI, 2009.
*
* @see [[http://en.wikipedia.org/wiki/Latent_Dirichlet_allocation Latent Dirichlet allocation
* (Wikipedia)]]
@@ -63,10 +51,11 @@ class LDA private (
private var docConcentration: Double,
private var topicConcentration: Double,
private var seed: Long,
- private var checkpointInterval: Int) extends Logging {
+ private var checkpointInterval: Int,
+ private var ldaOptimizer: LDAOptimizer) extends Logging {
def this() = this(k = 10, maxIterations = 20, docConcentration = -1, topicConcentration = -1,
- seed = Utils.random.nextLong(), checkpointInterval = 10)
+ seed = Utils.random.nextLong(), checkpointInterval = 10, ldaOptimizer = new EMLDAOptimizer)
/**
* Number of topics to infer. I.e., the number of soft cluster centers.
@@ -220,6 +209,32 @@ class LDA private (
this
}
+
+ /** LDAOptimizer used to perform the actual calculation */
+ def getOptimizer: LDAOptimizer = ldaOptimizer
+
+ /**
+ * LDAOptimizer used to perform the actual calculation (default = EMLDAOptimizer)
+ */
+ def setOptimizer(optimizer: LDAOptimizer): this.type = {
+ this.ldaOptimizer = optimizer
+ this
+ }
+
+ /**
+ * Set the LDAOptimizer used to perform the actual calculation by algorithm name.
+ * Currently "em" is supported.
+ */
+ def setOptimizer(optimizerName: String): this.type = {
+ this.ldaOptimizer =
+ optimizerName.toLowerCase match {
+ case "em" => new EMLDAOptimizer
+ case other =>
+ throw new IllegalArgumentException(s"Only em is supported but got $other.")
+ }
+ this
+ }
+
/**
* Learn an LDA model using the given dataset.
*
@@ -229,9 +244,9 @@ class LDA private (
* Document IDs must be unique and >= 0.
* @return Inferred LDA model
*/
- def run(documents: RDD[(Long, Vector)]): DistributedLDAModel = {
- val state = LDA.initialState(documents, k, getDocConcentration, getTopicConcentration, seed,
- checkpointInterval)
+ def run(documents: RDD[(Long, Vector)]): LDAModel = {
+ val state = ldaOptimizer.initialState(documents, k, getDocConcentration, getTopicConcentration,
+ seed, checkpointInterval)
var iter = 0
val iterationTimes = Array.fill[Double](maxIterations)(0)
while (iter < maxIterations) {
@@ -241,12 +256,11 @@ class LDA private (
iterationTimes(iter) = elapsedSeconds
iter += 1
}
- state.graphCheckpointer.deleteAllCheckpoints()
- new DistributedLDAModel(state, iterationTimes)
+ state.getLDAModel(iterationTimes)
}
/** Java-friendly version of [[run()]] */
- def run(documents: JavaPairRDD[java.lang.Long, Vector]): DistributedLDAModel = {
+ def run(documents: JavaPairRDD[java.lang.Long, Vector]): LDAModel = {
run(documents.rdd.asInstanceOf[RDD[(Long, Vector)]])
}
}
@@ -321,87 +335,9 @@ private[clustering] object LDA {
private[clustering] def isTermVertex(v: (VertexId, _)): Boolean = v._1 < 0
/**
- * Optimizer for EM algorithm which stores data + parameter graph, plus algorithm parameters.
- *
- * @param graph EM graph, storing current parameter estimates in vertex descriptors and
- * data (token counts) in edge descriptors.
- * @param k Number of topics
- * @param vocabSize Number of unique terms
- * @param docConcentration "alpha"
- * @param topicConcentration "beta" or "eta"
- */
- private[clustering] class EMOptimizer(
- var graph: Graph[TopicCounts, TokenCount],
- val k: Int,
- val vocabSize: Int,
- val docConcentration: Double,
- val topicConcentration: Double,
- checkpointInterval: Int) {
-
- private[LDA] val graphCheckpointer = new PeriodicGraphCheckpointer[TopicCounts, TokenCount](
- graph, checkpointInterval)
-
- def next(): EMOptimizer = {
- val eta = topicConcentration
- val W = vocabSize
- val alpha = docConcentration
-
- val N_k = globalTopicTotals
- val sendMsg: EdgeContext[TopicCounts, TokenCount, (Boolean, TopicCounts)] => Unit =
- (edgeContext) => {
- // Compute N_{wj} gamma_{wjk}
- val N_wj = edgeContext.attr
- // E-STEP: Compute gamma_{wjk} (smoothed topic distributions), scaled by token count
- // N_{wj}.
- val scaledTopicDistribution: TopicCounts =
- computePTopic(edgeContext.srcAttr, edgeContext.dstAttr, N_k, W, eta, alpha) *= N_wj
- edgeContext.sendToDst((false, scaledTopicDistribution))
- edgeContext.sendToSrc((false, scaledTopicDistribution))
- }
- // This is a hack to detect whether we could modify the values in-place.
- // TODO: Add zero/seqOp/combOp option to aggregateMessages. (SPARK-5438)
- val mergeMsg: ((Boolean, TopicCounts), (Boolean, TopicCounts)) => (Boolean, TopicCounts) =
- (m0, m1) => {
- val sum =
- if (m0._1) {
- m0._2 += m1._2
- } else if (m1._1) {
- m1._2 += m0._2
- } else {
- m0._2 + m1._2
- }
- (true, sum)
- }
- // M-STEP: Aggregation computes new N_{kj}, N_{wk} counts.
- val docTopicDistributions: VertexRDD[TopicCounts] =
- graph.aggregateMessages[(Boolean, TopicCounts)](sendMsg, mergeMsg)
- .mapValues(_._2)
- // Update the vertex descriptors with the new counts.
- val newGraph = GraphImpl.fromExistingRDDs(docTopicDistributions, graph.edges)
- graph = newGraph
- graphCheckpointer.updateGraph(newGraph)
- globalTopicTotals = computeGlobalTopicTotals()
- this
- }
-
- /**
- * Aggregate distributions over topics from all term vertices.
- *
- * Note: This executes an action on the graph RDDs.
- */
- var globalTopicTotals: TopicCounts = computeGlobalTopicTotals()
-
- private def computeGlobalTopicTotals(): TopicCounts = {
- val numTopics = k
- graph.vertices.filter(isTermVertex).values.fold(BDV.zeros[Double](numTopics))(_ += _)
- }
-
- }
-
- /**
* Compute gamma_{wjk}, a distribution over topics k.
*/
- private def computePTopic(
+ private[clustering] def computePTopic(
docTopicCounts: TopicCounts,
termTopicCounts: TopicCounts,
totalTopicCounts: TopicCounts,
@@ -427,49 +363,4 @@ private[clustering] object LDA {
// normalize
BDV(gamma_wj) /= sum
}
-
- /**
- * Compute bipartite term/doc graph.
- */
- private def initialState(
- docs: RDD[(Long, Vector)],
- k: Int,
- docConcentration: Double,
- topicConcentration: Double,
- randomSeed: Long,
- checkpointInterval: Int): EMOptimizer = {
- // For each document, create an edge (Document -> Term) for each unique term in the document.
- val edges: RDD[Edge[TokenCount]] = docs.flatMap { case (docID: Long, termCounts: Vector) =>
- // Add edges for terms with non-zero counts.
- termCounts.toBreeze.activeIterator.filter(_._2 != 0.0).map { case (term, cnt) =>
- Edge(docID, term2index(term), cnt)
- }
- }
-
- val vocabSize = docs.take(1).head._2.size
-
- // Create vertices.
- // Initially, we use random soft assignments of tokens to topics (random gamma).
- def createVertices(): RDD[(VertexId, TopicCounts)] = {
- val verticesTMP: RDD[(VertexId, TopicCounts)] =
- edges.mapPartitionsWithIndex { case (partIndex, partEdges) =>
- val random = new Random(partIndex + randomSeed)
- partEdges.flatMap { edge =>
- val gamma = normalize(BDV.fill[Double](k)(random.nextDouble()), 1.0)
- val sum = gamma * edge.attr
- Seq((edge.srcId, sum), (edge.dstId, sum))
- }
- }
- verticesTMP.reduceByKey(_ + _)
- }
-
- val docTermVertices = createVertices()
-
- // Partition such that edges are grouped by document
- val graph = Graph(docTermVertices, edges)
- .partitionBy(PartitionStrategy.EdgePartition1D)
-
- new EMOptimizer(graph, k, vocabSize, docConcentration, topicConcentration, checkpointInterval)
- }
-
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
index 0a3f21ecee..6cf26445f2 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
@@ -203,7 +203,7 @@ class DistributedLDAModel private (
import LDA._
- private[clustering] def this(state: LDA.EMOptimizer, iterationTimes: Array[Double]) = {
+ private[clustering] def this(state: EMLDAOptimizer, iterationTimes: Array[Double]) = {
this(state.graph, state.globalTopicTotals, state.k, state.vocabSize, state.docConcentration,
state.topicConcentration, iterationTimes)
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
new file mode 100644
index 0000000000..ffd72a294c
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
@@ -0,0 +1,210 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.clustering
+
+import java.util.Random
+
+import breeze.linalg.{DenseVector => BDV, normalize}
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.graphx._
+import org.apache.spark.graphx.impl.GraphImpl
+import org.apache.spark.mllib.impl.PeriodicGraphCheckpointer
+import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.rdd.RDD
+
+/**
+ * :: Experimental ::
+ *
+ * An LDAOptimizer specifies which optimization/learning/inference algorithm to use, and it can
+ * hold optimizer-specific parameters for users to set.
+ */
+@Experimental
+trait LDAOptimizer{
+
+ /*
+ DEVELOPERS NOTE:
+
+ An LDAOptimizer contains an algorithm for LDA and performs the actual computation, which
+ stores internal data structure (Graph or Matrix) and other parameters for the algorithm.
+ The interface is isolated to improve the extensibility of LDA.
+ */
+
+ /**
+ * Initializer for the optimizer. LDA passes the common parameters to the optimizer and
+ * the internal structure can be initialized properly.
+ */
+ private[clustering] def initialState(
+ docs: RDD[(Long, Vector)],
+ k: Int,
+ docConcentration: Double,
+ topicConcentration: Double,
+ randomSeed: Long,
+ checkpointInterval: Int): LDAOptimizer
+
+ private[clustering] def next(): LDAOptimizer
+
+ private[clustering] def getLDAModel(iterationTimes: Array[Double]): LDAModel
+}
+
+/**
+ * :: Experimental ::
+ *
+ * Optimizer for EM algorithm which stores data + parameter graph, plus algorithm parameters.
+ *
+ * Currently, the underlying implementation uses Expectation-Maximization (EM), implemented
+ * according to the Asuncion et al. (2009) paper referenced below.
+ *
+ * References:
+ * - Original LDA paper (journal version):
+ * Blei, Ng, and Jordan. "Latent Dirichlet Allocation." JMLR, 2003.
+ * - This class implements their "smoothed" LDA model.
+ * - Paper which clearly explains several algorithms, including EM:
+ * Asuncion, Welling, Smyth, and Teh.
+ * "On Smoothing and Inference for Topic Models." UAI, 2009.
+ *
+ */
+@Experimental
+class EMLDAOptimizer extends LDAOptimizer{
+
+ import LDA._
+
+ /**
+ * Following fields will only be initialized through initialState method
+ */
+ private[clustering] var graph: Graph[TopicCounts, TokenCount] = null
+ private[clustering] var k: Int = 0
+ private[clustering] var vocabSize: Int = 0
+ private[clustering] var docConcentration: Double = 0
+ private[clustering] var topicConcentration: Double = 0
+ private[clustering] var checkpointInterval: Int = 10
+ private var graphCheckpointer: PeriodicGraphCheckpointer[TopicCounts, TokenCount] = null
+
+ /**
+ * Compute bipartite term/doc graph.
+ */
+ private[clustering] override def initialState(
+ docs: RDD[(Long, Vector)],
+ k: Int,
+ docConcentration: Double,
+ topicConcentration: Double,
+ randomSeed: Long,
+ checkpointInterval: Int): LDAOptimizer = {
+ // For each document, create an edge (Document -> Term) for each unique term in the document.
+ val edges: RDD[Edge[TokenCount]] = docs.flatMap { case (docID: Long, termCounts: Vector) =>
+ // Add edges for terms with non-zero counts.
+ termCounts.toBreeze.activeIterator.filter(_._2 != 0.0).map { case (term, cnt) =>
+ Edge(docID, term2index(term), cnt)
+ }
+ }
+
+ val vocabSize = docs.take(1).head._2.size
+
+ // Create vertices.
+ // Initially, we use random soft assignments of tokens to topics (random gamma).
+ def createVertices(): RDD[(VertexId, TopicCounts)] = {
+ val verticesTMP: RDD[(VertexId, TopicCounts)] =
+ edges.mapPartitionsWithIndex { case (partIndex, partEdges) =>
+ val random = new Random(partIndex + randomSeed)
+ partEdges.flatMap { edge =>
+ val gamma = normalize(BDV.fill[Double](k)(random.nextDouble()), 1.0)
+ val sum = gamma * edge.attr
+ Seq((edge.srcId, sum), (edge.dstId, sum))
+ }
+ }
+ verticesTMP.reduceByKey(_ + _)
+ }
+
+ val docTermVertices = createVertices()
+
+ // Partition such that edges are grouped by document
+ this.graph = Graph(docTermVertices, edges).partitionBy(PartitionStrategy.EdgePartition1D)
+ this.k = k
+ this.vocabSize = vocabSize
+ this.docConcentration = docConcentration
+ this.topicConcentration = topicConcentration
+ this.checkpointInterval = checkpointInterval
+ this.graphCheckpointer = new
+ PeriodicGraphCheckpointer[TopicCounts, TokenCount](graph, checkpointInterval)
+ this.globalTopicTotals = computeGlobalTopicTotals()
+ this
+ }
+
+ private[clustering] override def next(): EMLDAOptimizer = {
+ require(graph != null, "graph is null, EMLDAOptimizer not initialized.")
+
+ val eta = topicConcentration
+ val W = vocabSize
+ val alpha = docConcentration
+
+ val N_k = globalTopicTotals
+ val sendMsg: EdgeContext[TopicCounts, TokenCount, (Boolean, TopicCounts)] => Unit =
+ (edgeContext) => {
+ // Compute N_{wj} gamma_{wjk}
+ val N_wj = edgeContext.attr
+ // E-STEP: Compute gamma_{wjk} (smoothed topic distributions), scaled by token count
+ // N_{wj}.
+ val scaledTopicDistribution: TopicCounts =
+ computePTopic(edgeContext.srcAttr, edgeContext.dstAttr, N_k, W, eta, alpha) *= N_wj
+ edgeContext.sendToDst((false, scaledTopicDistribution))
+ edgeContext.sendToSrc((false, scaledTopicDistribution))
+ }
+ // This is a hack to detect whether we could modify the values in-place.
+ // TODO: Add zero/seqOp/combOp option to aggregateMessages. (SPARK-5438)
+ val mergeMsg: ((Boolean, TopicCounts), (Boolean, TopicCounts)) => (Boolean, TopicCounts) =
+ (m0, m1) => {
+ val sum =
+ if (m0._1) {
+ m0._2 += m1._2
+ } else if (m1._1) {
+ m1._2 += m0._2
+ } else {
+ m0._2 + m1._2
+ }
+ (true, sum)
+ }
+ // M-STEP: Aggregation computes new N_{kj}, N_{wk} counts.
+ val docTopicDistributions: VertexRDD[TopicCounts] =
+ graph.aggregateMessages[(Boolean, TopicCounts)](sendMsg, mergeMsg)
+ .mapValues(_._2)
+ // Update the vertex descriptors with the new counts.
+ val newGraph = GraphImpl.fromExistingRDDs(docTopicDistributions, graph.edges)
+ graph = newGraph
+ graphCheckpointer.updateGraph(newGraph)
+ globalTopicTotals = computeGlobalTopicTotals()
+ this
+ }
+
+ /**
+ * Aggregate distributions over topics from all term vertices.
+ *
+ * Note: This executes an action on the graph RDDs.
+ */
+ private[clustering] var globalTopicTotals: TopicCounts = null
+
+ private def computeGlobalTopicTotals(): TopicCounts = {
+ val numTopics = k
+ graph.vertices.filter(isTermVertex).values.fold(BDV.zeros[Double](numTopics))(_ += _)
+ }
+
+ private[clustering] override def getLDAModel(iterationTimes: Array[Double]): LDAModel = {
+ require(graph != null, "graph is null, EMLDAOptimizer not initialized.")
+ this.graphCheckpointer.deleteAllCheckpoints()
+ new DistributedLDAModel(this, iterationTimes)
+ }
+}
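The commit message lists OnlineLDAOptimizer and other optimizers as further work. As a purely illustrative, hypothetical sketch (not part of this commit), a new algorithm would plug into this extension point by implementing the three private[clustering] hooks of the LDAOptimizer trait added above; since the hooks are package-private, such a class has to live in org.apache.spark.mllib.clustering. The class name, fields, and update logic below are placeholders only.

  package org.apache.spark.mllib.clustering

  import org.apache.spark.mllib.linalg.Vector
  import org.apache.spark.rdd.RDD

  // Hypothetical skeleton of a future optimizer; the internal state and update
  // rule are placeholders, not an actual implementation of any algorithm.
  class MyLDAOptimizer extends LDAOptimizer {

    private var docs: RDD[(Long, Vector)] = null
    private var k: Int = 0

    // Receive the common LDA parameters and set up internal state.
    private[clustering] override def initialState(
        docs: RDD[(Long, Vector)],
        k: Int,
        docConcentration: Double,
        topicConcentration: Double,
        randomSeed: Long,
        checkpointInterval: Int): LDAOptimizer = {
      this.docs = docs
      this.k = k
      this
    }

    // One iteration of the algorithm; LDA.run is expected to call this once per iteration.
    private[clustering] override def next(): LDAOptimizer = {
      // update internal parameter estimates here
      this
    }

    // Package the learned parameters into an LDAModel once training finishes.
    private[clustering] override def getLDAModel(iterationTimes: Array[Double]): LDAModel = {
      // e.g. build a LocalLDAModel from the estimated topics, once available
      throw new UnsupportedOperationException("illustrative skeleton only")
    }
  }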
diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java
index dc10aa67c7..fbe171b4b1 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java
@@ -88,7 +88,7 @@ public class JavaLDASuite implements Serializable {
.setMaxIterations(5)
.setSeed(12345);
- DistributedLDAModel model = lda.run(corpus);
+ DistributedLDAModel model = (DistributedLDAModel)lda.run(corpus);
// Check: basic parameters
LocalLDAModel localModel = model.toLocal();
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
index cc747dabb9..41ec794146 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
@@ -68,7 +68,7 @@ class LDASuite extends FunSuite with MLlibTestSparkContext {
.setSeed(12345)
val corpus = sc.parallelize(tinyCorpus, 2)
- val model: DistributedLDAModel = lda.run(corpus)
+ val model: DistributedLDAModel = lda.run(corpus).asInstanceOf[DistributedLDAModel]
// Check: basic parameters
val localModel = model.toLocal
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 7ef363a2f0..967961c2bf 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -72,6 +72,10 @@ object MimaExcludes {
// SPARK-6703 Add getOrCreate method to SparkContext
ProblemFilters.exclude[IncompatibleResultTypeProblem]
("org.apache.spark.SparkContext.org$apache$spark$SparkContext$$activeContext")
+ )++ Seq(
+ // SPARK-7090 Introduce LDAOptimizer to LDA to further improve extensibility
+ ProblemFilters.exclude[MissingClassProblem](
+ "org.apache.spark.mllib.clustering.LDA$EMOptimizer")
)
case v if v.startsWith("1.3") =>