aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorXiangrui Meng <meng@databricks.com>2014-11-13 15:36:03 -0800
committerXiangrui Meng <meng@databricks.com>2014-11-13 15:36:03 -0800
commit4b1c77cbf59ccc752bc0d0291df3550cbfbe730c (patch)
treeae6cdb176de0d8fc739d9aefda1db2266047ad54
parent685bdd2b7e584c84e7d39e40de2d5f30c5388cb5 (diff)
downloadspark-4b1c77cbf59ccc752bc0d0291df3550cbfbe730c.tar.gz
spark-4b1c77cbf59ccc752bc0d0291df3550cbfbe730c.tar.bz2
spark-4b1c77cbf59ccc752bc0d0291df3550cbfbe730c.zip
[branch-1.1][SPARK-4355] OnlineSummarizer doesn't merge mean correctly
andrewor14 This backports the bug fix in #3220 . It would be good if we can get it in 1.1.1. But this is minor. Author: Xiangrui Meng <meng@databricks.com> Closes #3251 from mengxr/SPARK-4355-1.1 and squashes the following commits: 33886b6 [Xiangrui Meng] Merge remote-tracking branch 'apache/branch-1.1' into SPARK-4355-1.1 91fe1a3 [Xiangrui Meng] fix OnlineSummarizer.merge when other.mean is zero
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala20
1 files changed, 9 insertions, 11 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
index 7d845c4436..f23eb5b96d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
@@ -104,21 +104,19 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
val deltaMean: BDV[Double] = currMean - other.currMean
var i = 0
while (i < n) {
- // merge mean together
- if (other.currMean(i) != 0.0) {
+ if (nnz(i) + other.nnz(i) != 0.0) {
+ // merge mean together
currMean(i) = (currMean(i) * nnz(i) + other.currMean(i) * other.nnz(i)) /
(nnz(i) + other.nnz(i))
- }
- // merge m2n together
- if (nnz(i) + other.nnz(i) != 0.0) {
+ // merge m2n together
currM2n(i) += other.currM2n(i) + deltaMean(i) * deltaMean(i) * nnz(i) * other.nnz(i) /
(nnz(i) + other.nnz(i))
- }
- if (currMax(i) < other.currMax(i)) {
- currMax(i) = other.currMax(i)
- }
- if (currMin(i) > other.currMin(i)) {
- currMin(i) = other.currMin(i)
+ if (currMax(i) < other.currMax(i)) {
+ currMax(i) = other.currMax(i)
+ }
+ if (currMin(i) > other.currMin(i)) {
+ currMin(i) = other.currMin(i)
+ }
}
i += 1
}