aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorYuhao Yang <hhbyyh@gmail.com>2016-04-06 11:36:26 -0700
committerJoseph K. Bradley <joseph@databricks.com>2016-04-06 11:36:26 -0700
commit8cffcb60deb82d04a5c6e144ec9927f6f7addc8b (patch)
tree681ff2c5b61076aabe3f52f7e40d8099893fb736 /mllib
parentdb0b06c6ea7412266158b1c710bdc8ca30e26430 (diff)
downloadspark-8cffcb60deb82d04a5c6e144ec9927f6f7addc8b.tar.gz
spark-8cffcb60deb82d04a5c6e144ec9927f6f7addc8b.tar.bz2
spark-8cffcb60deb82d04a5c6e144ec9927f6f7addc8b.zip
[SPARK-14322][MLLIB] Use treeAggregate instead of reduce in OnlineLDAOptimizer
## What changes were proposed in this pull request? jira: https://issues.apache.org/jira/browse/SPARK-14322 OnlineLDAOptimizer uses RDD.reduce in two places where it could use treeAggregate. This can cause scalability issues. This should be an easy fix. This is also a bug since it modifies the first argument to reduce, so we should use aggregate or treeAggregate. See this line: https://github.com/apache/spark/blob/f12f11e578169b47e3f8b18b299948c0670ba585/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala#L452 and a few lines below it. ## How was this patch tested? unit tests Author: Yuhao Yang <hhbyyh@gmail.com> Closes #12106 from hhbyyh/ldaTreeReduce.
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala5
1 files changed, 3 insertions, 2 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
index 7491ab0d51..2b404a8651 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
@@ -451,10 +451,11 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
}
Iterator((stat, gammaPart))
}
- val statsSum: BDM[Double] = stats.map(_._1).reduce(_ += _)
+ val statsSum: BDM[Double] = stats.map(_._1).treeAggregate(BDM.zeros[Double](k, vocabSize))(
+ _ += _, _ += _)
expElogbetaBc.unpersist()
val gammat: BDM[Double] = breeze.linalg.DenseMatrix.vertcat(
- stats.map(_._2).reduce(_ ++ _).map(_.toDenseMatrix): _*)
+ stats.map(_._2).flatMap(list => list).collect().map(_.toDenseMatrix): _*)
val batchResult = statsSum :* expElogbeta.t
// Note that this is an optimization to avoid batch.count