-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala  | 26
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala  | 15
2 files changed, 38 insertions(+), 3 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
index 2fce3ff641..b30ecb8020 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
@@ -388,6 +388,32 @@ class LocalLDAModel private[spark] (
}
/**
+ * Predicts the topic mixture distribution for a document (often called "theta" in the
+ * literature). Returns a vector of zeros for an empty document.
+ *
+ * This method is intended for quick queries against a single document. For batches of
+ * documents, use [[topicDistributions()]] instead to avoid per-document overhead.
+ *
+ * @param document document to predict topic mixture distributions for
+ * @return topic mixture distribution for the document
+ */
+ @Since("2.0.0")
+ def topicDistribution(document: Vector): Vector = {
+ val expElogbeta = exp(LDAUtils.dirichletExpectation(topicsMatrix.toBreeze.toDenseMatrix.t).t)
+ if (document.numNonzeros == 0) {
+ Vectors.zeros(this.k)
+ } else {
+ val (gamma, _) = OnlineLDAOptimizer.variationalTopicInference(
+ document,
+ expElogbeta,
+ this.docConcentration.toBreeze,
+ gammaShape,
+ this.k)
+ Vectors.dense(normalize(gamma, 1.0).toArray)
+ }
+ }
+
+ /**
* Java-friendly version of [[topicDistributions]]
*/
@Since("1.4.1")
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
index faef60e084..ea23196d2c 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
@@ -366,7 +366,8 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
(0, 0.99504), (1, 0.99504),
(1, 0.99504), (1, 0.99504))
- val actualPredictions = ldaModel.topicDistributions(docs).map { case (id, topics) =>
+ val actualPredictions = ldaModel.topicDistributions(docs).cache()
+ val topTopics = actualPredictions.map { case (id, topics) =>
// convert results to expectedPredictions format, which only has highest probability topic
val topicsBz = topics.toBreeze.toDenseVector
(id, (argmax(topicsBz), max(topicsBz)))
@@ -374,9 +375,17 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
.values
.collect()
- expectedPredictions.zip(actualPredictions).forall { case (expected, actual) =>
- expected._1 === actual._1 && (expected._2 ~== actual._2 relTol 1E-3D)
+ expectedPredictions.zip(topTopics).foreach { case (expected, actual) =>
+ assert(expected._1 === actual._1 && (expected._2 ~== actual._2 relTol 1E-3D))
}
+
+ docs.collect()
+ .map(doc => ldaModel.topicDistribution(doc._2))
+ .zip(actualPredictions.map(_._2).collect())
+ .foreach { case (single, batch) =>
+ assert(single ~== batch relTol 1E-3D)
+ }
+ actualPredictions.unpersist()
}
test("OnlineLDAOptimizer with asymmetric prior") {