author | Yuhao Yang <hhbyyh@gmail.com> | 2016-01-11 14:55:44 -0800 |
---|---|---|
committer | Joseph K. Bradley <joseph@databricks.com> | 2016-01-11 14:55:44 -0800 |
commit | bbea88852ce6a3127d071ca40dbca2d042f9fbcf (patch) | |
tree | 7aab3ab1a41d5589c6ac7bbf2b3c3b7d7c647512 /mllib/src/main | |
parent | 4f8eefa36bb90812aac61ac7a762c9452de666bf (diff) | |
[SPARK-10809][MLLIB] Single-document topicDistributions method for LocalLDAModel
jira: https://issues.apache.org/jira/browse/SPARK-10809
Provide a single-document topicDistribution method for LocalLDAModel to allow quick queries that avoid RDD operations. Previously, the user had to supply an RDD of documents, even to query a single document.
Also add some missing asserts.
Author: Yuhao Yang <hhbyyh@gmail.com>
Closes #9484 from hhbyyh/ldaTopicPre.
Diffstat (limited to 'mllib/src/main')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala | 26 |
1 files changed, 26 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
index 2fce3ff641..b30ecb8020 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
@@ -388,6 +388,32 @@ class LocalLDAModel private[spark] (
   }
 
   /**
+   * Predicts the topic mixture distribution for a document (often called "theta" in the
+   * literature). Returns a vector of zeros for an empty document.
+   *
+   * Note this means to allow quick query for single document. For batch documents, please refer
+   * to [[topicDistributions()]] to avoid overhead.
+   *
+   * @param document document to predict topic mixture distributions for
+   * @return topic mixture distribution for the document
+   */
+  @Since("2.0.0")
+  def topicDistribution(document: Vector): Vector = {
+    val expElogbeta = exp(LDAUtils.dirichletExpectation(topicsMatrix.toBreeze.toDenseMatrix.t).t)
+    if (document.numNonzeros == 0) {
+      Vectors.zeros(this.k)
+    } else {
+      val (gamma, _) = OnlineLDAOptimizer.variationalTopicInference(
+        document,
+        expElogbeta,
+        this.docConcentration.toBreeze,
+        gammaShape,
+        this.k)
+      Vectors.dense(normalize(gamma, 1.0).toArray)
+    }
+  }
+
+  /**
    * Java-friendly version of [[topicDistributions]]
    */
   @Since("1.4.1")
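
The following is a minimal usage sketch of the method added above, not part of the commit. It assumes Spark 2.0.0 or later (per the `@Since` annotation) with `spark-mllib` on the classpath. The toy corpus, the object name `TopicDistributionExample`, the helper `demo`, and the parameter values are all made up for illustration. The model is trained with the online optimizer because that optimizer yields a `LocalLDAModel`.

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.clustering.{LDA, LocalLDAModel, OnlineLDAOptimizer}
import org.apache.spark.mllib.linalg.{Vector, Vectors}

// Hypothetical example object; names and data are illustrative only.
object TopicDistributionExample {

  /** Trains a tiny model and returns (theta for one document, theta for an empty document). */
  def demo(sc: SparkContext): (Vector, Vector) = {
    // Toy corpus: (document id, term-count vector) over a 5-term vocabulary.
    val corpus = sc.parallelize(Seq[(Long, Vector)](
      (0L, Vectors.dense(2.0, 1.0, 0.0, 0.0, 0.0)),
      (1L, Vectors.dense(3.0, 2.0, 0.0, 0.0, 1.0)),
      (2L, Vectors.dense(0.0, 0.0, 4.0, 2.0, 0.0)),
      (3L, Vectors.dense(0.0, 1.0, 3.0, 3.0, 0.0))
    ))

    // Use every document in each mini-batch, since the corpus is tiny.
    val optimizer = new OnlineLDAOptimizer().setMiniBatchFraction(1.0)

    val model = new LDA()
      .setK(2)
      .setMaxIterations(20)
      .setOptimizer(optimizer)
      .run(corpus) match {
        case local: LocalLDAModel => local
        case other => sys.error(s"Expected a LocalLDAModel but got ${other.getClass.getName}")
      }

    // Single-document query: no RDD is needed for the document itself.
    val doc = Vectors.sparse(5, Array(0, 1), Array(3.0, 1.0))
    val theta = model.topicDistribution(doc)

    // An empty document (no nonzero entries) yields a vector of zeros.
    val emptyDoc = Vectors.sparse(5, Array.empty[Int], Array.empty[Double])
    val emptyTheta = model.topicDistribution(emptyDoc)

    (theta, emptyTheta)
  }

  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setMaster("local[2]").setAppName("TopicDistributionExample")
    val sc = new SparkContext(conf)
    try {
      val (theta, emptyTheta) = demo(sc)
      println(s"theta for the query document: $theta")
      println(s"theta for the empty document: $emptyTheta")
    } finally {
      sc.stop()
    }
  }
}
```

For a non-empty document the returned vector has `k` entries that sum to 1, because the method L1-normalizes `gamma`. For many documents, the existing RDD-based `topicDistributions` is the better choice, as the Scaladoc in the diff notes.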