author	Yuhao Yang <hhbyyh@gmail.com>	2015-07-31 11:50:15 -0700
committer	Joseph K. Bradley <joseph@databricks.com>	2015-07-31 11:50:15 -0700
commit	4011a947154d97a9ffb5a71f077481a12534d36b (patch)
tree	b117215285eae619072afca425a0c35ec9b1d960 /mllib/src/main
parent	6add4eddb39e7748a87da3e921ea3c7881d30a82 (diff)
[SPARK-9231] [MLLIB] DistributedLDAModel method for top topics per document
jira: https://issues.apache.org/jira/browse/SPARK-9231

Helper method in DistributedLDAModel of this form:

```
/**
 * For each document, return the top k weighted topics for that document.
 * return RDD of (doc ID, topic indices, topic weights)
 */
def topTopicsPerDocument(k: Int): RDD[(Long, Array[Int], Array[Double])]
```

Author: Yuhao Yang <hhbyyh@gmail.com>

Closes #7785 from hhbyyh/topTopicsPerdoc and squashes the following commits:

30ad153 [Yuhao Yang] small fix
fd24580 [Yuhao Yang] add topTopics per document to DistributedLDAModel
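A minimal usage sketch of the new method (not part of this patch), assuming an existing `corpus: RDD[(Long, Vector)]` of (doc ID, term-count vector) pairs and the default EM optimizer, which yields a DistributedLDAModel:

```
import org.apache.spark.mllib.clustering.{DistributedLDAModel, LDA}
import org.apache.spark.rdd.RDD

// corpus: RDD[(Long, Vector)] of (doc ID, term counts); assumed to exist already.
val ldaModel = new LDA().setK(10).run(corpus)
// The default EM optimizer returns a DistributedLDAModel.
val distModel = ldaModel.asInstanceOf[DistributedLDAModel]

// Top 3 topics per document: (doc ID, topic indices, topic weights).
val topTopics: RDD[(Long, Array[Int], Array[Double])] = distModel.topTopicsPerDocument(3)
topTopics.take(5).foreach { case (docId, topics, weights) =>
  println(s"doc $docId: " +
    topics.zip(weights).map { case (t, w) => f"topic $t -> $w%.3f" }.mkString(", "))
}
```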
Diffstat (limited to 'mllib/src/main')
-rw-r--r--	mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala | 19
1 file changed, 18 insertions(+), 1 deletion(-)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
index 6cfad3fbbd..82281a0daf 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
@@ -17,7 +17,7 @@
package org.apache.spark.mllib.clustering
-import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, normalize, sum}
+import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argtopk, normalize, sum}
import breeze.numerics.{exp, lgamma}
import org.apache.hadoop.fs.Path
import org.json4s.DefaultFormats
@@ -591,6 +591,23 @@ class DistributedLDAModel private[clustering] (
JavaPairRDD.fromRDD(topicDistributions.asInstanceOf[RDD[(java.lang.Long, Vector)]])
}
+ /**
+ * For each document, return the top k weighted topics for that document and their weights.
+ * @return RDD of (doc ID, topic indices, topic weights)
+ */
+ def topTopicsPerDocument(k: Int): RDD[(Long, Array[Int], Array[Double])] = {
+ graph.vertices.filter(LDA.isDocumentVertex).map { case (docID, topicCounts) =>
+ val topIndices = argtopk(topicCounts, k)
+ val sumCounts = sum(topicCounts)
+ val weights = if (sumCounts != 0) {
+ topicCounts(topIndices) / sumCounts
+ } else {
+ topicCounts(topIndices)
+ }
+ (docID.toLong, topIndices.toArray, weights.toArray)
+ }
+ }
+
// TODO:
// override def topicDistributions(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)] = ???
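For illustration only (not part of the patch), the per-document computation added above reduces to the following Breeze operations on a single topic-count vector; `counts` and `k` here are hypothetical values:

```
import breeze.linalg.{DenseVector, argtopk, sum}

val counts = DenseVector(2.0, 7.0, 1.0, 10.0)  // topic counts for one document
val k = 2
val topIndices = argtopk(counts, k)            // indices of the k largest counts: IndexedSeq(3, 1)
val total = sum(counts)                        // 20.0
// Normalize the selected counts into weights unless the document has no counts at all.
val weights = if (total != 0) counts(topIndices) / total else counts(topIndices)
println(topIndices.toArray.mkString(",") + " -> " + weights.toArray.mkString(","))  // 3,1 -> 0.5,0.35
```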