diff options
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala | 5 |
1 files changed, 3 insertions, 2 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala index 7491ab0d51..2b404a8651 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala @@ -451,10 +451,11 @@ final class OnlineLDAOptimizer extends LDAOptimizer { } Iterator((stat, gammaPart)) } - val statsSum: BDM[Double] = stats.map(_._1).reduce(_ += _) + val statsSum: BDM[Double] = stats.map(_._1).treeAggregate(BDM.zeros[Double](k, vocabSize))( + _ += _, _ += _) expElogbetaBc.unpersist() val gammat: BDM[Double] = breeze.linalg.DenseMatrix.vertcat( - stats.map(_._2).reduce(_ ++ _).map(_.toDenseMatrix): _*) + stats.map(_._2).flatMap(list => list).collect().map(_.toDenseMatrix): _*) val batchResult = statsSum :* expElogbeta.t // Note that this is an optimization to avoid batch.count |