author    Feynman Liang <fliang@databricks.com>    2015-07-31 18:36:22 -0700
committer Joseph K. Bradley <joseph@databricks.com>    2015-07-31 18:36:22 -0700
commit    f51fd6fbb4d9822502f98b312251e317d757bc3a (patch)
tree      d69b292785aaeaf1537c2937c24653ee32e1c594 /mllib
parent    4d5a6e7b60b315968973e2298eeee5eb174ec721 (diff)
[SPARK-8936] [MLLIB] OnlineLDA document-topic Dirichlet hyperparameter optimization
Adds `alpha` (document-topic Dirichlet parameter) hyperparameter optimization to
`OnlineLDAOptimizer`, following Huang: Maximum Likelihood Estimation of Dirichlet
Distribution Parameters. Also introduces a private `setSampleWithReplacement` to
`OnlineLDAOptimizer` for unit testing purposes.

Author: Feynman Liang <fliang@databricks.com>

Closes #7836 from feynmanliang/SPARK-8936-alpha-optimize and squashes the following commits:

4bef484 [Feynman Liang] Documentation improvements
c3c6c1d [Feynman Liang] Fix docs
151e859 [Feynman Liang] Fix style
fa77518 [Feynman Liang] Hyperparameter optimization
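For context, here is a minimal self-contained sketch of the damped Newton-Raphson
step that `updateAlpha` (in the diff below) implements, following Section 3.3 of
Huang. This is illustrative only and not part of the patch; it assumes Breeze, and
`newtonStep`, `n`, and `logphat` are hypothetical names:

    import breeze.linalg.{DenseVector => BDV, all, sum}
    import breeze.numerics.{digamma, trigamma}

    /** E[log theta] under theta ~ Dirichlet(v): psi(v_k) - psi(sum_j v_j). */
    def dirichletExpectation(v: BDV[Double]): BDV[Double] =
      digamma(v) - digamma(sum(v))

    /**
     * One damped Newton-Raphson step for the Dirichlet MLE. `logphat` is
     * E[log theta] averaged over the n documents of a mini-batch; `weight`
     * is the learning rate rho. Returns the updated alpha.
     */
    def newtonStep(
        alpha: BDV[Double],
        logphat: BDV[Double],
        n: Double,
        weight: Double): BDV[Double] = {
      // Gradient of the Dirichlet log-likelihood over n documents.
      val gradf = n * (-dirichletExpectation(alpha) + logphat)
      // The Hessian is diag(q) + c * 1 * 1^T, so H^{-1} * gradf has a
      // closed form via Sherman-Morrison; no matrix inverse is needed.
      val c = n * trigamma(sum(alpha))
      val q = -n * trigamma(alpha)
      val b = sum(gradf / q) / (1D / c + sum(1D / q))
      val dalpha = -(gradf - b) / q
      // Accept the damped step only if every component stays positive.
      if (all((weight * dalpha + alpha) :> 0D)) alpha + (dalpha * weight) else alpha
    }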
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala | 75
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala     | 34
2 files changed, 99 insertions, 10 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
index d6f8b29a43..b0e14cb829 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
@@ -19,8 +19,8 @@ package org.apache.spark.mllib.clustering
import java.util.Random
-import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, normalize, sum}
-import breeze.numerics.{abs, exp}
+import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, all, normalize, sum}
+import breeze.numerics.{abs, exp, trigamma}
import breeze.stats.distributions.{Gamma, RandBasis}
import org.apache.spark.annotation.DeveloperApi
@@ -239,22 +239,26 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
/** alias for docConcentration */
private var alpha: Vector = Vectors.dense(0)
- /** (private[clustering] for debugging) Get docConcentration */
+ /** (for debugging) Get docConcentration */
private[clustering] def getAlpha: Vector = alpha
/** alias for topicConcentration */
private var eta: Double = 0
- /** (private[clustering] for debugging) Get topicConcentration */
+ /** (for debugging) Get topicConcentration */
private[clustering] def getEta: Double = eta
private var randomGenerator: java.util.Random = null
+ /** (for debugging) Whether to sample mini-batches with replacement. (default = true) */
+ private var sampleWithReplacement: Boolean = true
+
// Online LDA specific parameters
// Learning rate is: (tau0 + t)^{-kappa}
private var tau0: Double = 1024
private var kappa: Double = 0.51
private var miniBatchFraction: Double = 0.05
+ private var optimizeAlpha: Boolean = false
// internal data structure
private var docs: RDD[(Long, Vector)] = null
@@ -262,7 +266,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
/** Dirichlet parameter for the posterior over topics */
private var lambda: BDM[Double] = null
- /** (private[clustering] for debugging) Get parameter for topics */
+ /** (for debugging) Get parameter for topics */
private[clustering] def getLambda: BDM[Double] = lambda
/** Current iteration (count of invocations of [[next()]]) */
@@ -325,7 +329,22 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
}
/**
- * (private[clustering])
+ * Indicates whether alpha (the Dirichlet parameter for the document-topic distribution)
+ * will be optimized during training.
+ */
+ def getOptimizeAlpha: Boolean = this.optimizeAlpha
+
+ /**
+ * Sets whether to optimize alpha parameter during training.
+ *
+ * Default: false
+ */
+ def setOptimizeAlpha(optimizeAlpha: Boolean): this.type = {
+ this.optimizeAlpha = optimizeAlpha
+ this
+ }
+
+ /**
* Set the Dirichlet parameter for the posterior over topics.
* This is only used for testing now. In the future, it can help support training stop/resume.
*/
@@ -335,7 +354,6 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
}
/**
- * (private[clustering])
* Used for random initialization of the variational parameters.
* Larger value produces values closer to 1.0.
* This is only used for testing currently.
@@ -345,6 +363,15 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
this
}
+ /**
+ * Sets whether to sample mini-batches with or without replacement. (default = true)
+ * This is only used for testing currently.
+ */
+ private[clustering] def setSampleWithReplacement(replace: Boolean): this.type = {
+ this.sampleWithReplacement = replace
+ this
+ }
+
override private[clustering] def initialize(
docs: RDD[(Long, Vector)],
lda: LDA): OnlineLDAOptimizer = {
@@ -376,7 +403,8 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
}
override private[clustering] def next(): OnlineLDAOptimizer = {
- val batch = docs.sample(withReplacement = true, miniBatchFraction, randomGenerator.nextLong())
+ val batch = docs.sample(withReplacement = sampleWithReplacement, miniBatchFraction,
+ randomGenerator.nextLong())
if (batch.isEmpty()) return this
submitMiniBatch(batch)
}
@@ -418,6 +446,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
// Note that this is an optimization to avoid batch.count
updateLambda(batchResult, (miniBatchFraction * corpusSize).ceil.toInt)
+ if (optimizeAlpha) updateAlpha(gammat)
this
}
@@ -433,13 +462,39 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
weight * (stat * (corpusSize.toDouble / batchSize.toDouble) + eta)
}
- /** Calculates learning rate rho, which decays as a function of [[iteration]] */
+ /**
+ * Update alpha based on `gammat`, the inferred topic distributions for documents in the
+ * current mini-batch. Uses the Newton-Raphson method.
+ * @see Section 3.3, Huang: Maximum Likelihood Estimation of Dirichlet Distribution Parameters
+ * (http://jonathan-huang.org/research/dirichlet/dirichlet.pdf)
+ */
+ private def updateAlpha(gammat: BDM[Double]): Unit = {
+ val weight = rho()
+ val N = gammat.rows.toDouble
+ val alpha = this.alpha.toBreeze.toDenseVector
+ val logphat: BDM[Double] = sum(LDAUtils.dirichletExpectation(gammat)(::, breeze.linalg.*)) / N
+ val gradf = N * (-LDAUtils.dirichletExpectation(alpha) + logphat.toDenseVector)
+
+ val c = N * trigamma(sum(alpha))
+ val q = -N * trigamma(alpha)
+ val b = sum(gradf / q) / (1D / c + sum(1D / q))
+
+ val dalpha = -(gradf - b) / q
+
+ if (all((weight * dalpha + alpha) :> 0D)) {
+ alpha :+= weight * dalpha
+ this.alpha = Vectors.dense(alpha.toArray)
+ }
+ }
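In symbols (with N the number of documents in the mini-batch and \overline{\log\hat\theta}_k
the averaged E[\log\theta_k], i.e. `logphat` above), the update implemented by `updateAlpha`
is Huang's Section 3.3 step:

    g_k = N(\Psi(\sum_j \alpha_j) - \Psi(\alpha_k) + \overline{\log\hat\theta}_k)
    q_k = -N\,\Psi'(\alpha_k), \qquad c = N\,\Psi'(\sum_j \alpha_j)
    b = \frac{\sum_j g_j/q_j}{1/c + \sum_j 1/q_j}, \qquad
    \Delta\alpha_k = -\frac{g_k - b}{q_k}, \qquad
    \alpha \leftarrow \alpha + \rho\,\Delta\alpha

Because the Hessian decomposes as H = diag(q) + c\,\mathbf{1}\mathbf{1}^T, the step
\Delta\alpha = -H^{-1}g follows from the Sherman-Morrison formula without materializing
or inverting H; the damped step is applied only when every component of the new alpha
remains positive.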
+
+
+ /** Calculate learning rate rho for the current [[iteration]]. */
private def rho(): Double = {
math.pow(getTau0 + this.iteration, -getKappa)
}
/**
- * Get a random matrix to initialize lambda
+ * Get a random matrix to initialize lambda.
*/
private def getGammaMatrix(row: Int, col: Int): BDM[Double] = {
val randBasis = new RandBasis(new org.apache.commons.math3.random.MersenneTwister(
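An aside on `gammaShape` (set to 100 in the test below): `getGammaMatrix` draws the
initial `lambda` entries from Gamma(shape, 1/shape), which has mean 1 and variance
1/shape, so a larger shape concentrates the initialization near 1.0. A one-line
illustration, assuming Breeze's `Gamma` distribution and not part of the patch:

    import breeze.stats.distributions.Gamma
    // Mean 1, variance 1/100: draws cluster tightly around 1.0.
    val draws = Gamma(100.0, 1.0 / 100.0).sample(5)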
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
index f2b94707fd..fdc2554ab8 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
@@ -400,6 +400,40 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
}
}
+ test("OnlineLDAOptimizer alpha hyperparameter optimization") {
+ val k = 2
+ val docs = sc.parallelize(toyData)
+ val op = new OnlineLDAOptimizer().setMiniBatchFraction(1).setTau0(1024).setKappa(0.51)
+ .setGammaShape(100).setOptimizeAlpha(true).setSampleWithReplacement(false)
+ val lda = new LDA().setK(k)
+ .setDocConcentration(1D / k)
+ .setTopicConcentration(0.01)
+ .setMaxIterations(100)
+ .setOptimizer(op)
+ .setSeed(12345)
+ val ldaModel: LocalLDAModel = lda.run(docs).asInstanceOf[LocalLDAModel]
+
+ /* Verify the results with gensim:
+ import numpy as np
+ from gensim import models
+ corpus = [
+ [(0, 1.0), (1, 1.0)],
+ [(1, 1.0), (2, 1.0)],
+ [(0, 1.0), (2, 1.0)],
+ [(3, 1.0), (4, 1.0)],
+ [(3, 1.0), (5, 1.0)],
+ [(4, 1.0), (5, 1.0)]]
+ np.random.seed(2345)
+ lda = models.ldamodel.LdaModel(
+ corpus=corpus, alpha='auto', eta=0.01, num_topics=2, update_every=0, passes=100,
+ decay=0.51, offset=1024)
+ print(lda.alpha)
+ > [ 0.42582646 0.43511073]
+ */
+
+ assert(ldaModel.docConcentration ~== Vectors.dense(0.42582646, 0.43511073) absTol 0.05)
+ }
+
test("model save/load") {
// Test for LocalLDAModel.
val localModel = new LocalLDAModel(tinyTopics,