aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala5
1 files changed, 3 insertions, 2 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
index 7491ab0d51..2b404a8651 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
@@ -451,10 +451,11 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
}
Iterator((stat, gammaPart))
}
- val statsSum: BDM[Double] = stats.map(_._1).reduce(_ += _)
+ val statsSum: BDM[Double] = stats.map(_._1).treeAggregate(BDM.zeros[Double](k, vocabSize))(
+ _ += _, _ += _)
expElogbetaBc.unpersist()
val gammat: BDM[Double] = breeze.linalg.DenseMatrix.vertcat(
- stats.map(_._2).reduce(_ ++ _).map(_.toDenseMatrix): _*)
+ stats.map(_._2).flatMap(list => list).collect().map(_.toDenseMatrix): _*)
val batchResult = statsSum :* expElogbeta.t
// Note that this is an optimization to avoid batch.count