author     Yuhao Yang <hhbyyh@gmail.com>    2015-07-31 11:50:15 -0700
committer  Joseph K. Bradley <joseph@databricks.com>    2015-07-31 11:50:15 -0700
commit     4011a947154d97a9ffb5a71f077481a12534d36b (patch)
tree       b117215285eae619072afca425a0c35ec9b1d960 /mllib/src/test
parent     6add4eddb39e7748a87da3e921ea3c7881d30a82 (diff)
[SPARK-9231] [MLLIB] DistributedLDAModel method for top topics per document
jira: https://issues.apache.org/jira/browse/SPARK-9231

Helper method in DistributedLDAModel of this form:

```
/**
 * For each document, return the top k weighted topics for that document.
 * return RDD of (doc ID, topic indices, topic weights)
 */
def topTopicsPerDocument(k: Int): RDD[(Long, Array[Int], Array[Double])]
```

Author: Yuhao Yang <hhbyyh@gmail.com>

Closes #7785 from hhbyyh/topTopicsPerdoc and squashes the following commits:

30ad153 [Yuhao Yang] small fix
fd24580 [Yuhao Yang] add topTopics per document to DistributedLDAModel
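A minimal usage sketch of the new helper, assuming a trained `DistributedLDAModel` named `model` is already in scope (the variable names and the choice of k = 2 are illustrative only):

```scala
import org.apache.spark.mllib.clustering.DistributedLDAModel
import org.apache.spark.rdd.RDD

// Assumes `model: DistributedLDAModel` exists, e.g. obtained from
// new LDA().setK(3).run(corpus).asInstanceOf[DistributedLDAModel]
// when using the (default) EM optimizer.
val topTopics: RDD[(Long, Array[Int], Array[Double])] =
  model.topTopicsPerDocument(2)

topTopics.take(5).foreach { case (docId, indices, weights) =>
  // `indices` holds the 2 highest-weighted topic ids for this document;
  // `weights` holds the matching entries of its topic distribution.
  println(s"doc $docId -> topics [${indices.mkString(", ")}] " +
    s"weights [${weights.mkString(", ")}]")
}
```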
Diffstat (limited to 'mllib/src/test')
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala  13
1 file changed, 12 insertions(+), 1 deletion(-)
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
index c43e1e575c..695ee3b82e 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
@@ -17,7 +17,7 @@
package org.apache.spark.mllib.clustering
-import breeze.linalg.{DenseMatrix => BDM, max, argmax}
+import breeze.linalg.{DenseMatrix => BDM, argtopk, max, argmax}
import org.apache.spark.SparkFunSuite
import org.apache.spark.graphx.Edge
@@ -108,6 +108,17 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
assert(topicDistribution.toArray.sum ~== 1.0 absTol 1e-5)
}
+    val top2TopicsPerDoc = model.topTopicsPerDocument(2).map(t => (t._1, (t._2, t._3)))
+    model.topicDistributions.join(top2TopicsPerDoc).collect().foreach {
+      case (docId, (topicDistribution, (indices, weights))) =>
+        assert(indices.length == 2)
+        assert(weights.length == 2)
+        val bdvTopicDist = topicDistribution.toBreeze
+        val top2Indices = argtopk(bdvTopicDist, 2)
+        assert(top2Indices.toArray === indices)
+        assert(bdvTopicDist(top2Indices).toArray === weights)
+    }
+
// Check: log probabilities
assert(model.logLikelihood < 0.0)
assert(model.logPrior < 0.0)
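For context, the new assertions lean on breeze's `argtopk`, which returns the indices of the k largest entries ordered by decreasing value; a small standalone sketch of that behavior (the example vector is arbitrary):

```scala
import breeze.linalg.{DenseVector, argtopk}

object ArgtopkSketch extends App {
  val topicDist = DenseVector(0.1, 0.5, 0.4)
  // Indices of the 2 largest entries, largest first: IndexedSeq(1, 2)
  val top2 = argtopk(topicDist, 2)
  // Slicing the vector by those indices recovers the matching weights: 0.5, 0.4
  println(top2)
  println(topicDist(top2).toArray.mkString(", "))
}
```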