From eaafe139f881d6105996373c9b11f2ccd91b5b3e Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Thu, 20 Aug 2015 15:01:31 -0700 Subject: [SPARK-9245] [MLLIB] LDA topic assignments For each (document, term) pair, return top topic. Note that instances of (doc, term) pairs within a document (a.k.a. "tokens") are exchangeable, so we should provide an estimate per document-term, rather than per token. CC: rotationsymmetry mengxr Author: Joseph K. Bradley Closes #8329 from jkbradley/lda-topic-assignments. --- .../apache/spark/mllib/clustering/LDAModel.scala | 51 ++++++++++++++++++++-- .../spark/mllib/clustering/LDAOptimizer.scala | 2 +- .../spark/mllib/clustering/JavaLDASuite.java | 7 +++ .../apache/spark/mllib/clustering/LDASuite.scala | 21 ++++++++- 4 files changed, 74 insertions(+), 7 deletions(-) (limited to 'mllib/src') diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index b70e380c03..6bc68a4c18 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.clustering -import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argtopk, normalize, sum} +import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax, argtopk, normalize, sum} import breeze.numerics.{exp, lgamma} import org.apache.hadoop.fs.Path import org.json4s.DefaultFormats @@ -438,7 +438,7 @@ object LocalLDAModel extends Loader[LocalLDAModel] { Loader.checkSchema[Data](dataFrame.schema) val topics = dataFrame.collect() val vocabSize = topics(0).getAs[Vector](0).size - val k = topics.size + val k = topics.length val brzTopics = BDM.zeros[Double](vocabSize, k) topics.foreach { case Row(vec: Vector, ind: Int) => @@ -610,6 +610,50 @@ class DistributedLDAModel private[clustering] ( } } + /** + * Return the top topic for each (doc, term) pair. I.e., for each document, what is the most + * likely topic generating each term? + * + * @return RDD of (doc ID, assignment of top topic index for each term), + * where the assignment is specified via a pair of zippable arrays + * (term indices, topic indices). Note that terms will be omitted if not present in + * the document. + */ + lazy val topicAssignments: RDD[(Long, Array[Int], Array[Int])] = { + // For reference, compare the below code with the core part of EMLDAOptimizer.next(). + val eta = topicConcentration + val W = vocabSize + val alpha = docConcentration(0) + val N_k = globalTopicTotals + val sendMsg: EdgeContext[TopicCounts, TokenCount, (Array[Int], Array[Int])] => Unit = + (edgeContext) => { + // E-STEP: Compute gamma_{wjk} (smoothed topic distributions). + val scaledTopicDistribution: TopicCounts = + computePTopic(edgeContext.srcAttr, edgeContext.dstAttr, N_k, W, eta, alpha) + // For this (doc j, term w), send top topic k to doc vertex. + val topTopic: Int = argmax(scaledTopicDistribution) + val term: Int = index2term(edgeContext.dstId) + edgeContext.sendToSrc((Array(term), Array(topTopic))) + } + val mergeMsg: ((Array[Int], Array[Int]), (Array[Int], Array[Int])) => (Array[Int], Array[Int]) = + (terms_topics0, terms_topics1) => { + (terms_topics0._1 ++ terms_topics1._1, terms_topics0._2 ++ terms_topics1._2) + } + // M-STEP: Aggregation computes new N_{kj}, N_{wk} counts. + val perDocAssignments = + graph.aggregateMessages[(Array[Int], Array[Int])](sendMsg, mergeMsg).filter(isDocumentVertex) + perDocAssignments.map { case (docID: Long, (terms: Array[Int], topics: Array[Int])) => + // TODO: Avoid zip, which is inefficient. + val (sortedTerms, sortedTopics) = terms.zip(topics).sortBy(_._1).unzip + (docID, sortedTerms.toArray, sortedTopics.toArray) + } + } + + /** Java-friendly version of [[topicAssignments]] */ + lazy val javaTopicAssignments: JavaRDD[(java.lang.Long, Array[Int], Array[Int])] = { + topicAssignments.asInstanceOf[RDD[(java.lang.Long, Array[Int], Array[Int])]].toJavaRDD() + } + // TODO // override def logLikelihood(documents: RDD[(Long, Vector)]): Double = ??? @@ -849,10 +893,9 @@ object DistributedLDAModel extends Loader[DistributedLDAModel] { val classNameV1_0 = SaveLoadV1_0.thisClassName val model = (loadedClassName, loadedVersion) match { - case (className, "1.0") if className == classNameV1_0 => { + case (className, "1.0") if className == classNameV1_0 => DistributedLDAModel.SaveLoadV1_0.load(sc, path, vocabSize, docConcentration, topicConcentration, iterationTimes.toArray, gammaShape) - } case _ => throw new Exception( s"DistributedLDAModel.load did not recognize model with (className, format version):" + s"($loadedClassName, $loadedVersion). Supported: ($classNameV1_0, 1.0)") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala index 360241c808..cb517f9689 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala @@ -167,7 +167,7 @@ final class EMLDAOptimizer extends LDAOptimizer { edgeContext.sendToDst((false, scaledTopicDistribution)) edgeContext.sendToSrc((false, scaledTopicDistribution)) } - // This is a hack to detect whether we could modify the values in-place. + // The Boolean is a hack to detect whether we could modify the values in-place. // TODO: Add zero/seqOp/combOp option to aggregateMessages. (SPARK-5438) val mergeMsg: ((Boolean, TopicCounts), (Boolean, TopicCounts)) => (Boolean, TopicCounts) = (m0, m1) => { diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java index 6e91cde2ea..3fea359a3b 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java @@ -134,6 +134,13 @@ public class JavaLDASuite implements Serializable { double[] topicWeights = topTopics._3(); assertEquals(3, topicIndices.length); assertEquals(3, topicWeights.length); + + // Check: topTopicAssignments + Tuple3 topicAssignment = model.javaTopicAssignments().first(); + Long docId2 = topicAssignment._1(); + int[] termIndices2 = topicAssignment._2(); + int[] topicIndices2 = topicAssignment._3(); + assertEquals(termIndices2.length, topicIndices2.length); } @Test diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala index 99e28499fd..8a714f9b79 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala @@ -135,17 +135,34 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { } // Top 3 documents per topic - model.topDocumentsPerTopic(3).zip(topDocsByTopicDistributions(3)).foreach {case (t1, t2) => + model.topDocumentsPerTopic(3).zip(topDocsByTopicDistributions(3)).foreach { case (t1, t2) => assert(t1._1 === t2._1) assert(t1._2 === t2._2) } // All documents per topic val q = tinyCorpus.length - model.topDocumentsPerTopic(q).zip(topDocsByTopicDistributions(q)).foreach {case (t1, t2) => + model.topDocumentsPerTopic(q).zip(topDocsByTopicDistributions(q)).foreach { case (t1, t2) => assert(t1._1 === t2._1) assert(t1._2 === t2._2) } + + // Check: topTopicAssignments + // Make sure it assigns a topic to each term appearing in each doc. + val topTopicAssignments: Map[Long, (Array[Int], Array[Int])] = + model.topicAssignments.collect().map(x => x._1 -> (x._2, x._3)).toMap + assert(topTopicAssignments.keys.max < tinyCorpus.length) + tinyCorpus.foreach { case (docID: Long, doc: Vector) => + if (topTopicAssignments.contains(docID)) { + val (inds, vals) = topTopicAssignments(docID) + assert(inds.length === doc.numNonzeros) + // For "term" in actual doc, + // check that it has a topic assigned. + doc.foreachActive((term, wcnt) => assert(wcnt === 0 || inds.contains(term))) + } else { + assert(doc.numNonzeros === 0) + } + } } test("vertex indexing") { -- cgit v1.2.3