aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJoseph K. Bradley <joseph@databricks.com>2015-11-10 16:20:10 -0800
committerJoseph K. Bradley <joseph@databricks.com>2015-11-10 16:20:10 -0800
commite281b87398f1298cc3df8e0409c7040acdddce03 (patch)
tree0b3c9361181479c47bc61e1000e103c831d52f72
parent1dde39d796bbf42336051a86bedf871c7fddd513 (diff)
downloadspark-e281b87398f1298cc3df8e0409c7040acdddce03.tar.gz
spark-e281b87398f1298cc3df8e0409c7040acdddce03.tar.bz2
spark-e281b87398f1298cc3df8e0409c7040acdddce03.zip
[SPARK-5565][ML] LDA wrapper for Pipelines API
This adds LDA to spark.ml, the Pipelines API. It follows the design doc in the JIRA: [https://issues.apache.org/jira/browse/SPARK-5565], with one major change: * I eliminated doc IDs. These are not necessary with DataFrames since the user can add an ID column as needed. Note: This will conflict with [https://github.com/apache/spark/pull/9484], but I'll try to merge [https://github.com/apache/spark/pull/9484] first and then rebase this PR. CC: hhbyyh feynmanliang If you have a chance to make a pass, that'd be really helpful--thanks! Now that I'm done traveling & this PR is almost ready, I'll see about reviewing other PRs critical for 1.6. CC: mengxr Author: Joseph K. Bradley <joseph@databricks.com> Closes #9513 from jkbradley/lda-pipelines.
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala701
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala29
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala221
3 files changed, 946 insertions, 5 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
new file mode 100644
index 0000000000..f66233ed3d
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
@@ -0,0 +1,701 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.clustering
+
+import org.apache.spark.Logging
+import org.apache.spark.annotation.{Experimental, Since}
+import org.apache.spark.ml.util.{SchemaUtils, Identifiable}
+import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.param.shared.{HasCheckpointInterval, HasFeaturesCol, HasSeed, HasMaxIter}
+import org.apache.spark.ml.param._
+import org.apache.spark.mllib.clustering.{DistributedLDAModel => OldDistributedLDAModel,
+ EMLDAOptimizer => OldEMLDAOptimizer, LDA => OldLDA, LDAModel => OldLDAModel,
+ LDAOptimizer => OldLDAOptimizer, LocalLDAModel => OldLocalLDAModel,
+ OnlineLDAOptimizer => OldOnlineLDAOptimizer}
+import org.apache.spark.mllib.linalg.{VectorUDT, Vectors, Matrix, Vector}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{SQLContext, DataFrame, Row}
+import org.apache.spark.sql.functions.{col, monotonicallyIncreasingId, udf}
+import org.apache.spark.sql.types.StructType
+
+
+private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasMaxIter
+ with HasSeed with HasCheckpointInterval {
+
+ /**
+ * Param for the number of topics (clusters) to infer. Must be > 1. Default: 10.
+ * @group param
+ */
+ @Since("1.6.0")
+ final val k = new IntParam(this, "k", "number of topics (clusters) to infer",
+ ParamValidators.gt(1))
+
+ /** @group getParam */
+ @Since("1.6.0")
+ def getK: Int = $(k)
+
+ /**
+ * Concentration parameter (commonly named "alpha") for the prior placed on documents'
+ * distributions over topics ("theta").
+ *
+ * This is the parameter to a Dirichlet distribution, where larger values mean more smoothing
+ * (more regularization).
+ *
+ * If not set by the user, then docConcentration is set automatically. If set to
+ * singleton vector [alpha], then alpha is replicated to a vector of length k in fitting.
+ * Otherwise, the [[docConcentration]] vector must be length k.
+ * (default = automatic)
+ *
+ * Optimizer-specific parameter settings:
+ * - EM
+ * - Currently only supports symmetric distributions, so all values in the vector should be
+ * the same.
+ * - Values should be > 1.0
+ * - default = uniformly (50 / k) + 1, where 50/k is common in LDA libraries and +1 follows
+ * from Asuncion et al. (2009), who recommend a +1 adjustment for EM.
+ * - Online
+ * - Values should be >= 0
+ * - default = uniformly (1.0 / k), following the implementation from
+ * [[https://github.com/Blei-Lab/onlineldavb]].
+ * @group param
+ */
+ @Since("1.6.0")
+ final val docConcentration = new DoubleArrayParam(this, "docConcentration",
+ "Concentration parameter (commonly named \"alpha\") for the prior placed on documents'" +
+ " distributions over topics (\"theta\").", (alpha: Array[Double]) => alpha.forall(_ >= 0.0))
+
+ /** @group getParam */
+ @Since("1.6.0")
+ def getDocConcentration: Array[Double] = $(docConcentration)
+
+ /** Get docConcentration used by spark.mllib LDA */
+ protected def getOldDocConcentration: Vector = {
+ if (isSet(docConcentration)) {
+ Vectors.dense(getDocConcentration)
+ } else {
+ Vectors.dense(-1.0)
+ }
+ }
+
+ /**
+ * Concentration parameter (commonly named "beta" or "eta") for the prior placed on topics'
+ * distributions over terms.
+ *
+ * This is the parameter to a symmetric Dirichlet distribution.
+ *
+ * Note: The topics' distributions over terms are called "beta" in the original LDA paper
+ * by Blei et al., but are called "phi" in many later papers such as Asuncion et al., 2009.
+ *
+ * If not set by the user, then topicConcentration is set automatically.
+ * (default = automatic)
+ *
+ * Optimizer-specific parameter settings:
+ * - EM
+ * - Value should be > 1.0
+ * - default = 0.1 + 1, where 0.1 gives a small amount of smoothing and +1 follows
+ * Asuncion et al. (2009), who recommend a +1 adjustment for EM.
+ * - Online
+ * - Value should be >= 0
+ * - default = (1.0 / k), following the implementation from
+ * [[https://github.com/Blei-Lab/onlineldavb]].
+ * @group param
+ */
+ @Since("1.6.0")
+ final val topicConcentration = new DoubleParam(this, "topicConcentration",
+ "Concentration parameter (commonly named \"beta\" or \"eta\") for the prior placed on topic'" +
+ " distributions over terms.", ParamValidators.gtEq(0))
+
+ /** @group getParam */
+ @Since("1.6.0")
+ def getTopicConcentration: Double = $(topicConcentration)
+
+ /** Get topicConcentration used by spark.mllib LDA */
+ protected def getOldTopicConcentration: Double = {
+ if (isSet(topicConcentration)) {
+ getTopicConcentration
+ } else {
+ -1.0
+ }
+ }
+
+ /** Supported values for Param [[optimizer]]. */
+ @Since("1.6.0")
+ final val supportedOptimizers: Array[String] = Array("online", "em")
+
+ /**
+ * Optimizer or inference algorithm used to estimate the LDA model.
+ * Currently supported (case-insensitive):
+ * - "online": Online Variational Bayes (default)
+ * - "em": Expectation-Maximization
+ *
+ * For details, see the following papers:
+ * - Online LDA:
+ * Hoffman, Blei and Bach. "Online Learning for Latent Dirichlet Allocation."
+ * Neural Information Processing Systems, 2010.
+ * [[http://www.cs.columbia.edu/~blei/papers/HoffmanBleiBach2010b.pdf]]
+ * - EM:
+ * Asuncion et al. "On Smoothing and Inference for Topic Models."
+ * Uncertainty in Artificial Intelligence, 2009.
+ * [[http://arxiv.org/pdf/1205.2662.pdf]]
+ *
+ * @group param
+ */
+ @Since("1.6.0")
+ final val optimizer = new Param[String](this, "optimizer", "Optimizer or inference" +
+ " algorithm used to estimate the LDA model. Supported: " + supportedOptimizers.mkString(", "),
+ (o: String) => ParamValidators.inArray(supportedOptimizers).apply(o.toLowerCase))
+
+ /** @group getParam */
+ @Since("1.6.0")
+ def getOptimizer: String = $(optimizer)
+
+ /**
+ * Output column with estimates of the topic mixture distribution for each document (often called
+ * "theta" in the literature). Returns a vector of zeros for an empty document.
+ *
+ * This uses a variational approximation following Hoffman et al. (2010), where the approximate
+ * distribution is called "gamma." Technically, this method returns this approximation "gamma"
+ * for each document.
+ * @group param
+ */
+ @Since("1.6.0")
+ final val topicDistributionCol = new Param[String](this, "topicDistribution", "Output column" +
+ " with estimates of the topic mixture distribution for each document (often called \"theta\"" +
+ " in the literature). Returns a vector of zeros for an empty document.")
+
+ setDefault(topicDistributionCol -> "topicDistribution")
+
+ /** @group getParam */
+ @Since("1.6.0")
+ def getTopicDistributionCol: String = $(topicDistributionCol)
+
+ /**
+ * A (positive) learning parameter that downweights early iterations. Larger values make early
+ * iterations count less.
+ * This is called "tau0" in the Online LDA paper (Hoffman et al., 2010)
+ * Default: 1024, following Hoffman et al.
+ * @group expertParam
+ */
+ @Since("1.6.0")
+ final val learningOffset = new DoubleParam(this, "learningOffset", "A (positive) learning" +
+ " parameter that downweights early iterations. Larger values make early iterations count less.",
+ ParamValidators.gt(0))
+
+ /** @group expertGetParam */
+ @Since("1.6.0")
+ def getLearningOffset: Double = $(learningOffset)
+
+ /**
+ * Learning rate, set as an exponential decay rate.
+ * This should be between (0.5, 1.0] to guarantee asymptotic convergence.
+ * This is called "kappa" in the Online LDA paper (Hoffman et al., 2010).
+ * Default: 0.51, based on Hoffman et al.
+ * @group expertParam
+ */
+ @Since("1.6.0")
+ final val learningDecay = new DoubleParam(this, "learningDecay", "Learning rate, set as an" +
+ " exponential decay rate. This should be between (0.5, 1.0] to guarantee asymptotic" +
+ " convergence.", ParamValidators.gt(0))
+
+ /** @group expertGetParam */
+ @Since("1.6.0")
+ def getLearningDecay: Double = $(learningDecay)
+
+ /**
+ * Fraction of the corpus to be sampled and used in each iteration of mini-batch gradient descent,
+ * in range (0, 1].
+ *
+ * Note that this should be adjusted in synch with [[LDA.maxIter]]
+ * so the entire corpus is used. Specifically, set both so that
+ * maxIterations * miniBatchFraction >= 1.
+ *
+ * Note: This is the same as the `miniBatchFraction` parameter in
+ * [[org.apache.spark.mllib.clustering.OnlineLDAOptimizer]].
+ *
+ * Default: 0.05, i.e., 5% of total documents.
+ * @group param
+ */
+ @Since("1.6.0")
+ final val subsamplingRate = new DoubleParam(this, "subsamplingRate", "Fraction of the corpus" +
+ " to be sampled and used in each iteration of mini-batch gradient descent, in range (0, 1].",
+ ParamValidators.inRange(0.0, 1.0, lowerInclusive = false, upperInclusive = true))
+
+ /** @group getParam */
+ @Since("1.6.0")
+ def getSubsamplingRate: Double = $(subsamplingRate)
+
+ /**
+ * Indicates whether the docConcentration (Dirichlet parameter for
+ * document-topic distribution) will be optimized during training.
+ * Setting this to true will make the model more expressive and fit the training data better.
+ * Default: false
+ * @group expertParam
+ */
+ @Since("1.6.0")
+ final val optimizeDocConcentration = new BooleanParam(this, "optimizeDocConcentration",
+ "Indicates whether the docConcentration (Dirichlet parameter for document-topic" +
+ " distribution) will be optimized during training.")
+
+ /** @group expertGetParam */
+ @Since("1.6.0")
+ def getOptimizeDocConcentration: Boolean = $(optimizeDocConcentration)
+
+ /**
+ * Validates and transforms the input schema.
+ * @param schema input schema
+ * @return output schema
+ */
+ protected def validateAndTransformSchema(schema: StructType): StructType = {
+ SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
+ SchemaUtils.appendColumn(schema, $(topicDistributionCol), new VectorUDT)
+ }
+
+ @Since("1.6.0")
+ override def validateParams(): Unit = {
+ if (isSet(docConcentration)) {
+ if (getDocConcentration.length != 1) {
+ require(getDocConcentration.length == getK, s"LDA docConcentration was of length" +
+ s" ${getDocConcentration.length}, but k = $getK. docConcentration must be an array of" +
+ s" length either 1 (scalar) or k (num topics).")
+ }
+ getOptimizer match {
+ case "online" =>
+ require(getDocConcentration.forall(_ >= 0),
+ "For Online LDA optimizer, docConcentration values must be >= 0. Found values: " +
+ getDocConcentration.mkString(","))
+ case "em" =>
+ require(getDocConcentration.forall(_ >= 0),
+ "For EM optimizer, docConcentration values must be >= 1. Found values: " +
+ getDocConcentration.mkString(","))
+ }
+ }
+ if (isSet(topicConcentration)) {
+ getOptimizer match {
+ case "online" =>
+ require(getTopicConcentration >= 0, s"For Online LDA optimizer, topicConcentration" +
+ s" must be >= 0. Found value: $getTopicConcentration")
+ case "em" =>
+ require(getTopicConcentration >= 0, s"For EM optimizer, topicConcentration" +
+ s" must be >= 1. Found value: $getTopicConcentration")
+ }
+ }
+ }
+
+ private[clustering] def getOldOptimizer: OldLDAOptimizer = getOptimizer match {
+ case "online" =>
+ new OldOnlineLDAOptimizer()
+ .setTau0($(learningOffset))
+ .setKappa($(learningDecay))
+ .setMiniBatchFraction($(subsamplingRate))
+ .setOptimizeDocConcentration($(optimizeDocConcentration))
+ case "em" =>
+ new OldEMLDAOptimizer()
+ }
+}
+
+
+/**
+ * :: Experimental ::
+ * Model fitted by [[LDA]].
+ *
+ * @param vocabSize Vocabulary size (number of terms or terms in the vocabulary)
+ * @param oldLocalModel Underlying spark.mllib model.
+ * If this model was produced by Online LDA, then this is the
+ * only model representation.
+ * If this model was produced by EM, then this local
+ * representation may be built lazily.
+ * @param sqlContext Used to construct local DataFrames for returning query results
+ */
+@Since("1.6.0")
+@Experimental
+class LDAModel private[ml] (
+ @Since("1.6.0") override val uid: String,
+ @Since("1.6.0") val vocabSize: Int,
+ @Since("1.6.0") protected var oldLocalModel: Option[OldLocalLDAModel],
+ @Since("1.6.0") @transient protected val sqlContext: SQLContext)
+ extends Model[LDAModel] with LDAParams with Logging {
+
+ /** Returns underlying spark.mllib model */
+ @Since("1.6.0")
+ protected def getModel: OldLDAModel = oldLocalModel match {
+ case Some(m) => m
+ case None =>
+ // Should never happen.
+ throw new RuntimeException("LDAModel required local model format," +
+ " but the underlying model is missing.")
+ }
+
+ /**
+ * The features for LDA should be a [[Vector]] representing the word counts in a document.
+ * The vector should be of length vocabSize, with counts for each term (word).
+ * @group setParam
+ */
+ @Since("1.6.0")
+ def setFeaturesCol(value: String): this.type = set(featuresCol, value)
+
+ /** @group setParam */
+ @Since("1.6.0")
+ def setSeed(value: Long): this.type = set(seed, value)
+
+ @Since("1.6.0")
+ override def copy(extra: ParamMap): LDAModel = {
+ val copied = new LDAModel(uid, vocabSize, oldLocalModel, sqlContext)
+ copyValues(copied, extra).setParent(parent)
+ }
+
+ @Since("1.6.0")
+ override def transform(dataset: DataFrame): DataFrame = {
+ if ($(topicDistributionCol).nonEmpty) {
+ val t = udf(oldLocalModel.get.getTopicDistributionMethod(sqlContext.sparkContext))
+ dataset.withColumn($(topicDistributionCol), t(col($(featuresCol))))
+ } else {
+ logWarning("LDAModel.transform was called without any output columns. Set an output column" +
+ " such as topicDistributionCol to produce results.")
+ dataset
+ }
+ }
+
+ @Since("1.6.0")
+ override def transformSchema(schema: StructType): StructType = {
+ validateAndTransformSchema(schema)
+ }
+
+ /**
+ * Value for [[docConcentration]] estimated from data.
+ * If Online LDA was used and [[optimizeDocConcentration]] was set to false,
+ * then this returns the fixed (given) value for the [[docConcentration]] parameter.
+ */
+ @Since("1.6.0")
+ def estimatedDocConcentration: Vector = getModel.docConcentration
+
+ /**
+ * Inferred topics, where each topic is represented by a distribution over terms.
+ * This is a matrix of size vocabSize x k, where each column is a topic.
+ * No guarantees are given about the ordering of the topics.
+ *
+ * WARNING: If this model is actually a [[DistributedLDAModel]] instance from EM,
+ * then this method could involve collecting a large amount of data to the driver
+ * (on the order of vocabSize x k).
+ */
+ @Since("1.6.0")
+ def topicsMatrix: Matrix = getModel.topicsMatrix
+
+ /** Indicates whether this instance is of type [[DistributedLDAModel]] */
+ @Since("1.6.0")
+ def isDistributed: Boolean = false
+
+ /**
+ * Calculates a lower bound on the log likelihood of the entire corpus.
+ *
+ * See Equation (16) in the Online LDA paper (Hoffman et al., 2010).
+ *
+ * WARNING: If this model was learned via a [[DistributedLDAModel]], this involves collecting
+ * a large [[topicsMatrix]] to the driver. This implementation may be changed in the
+ * future.
+ *
+ * @param dataset test corpus to use for calculating log likelihood
+ * @return variational lower bound on the log likelihood of the entire corpus
+ */
+ @Since("1.6.0")
+ def logLikelihood(dataset: DataFrame): Double = oldLocalModel match {
+ case Some(m) =>
+ val oldDataset = LDA.getOldDataset(dataset, $(featuresCol))
+ m.logLikelihood(oldDataset)
+ case None =>
+ // Should never happen.
+ throw new RuntimeException("LocalLDAModel.logLikelihood was called," +
+ " but the underlying model is missing.")
+ }
+
+ /**
+ * Calculate an upper bound bound on perplexity. (Lower is better.)
+ * See Equation (16) in the Online LDA paper (Hoffman et al., 2010).
+ *
+ * @param dataset test corpus to use for calculating perplexity
+ * @return Variational upper bound on log perplexity per token.
+ */
+ @Since("1.6.0")
+ def logPerplexity(dataset: DataFrame): Double = oldLocalModel match {
+ case Some(m) =>
+ val oldDataset = LDA.getOldDataset(dataset, $(featuresCol))
+ m.logPerplexity(oldDataset)
+ case None =>
+ // Should never happen.
+ throw new RuntimeException("LocalLDAModel.logPerplexity was called," +
+ " but the underlying model is missing.")
+ }
+
+ /**
+ * Return the topics described by their top-weighted terms.
+ *
+ * @param maxTermsPerTopic Maximum number of terms to collect for each topic.
+ * Default value of 10.
+ * @return Local DataFrame with one topic per Row, with columns:
+ * - "topic": IntegerType: topic index
+ * - "termIndices": ArrayType(IntegerType): term indices, sorted in order of decreasing
+ * term importance
+ * - "termWeights": ArrayType(DoubleType): corresponding sorted term weights
+ */
+ @Since("1.6.0")
+ def describeTopics(maxTermsPerTopic: Int): DataFrame = {
+ val topics = getModel.describeTopics(maxTermsPerTopic).zipWithIndex.map {
+ case ((termIndices, termWeights), topic) =>
+ (topic, termIndices.toSeq, termWeights.toSeq)
+ }
+ sqlContext.createDataFrame(topics).toDF("topic", "termIndices", "termWeights")
+ }
+
+ @Since("1.6.0")
+ def describeTopics(): DataFrame = describeTopics(10)
+}
+
+
+/**
+ * :: Experimental ::
+ *
+ * Distributed model fitted by [[LDA]] using Expectation-Maximization (EM).
+ *
+ * This model stores the inferred topics, the full training dataset, and the topic distribution
+ * for each training document.
+ */
+@Since("1.6.0")
+@Experimental
+class DistributedLDAModel private[ml] (
+ uid: String,
+ vocabSize: Int,
+ private val oldDistributedModel: OldDistributedLDAModel,
+ sqlContext: SQLContext)
+ extends LDAModel(uid, vocabSize, None, sqlContext) {
+
+ /**
+ * Convert this distributed model to a local representation. This discards info about the
+ * training dataset.
+ */
+ @Since("1.6.0")
+ def toLocal: LDAModel = {
+ if (oldLocalModel.isEmpty) {
+ oldLocalModel = Some(oldDistributedModel.toLocal)
+ }
+ new LDAModel(uid, vocabSize, oldLocalModel, sqlContext)
+ }
+
+ @Since("1.6.0")
+ override protected def getModel: OldLDAModel = oldDistributedModel
+
+ @Since("1.6.0")
+ override def copy(extra: ParamMap): DistributedLDAModel = {
+ val copied = new DistributedLDAModel(uid, vocabSize, oldDistributedModel, sqlContext)
+ if (oldLocalModel.nonEmpty) copied.oldLocalModel = oldLocalModel
+ copyValues(copied, extra).setParent(parent)
+ copied
+ }
+
+ @Since("1.6.0")
+ override def topicsMatrix: Matrix = {
+ if (oldLocalModel.isEmpty) {
+ oldLocalModel = Some(oldDistributedModel.toLocal)
+ }
+ super.topicsMatrix
+ }
+
+ @Since("1.6.0")
+ override def isDistributed: Boolean = true
+
+ @Since("1.6.0")
+ override def logLikelihood(dataset: DataFrame): Double = {
+ if (oldLocalModel.isEmpty) {
+ oldLocalModel = Some(oldDistributedModel.toLocal)
+ }
+ super.logLikelihood(dataset)
+ }
+
+ @Since("1.6.0")
+ override def logPerplexity(dataset: DataFrame): Double = {
+ if (oldLocalModel.isEmpty) {
+ oldLocalModel = Some(oldDistributedModel.toLocal)
+ }
+ super.logPerplexity(dataset)
+ }
+
+ /**
+ * Log likelihood of the observed tokens in the training set,
+ * given the current parameter estimates:
+ * log P(docs | topics, topic distributions for docs, Dirichlet hyperparameters)
+ *
+ * Notes:
+ * - This excludes the prior; for that, use [[logPrior]].
+ * - Even with [[logPrior]], this is NOT the same as the data log likelihood given the
+ * hyperparameters.
+ * - This is computed from the topic distributions computed during training. If you call
+ * [[logLikelihood()]] on the same training dataset, the topic distributions will be computed
+ * again, possibly giving different results.
+ */
+ @Since("1.6.0")
+ lazy val trainingLogLikelihood: Double = oldDistributedModel.logLikelihood
+
+ /**
+ * Log probability of the current parameter estimate:
+ * log P(topics, topic distributions for docs | Dirichlet hyperparameters)
+ */
+ @Since("1.6.0")
+ lazy val logPrior: Double = oldDistributedModel.logPrior
+}
+
+
+/**
+ * :: Experimental ::
+ *
+ * Latent Dirichlet Allocation (LDA), a topic model designed for text documents.
+ *
+ * Terminology:
+ * - "term" = "word": an element of the vocabulary
+ * - "token": instance of a term appearing in a document
+ * - "topic": multinomial distribution over terms representing some concept
+ * - "document": one piece of text, corresponding to one row in the input data
+ *
+ * References:
+ * - Original LDA paper (journal version):
+ * Blei, Ng, and Jordan. "Latent Dirichlet Allocation." JMLR, 2003.
+ *
+ * Input data (featuresCol):
+ * LDA is given a collection of documents as input data, via the featuresCol parameter.
+ * Each document is specified as a [[Vector]] of length vocabSize, where each entry is the
+ * count for the corresponding term (word) in the document. Feature transformers such as
+ * [[org.apache.spark.ml.feature.Tokenizer]] and [[org.apache.spark.ml.feature.CountVectorizer]]
+ * can be useful for converting text to word count vectors.
+ *
+ * @see [[http://en.wikipedia.org/wiki/Latent_Dirichlet_allocation Latent Dirichlet allocation
+ * (Wikipedia)]]
+ */
+@Since("1.6.0")
+@Experimental
+class LDA @Since("1.6.0") (
+ @Since("1.6.0") override val uid: String) extends Estimator[LDAModel] with LDAParams {
+
+ @Since("1.6.0")
+ def this() = this(Identifiable.randomUID("lda"))
+
+ setDefault(maxIter -> 20, k -> 10, optimizer -> "online", checkpointInterval -> 10,
+ learningOffset -> 1024, learningDecay -> 0.51, subsamplingRate -> 0.05,
+ optimizeDocConcentration -> true)
+
+ /**
+ * The features for LDA should be a [[Vector]] representing the word counts in a document.
+ * The vector should be of length vocabSize, with counts for each term (word).
+ * @group setParam
+ */
+ @Since("1.6.0")
+ def setFeaturesCol(value: String): this.type = set(featuresCol, value)
+
+ /** @group setParam */
+ @Since("1.6.0")
+ def setMaxIter(value: Int): this.type = set(maxIter, value)
+
+ /** @group setParam */
+ @Since("1.6.0")
+ def setSeed(value: Long): this.type = set(seed, value)
+
+ /** @group setParam */
+ @Since("1.6.0")
+ def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value)
+
+ /** @group setParam */
+ @Since("1.6.0")
+ def setK(value: Int): this.type = set(k, value)
+
+ /** @group setParam */
+ @Since("1.6.0")
+ def setDocConcentration(value: Array[Double]): this.type = set(docConcentration, value)
+
+ /** @group setParam */
+ @Since("1.6.0")
+ def setDocConcentration(value: Double): this.type = set(docConcentration, Array(value))
+
+ /** @group setParam */
+ @Since("1.6.0")
+ def setTopicConcentration(value: Double): this.type = set(topicConcentration, value)
+
+ /** @group setParam */
+ @Since("1.6.0")
+ def setOptimizer(value: String): this.type = set(optimizer, value)
+
+ /** @group setParam */
+ @Since("1.6.0")
+ def setTopicDistributionCol(value: String): this.type = set(topicDistributionCol, value)
+
+ /** @group expertSetParam */
+ @Since("1.6.0")
+ def setLearningOffset(value: Double): this.type = set(learningOffset, value)
+
+ /** @group expertSetParam */
+ @Since("1.6.0")
+ def setLearningDecay(value: Double): this.type = set(learningDecay, value)
+
+ /** @group setParam */
+ @Since("1.6.0")
+ def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value)
+
+ /** @group expertSetParam */
+ @Since("1.6.0")
+ def setOptimizeDocConcentration(value: Boolean): this.type = set(optimizeDocConcentration, value)
+
+ @Since("1.6.0")
+ override def copy(extra: ParamMap): LDA = defaultCopy(extra)
+
+ @Since("1.6.0")
+ override def fit(dataset: DataFrame): LDAModel = {
+ transformSchema(dataset.schema, logging = true)
+ val oldLDA = new OldLDA()
+ .setK($(k))
+ .setDocConcentration(getOldDocConcentration)
+ .setTopicConcentration(getOldTopicConcentration)
+ .setMaxIterations($(maxIter))
+ .setSeed($(seed))
+ .setCheckpointInterval($(checkpointInterval))
+ .setOptimizer(getOldOptimizer)
+ // TODO: persist here, or in old LDA?
+ val oldData = LDA.getOldDataset(dataset, $(featuresCol))
+ val oldModel = oldLDA.run(oldData)
+ val newModel = oldModel match {
+ case m: OldLocalLDAModel =>
+ new LDAModel(uid, m.vocabSize, Some(m), dataset.sqlContext)
+ case m: OldDistributedLDAModel =>
+ new DistributedLDAModel(uid, m.vocabSize, m, dataset.sqlContext)
+ }
+ copyValues(newModel).setParent(this)
+ }
+
+ @Since("1.6.0")
+ override def transformSchema(schema: StructType): StructType = {
+ validateAndTransformSchema(schema)
+ }
+}
+
+
+private[clustering] object LDA {
+
+ /** Get dataset for spark.mllib LDA */
+ def getOldDataset(dataset: DataFrame, featuresCol: String): RDD[(Long, Vector)] = {
+ dataset
+ .withColumn("docId", monotonicallyIncreasingId())
+ .select("docId", featuresCol)
+ .map { case Row(docId: Long, features: Vector) =>
+ (docId, features)
+ }
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
index 31d8a9fdea..cd520f09bd 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
@@ -183,8 +183,7 @@ abstract class LDAModel private[clustering] extends Saveable {
/**
* Local LDA model.
* This model stores only the inferred topics.
- * It may be used for computing topics for new documents, but it may give less accurate answers
- * than the [[DistributedLDAModel]].
+ *
* @param topics Inferred topics (vocabSize x k matrix).
*/
@Since("1.3.0")
@@ -353,7 +352,7 @@ class LocalLDAModel private[clustering] (
documents.map { case (id: Long, termCounts: Vector) =>
if (termCounts.numNonzeros == 0) {
- (id, Vectors.zeros(k))
+ (id, Vectors.zeros(k))
} else {
val (gamma, _) = OnlineLDAOptimizer.variationalTopicInference(
termCounts,
@@ -366,6 +365,28 @@ class LocalLDAModel private[clustering] (
}
}
+ /** Get a method usable as a UDF for [[topicDistributions()]] */
+ private[spark] def getTopicDistributionMethod(sc: SparkContext): Vector => Vector = {
+ val expElogbeta = exp(LDAUtils.dirichletExpectation(topicsMatrix.toBreeze.toDenseMatrix.t).t)
+ val expElogbetaBc = sc.broadcast(expElogbeta)
+ val docConcentrationBrz = this.docConcentration.toBreeze
+ val gammaShape = this.gammaShape
+ val k = this.k
+
+ (termCounts: Vector) =>
+ if (termCounts.numNonzeros == 0) {
+ Vectors.zeros(k)
+ } else {
+ val (gamma, _) = OnlineLDAOptimizer.variationalTopicInference(
+ termCounts,
+ expElogbetaBc.value,
+ docConcentrationBrz,
+ gammaShape,
+ k)
+ Vectors.dense(normalize(gamma, 1.0).toArray)
+ }
+ }
+
/**
* Java-friendly version of [[topicDistributions]]
*/
@@ -477,8 +498,6 @@ object LocalLDAModel extends Loader[LocalLDAModel] {
/**
* Distributed LDA model.
* This model stores the inferred topics, the full training dataset, and the topic distributions.
- * When computing topics for new documents, it may give more accurate answers
- * than the [[LocalLDAModel]].
*/
@Since("1.3.0")
class DistributedLDAModel private[clustering] (
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
new file mode 100644
index 0000000000..edb927495e
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
@@ -0,0 +1,221 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.clustering
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.util.MLTestingUtils
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+
+
+object LDASuite {
+ def generateLDAData(
+ sql: SQLContext,
+ rows: Int,
+ k: Int,
+ vocabSize: Int): DataFrame = {
+ val avgWC = 1 // average instances of each word in a doc
+ val sc = sql.sparkContext
+ val rng = new java.util.Random()
+ rng.setSeed(1)
+ val rdd = sc.parallelize(1 to rows).map { i =>
+ Vectors.dense(Array.fill(vocabSize)(rng.nextInt(2 * avgWC).toDouble))
+ }.map(v => new TestRow(v))
+ sql.createDataFrame(rdd)
+ }
+}
+
+
+class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
+
+ val k: Int = 5
+ val vocabSize: Int = 30
+ @transient var dataset: DataFrame = _
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+ dataset = LDASuite.generateLDAData(sqlContext, 50, k, vocabSize)
+ }
+
+ test("default parameters") {
+ val lda = new LDA()
+
+ assert(lda.getFeaturesCol === "features")
+ assert(lda.getMaxIter === 20)
+ assert(lda.isDefined(lda.seed))
+ assert(lda.getCheckpointInterval === 10)
+ assert(lda.getK === 10)
+ assert(!lda.isSet(lda.docConcentration))
+ assert(!lda.isSet(lda.topicConcentration))
+ assert(lda.getOptimizer === "online")
+ assert(lda.getLearningDecay === 0.51)
+ assert(lda.getLearningOffset === 1024)
+ assert(lda.getSubsamplingRate === 0.05)
+ assert(lda.getOptimizeDocConcentration)
+ assert(lda.getTopicDistributionCol === "topicDistribution")
+ }
+
+ test("set parameters") {
+ val lda = new LDA()
+ .setFeaturesCol("test_feature")
+ .setMaxIter(33)
+ .setSeed(123)
+ .setCheckpointInterval(7)
+ .setK(9)
+ .setTopicConcentration(0.56)
+ .setTopicDistributionCol("myOutput")
+
+ assert(lda.getFeaturesCol === "test_feature")
+ assert(lda.getMaxIter === 33)
+ assert(lda.getSeed === 123)
+ assert(lda.getCheckpointInterval === 7)
+ assert(lda.getK === 9)
+ assert(lda.getTopicConcentration === 0.56)
+ assert(lda.getTopicDistributionCol === "myOutput")
+
+
+ // setOptimizer
+ lda.setOptimizer("em")
+ assert(lda.getOptimizer === "em")
+ lda.setOptimizer("online")
+ assert(lda.getOptimizer === "online")
+ lda.setLearningDecay(0.53)
+ assert(lda.getLearningDecay === 0.53)
+ lda.setLearningOffset(1027)
+ assert(lda.getLearningOffset === 1027)
+ lda.setSubsamplingRate(0.06)
+ assert(lda.getSubsamplingRate === 0.06)
+ lda.setOptimizeDocConcentration(false)
+ assert(!lda.getOptimizeDocConcentration)
+ }
+
+ test("parameters validation") {
+ val lda = new LDA()
+
+ // misc Params
+ intercept[IllegalArgumentException] {
+ new LDA().setK(1)
+ }
+ intercept[IllegalArgumentException] {
+ new LDA().setOptimizer("no_such_optimizer")
+ }
+ intercept[IllegalArgumentException] {
+ new LDA().setDocConcentration(-1.1)
+ }
+ intercept[IllegalArgumentException] {
+ new LDA().setTopicConcentration(-1.1)
+ }
+
+ // validateParams()
+ lda.validateParams()
+ lda.setDocConcentration(1.1)
+ lda.validateParams()
+ lda.setDocConcentration(Range(0, lda.getK).map(_ + 2.0).toArray)
+ lda.validateParams()
+ lda.setDocConcentration(Range(0, lda.getK - 1).map(_ + 2.0).toArray)
+ withClue("LDA docConcentration validity check failed for bad array length") {
+ intercept[IllegalArgumentException] {
+ lda.validateParams()
+ }
+ }
+
+ // Online LDA
+ intercept[IllegalArgumentException] {
+ new LDA().setLearningOffset(0)
+ }
+ intercept[IllegalArgumentException] {
+ new LDA().setLearningDecay(0)
+ }
+ intercept[IllegalArgumentException] {
+ new LDA().setSubsamplingRate(0)
+ }
+ intercept[IllegalArgumentException] {
+ new LDA().setSubsamplingRate(1.1)
+ }
+ }
+
+ test("fit & transform with Online LDA") {
+ val lda = new LDA().setK(k).setSeed(1).setOptimizer("online").setMaxIter(2)
+ val model = lda.fit(dataset)
+
+ MLTestingUtils.checkCopy(model)
+
+ assert(!model.isInstanceOf[DistributedLDAModel])
+ assert(model.vocabSize === vocabSize)
+ assert(model.estimatedDocConcentration.size === k)
+ assert(model.topicsMatrix.numRows === vocabSize)
+ assert(model.topicsMatrix.numCols === k)
+ assert(!model.isDistributed)
+
+ // transform()
+ val transformed = model.transform(dataset)
+ val expectedColumns = Array("features", lda.getTopicDistributionCol)
+ expectedColumns.foreach { column =>
+ assert(transformed.columns.contains(column))
+ }
+ transformed.select(lda.getTopicDistributionCol).collect().foreach { r =>
+ val topicDistribution = r.getAs[Vector](0)
+ assert(topicDistribution.size === k)
+ assert(topicDistribution.toArray.forall(w => w >= 0.0 && w <= 1.0))
+ }
+
+ // logLikelihood, logPerplexity
+ val ll = model.logLikelihood(dataset)
+ assert(ll <= 0.0 && ll != Double.NegativeInfinity)
+ val lp = model.logPerplexity(dataset)
+ assert(lp >= 0.0 && lp != Double.PositiveInfinity)
+
+ // describeTopics
+ val topics = model.describeTopics(3)
+ assert(topics.count() === k)
+ assert(topics.select("topic").map(_.getInt(0)).collect().toSet === Range(0, k).toSet)
+ topics.select("termIndices").collect().foreach { case r: Row =>
+ val termIndices = r.getAs[Seq[Int]](0)
+ assert(termIndices.length === 3 && termIndices.toSet.size === 3)
+ }
+ topics.select("termWeights").collect().foreach { case r: Row =>
+ val termWeights = r.getAs[Seq[Double]](0)
+ assert(termWeights.length === 3 && termWeights.forall(w => w >= 0.0 && w <= 1.0))
+ }
+ }
+
+ test("fit & transform with EM LDA") {
+ val lda = new LDA().setK(k).setSeed(1).setOptimizer("em").setMaxIter(2)
+ val model_ = lda.fit(dataset)
+
+ MLTestingUtils.checkCopy(model_)
+
+ assert(model_.isInstanceOf[DistributedLDAModel])
+ val model = model_.asInstanceOf[DistributedLDAModel]
+ assert(model.vocabSize === vocabSize)
+ assert(model.estimatedDocConcentration.size === k)
+ assert(model.topicsMatrix.numRows === vocabSize)
+ assert(model.topicsMatrix.numCols === k)
+ assert(model.isDistributed)
+
+ val localModel = model.toLocal
+ assert(!localModel.isInstanceOf[DistributedLDAModel])
+
+ // training logLikelihood, logPrior
+ val ll = model.trainingLogLikelihood
+ assert(ll <= 0.0 && ll != Double.NegativeInfinity)
+ val lp = model.logPrior
+ assert(lp <= 0.0 && lp != Double.NegativeInfinity)
+ }
+}