author | Yuhao Yang <hhbyyh@gmail.com> | 2016-04-06 11:36:26 -0700
---|---|---
committer | Joseph K. Bradley <joseph@databricks.com> | 2016-04-06 11:36:26 -0700
commit | 8cffcb60deb82d04a5c6e144ec9927f6f7addc8b (patch) |
tree | 681ff2c5b61076aabe3f52f7e40d8099893fb736 |
parent | db0b06c6ea7412266158b1c710bdc8ca30e26430 (diff) |
[SPARK-14322][MLLIB] Use treeAggregate instead of reduce in OnlineLDAOptimizer
## What changes were proposed in this pull request?
jira: https://issues.apache.org/jira/browse/SPARK-14322
OnlineLDAOptimizer uses RDD.reduce in two places where it could use treeAggregate. treeAggregate combines partition results in a multi-level tree instead of merging every partial result at the driver, so reduce can become a scalability bottleneck as the number of partitions grows. This should be an easy fix.
Using reduce here is also a bug: the combine closure `_ += _` mutates its first argument in place, but reduce gives the closure no fresh zero value that is safe to mutate, so an input partial result gets modified. aggregate and treeAggregate take a caller-supplied zero value for exactly this reason, so we should use one of them.
See this line: https://github.com/apache/spark/blob/f12f11e578169b47e3f8b18b299948c0670ba585/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala#L452
and a few lines below it.
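The hazard described above can be reproduced outside Spark. A minimal sketch in plain Scala, where a mutable `Array[Double]` stands in for the Breeze `BDM[Double]` in the real code (the names `addInPlace`, `sumByReduce`, and `sumWithZero` are illustrative, not from the patch):

```scala
object ReduceMutationSketch {
  // In-place elementwise addition, analogous to Breeze's `_ += _` on BDM[Double]:
  // mutates and returns its first argument.
  def addInPlace(a: Array[Double], b: Array[Double]): Array[Double] = {
    var i = 0
    while (i < a.length) { a(i) += b(i); i += 1 }
    a
  }

  // reduce-style combine: the first partial result is mutated in place.
  def sumByReduce(parts: Seq[Array[Double]]): Array[Double] =
    parts.reduce(addInPlace)

  // aggregate-style combine: a caller-supplied zero value absorbs all
  // mutation, mirroring treeAggregate(BDM.zeros[Double](k, vocabSize))(...).
  def sumWithZero(parts: Seq[Array[Double]]): Array[Double] =
    parts.foldLeft(Array.fill(parts.head.length)(0.0))(addInPlace)

  def main(args: Array[String]): Unit = {
    val parts = Seq(Array(1.0, 2.0), Array(3.0, 4.0), Array(5.0, 6.0))
    println(sumByReduce(parts).mkString(","))
    // The first input partial result has been corrupted by the combine:
    println(parts.head.mkString(","))

    val parts2 = Seq(Array(1.0, 2.0), Array(3.0, 4.0), Array(5.0, 6.0))
    println(sumWithZero(parts2).mkString(","))
    // All inputs are untouched; only the fresh zero array was mutated:
    println(parts2.head.mkString(","))
  }
}
```

On a local Seq both variants return the correct sum, but the reduce version leaves the first input overwritten. On an RDD the stakes are higher: the mutated value may be cached or reused by Spark, so the corruption can surface as wrong results rather than just a side effect.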
## How was this patch tested?
unit tests
Author: Yuhao Yang <hhbyyh@gmail.com>
Closes #12106 from hhbyyh/ldaTreeReduce.
mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala | 5 |
1 file changed, 3 insertions(+), 2 deletions(-)
```diff
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
index 7491ab0d51..2b404a8651 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
@@ -451,10 +451,11 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
       }
       Iterator((stat, gammaPart))
     }
-    val statsSum: BDM[Double] = stats.map(_._1).reduce(_ += _)
+    val statsSum: BDM[Double] = stats.map(_._1).treeAggregate(BDM.zeros[Double](k, vocabSize))(
+      _ += _, _ += _)
     expElogbetaBc.unpersist()
     val gammat: BDM[Double] = breeze.linalg.DenseMatrix.vertcat(
-      stats.map(_._2).reduce(_ ++ _).map(_.toDenseMatrix): _*)
+      stats.map(_._2).flatMap(list => list).collect().map(_.toDenseMatrix): _*)
     val batchResult = statsSum :* expElogbeta.t
     // Note that this is an optimization to avoid batch.count
```
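The shape of treeAggregate that the patch relies on can be modeled locally: partial results are combined pairwise in rounds, forming a tree, rather than folded one after another into a single driver-side value. A sketch over plain Scala collections, assuming per-partition results have already been computed (`treeCombine` is a hypothetical helper, not Spark's implementation):

```scala
object TreeAggregateSketch {
  // Local model of the combine phase of RDD.treeAggregate: merge partial
  // results pairwise, level by level, so no single step sees all of them.
  def treeCombine[T](zero: T)(partials: Seq[T])(combOp: (T, T) => T): T = {
    var level = partials
    while (level.length > 1) {
      // One round of the tree: adjacent pairs are merged in parallel
      // (conceptually); a leftover odd element passes through unchanged.
      level = level.grouped(2).map { pair =>
        if (pair.length == 2) combOp(pair(0), pair(1)) else pair(0)
      }.toSeq
    }
    level.headOption.getOrElse(zero)
  }

  def main(args: Array[String]): Unit = {
    // Five "partition" sums combined in ceil(log2(5)) = 3 rounds.
    val total = treeCombine(0.0)(Seq(1.0, 2.0, 3.0, 4.0, 5.0))(_ + _)
    println(total)
  }
}
```

In the real patch each `T` is a `BDM[Double]` of size k x vocabSize and `combOp` is `_ += _`, which is safe there because treeAggregate seeds each merge chain with the caller's `BDM.zeros[Double](k, vocabSize)` rather than mutating an input.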