author    Meihua Wu <meihuawu@umich.edu>        2015-07-31 13:01:10 -0700
committer Joseph K. Bradley <joseph@databricks.com>        2015-07-31 13:01:10 -0700
commit    3c0d2e55210735e0df2f8febb5f63c224af230e3 (patch)
tree      34726b0b26e6f5e243fab0ff15a56485f8a3ce9c /mllib/src/test
parent    c0686668ae6a92b6bb4801a55c3b78aedbee816a (diff)
[SPARK-9246] [MLLIB] DistributedLDAModel predict top docs per topic
Add topDocumentsPerTopic to DistributedLDAModel. Add ScalaDoc and unit tests.

Author: Meihua Wu <meihuawu@umich.edu>

Closes #7769 from rotationsymmetry/SPARK-9246 and squashes the following commits:

1029e79c [Meihua Wu] clean up code comments
a023b82 [Meihua Wu] Update tests to use Long for doc index.
91e5998 [Meihua Wu] Use Long for doc index.
b9f70cf [Meihua Wu] Revise topDocumentsPerTopic
26ff3f6 [Meihua Wu] Add topDocumentsPerTopic, scala doc and unit tests
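For context, a minimal sketch (not part of this commit) of how the new API can be called on a trained model. The toy corpus, app name, and object name are illustrative assumptions; only topDocumentsPerTopic and its Array[(Array[Long], Array[Double])] return type come from this change:

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.clustering.{DistributedLDAModel, LDA}
import org.apache.spark.mllib.linalg.Vectors

object TopDocsSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local").setAppName("lda-sketch"))

    // Toy (docId, termCountVector) corpus; Long doc IDs match this commit.
    val corpus = sc.parallelize(Seq(
      (0L, Vectors.dense(1.0, 2.0, 0.0)),
      (1L, Vectors.dense(0.0, 1.0, 3.0)),
      (2L, Vectors.dense(2.0, 0.0, 1.0))
    ))

    // The default (EM) optimizer yields a DistributedLDAModel.
    val model = new LDA().setK(2).run(corpus).asInstanceOf[DistributedLDAModel]

    // For each of the k topics: IDs and weights of its top 3 documents,
    // sorted by the documents' per-topic weight.
    val topDocs: Array[(Array[Long], Array[Double])] = model.topDocumentsPerTopic(3)
    topDocs.zipWithIndex.foreach { case ((ids, weights), topic) =>
      println(s"topic $topic: " + ids.zip(weights).mkString(", "))
    }

    sc.stop()
  }
}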
Diffstat (limited to 'mllib/src/test')
-rw-r--r-- mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala | 22
1 file changed, 22 insertions(+), 0 deletions(-)
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
index 79d2a1cafd..f2b94707fd 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
@@ -122,6 +122,28 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
// Check: log probabilities
assert(model.logLikelihood < 0.0)
assert(model.logPrior < 0.0)
+
+ // Check: topDocumentsPerTopic
+ // Compare it with top documents per topic derived from topicDistributions
+ val topDocsByTopicDistributions = { n: Int =>
+ Range(0, k).map { topic =>
+ val (doc, docWeights) = topicDistributions.sortBy(-_._2(topic)).take(n).unzip
+ (doc.toArray, docWeights.map(_(topic)).toArray)
+ }.toArray
+ }
+
+ // Top 3 documents per topic
+ model.topDocumentsPerTopic(3).zip(topDocsByTopicDistributions(3)).foreach {case (t1, t2) =>
+ assert(t1._1 === t2._1)
+ assert(t1._2 === t2._2)
+ }
+
+ // All documents per topic
+ val q = tinyCorpus.length
+ model.topDocumentsPerTopic(q).zip(topDocsByTopicDistributions(q)).foreach {case (t1, t2) =>
+ assert(t1._1 === t2._1)
+ assert(t1._2 === t2._2)
+ }
}
test("vertex indexing") {