path: root/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.mllib.clustering

import breeze.linalg.{argmax, argtopk, normalize, sum, DenseMatrix => BDM, DenseVector => BDV}
import breeze.numerics.{exp, lgamma}
import org.apache.hadoop.fs.Path
import org.json4s.DefaultFormats
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._

import org.apache.spark.SparkContext
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.api.java.{JavaPairRDD, JavaRDD}
import org.apache.spark.graphx.{Edge, EdgeContext, Graph, VertexId}
import org.apache.spark.mllib.linalg.{Matrices, Matrix, Vector, Vectors}
import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.util.BoundedPriorityQueue

/**
 * Latent Dirichlet Allocation (LDA) model.
 *
 * This abstraction permits different underlying representations,
 * including local and distributed data structures.
 */
@Since("1.3.0")
abstract class LDAModel private[clustering] extends Saveable {

  /** Number of topics */
  @Since("1.3.0")
  def k: Int

  /** Vocabulary size (number of terms in the vocabulary) */
  @Since("1.3.0")
  def vocabSize: Int

  /**
   * Concentration parameter (commonly named "alpha") for the prior placed on documents'
   * distributions over topics ("theta").
   *
   * This is the parameter to a Dirichlet distribution.
   */
  @Since("1.5.0")
  def docConcentration: Vector

  /**
   * Concentration parameter (commonly named "beta" or "eta") for the prior placed on topics'
   * distributions over terms.
   *
   * This is the parameter to a symmetric Dirichlet distribution.
   *
   * Note: The topics' distributions over terms are called "beta" in the original LDA paper
   * by Blei et al., but are called "phi" in many later papers such as Asuncion et al., 2009.
   */
  @Since("1.5.0")
  def topicConcentration: Double

  /**
   * Shape parameter for random initialization of variational parameter gamma.
   * Used for variational inference for perplexity and other test-time computations.
   */
  protected def gammaShape: Double

  /**
   * Inferred topics, where each topic is represented by a distribution over terms.
   * This is a matrix of size vocabSize x k, where each column is a topic.
   * No guarantees are given about the ordering of the topics.
   */
  @Since("1.3.0")
  def topicsMatrix: Matrix

  /**
   * Return the topics described by weighted terms.
   *
   * @param maxTermsPerTopic  Maximum number of terms to collect for each topic.
   * @return  Array over topics.  Each topic is represented as a pair of matching arrays:
   *          (term indices, term weights in topic).
   *          Each topic's terms are sorted in order of decreasing weight.
   */
  @Since("1.3.0")
  def describeTopics(maxTermsPerTopic: Int): Array[(Array[Int], Array[Double])]

  /**
   * Return the topics described by weighted terms.
   *
   * WARNING: If vocabSize and k are large, this can return a large object!
   *
   * @return  Array over topics.  Each topic is represented as a pair of matching arrays:
   *          (term indices, term weights in topic).
   *          Each topic's terms are sorted in order of decreasing weight.
   */
  @Since("1.3.0")
  def describeTopics(): Array[(Array[Int], Array[Double])] = describeTopics(vocabSize)
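
  /* A hedged usage sketch: printing each topic's top terms from a concrete model.
   * Here `model` and `vocab` are hypothetical; `vocab` maps term indices to term text:
   * {{{
   *   val vocab: Array[String] = Array("apple", "banana", "cherry")
   *   model.describeTopics(maxTermsPerTopic = 2).zipWithIndex.foreach {
   *     case ((termIndices, termWeights), topic) =>
   *       val top = termIndices.zip(termWeights)
   *         .map { case (i, w) => s"${vocab(i)}=$w" }.mkString(", ")
   *       println(s"Topic $topic: $top")
   *   }
   * }}}
   */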

  /* TODO (once LDA can be trained with Strings or given a dictionary)
   * Return the topics described by weighted terms.
   *
   * This is similar to [[describeTopics()]] but returns String values for terms.
   * If this model was trained using Strings or was given a dictionary, then this method returns
   * terms as text.  Otherwise, this method returns terms as term indices.
   *
   * This limits the number of terms per topic.
   * This is approximate; it may not return exactly the top-weighted terms for each topic.
   * To get a more precise set of top terms, increase maxTermsPerTopic.
   *
   * @param maxTermsPerTopic  Maximum number of terms to collect for each topic.
   * @return  Array over topics.  Each topic is represented as a pair of matching arrays:
   *          (terms, term weights in topic) where terms are either the actual term text
   *          (if available) or the term indices.
   *          Each topic's terms are sorted in order of decreasing weight.
   */
  // def describeTopicsAsStrings(maxTermsPerTopic: Int): Array[(Array[Double], Array[String])]

  /* TODO (once LDA can be trained with Strings or given a dictionary)
   * Return the topics described by weighted terms.
   *
   * This is similar to [[describeTopics()]] but returns String values for terms.
   * If this model was trained using Strings or was given a dictionary, then this method returns
   * terms as text.  Otherwise, this method returns terms as term indices.
   *
   * WARNING: If vocabSize and k are large, this can return a large object!
   *
   * @return  Array over topics.  Each topic is represented as a pair of matching arrays:
   *          (terms, term weights in topic) where terms are either the actual term text
   *          (if available) or the term indices.
   *          Each topic's terms are sorted in order of decreasing weight.
   */
  // def describeTopicsAsStrings(): Array[(Array[Double], Array[String])] =
  //  describeTopicsAsStrings(vocabSize)

  /* TODO
   * Compute the log likelihood of the observed tokens, given the current parameter estimates:
   *  log P(docs | topics, topic distributions for docs, alpha, eta)
   *
   * Note:
   *  - This excludes the prior.
   *  - Even with the prior, this is NOT the same as the data log likelihood given the
   *    hyperparameters.
   *
   * @param documents  RDD of documents, which are term (word) count vectors paired with IDs.
   *                   The term count vectors are "bags of words" with a fixed-size vocabulary
   *                   (where the vocabulary size is the length of the vector).
   *                   This must use the same vocabulary (ordering of term counts) as in training.
   *                   Document IDs must be unique and >= 0.
   * @return  Estimated log likelihood of the data under this model
   */
  // def logLikelihood(documents: RDD[(Long, Vector)]): Double

  /* TODO
   * Compute the estimated topic distribution for each document.
   * This is often called 'theta' in the literature.
   *
   * @param documents  RDD of documents, which are term (word) count vectors paired with IDs.
   *                   The term count vectors are "bags of words" with a fixed-size vocabulary
   *                   (where the vocabulary size is the length of the vector).
   *                   This must use the same vocabulary (ordering of term counts) as in training.
   *                   Document IDs must be unique and >= 0.
   * @return  Estimated topic distribution for each document.
   *          The returned RDD may be zipped with the given RDD, where each returned vector
   *          is a multinomial distribution over topics.
   */
  // def topicDistributions(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)]

}

/**
 * Local LDA model.
 * This model stores only the inferred topics.
 *
 * @param topics Inferred topics (vocabSize x k matrix).
 */
@Since("1.3.0")
class LocalLDAModel private[spark] (
    @Since("1.3.0") val topics: Matrix,
    @Since("1.5.0") override val docConcentration: Vector,
    @Since("1.5.0") override val topicConcentration: Double,
    override protected[spark] val gammaShape: Double = 100)
  extends LDAModel with Serializable {

  @Since("1.3.0")
  override def k: Int = topics.numCols

  @Since("1.3.0")
  override def vocabSize: Int = topics.numRows

  @Since("1.3.0")
  override def topicsMatrix: Matrix = topics

  @Since("1.3.0")
  override def describeTopics(maxTermsPerTopic: Int): Array[(Array[Int], Array[Double])] = {
    val brzTopics = topics.toBreeze.toDenseMatrix
    Range(0, k).map { topicIndex =>
      val topic = normalize(brzTopics(::, topicIndex), 1.0)
      val (termWeights, terms) =
        topic.toArray.zipWithIndex.sortBy(-_._1).take(maxTermsPerTopic).unzip
      (terms.toArray, termWeights.toArray)
    }.toArray
  }

  override protected def formatVersion = "1.0"

  @Since("1.5.0")
  override def save(sc: SparkContext, path: String): Unit = {
    LocalLDAModel.SaveLoadV1_0.save(sc, path, topicsMatrix, docConcentration, topicConcentration,
      gammaShape)
  }
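
  /* Save/load round trip, as a minimal sketch (the path is hypothetical and `sc` is an
   * existing SparkContext):
   * {{{
   *   model.save(sc, "/tmp/myLocalLDAModel")
   *   val sameModel = LocalLDAModel.load(sc, "/tmp/myLocalLDAModel")
   * }}}
   */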

  // TODO: declare in LDAModel and override once implemented in DistributedLDAModel
  /**
   * Calculates a lower bound on the log likelihood of the entire corpus.
   *
   * See Equation (16) in the original Online LDA paper.
   *
   * @param documents test corpus to use for calculating log likelihood
   * @return variational lower bound on the log likelihood of the entire corpus
   */
  @Since("1.5.0")
  def logLikelihood(documents: RDD[(Long, Vector)]): Double = logLikelihoodBound(documents,
    docConcentration, topicConcentration, topicsMatrix.toBreeze.toDenseMatrix, gammaShape, k,
    vocabSize)

  /**
   * Java-friendly version of [[logLikelihood]]
   */
  @Since("1.5.0")
  def logLikelihood(documents: JavaPairRDD[java.lang.Long, Vector]): Double = {
    logLikelihood(documents.rdd.asInstanceOf[RDD[(Long, Vector)]])
  }

  /**
   * Calculate an upper bound on perplexity.  (Lower is better.)
   * See Equation (16) in the original Online LDA paper.
   *
   * @param documents test corpus to use for calculating perplexity
   * @return Variational upper bound on log perplexity per token.
   */
  @Since("1.5.0")
  def logPerplexity(documents: RDD[(Long, Vector)]): Double = {
    val corpusTokenCount = documents
      .map { case (_, termCounts) => termCounts.toArray.sum }
      .sum()
    -logLikelihood(documents) / corpusTokenCount
  }
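
  /* Since the return value is an upper bound on log perplexity per token, exponentiating it
   * gives a per-token perplexity bound; a hedged sketch, where `docs` is a hypothetical test
   * corpus of (doc ID, term-count vector) pairs:
   * {{{
   *   val logPerpPerToken = model.logPerplexity(docs)
   *   val perplexityBound = math.exp(logPerpPerToken)  // lower is better
   * }}}
   */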

  /** Java-friendly version of [[logPerplexity]] */
  @Since("1.5.0")
  def logPerplexity(documents: JavaPairRDD[java.lang.Long, Vector]): Double = {
    logPerplexity(documents.rdd.asInstanceOf[RDD[(Long, Vector)]])
  }

  /**
   * Estimate the variational likelihood bound from `documents`:
   *    log p(documents) >= E_q[log p(documents)] - E_q[log q(documents)]
   * This bound is derived by decomposing the LDA model to:
   *    log p(documents) = E_q[log p(documents)] - E_q[log q(documents)] + D(q|p)
   * and noting that the KL-divergence D(q|p) >= 0.
   *
   * See Equation (16) in the original Online LDA paper, as well as Appendix A.3 in the JMLR
   * version of the original LDA paper.
   * @param documents a subset of the test corpus
   * @param alpha document-topic Dirichlet prior parameters
   * @param eta topic-word Dirichlet prior parameter
   * @param lambda parameters for variational q(beta | lambda) topic-word distributions
   * @param gammaShape shape parameter for random initialization of variational q(theta | gamma)
   *                   topic mixture distributions
   * @param k number of topics
   * @param vocabSize number of unique terms in the entire test corpus
   */
  private def logLikelihoodBound(
      documents: RDD[(Long, Vector)],
      alpha: Vector,
      eta: Double,
      lambda: BDM[Double],
      gammaShape: Double,
      k: Int,
      vocabSize: Long): Double = {
    val brzAlpha = alpha.toBreeze.toDenseVector
    // transpose because dirichletExpectation normalizes by row and we need to normalize
    // by topic (columns of lambda)
    val Elogbeta = LDAUtils.dirichletExpectation(lambda.t).t
    val ElogbetaBc = documents.sparkContext.broadcast(Elogbeta)

    // Sum bound components for each document:
    //  component for prob(tokens) + component for prob(document-topic distribution)
    val corpusPart =
      documents.filter(_._2.numNonzeros > 0).map { case (id: Long, termCounts: Vector) =>
        val localElogbeta = ElogbetaBc.value
        var docBound = 0.0D
        val (gammad: BDV[Double], _) = OnlineLDAOptimizer.variationalTopicInference(
          termCounts, exp(localElogbeta), brzAlpha, gammaShape, k)
        val Elogthetad: BDV[Double] = LDAUtils.dirichletExpectation(gammad)

        // E[log p(doc | theta, beta)]
        termCounts.foreachActive { case (idx, count) =>
          docBound += count * LDAUtils.logSumExp(Elogthetad + localElogbeta(idx, ::).t)
        }
        // E[log p(theta | alpha) - log q(theta | gamma)]
        docBound += sum((brzAlpha - gammad) :* Elogthetad)
        docBound += sum(lgamma(gammad) - lgamma(brzAlpha))
        docBound += lgamma(sum(brzAlpha)) - lgamma(sum(gammad))

        docBound
      }.sum()

    // Bound component for prob(topic-term distributions):
    //   E[log p(beta | eta) - log q(beta | lambda)]
    val sumEta = eta * vocabSize
    val topicsPart = sum((eta - lambda) :* Elogbeta) +
      sum(lgamma(lambda) - lgamma(eta)) +
      sum(lgamma(sumEta) - lgamma(sum(lambda(::, breeze.linalg.*))))

    corpusPart + topicsPart
  }

  /**
   * Predicts the topic mixture distribution for each document (often called "theta" in the
   * literature).  Returns a vector of zeros for an empty document.
   *
   * This uses a variational approximation following Hoffman et al. (2010), where the approximate
   * distribution is called "gamma."  Technically, this method returns this approximation "gamma"
   * for each document.
   * @param documents documents to predict topic mixture distributions for
   * @return An RDD of (document ID, topic mixture distribution for document)
   */
  @Since("1.3.0")
  // TODO: declare in LDAModel and override once implemented in DistributedLDAModel
  def topicDistributions(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)] = {
    // Double transpose because dirichletExpectation normalizes by row and we need to normalize
    // by topic (columns of lambda)
    val expElogbeta = exp(LDAUtils.dirichletExpectation(topicsMatrix.toBreeze.toDenseMatrix.t).t)
    val expElogbetaBc = documents.sparkContext.broadcast(expElogbeta)
    val docConcentrationBrz = this.docConcentration.toBreeze
    val gammaShape = this.gammaShape
    val k = this.k

    documents.map { case (id: Long, termCounts: Vector) =>
      if (termCounts.numNonzeros == 0) {
        (id, Vectors.zeros(k))
      } else {
        val (gamma, _) = OnlineLDAOptimizer.variationalTopicInference(
          termCounts,
          expElogbetaBc.value,
          docConcentrationBrz,
          gammaShape,
          k)
        (id, Vectors.dense(normalize(gamma, 1.0).toArray))
      }
    }
  }
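
  /* A minimal inference sketch on new documents (hypothetical counts; `sc` is an existing
   * SparkContext, and the vectors must use the same vocabulary ordering as in training):
   * {{{
   *   val newDocs: RDD[(Long, Vector)] = sc.parallelize(Seq(
   *     (0L, Vectors.sparse(model.vocabSize, Seq((0, 2.0), (3, 1.0)))),
   *     (1L, Vectors.sparse(model.vocabSize, Seq((1, 1.0))))))
   *   val theta: RDD[(Long, Vector)] = model.topicDistributions(newDocs)
   * }}}
   */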

  /** Get a method usable as a UDF for [[topicDistributions()]] */
  private[spark] def getTopicDistributionMethod(sc: SparkContext): Vector => Vector = {
    val expElogbeta = exp(LDAUtils.dirichletExpectation(topicsMatrix.toBreeze.toDenseMatrix.t).t)
    val expElogbetaBc = sc.broadcast(expElogbeta)
    val docConcentrationBrz = this.docConcentration.toBreeze
    val gammaShape = this.gammaShape
    val k = this.k

    (termCounts: Vector) =>
      if (termCounts.numNonzeros == 0) {
        Vectors.zeros(k)
      } else {
        val (gamma, _) = OnlineLDAOptimizer.variationalTopicInference(
          termCounts,
          expElogbetaBc.value,
          docConcentrationBrz,
          gammaShape,
          k)
        Vectors.dense(normalize(gamma, 1.0).toArray)
      }
  }

  /**
   * Predicts the topic mixture distribution for a document (often called "theta" in the
   * literature).  Returns a vector of zeros for an empty document.
   *
   * Note: This is intended for quick queries on a single document.  For batches of documents,
   * use [[topicDistributions()]] to avoid the per-call overhead.
   *
   * @param document document to predict topic mixture distributions for
   * @return topic mixture distribution for the document
   */
  @Since("2.0.0")
  def topicDistribution(document: Vector): Vector = {
    val expElogbeta = exp(LDAUtils.dirichletExpectation(topicsMatrix.toBreeze.toDenseMatrix.t).t)
    if (document.numNonzeros == 0) {
      Vectors.zeros(this.k)
    } else {
      val (gamma, _) = OnlineLDAOptimizer.variationalTopicInference(
        document,
        expElogbeta,
        this.docConcentration.toBreeze,
        gammaShape,
        this.k)
      Vectors.dense(normalize(gamma, 1.0).toArray)
    }
  }
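
  /* A single-document sketch (hypothetical counts).  Note that exp(E[log(beta)]) is recomputed
   * on every call, so [[topicDistributions()]] is preferable for batches:
   * {{{
   *   val doc = Vectors.sparse(model.vocabSize, Seq((0, 1.0), (2, 3.0)))
   *   val theta: Vector = model.topicDistribution(doc)
   * }}}
   */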

  /**
   * Java-friendly version of [[topicDistributions]]
   */
  @Since("1.4.1")
  def topicDistributions(
      documents: JavaPairRDD[java.lang.Long, Vector]): JavaPairRDD[java.lang.Long, Vector] = {
    val distributions = topicDistributions(documents.rdd.asInstanceOf[RDD[(Long, Vector)]])
    JavaPairRDD.fromRDD(distributions.asInstanceOf[RDD[(java.lang.Long, Vector)]])
  }

}

@Experimental
@Since("1.5.0")
object LocalLDAModel extends Loader[LocalLDAModel] {

  private object SaveLoadV1_0 {

    val thisFormatVersion = "1.0"

    val thisClassName = "org.apache.spark.mllib.clustering.LocalLDAModel"

    // Store each topic's distribution over terms, together with its column index in
    // topicsMatrix, as a Row in data.
    case class Data(topic: Vector, index: Int)

    def save(
        sc: SparkContext,
        path: String,
        topicsMatrix: Matrix,
        docConcentration: Vector,
        topicConcentration: Double,
        gammaShape: Double): Unit = {
      val sqlContext = SQLContext.getOrCreate(sc)
      import sqlContext.implicits._

      val k = topicsMatrix.numCols
      val metadata = compact(render
        (("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~
          ("k" -> k) ~ ("vocabSize" -> topicsMatrix.numRows) ~
          ("docConcentration" -> docConcentration.toArray.toSeq) ~
          ("topicConcentration" -> topicConcentration) ~
          ("gammaShape" -> gammaShape)))
      sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))

      val topicsDenseMatrix = topicsMatrix.toBreeze.toDenseMatrix
      val topics = Range(0, k).map { topicInd =>
        Data(Vectors.dense(topicsDenseMatrix(::, topicInd).toArray), topicInd)
      }.toSeq
      sc.parallelize(topics, 1).toDF().write.parquet(Loader.dataPath(path))
    }

    def load(
        sc: SparkContext,
        path: String,
        docConcentration: Vector,
        topicConcentration: Double,
        gammaShape: Double): LocalLDAModel = {
      val dataPath = Loader.dataPath(path)
      val sqlContext = SQLContext.getOrCreate(sc)
      val dataFrame = sqlContext.read.parquet(dataPath)

      Loader.checkSchema[Data](dataFrame.schema)
      val topics = dataFrame.collect()
      val vocabSize = topics(0).getAs[Vector](0).size
      val k = topics.length

      val brzTopics = BDM.zeros[Double](vocabSize, k)
      topics.foreach { case Row(vec: Vector, ind: Int) =>
        brzTopics(::, ind) := vec.toBreeze
      }
      val topicsMat = Matrices.fromBreeze(brzTopics)

      new LocalLDAModel(topicsMat, docConcentration, topicConcentration, gammaShape)
    }
  }

  @Since("1.5.0")
  override def load(sc: SparkContext, path: String): LocalLDAModel = {
    val (loadedClassName, loadedVersion, metadata) = Loader.loadMetadata(sc, path)
    implicit val formats = DefaultFormats
    val expectedK = (metadata \ "k").extract[Int]
    val expectedVocabSize = (metadata \ "vocabSize").extract[Int]
    val docConcentration =
      Vectors.dense((metadata \ "docConcentration").extract[Seq[Double]].toArray)
    val topicConcentration = (metadata \ "topicConcentration").extract[Double]
    val gammaShape = (metadata \ "gammaShape").extract[Double]
    val classNameV1_0 = SaveLoadV1_0.thisClassName

    val model = (loadedClassName, loadedVersion) match {
      case (className, "1.0") if className == classNameV1_0 =>
        SaveLoadV1_0.load(sc, path, docConcentration, topicConcentration, gammaShape)
      case _ => throw new Exception(
        s"LocalLDAModel.load did not recognize model with (className, format version):" +
          s"($loadedClassName, $loadedVersion).  Supported:\n" +
          s"  ($classNameV1_0, 1.0)")
    }

    val topicsMatrix = model.topicsMatrix
    require(expectedK == topicsMatrix.numCols,
      s"LocalLDAModel requires $expectedK topics, got ${topicsMatrix.numCols} topics")
    require(expectedVocabSize == topicsMatrix.numRows,
      s"LocalLDAModel requires $expectedVocabSize terms for each topic, " +
        s"but got ${topicsMatrix.numRows}")
    model
  }
}

/**
 * Distributed LDA model.
 * This model stores the inferred topics, the full training dataset, and the topic distributions.
 */
@Since("1.3.0")
class DistributedLDAModel private[clustering] (
    private[clustering] val graph: Graph[LDA.TopicCounts, LDA.TokenCount],
    private[clustering] val globalTopicTotals: LDA.TopicCounts,
    @Since("1.3.0") val k: Int,
    @Since("1.3.0") val vocabSize: Int,
    @Since("1.5.0") override val docConcentration: Vector,
    @Since("1.5.0") override val topicConcentration: Double,
    private[spark] val iterationTimes: Array[Double],
    override protected[clustering] val gammaShape: Double = DistributedLDAModel.defaultGammaShape,
    private[spark] val checkpointFiles: Array[String] = Array.empty[String])
  extends LDAModel {

  import LDA._

  /**
   * Convert model to a local model.
   * The local model stores the inferred topics but not the topic distributions for training
   * documents.
   */
  @Since("1.3.0")
  def toLocal: LocalLDAModel = new LocalLDAModel(topicsMatrix, docConcentration, topicConcentration,
    gammaShape)
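
  /* A sketch of the common pattern: train distributed, then convert to a local model for
   * predicting topic mixtures on new documents (`distModel` and `newDocs` are hypothetical):
   * {{{
   *   val localModel: LocalLDAModel = distModel.toLocal
   *   val theta = localModel.topicDistributions(newDocs)
   * }}}
   */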

  /**
   * Inferred topics, where each topic is represented by a distribution over terms.
   * This is a matrix of size vocabSize x k, where each column is a topic.
   * No guarantees are given about the ordering of the topics.
   *
   * WARNING: This matrix is collected from an RDD.  Beware memory usage when vocabSize and k
   * are large.
   */
  @Since("1.3.0")
  override lazy val topicsMatrix: Matrix = {
    // Collect row-major topics
    val termTopicCounts: Array[(Int, TopicCounts)] =
      graph.vertices.filter(_._1 < 0).map { case (termIndex, cnts) =>
        (index2term(termIndex), cnts)
      }.collect()
    // Convert to Matrix
    val brzTopics = BDM.zeros[Double](vocabSize, k)
    termTopicCounts.foreach { case (term, cnts) =>
      var j = 0
      while (j < k) {
        brzTopics(term, j) = cnts(j)
        j += 1
      }
    }
    Matrices.fromBreeze(brzTopics)
  }

  @Since("1.3.0")
  override def describeTopics(maxTermsPerTopic: Int): Array[(Array[Int], Array[Double])] = {
    val numTopics = k
    // Note: N_k is not needed to find the top terms, but it is needed to normalize weights
    //       to a distribution over terms.
    val N_k: TopicCounts = globalTopicTotals
    val topicsInQueues: Array[BoundedPriorityQueue[(Double, Int)]] =
      graph.vertices.filter(isTermVertex)
        .mapPartitions { termVertices =>
        // For this partition, collect the most common terms for each topic in queues:
        //  queues(topic) = queue of (term weight, term index).
        // Term weights are N_{wk} / N_k.
        val queues =
          Array.fill(numTopics)(new BoundedPriorityQueue[(Double, Int)](maxTermsPerTopic))
        for ((termId, n_wk) <- termVertices) {
          var topic = 0
          while (topic < numTopics) {
            queues(topic) += (n_wk(topic) / N_k(topic) -> index2term(termId.toInt))
            topic += 1
          }
        }
        Iterator(queues)
      }.reduce { (q1, q2) =>
        q1.zip(q2).foreach { case (a, b) => a ++= b }
        q1
      }
    topicsInQueues.map { q =>
      val (termWeights, terms) = q.toArray.sortBy(-_._1).unzip
      (terms.toArray, termWeights.toArray)
    }
  }

  /**
   * Return the top documents for each topic
   *
   * @param maxDocumentsPerTopic  Maximum number of documents to collect for each topic.
   * @return  Array over topics.  Each element represent as a pair of matching arrays:
   *          (IDs for the documents, weights of the topic in these documents).
   *          For each topic, documents are sorted in order of decreasing topic weights.
   */
  @Since("1.5.0")
  def topDocumentsPerTopic(maxDocumentsPerTopic: Int): Array[(Array[Long], Array[Double])] = {
    val numTopics = k
    val topicsInQueues: Array[BoundedPriorityQueue[(Double, Long)]] =
      topicDistributions.mapPartitions { docVertices =>
        // For this partition, collect the most common docs for each topic in queues:
        //  queues(topic) = queue of (weight of topic in doc, doc ID).
        val queues =
          Array.fill(numTopics)(new BoundedPriorityQueue[(Double, Long)](maxDocumentsPerTopic))
        for ((docId, docTopics) <- docVertices) {
          var topic = 0
          while (topic < numTopics) {
            queues(topic) += (docTopics(topic) -> docId)
            topic += 1
          }
        }
        Iterator(queues)
      }.treeReduce { (q1, q2) =>
        q1.zip(q2).foreach { case (a, b) => a ++= b }
        q1
      }
    topicsInQueues.map { q =>
      val (docTopics, docs) = q.toArray.sortBy(-_._1).unzip
      (docs.toArray, docTopics.toArray)
    }
  }
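
  /* A usage sketch (hypothetical trained `distModel`):
   * {{{
   *   distModel.topDocumentsPerTopic(maxDocumentsPerTopic = 3).zipWithIndex.foreach {
   *     case ((docIds, weights), topic) =>
   *       println(s"Topic $topic: " + docIds.zip(weights).mkString(", "))
   *   }
   * }}}
   */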

  /**
   * Return the top topic for each (doc, term) pair.  I.e., for each document, what is the most
   * likely topic generating each term?
   *
   * @return RDD of (doc ID, assignment of top topic index for each term),
   *         where the assignment is specified via a pair of zippable arrays
   *         (term indices, topic indices).  Note that terms will be omitted if not present in
   *         the document.
   */
  @Since("1.5.0")
  lazy val topicAssignments: RDD[(Long, Array[Int], Array[Int])] = {
    // For reference, compare the below code with the core part of EMLDAOptimizer.next().
    val eta = topicConcentration
    val W = vocabSize
    val alpha = docConcentration(0)
    val N_k = globalTopicTotals
    val sendMsg: EdgeContext[TopicCounts, TokenCount, (Array[Int], Array[Int])] => Unit =
      (edgeContext) => {
        // E-STEP: Compute gamma_{wjk} (smoothed topic distributions).
        val scaledTopicDistribution: TopicCounts =
          computePTopic(edgeContext.srcAttr, edgeContext.dstAttr, N_k, W, eta, alpha)
        // For this (doc j, term w), send top topic k to doc vertex.
        val topTopic: Int = argmax(scaledTopicDistribution)
        val term: Int = index2term(edgeContext.dstId)
        edgeContext.sendToSrc((Array(term), Array(topTopic)))
      }
    val mergeMsg: ((Array[Int], Array[Int]), (Array[Int], Array[Int])) => (Array[Int], Array[Int]) =
      (terms_topics0, terms_topics1) => {
        (terms_topics0._1 ++ terms_topics1._1, terms_topics0._2 ++ terms_topics1._2)
      }
    // M-STEP: Aggregation computes new N_{kj}, N_{wk} counts.
    val perDocAssignments =
      graph.aggregateMessages[(Array[Int], Array[Int])](sendMsg, mergeMsg).filter(isDocumentVertex)
    perDocAssignments.map { case (docID: Long, (terms: Array[Int], topics: Array[Int])) =>
      // TODO: Avoid zip, which is inefficient.
      val (sortedTerms, sortedTopics) = terms.zip(topics).sortBy(_._1).unzip
      (docID, sortedTerms.toArray, sortedTopics.toArray)
    }
  }
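
  /* A hedged sketch of inspecting per-token topic assignments for a few documents
   * (hypothetical `distModel`):
   * {{{
   *   distModel.topicAssignments.take(2).foreach { case (docId, terms, topics) =>
   *     terms.zip(topics).foreach { case (term, topic) =>
   *       println(s"doc $docId: term $term -> topic $topic")
   *     }
   *   }
   * }}}
   */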

  /** Java-friendly version of [[topicAssignments]] */
  @Since("1.5.0")
  lazy val javaTopicAssignments: JavaRDD[(java.lang.Long, Array[Int], Array[Int])] = {
    topicAssignments.asInstanceOf[RDD[(java.lang.Long, Array[Int], Array[Int])]].toJavaRDD()
  }

  // TODO
  // override def logLikelihood(documents: RDD[(Long, Vector)]): Double = ???

  /**
   * Log likelihood of the observed tokens in the training set,
   * given the current parameter estimates:
   *  log P(docs | topics, topic distributions for docs, alpha, eta)
   *
   * Note:
   *  - This excludes the prior; for that, use [[logPrior]].
   *  - Even with [[logPrior]], this is NOT the same as the data log likelihood given the
   *    hyperparameters.
   */
  @Since("1.3.0")
  lazy val logLikelihood: Double = {
    // TODO: generalize this for asymmetric (non-scalar) alpha
    val alpha = this.docConcentration(0) // To avoid closure capture of enclosing object
    val eta = this.topicConcentration
    assert(eta > 1.0)
    assert(alpha > 1.0)
    val N_k = globalTopicTotals
    val smoothed_N_k: TopicCounts = N_k + (vocabSize * (eta - 1.0))
    // Edges: Compute token log probability from phi_{wk}, theta_{kj}.
    val sendMsg: EdgeContext[TopicCounts, TokenCount, Double] => Unit = (edgeContext) => {
      val N_wj = edgeContext.attr
      val smoothed_N_wk: TopicCounts = edgeContext.dstAttr + (eta - 1.0)
      val smoothed_N_kj: TopicCounts = edgeContext.srcAttr + (alpha - 1.0)
      val phi_wk: TopicCounts = smoothed_N_wk :/ smoothed_N_k
      val theta_kj: TopicCounts = normalize(smoothed_N_kj, 1.0)
      val tokenLogLikelihood = N_wj * math.log(phi_wk.dot(theta_kj))
      edgeContext.sendToDst(tokenLogLikelihood)
    }
    graph.aggregateMessages[Double](sendMsg, _ + _)
      .map(_._2).fold(0.0)(_ + _)
  }

  /**
   * Log probability of the current parameter estimate:
   * log P(topics, topic distributions for docs | alpha, eta)
   */
  @Since("1.3.0")
  lazy val logPrior: Double = {
    // TODO: generalize this for asymmetric (non-scalar) alpha
    val alpha = this.docConcentration(0) // To avoid closure capture of enclosing object
    val eta = this.topicConcentration
    // Term vertices: Compute phi_{wk}.  Use to compute prior log probability.
    // Doc vertex: Compute theta_{kj}.  Use to compute prior log probability.
    val N_k = globalTopicTotals
    val smoothed_N_k: TopicCounts = N_k + (vocabSize * (eta - 1.0))
    val seqOp: (Double, (VertexId, TopicCounts)) => Double = {
      case (sumPrior: Double, vertex: (VertexId, TopicCounts)) =>
        if (isTermVertex(vertex)) {
          val N_wk = vertex._2
          val smoothed_N_wk: TopicCounts = N_wk + (eta - 1.0)
          val phi_wk: TopicCounts = smoothed_N_wk :/ smoothed_N_k
          (eta - 1.0) * sum(phi_wk.map(math.log))
        } else {
          val N_kj = vertex._2
          val smoothed_N_kj: TopicCounts = N_kj + (alpha - 1.0)
          val theta_kj: TopicCounts = normalize(smoothed_N_kj, 1.0)
          (alpha - 1.0) * sum(theta_kj.map(math.log))
        }
    }
    graph.vertices.aggregate(0.0)(seqOp, _ + _)
  }
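
  /* Per the docs above, logLikelihood is log P(docs | params) and logPrior is
   * log P(params | alpha, eta), so their sum gives the joint log probability of the training
   * corpus and the parameter estimate; a minimal sketch:
   * {{{
   *   val jointLogProb = distModel.logLikelihood + distModel.logPrior
   * }}}
   */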

  /**
   * For each document in the training set, return the distribution over topics for that document
   * ("theta_doc").
   *
   * @return  RDD of (document ID, topic distribution) pairs
   */
  @Since("1.3.0")
  def topicDistributions: RDD[(Long, Vector)] = {
    graph.vertices.filter(LDA.isDocumentVertex).map { case (docID, topicCounts) =>
      (docID.toLong, Vectors.fromBreeze(normalize(topicCounts, 1.0)))
    }
  }

  /**
   * Java-friendly version of [[topicDistributions]]
   */
  @Since("1.4.1")
  def javaTopicDistributions: JavaPairRDD[java.lang.Long, Vector] = {
    JavaPairRDD.fromRDD(topicDistributions.asInstanceOf[RDD[(java.lang.Long, Vector)]])
  }

  /**
   * For each document, return the top k weighted topics for that document and their weights.
   * @return RDD of (doc ID, topic indices, topic weights)
   */
  @Since("1.5.0")
  def topTopicsPerDocument(k: Int): RDD[(Long, Array[Int], Array[Double])] = {
    graph.vertices.filter(LDA.isDocumentVertex).map { case (docID, topicCounts) =>
      val topIndices = argtopk(topicCounts, k)
      val sumCounts = sum(topicCounts)
      val weights = if (sumCounts != 0) {
        topicCounts(topIndices) / sumCounts
      } else {
        topicCounts(topIndices)
      }
      (docID.toLong, topIndices.toArray, weights.toArray)
    }
  }
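
  /* Usage sketch (here the argument is the number of top topics to keep per document, not the
   * model's topic count; `distModel` is hypothetical):
   * {{{
   *   distModel.topTopicsPerDocument(2).take(5).foreach { case (docId, topics, weights) =>
   *     println(s"doc $docId: " + topics.zip(weights).mkString(", "))
   *   }
   * }}}
   */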

  /**
   * Java-friendly version of [[topTopicsPerDocument]]
   */
  @Since("1.5.0")
  def javaTopTopicsPerDocument(k: Int): JavaRDD[(java.lang.Long, Array[Int], Array[Double])] = {
    val topics = topTopicsPerDocument(k)
    topics.asInstanceOf[RDD[(java.lang.Long, Array[Int], Array[Double])]].toJavaRDD()
  }

  // TODO:
  // override def topicDistributions(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)] = ???

  override protected def formatVersion = "1.0"

  @Since("1.5.0")
  override def save(sc: SparkContext, path: String): Unit = {
    // Note: This intentionally does not save checkpointFiles.
    DistributedLDAModel.SaveLoadV1_0.save(
      sc, path, graph, globalTopicTotals, k, vocabSize, docConcentration, topicConcentration,
      iterationTimes, gammaShape)
  }
}


@Experimental
@Since("1.5.0")
object DistributedLDAModel extends Loader[DistributedLDAModel] {

  /**
   * The [[DistributedLDAModel]] constructor's default arguments assume gammaShape = 100
   * to ensure equivalence in LDAModel.toLocal conversion.
   */
  private[clustering] val defaultGammaShape: Double = 100

  private object SaveLoadV1_0 {

    val thisFormatVersion = "1.0"

    val thisClassName = "org.apache.spark.mllib.clustering.DistributedLDAModel"

    // Store globalTopicTotals as a Vector.
    case class Data(globalTopicTotals: Vector)

    // Store each term and document vertex with an id and the topicWeights.
    case class VertexData(id: Long, topicWeights: Vector)

    // Store each edge with the source id, destination id and tokenCounts.
    case class EdgeData(srcId: Long, dstId: Long, tokenCounts: Double)

    def save(
        sc: SparkContext,
        path: String,
        graph: Graph[LDA.TopicCounts, LDA.TokenCount],
        globalTopicTotals: LDA.TopicCounts,
        k: Int,
        vocabSize: Int,
        docConcentration: Vector,
        topicConcentration: Double,
        iterationTimes: Array[Double],
        gammaShape: Double): Unit = {
      val sqlContext = SQLContext.getOrCreate(sc)
      import sqlContext.implicits._

      val metadata = compact(render
        (("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~
          ("k" -> k) ~ ("vocabSize" -> vocabSize) ~
          ("docConcentration" -> docConcentration.toArray.toSeq) ~
          ("topicConcentration" -> topicConcentration) ~
          ("iterationTimes" -> iterationTimes.toSeq) ~
          ("gammaShape" -> gammaShape)))
      sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))

      val newPath = new Path(Loader.dataPath(path), "globalTopicTotals").toUri.toString
      sc.parallelize(Seq(Data(Vectors.fromBreeze(globalTopicTotals)))).toDF()
        .write.parquet(newPath)

      val verticesPath = new Path(Loader.dataPath(path), "topicCounts").toUri.toString
      graph.vertices.map { case (ind, vertex) =>
        VertexData(ind, Vectors.fromBreeze(vertex))
      }.toDF().write.parquet(verticesPath)

      val edgesPath = new Path(Loader.dataPath(path), "tokenCounts").toUri.toString
      graph.edges.map { case Edge(srcId, dstId, prop) =>
        EdgeData(srcId, dstId, prop)
      }.toDF().write.parquet(edgesPath)
    }

    def load(
        sc: SparkContext,
        path: String,
        vocabSize: Int,
        docConcentration: Vector,
        topicConcentration: Double,
        iterationTimes: Array[Double],
        gammaShape: Double): DistributedLDAModel = {
      val dataPath = new Path(Loader.dataPath(path), "globalTopicTotals").toUri.toString
      val vertexDataPath = new Path(Loader.dataPath(path), "topicCounts").toUri.toString
      val edgeDataPath = new Path(Loader.dataPath(path), "tokenCounts").toUri.toString
      val sqlContext = SQLContext.getOrCreate(sc)
      val dataFrame = sqlContext.read.parquet(dataPath)
      val vertexDataFrame = sqlContext.read.parquet(vertexDataPath)
      val edgeDataFrame = sqlContext.read.parquet(edgeDataPath)

      Loader.checkSchema[Data](dataFrame.schema)
      Loader.checkSchema[VertexData](vertexDataFrame.schema)
      Loader.checkSchema[EdgeData](edgeDataFrame.schema)
      val globalTopicTotals: LDA.TopicCounts =
        dataFrame.first().getAs[Vector](0).toBreeze.toDenseVector
      val vertices: RDD[(VertexId, LDA.TopicCounts)] = vertexDataFrame.rdd.map {
        case Row(ind: Long, vec: Vector) => (ind, vec.toBreeze.toDenseVector)
      }

      val edges: RDD[Edge[LDA.TokenCount]] = edgeDataFrame.rdd.map {
        case Row(srcId: Long, dstId: Long, prop: Double) => Edge(srcId, dstId, prop)
      }
      val graph: Graph[LDA.TopicCounts, LDA.TokenCount] = Graph(vertices, edges)

      new DistributedLDAModel(graph, globalTopicTotals, globalTopicTotals.length, vocabSize,
        docConcentration, topicConcentration, iterationTimes, gammaShape)
    }

  }

  @Since("1.5.0")
  override def load(sc: SparkContext, path: String): DistributedLDAModel = {
    val (loadedClassName, loadedVersion, metadata) = Loader.loadMetadata(sc, path)
    implicit val formats = DefaultFormats
    val expectedK = (metadata \ "k").extract[Int]
    val vocabSize = (metadata \ "vocabSize").extract[Int]
    val docConcentration =
      Vectors.dense((metadata \ "docConcentration").extract[Seq[Double]].toArray)
    val topicConcentration = (metadata \ "topicConcentration").extract[Double]
    val iterationTimes = (metadata \ "iterationTimes").extract[Seq[Double]]
    val gammaShape = (metadata \ "gammaShape").extract[Double]
    val classNameV1_0 = SaveLoadV1_0.thisClassName

    val model = (loadedClassName, loadedVersion) match {
      case (className, "1.0") if className == classNameV1_0 =>
        DistributedLDAModel.SaveLoadV1_0.load(sc, path, vocabSize, docConcentration,
          topicConcentration, iterationTimes.toArray, gammaShape)
      case _ => throw new Exception(
        s"DistributedLDAModel.load did not recognize model with (className, format version):" +
          s"($loadedClassName, $loadedVersion).  Supported: ($classNameV1_0, 1.0)")
    }

    require(model.vocabSize == vocabSize,
      s"DistributedLDAModel requires $vocabSize vocabSize, got ${model.vocabSize} vocabSize")
    require(model.docConcentration == docConcentration,
      s"DistributedLDAModel requires $docConcentration docConcentration, " +
        s"got ${model.docConcentration} docConcentration")
    require(model.topicConcentration == topicConcentration,
      s"DistributedLDAModel requires $topicConcentration docConcentration, " +
        s"got ${model.topicConcentration} docConcentration")
    require(expectedK == model.k,
      s"DistributedLDAModel requires $expectedK topics, got ${model.k} topics")
    model
  }

}