path: root/mllib/src/test
author    Joseph K. Bradley <joseph@databricks.com>  2015-08-20 15:01:31 -0700
committer Xiangrui Meng <meng@databricks.com>        2015-08-20 15:01:31 -0700
commit    eaafe139f881d6105996373c9b11f2ccd91b5b3e (patch)
tree      bb935d8fb5b9aca4478db5fbec2ffd13a229d490 /mllib/src/test
parent    7cfc0750e14f2c1b3847e4720cc02150253525a9 (diff)
[SPARK-9245] [MLLIB] LDA topic assignments
For each (document, term) pair, return the top topic. Note that instances of (doc, term) pairs within a document (a.k.a. "tokens") are exchangeable, so we should provide an estimate per document-term rather than per token.

CC: rotationsymmetry mengxr

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #8329 from jkbradley/lda-topic-assignments.
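(A rough usage sketch, not part of the patch: with the RDD-based API, topicAssignments is exposed on DistributedLDAModel. The corpus value below is an assumed RDD[(Long, Vector)] of (docID, term-count vector) pairs.)

import org.apache.spark.mllib.clustering.{DistributedLDAModel, LDA}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.rdd.RDD

// Train with the default EM optimizer, which yields a DistributedLDAModel.
val model = new LDA().setK(3).run(corpus).asInstanceOf[DistributedLDAModel]

// One record per document: (docID, distinct term indices, top topic per term).
val assignments: RDD[(Long, Array[Int], Array[Int])] = model.topicAssignments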
Diffstat (limited to 'mllib/src/test')
-rw-r--r--  mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java  |  7
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala    | 21
2 files changed, 26 insertions(+), 2 deletions(-)
diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java
index 6e91cde2ea..3fea359a3b 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java
@@ -134,6 +134,13 @@ public class JavaLDASuite implements Serializable {
double[] topicWeights = topTopics._3();
assertEquals(3, topicIndices.length);
assertEquals(3, topicWeights.length);
+
+ // Check: topTopicAssignments
+ Tuple3<Long, int[], int[]> topicAssignment = model.javaTopicAssignments().first();
+ Long docId2 = topicAssignment._1();
+ int[] termIndices2 = topicAssignment._2();
+ int[] topicIndices2 = topicAssignment._3();
+ assertEquals(termIndices2.length, topicIndices2.length);
}
@Test
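(The length assertion above reflects that the two returned arrays are parallel: entry i of the topic-index array is the top topic assigned to the distinct term at entry i of the term-index array. An illustrative Scala fragment, assuming a model as in the sketch above; the zipped map is not part of the patch:)

val (docId, termIndices, topicIndices) = model.topicAssignments.first()
// Pair each distinct term in document docId with its assigned top topic.
val termToTopic: Map[Int, Int] = termIndices.zip(topicIndices).toMap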
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
index 99e28499fd..8a714f9b79 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
@@ -135,17 +135,34 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
}
// Top 3 documents per topic
- model.topDocumentsPerTopic(3).zip(topDocsByTopicDistributions(3)).foreach {case (t1, t2) =>
+ model.topDocumentsPerTopic(3).zip(topDocsByTopicDistributions(3)).foreach { case (t1, t2) =>
assert(t1._1 === t2._1)
assert(t1._2 === t2._2)
}
// All documents per topic
val q = tinyCorpus.length
- model.topDocumentsPerTopic(q).zip(topDocsByTopicDistributions(q)).foreach {case (t1, t2) =>
+ model.topDocumentsPerTopic(q).zip(topDocsByTopicDistributions(q)).foreach { case (t1, t2) =>
assert(t1._1 === t2._1)
assert(t1._2 === t2._2)
}
+
+ // Check: topTopicAssignments
+ // Make sure it assigns a topic to each term appearing in each doc.
+ val topTopicAssignments: Map[Long, (Array[Int], Array[Int])] =
+ model.topicAssignments.collect().map(x => x._1 -> (x._2, x._3)).toMap
+ assert(topTopicAssignments.keys.max < tinyCorpus.length)
+ tinyCorpus.foreach { case (docID: Long, doc: Vector) =>
+ if (topTopicAssignments.contains(docID)) {
+ val (inds, vals) = topTopicAssignments(docID)
+ assert(inds.length === doc.numNonzeros)
+ // For "term" in actual doc,
+ // check that it has a topic assigned.
+ doc.foreachActive((term, wcnt) => assert(wcnt === 0 || inds.contains(term)))
+ } else {
+ assert(doc.numNonzeros === 0)
+ }
+ }
}
test("vertex indexing") {