author    Meihua Wu <meihuawu@umich.edu>    2015-07-31 13:01:10 -0700
committer Joseph K. Bradley <joseph@databricks.com>    2015-07-31 13:01:10 -0700
commit    3c0d2e55210735e0df2f8febb5f63c224af230e3 (patch)
tree      34726b0b26e6f5e243fab0ff15a56485f8a3ce9c /mllib
parent    c0686668ae6a92b6bb4801a55c3b78aedbee816a (diff)
[SPARK-9246] [MLLIB] DistributedLDAModel predict top docs per topic
Add topDocumentsPerTopic to DistributedLDAModel. Add ScalaDoc and unit tests.

Author: Meihua Wu <meihuawu@umich.edu>

Closes #7769 from rotationsymmetry/SPARK-9246 and squashes the following commits:

1029e79c [Meihua Wu] clean up code comments
a023b82 [Meihua Wu] Update tests to use Long for doc index.
91e5998 [Meihua Wu] Use Long for doc index.
b9f70cf [Meihua Wu] Revise topDocumentsPerTopic
26ff3f6 [Meihua Wu] Add topDocumentsPerTopic, scala doc and unit tests
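For context, a minimal usage sketch of the new API. This is hedged: the toy corpus, topic count, and variable names below are hypothetical, an active SparkContext sc is assumed, and the cast relies on the default EM optimizer of LDA.run returning a DistributedLDAModel.

import org.apache.spark.mllib.clustering.{DistributedLDAModel, LDA}
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.rdd.RDD

// Hypothetical toy corpus: (document ID, term-count vector) pairs.
val corpus: RDD[(Long, Vector)] = sc.parallelize(Seq(
  (0L, Vectors.dense(1.0, 2.0, 0.0)),
  (1L, Vectors.dense(0.0, 1.0, 3.0)),
  (2L, Vectors.dense(2.0, 0.0, 1.0))))

// The default (EM) optimizer yields a DistributedLDAModel.
val model = new LDA().setK(2).run(corpus).asInstanceOf[DistributedLDAModel]

// New in this patch: top documents per topic, as parallel arrays of IDs and weights.
val topDocs = model.topDocumentsPerTopic(maxDocumentsPerTopic = 2)
topDocs.zipWithIndex.foreach { case ((docIds, weights), topic) =>
  println(s"topic $topic: docs = ${docIds.mkString(", ")}; weights = ${weights.mkString(", ")}")
}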
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala | 37
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala | 22
2 files changed, 59 insertions(+), 0 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
index ff7035d224..0cdac84eeb 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
@@ -516,6 +516,43 @@ class DistributedLDAModel private[clustering] (
}
}
+  /**
+   * Return the top documents for each topic.
+   *
+   * This is approximate; it may not return exactly the top-weighted documents for each topic.
+   * To get a more precise set of top documents, increase maxDocumentsPerTopic.
+   *
+   * @param maxDocumentsPerTopic  Maximum number of documents to collect for each topic.
+   * @return  Array over topics. Each element is a pair of matching arrays:
+   *          (IDs for the documents, weights of the topic in these documents).
+   *          For each topic, documents are sorted in order of decreasing topic weights.
+   */
+  def topDocumentsPerTopic(maxDocumentsPerTopic: Int): Array[(Array[Long], Array[Double])] = {
+    val numTopics = k
+    val topicsInQueues: Array[BoundedPriorityQueue[(Double, Long)]] =
+      topicDistributions.mapPartitions { docVertices =>
+        // For this partition, collect each topic's top-weighted docs in queues:
+        //  queues(topic) = queue of (doc-topic weight, doc ID).
+        val queues =
+          Array.fill(numTopics)(new BoundedPriorityQueue[(Double, Long)](maxDocumentsPerTopic))
+        for ((docId, docTopics) <- docVertices) {
+          var topic = 0
+          while (topic < numTopics) {
+            queues(topic) += (docTopics(topic) -> docId)
+            topic += 1
+          }
+        }
+        Iterator(queues)
+      }.treeReduce { (q1, q2) =>
+        q1.zip(q2).foreach { case (a, b) => a ++= b }
+        q1
+      }
+    topicsInQueues.map { q =>
+      val (docTopics, docs) = q.toArray.sortBy(-_._1).unzip
+      (docs.toArray, docTopics.toArray)
+    }
+  }
+
// TODO
// override def logLikelihood(documents: RDD[(Long, Vector)]): Double = ???
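The aggregation above keeps at most maxDocumentsPerTopic entries per topic on each partition and merges the partial results with treeReduce, so the data returned to the driver stays bounded regardless of corpus size. Below is a minimal standalone sketch of that pattern, hedged: the helper name topKPerTopic is hypothetical, and a plain sorted buffer stands in for Spark's BoundedPriorityQueue, which is a private[spark] utility not usable outside Spark itself.

import org.apache.spark.rdd.RDD

// Hypothetical illustration of the per-partition top-K pattern used above.
// A sorted, truncated ArrayBuffer stands in for BoundedPriorityQueue.
def topKPerTopic(
    scores: RDD[(Long, Array[Double])],  // (doc ID, per-topic weights)
    numTopics: Int,
    k: Int): Array[Array[(Double, Long)]] = {
  scores.mapPartitions { docs =>
    val buffers =
      Array.fill(numTopics)(scala.collection.mutable.ArrayBuffer.empty[(Double, Long)])
    for ((docId, weights) <- docs; topic <- 0 until numTopics) {
      buffers(topic) += (weights(topic) -> docId)
      if (buffers(topic).size > k) {
        // Truncate to the k largest weights seen so far on this partition.
        val kept = buffers(topic).sortBy(-_._1).take(k)
        buffers(topic).clear()
        buffers(topic) ++= kept
      }
    }
    Iterator(buffers.map(_.toArray))
  }.treeReduce { (a, b) =>
    // Merge partition results pairwise, truncating to k per topic again.
    a.zip(b).map { case (x, y) => (x ++ y).sortBy(-_._1).take(k) }
  }
}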
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
index 79d2a1cafd..f2b94707fd 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
@@ -122,6 +122,28 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
// Check: log probabilities
assert(model.logLikelihood < 0.0)
assert(model.logPrior < 0.0)
+
+    // Check: topDocumentsPerTopic
+    // Compare it with top documents per topic derived from topicDistributions
+    val topDocsByTopicDistributions = { n: Int =>
+      Range(0, k).map { topic =>
+        val (doc, docWeights) = topicDistributions.sortBy(-_._2(topic)).take(n).unzip
+        (doc.toArray, docWeights.map(_(topic)).toArray)
+      }.toArray
+    }
+
+    // Top 3 documents per topic
+    model.topDocumentsPerTopic(3).zip(topDocsByTopicDistributions(3)).foreach { case (t1, t2) =>
+      assert(t1._1 === t2._1)
+      assert(t1._2 === t2._2)
+    }
+
+    // All documents per topic
+    val q = tinyCorpus.length
+    model.topDocumentsPerTopic(q).zip(topDocsByTopicDistributions(q)).foreach { case (t1, t2) =>
+      assert(t1._1 === t2._1)
+      assert(t1._2 === t2._2)
+    }
}
test("vertex indexing") {