author     Joseph K. Bradley <joseph@databricks.com>    2015-08-20 15:01:31 -0700
committer  Xiangrui Meng <meng@databricks.com>          2015-08-20 15:01:31 -0700
commit     eaafe139f881d6105996373c9b11f2ccd91b5b3e (patch)
tree       bb935d8fb5b9aca4478db5fbec2ffd13a229d490 /mllib/src
parent     7cfc0750e14f2c1b3847e4720cc02150253525a9 (diff)
[SPARK-9245] [MLLIB] LDA topic assignments
For each (document, term) pair, return the top topic. Note that instances of (doc, term) pairs within a document (a.k.a. "tokens") are exchangeable, so we provide an estimate per document-term pair rather than per token.

CC: rotationsymmetry mengxr

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #8329 from jkbradley/lda-topic-assignments.
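A minimal usage sketch of the new API (the toy corpus, K, and iteration count below are hypothetical; it assumes the default EM optimizer and a SparkContext named sc, so that run() returns a DistributedLDAModel):

import org.apache.spark.mllib.clustering.{DistributedLDAModel, LDA}
import org.apache.spark.mllib.linalg.Vectors

// Hypothetical toy corpus: (docID, term-count vector) pairs.
val corpus = sc.parallelize(Seq(
  (0L, Vectors.dense(1.0, 2.0, 0.0)),
  (1L, Vectors.dense(0.0, 1.0, 3.0))))

// EM is the default optimizer, so the fitted model is a DistributedLDAModel.
val model = new LDA().setK(2).setMaxIterations(10).run(corpus)
  .asInstanceOf[DistributedLDAModel]

// For each document: zippable arrays of term indices and their top topics.
model.topicAssignments.collect().foreach { case (docID, terms, topics) =>
  println(s"doc $docID: " + terms.zip(topics).mkString(", "))
}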
Diffstat (limited to 'mllib/src')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala      51
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala   2
-rw-r--r--  mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java     7
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala      21
4 files changed, 74 insertions, 7 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
index b70e380c03..6bc68a4c18 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
@@ -17,7 +17,7 @@
package org.apache.spark.mllib.clustering
-import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argtopk, normalize, sum}
+import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax, argtopk, normalize, sum}
import breeze.numerics.{exp, lgamma}
import org.apache.hadoop.fs.Path
import org.json4s.DefaultFormats
@@ -438,7 +438,7 @@ object LocalLDAModel extends Loader[LocalLDAModel] {
Loader.checkSchema[Data](dataFrame.schema)
val topics = dataFrame.collect()
val vocabSize = topics(0).getAs[Vector](0).size
- val k = topics.size
+ val k = topics.length
val brzTopics = BDM.zeros[Double](vocabSize, k)
topics.foreach { case Row(vec: Vector, ind: Int) =>
@@ -610,6 +610,50 @@ class DistributedLDAModel private[clustering] (
}
}
+ /**
+ * Return the top topic for each (doc, term) pair. I.e., for each document, what is the most
+ * likely topic generating each term?
+ *
+ * @return RDD of (doc ID, assignment of top topic index for each term),
+ * where the assignment is specified via a pair of zippable arrays
+ * (term indices, topic indices). Note that terms will be omitted if not present in
+ * the document.
+ */
+ lazy val topicAssignments: RDD[(Long, Array[Int], Array[Int])] = {
+ // For reference, compare the below code with the core part of EMLDAOptimizer.next().
+ val eta = topicConcentration
+ val W = vocabSize
+ val alpha = docConcentration(0)
+ val N_k = globalTopicTotals
+ val sendMsg: EdgeContext[TopicCounts, TokenCount, (Array[Int], Array[Int])] => Unit =
+ (edgeContext) => {
+ // E-STEP: Compute gamma_{wjk} (smoothed topic distributions).
+ val scaledTopicDistribution: TopicCounts =
+ computePTopic(edgeContext.srcAttr, edgeContext.dstAttr, N_k, W, eta, alpha)
+ // For this (doc j, term w), send top topic k to doc vertex.
+ val topTopic: Int = argmax(scaledTopicDistribution)
+ val term: Int = index2term(edgeContext.dstId)
+ edgeContext.sendToSrc((Array(term), Array(topTopic)))
+ }
+ val mergeMsg: ((Array[Int], Array[Int]), (Array[Int], Array[Int])) => (Array[Int], Array[Int]) =
+ (terms_topics0, terms_topics1) => {
+ (terms_topics0._1 ++ terms_topics1._1, terms_topics0._2 ++ terms_topics1._2)
+ }
+ // M-STEP: Aggregation computes new N_{kj}, N_{wk} counts.
+ val perDocAssignments =
+ graph.aggregateMessages[(Array[Int], Array[Int])](sendMsg, mergeMsg).filter(isDocumentVertex)
+ perDocAssignments.map { case (docID: Long, (terms: Array[Int], topics: Array[Int])) =>
+ // TODO: Avoid zip, which is inefficient.
+ val (sortedTerms, sortedTopics) = terms.zip(topics).sortBy(_._1).unzip
+ (docID, sortedTerms.toArray, sortedTopics.toArray)
+ }
+ }
+
+ /** Java-friendly version of [[topicAssignments]] */
+ lazy val javaTopicAssignments: JavaRDD[(java.lang.Long, Array[Int], Array[Int])] = {
+ topicAssignments.asInstanceOf[RDD[(java.lang.Long, Array[Int], Array[Int])]].toJavaRDD()
+ }
+
// TODO
// override def logLikelihood(documents: RDD[(Long, Vector)]): Double = ???
@@ -849,10 +893,9 @@ object DistributedLDAModel extends Loader[DistributedLDAModel] {
val classNameV1_0 = SaveLoadV1_0.thisClassName
val model = (loadedClassName, loadedVersion) match {
- case (className, "1.0") if className == classNameV1_0 => {
+ case (className, "1.0") if className == classNameV1_0 =>
DistributedLDAModel.SaveLoadV1_0.load(sc, path, vocabSize, docConcentration,
topicConcentration, iterationTimes.toArray, gammaShape)
- }
case _ => throw new Exception(
s"DistributedLDAModel.load did not recognize model with (className, format version):" +
s"($loadedClassName, $loadedVersion). Supported: ($classNameV1_0, 1.0)")
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
index 360241c808..cb517f9689 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
@@ -167,7 +167,7 @@ final class EMLDAOptimizer extends LDAOptimizer {
edgeContext.sendToDst((false, scaledTopicDistribution))
edgeContext.sendToSrc((false, scaledTopicDistribution))
}
- // This is a hack to detect whether we could modify the values in-place.
+ // The Boolean is a hack to detect whether we could modify the values in-place.
// TODO: Add zero/seqOp/combOp option to aggregateMessages. (SPARK-5438)
val mergeMsg: ((Boolean, TopicCounts), (Boolean, TopicCounts)) => (Boolean, TopicCounts) =
(m0, m1) => {
diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java
index 6e91cde2ea..3fea359a3b 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java
@@ -134,6 +134,13 @@ public class JavaLDASuite implements Serializable {
double[] topicWeights = topTopics._3();
assertEquals(3, topicIndices.length);
assertEquals(3, topicWeights.length);
+
+ // Check: topTopicAssignments
+ Tuple3<Long, int[], int[]> topicAssignment = model.javaTopicAssignments().first();
+ Long docId2 = topicAssignment._1();
+ int[] termIndices2 = topicAssignment._2();
+ int[] topicIndices2 = topicAssignment._3();
+ assertEquals(termIndices2.length, topicIndices2.length);
}
@Test
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
index 99e28499fd..8a714f9b79 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
@@ -135,17 +135,34 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
}
// Top 3 documents per topic
- model.topDocumentsPerTopic(3).zip(topDocsByTopicDistributions(3)).foreach {case (t1, t2) =>
+ model.topDocumentsPerTopic(3).zip(topDocsByTopicDistributions(3)).foreach { case (t1, t2) =>
assert(t1._1 === t2._1)
assert(t1._2 === t2._2)
}
// All documents per topic
val q = tinyCorpus.length
- model.topDocumentsPerTopic(q).zip(topDocsByTopicDistributions(q)).foreach {case (t1, t2) =>
+ model.topDocumentsPerTopic(q).zip(topDocsByTopicDistributions(q)).foreach { case (t1, t2) =>
assert(t1._1 === t2._1)
assert(t1._2 === t2._2)
}
+
+ // Check: topTopicAssignments
+ // Make sure it assigns a topic to each term appearing in each doc.
+ val topTopicAssignments: Map[Long, (Array[Int], Array[Int])] =
+ model.topicAssignments.collect().map(x => x._1 -> (x._2, x._3)).toMap
+ assert(topTopicAssignments.keys.max < tinyCorpus.length)
+ tinyCorpus.foreach { case (docID: Long, doc: Vector) =>
+ if (topTopicAssignments.contains(docID)) {
+ val (inds, vals) = topTopicAssignments(docID)
+ assert(inds.length === doc.numNonzeros)
+ // For "term" in actual doc,
+ // check that it has a topic assigned.
+ doc.foreachActive((term, wcnt) => assert(wcnt === 0 || inds.contains(term)))
+ } else {
+ assert(doc.numNonzeros === 0)
+ }
+ }
}
test("vertex indexing") {