aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorYuhao Yang <hhbyyh@gmail.com>2015-07-31 11:50:15 -0700
committerJoseph K. Bradley <joseph@databricks.com>2015-07-31 11:50:15 -0700
commit4011a947154d97a9ffb5a71f077481a12534d36b (patch)
treeb117215285eae619072afca425a0c35ec9b1d960 /mllib
parent6add4eddb39e7748a87da3e921ea3c7881d30a82 (diff)
downloadspark-4011a947154d97a9ffb5a71f077481a12534d36b.tar.gz
spark-4011a947154d97a9ffb5a71f077481a12534d36b.tar.bz2
spark-4011a947154d97a9ffb5a71f077481a12534d36b.zip
[SPARK-9231] [MLLIB] DistributedLDAModel method for top topics per document
jira: https://issues.apache.org/jira/browse/SPARK-9231 Helper method in DistributedLDAModel of this form: ``` /** * For each document, return the top k weighted topics for that document. * return RDD of (doc ID, topic indices, topic weights) */ def topTopicsPerDocument(k: Int): RDD[(Long, Array[Int], Array[Double])] ``` Author: Yuhao Yang <hhbyyh@gmail.com> Closes #7785 from hhbyyh/topTopicsPerdoc and squashes the following commits: 30ad153 [Yuhao Yang] small fix fd24580 [Yuhao Yang] add topTopics per document to DistributedLDAModel
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala19
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala13
2 files changed, 30 insertions, 2 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
index 6cfad3fbbd..82281a0daf 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
@@ -17,7 +17,7 @@
package org.apache.spark.mllib.clustering
-import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, normalize, sum}
+import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argtopk, normalize, sum}
import breeze.numerics.{exp, lgamma}
import org.apache.hadoop.fs.Path
import org.json4s.DefaultFormats
@@ -591,6 +591,23 @@ class DistributedLDAModel private[clustering] (
JavaPairRDD.fromRDD(topicDistributions.asInstanceOf[RDD[(java.lang.Long, Vector)]])
}
+ /**
+ * For each document, return the top k weighted topics for that document and their weights.
+ * @return RDD of (doc ID, topic indices, topic weights)
+ */
+ def topTopicsPerDocument(k: Int): RDD[(Long, Array[Int], Array[Double])] = {
+ graph.vertices.filter(LDA.isDocumentVertex).map { case (docID, topicCounts) =>
+ val topIndices = argtopk(topicCounts, k)
+ val sumCounts = sum(topicCounts)
+ val weights = if (sumCounts != 0) {
+ topicCounts(topIndices) / sumCounts
+ } else {
+ topicCounts(topIndices)
+ }
+ (docID.toLong, topIndices.toArray, weights.toArray)
+ }
+ }
+
// TODO:
// override def topicDistributions(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)] = ???
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
index c43e1e575c..695ee3b82e 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
@@ -17,7 +17,7 @@
package org.apache.spark.mllib.clustering
-import breeze.linalg.{DenseMatrix => BDM, max, argmax}
+import breeze.linalg.{DenseMatrix => BDM, argtopk, max, argmax}
import org.apache.spark.SparkFunSuite
import org.apache.spark.graphx.Edge
@@ -108,6 +108,17 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
assert(topicDistribution.toArray.sum ~== 1.0 absTol 1e-5)
}
+ val top2TopicsPerDoc = model.topTopicsPerDocument(2).map(t => (t._1, (t._2, t._3)))
+ model.topicDistributions.join(top2TopicsPerDoc).collect().foreach {
+ case (docId, (topicDistribution, (indices, weights))) =>
+ assert(indices.length == 2)
+ assert(weights.length == 2)
+ val bdvTopicDist = topicDistribution.toBreeze
+ val top2Indices = argtopk(bdvTopicDist, 2)
+ assert(top2Indices.toArray === indices)
+ assert(bdvTopicDist(top2Indices).toArray === weights)
+ }
+
// Check: log probabilities
assert(model.logLikelihood < 0.0)
assert(model.logPrior < 0.0)