aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main
diff options
context:
space:
mode:
authorYuhao Yang <hhbyyh@gmail.com>2016-01-11 14:55:44 -0800
committerJoseph K. Bradley <joseph@databricks.com>2016-01-11 14:55:44 -0800
commitbbea88852ce6a3127d071ca40dbca2d042f9fbcf (patch)
tree7aab3ab1a41d5589c6ac7bbf2b3c3b7d7c647512 /mllib/src/main
parent4f8eefa36bb90812aac61ac7a762c9452de666bf (diff)
downloadspark-bbea88852ce6a3127d071ca40dbca2d042f9fbcf.tar.gz
spark-bbea88852ce6a3127d071ca40dbca2d042f9fbcf.tar.bz2
spark-bbea88852ce6a3127d071ca40dbca2d042f9fbcf.zip
[SPARK-10809][MLLIB] Single-document topicDistributions method for LocalLDAModel
jira: https://issues.apache.org/jira/browse/SPARK-10809 We could provide a single-document topicDistributions method for LocalLDAModel to allow for quick queries which avoid RDD operations. Currently, the user must use an RDD of documents. add some missing assert too. Author: Yuhao Yang <hhbyyh@gmail.com> Closes #9484 from hhbyyh/ldaTopicPre.
Diffstat (limited to 'mllib/src/main')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala26
1 files changed, 26 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
index 2fce3ff641..b30ecb8020 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
@@ -388,6 +388,32 @@ class LocalLDAModel private[spark] (
}
/**
+ * Predicts the topic mixture distribution for a document (often called "theta" in the
+ * literature). Returns a vector of zeros for an empty document.
+ *
+ * Note: this method is meant for quick queries on a single document. For batches of documents,
+ * use [[topicDistributions]] instead to avoid overhead.
+ *
+ * @param document document to predict topic mixture distributions for
+ * @return topic mixture distribution for the document
+ */
+  @Since("2.0.0")
+  def topicDistribution(document: Vector): Vector = {
+    if (document.numNonzeros == 0) {
+      // Empty document: no inference is run, so expElogbeta is never needed on this path.
+      Vectors.zeros(this.k)
+    } else {
+      // Built only here: it densifies and transforms the whole topics matrix on every call.
+      val expElogbeta =
+        exp(LDAUtils.dirichletExpectation(topicsMatrix.toBreeze.toDenseMatrix.t).t)
+      val (gamma, _) = OnlineLDAOptimizer.variationalTopicInference(
+        document, expElogbeta, this.docConcentration.toBreeze, gammaShape, this.k)
+      // Normalize gamma (norm parameter 1.0) to obtain the returned topic mixture.
+      Vectors.dense(normalize(gamma, 1.0).toArray)
+    }
+  }
+
+ /**
* Java-friendly version of [[topicDistributions]]
*/
@Since("1.4.1")