From f51fd6fbb4d9822502f98b312251e317d757bc3a Mon Sep 17 00:00:00 2001 From: Feynman Liang Date: Fri, 31 Jul 2015 18:36:22 -0700 Subject: [SPARK-8936] [MLLIB] OnlineLDA document-topic Dirichlet hyperparameter optimization Adds `alpha` (document-topic Dirichlet parameter) hyperparameter optimization to `OnlineLDAOptimizer` following Huang: Maximum Likelihood Estimation of Dirichlet Distribution Parameters. Also introduces a private `setSampleWithReplacement` to `OnlineLDAOptimizer` for unit testing purposes. Author: Feynman Liang Closes #7836 from feynmanliang/SPARK-8936-alpha-optimize and squashes the following commits: 4bef484 [Feynman Liang] Documentation improvements c3c6c1d [Feynman Liang] Fix docs 151e859 [Feynman Liang] Fix style fa77518 [Feynman Liang] Hyperparameter optimization --- .../apache/spark/mllib/clustering/LDASuite.scala | 34 ++++++++++++++++++++++ 1 file changed, 34 insertions(+) (limited to 'mllib/src/test') diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala index f2b94707fd..fdc2554ab8 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala @@ -400,6 +400,40 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { } } + test("OnlineLDAOptimizer alpha hyperparameter optimization") { + val k = 2 + val docs = sc.parallelize(toyData) + val op = new OnlineLDAOptimizer().setMiniBatchFraction(1).setTau0(1024).setKappa(0.51) + .setGammaShape(100).setOptimzeAlpha(true).setSampleWithReplacement(false) + val lda = new LDA().setK(k) + .setDocConcentration(1D / k) + .setTopicConcentration(0.01) + .setMaxIterations(100) + .setOptimizer(op) + .setSeed(12345) + val ldaModel: LocalLDAModel = lda.run(docs).asInstanceOf[LocalLDAModel] + + /* Verify the results with gensim: + import numpy as np + from gensim import models + corpus = [ + [(0, 1.0), (1, 1.0)], + [(1, 1.0), (2, 1.0)], + [(0, 1.0), (2, 1.0)], + [(3, 1.0), (4, 1.0)], + [(3, 1.0), (5, 1.0)], + [(4, 1.0), (5, 1.0)]] + np.random.seed(2345) + lda = models.ldamodel.LdaModel( + corpus=corpus, alpha='auto', eta=0.01, num_topics=2, update_every=0, passes=100, + decay=0.51, offset=1024) + print(lda.alpha) + > [ 0.42582646 0.43511073] + */ + + assert(ldaModel.docConcentration ~== Vectors.dense(0.42582646, 0.43511073) absTol 0.05) + } + test("model save/load") { // Test for LocalLDAModel. val localModel = new LocalLDAModel(tinyTopics, -- cgit v1.2.3