diff options
author | Meihua Wu <meihuawu@umich.edu> | 2015-07-31 13:01:10 -0700 |
---|---|---|
committer | Joseph K. Bradley <joseph@databricks.com> | 2015-07-31 13:01:10 -0700 |
commit | 3c0d2e55210735e0df2f8febb5f63c224af230e3 (patch) | |
tree | 34726b0b26e6f5e243fab0ff15a56485f8a3ce9c /mllib/src/test | |
parent | c0686668ae6a92b6bb4801a55c3b78aedbee816a (diff) | |
download | spark-3c0d2e55210735e0df2f8febb5f63c224af230e3.tar.gz spark-3c0d2e55210735e0df2f8febb5f63c224af230e3.tar.bz2 spark-3c0d2e55210735e0df2f8febb5f63c224af230e3.zip |
[SPARK-9246] [MLLIB] DistributedLDAModel predict top docs per topic
Add topDocumentsPerTopic to DistributedLDAModel.
Add ScalaDoc and unit tests.
Author: Meihua Wu <meihuawu@umich.edu>
Closes #7769 from rotationsymmetry/SPARK-9246 and squashes the following commits:
1029e79c [Meihua Wu] clean up code comments
a023b82 [Meihua Wu] Update tests to use Long for doc index.
91e5998 [Meihua Wu] Use Long for doc index.
b9f70cf [Meihua Wu] Revise topDocumentsPerTopic
26ff3f6 [Meihua Wu] Add topDocumentsPerTopic, scala doc and unit tests
Diffstat (limited to 'mllib/src/test')
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala | 22 |
1 files changed, 22 insertions, 0 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala index 79d2a1cafd..f2b94707fd 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala @@ -122,6 +122,28 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { // Check: log probabilities assert(model.logLikelihood < 0.0) assert(model.logPrior < 0.0) + + // Check: topDocumentsPerTopic + // Compare it with top documents per topic derived from topicDistributions + val topDocsByTopicDistributions = { n: Int => + Range(0, k).map { topic => + val (doc, docWeights) = topicDistributions.sortBy(-_._2(topic)).take(n).unzip + (doc.toArray, docWeights.map(_(topic)).toArray) + }.toArray + } + + // Top 3 documents per topic + model.topDocumentsPerTopic(3).zip(topDocsByTopicDistributions(3)).foreach {case (t1, t2) => + assert(t1._1 === t2._1) + assert(t1._2 === t2._2) + } + + // All documents per topic + val q = tinyCorpus.length + model.topDocumentsPerTopic(q).zip(topDocsByTopicDistributions(q)).foreach {case (t1, t2) => + assert(t1._1 === t2._1) + assert(t1._2 === t2._2) + } } test("vertex indexing") { |