diff options
author | Xiangrui Meng <meng@databricks.com> | 2014-11-12 01:50:11 -0800 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2014-11-12 01:50:11 -0800 |
commit | 84324fbcb987db6e10e435f463eacace1bae43e2 (patch) | |
tree | 7637baffa919381bda06a472d46619aedef6d531 | |
parent | faeb41de215d3ac567ce72a43ab242ad433ca93e (diff) | |
download | spark-84324fbcb987db6e10e435f463eacace1bae43e2.tar.gz spark-84324fbcb987db6e10e435f463eacace1bae43e2.tar.bz2 spark-84324fbcb987db6e10e435f463eacace1bae43e2.zip |
[SPARK-4355][MLLIB] fix OnlineSummarizer.merge when other.mean is zero
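See inline comment about the bug: the old `merge` skipped the mean update whenever `other.currMean(i)` was exactly 0.0, even when `other` had nonzero samples at that index. I also did some code clean-up. @dbtsai: I moved `update` to a private method of `MultivariateOnlineSummarizer` (now named `add`). I don't think it will cause a performance regression, but it would be great if you have some time to test.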
Author: Xiangrui Meng <meng@databricks.com>
Closes #3220 from mengxr/SPARK-4355 and squashes the following commits:
5ef601f [Xiangrui Meng] fix OnlineSummarizer.merge when other.mean is zero and some code clean-up
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala | 85 | ||||
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala | 11 |
2 files changed, 51 insertions, 45 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala index fab7c4405c..654479ac2d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala @@ -50,6 +50,29 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S private var currMin: BDV[Double] = _ /** + * Adds input value to position i. + */ + private[this] def add(i: Int, value: Double) = { + if (value != 0.0) { + if (currMax(i) < value) { + currMax(i) = value + } + if (currMin(i) > value) { + currMin(i) = value + } + + val prevMean = currMean(i) + val diff = value - prevMean + currMean(i) = prevMean + diff / (nnz(i) + 1.0) + currM2n(i) += (value - currMean(i)) * diff + currM2(i) += value * value + currL1(i) += math.abs(value) + + nnz(i) += 1.0 + } + } + + /** * Add a new sample to this summarizer, and update the statistical summary. * * @param sample The sample in dense/sparse vector format to be added into this summarizer. @@ -72,37 +95,18 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S require(n == sample.size, s"Dimensions mismatch when adding new sample." 
+ s" Expecting $n but got ${sample.size}.") - @inline def update(i: Int, value: Double) = { - if (value != 0.0) { - if (currMax(i) < value) { - currMax(i) = value - } - if (currMin(i) > value) { - currMin(i) = value - } - - val tmpPrevMean = currMean(i) - currMean(i) = (currMean(i) * nnz(i) + value) / (nnz(i) + 1.0) - currM2n(i) += (value - currMean(i)) * (value - tmpPrevMean) - currM2(i) += value * value - currL1(i) += math.abs(value) - - nnz(i) += 1.0 - } - } - sample match { case dv: DenseVector => { var j = 0 while (j < dv.size) { - update(j, dv.values(j)) + add(j, dv.values(j)) j += 1 } } case sv: SparseVector => var j = 0 while (j < sv.indices.size) { - update(sv.indices(j), sv.values(j)) + add(sv.indices(j), sv.values(j)) j += 1 } case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass) @@ -124,37 +128,28 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S require(n == other.n, s"Dimensions mismatch when merging with another summarizer. 
" + s"Expecting $n but got ${other.n}.") totalCnt += other.totalCnt - val deltaMean: BDV[Double] = currMean - other.currMean var i = 0 while (i < n) { - // merge mean together - if (other.currMean(i) != 0.0) { - currMean(i) = (currMean(i) * nnz(i) + other.currMean(i) * other.nnz(i)) / - (nnz(i) + other.nnz(i)) - } - // merge m2n together - if (nnz(i) + other.nnz(i) != 0.0) { - currM2n(i) += other.currM2n(i) + deltaMean(i) * deltaMean(i) * nnz(i) * other.nnz(i) / - (nnz(i) + other.nnz(i)) - } - // merge m2 together - if (nnz(i) + other.nnz(i) != 0.0) { + val thisNnz = nnz(i) + val otherNnz = other.nnz(i) + val totalNnz = thisNnz + otherNnz + if (totalNnz != 0.0) { + val deltaMean = other.currMean(i) - currMean(i) + // merge mean together + currMean(i) += deltaMean * otherNnz / totalNnz + // merge m2n together + currM2n(i) += other.currM2n(i) + deltaMean * deltaMean * thisNnz * otherNnz / totalNnz + // merge m2 together currM2(i) += other.currM2(i) - } - // merge l1 together - if (nnz(i) + other.nnz(i) != 0.0) { + // merge l1 together currL1(i) += other.currL1(i) + // merge max and min + currMax(i) = math.max(currMax(i), other.currMax(i)) + currMin(i) = math.min(currMin(i), other.currMin(i)) } - - if (currMax(i) < other.currMax(i)) { - currMax(i) = other.currMax(i) - } - if (currMin(i) > other.currMin(i)) { - currMin(i) = other.currMin(i) - } + nnz(i) = totalNnz i += 1 } - nnz += other.nnz } else if (totalCnt == 0 && other.totalCnt != 0) { this.n = other.n this.currMean = other.currMean.copy diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala index 1e94152491..23b0eec865 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala @@ -208,4 +208,15 @@ class MultivariateOnlineSummarizerSuite 
extends FunSuite { assert(summarizer2.variance ~== Vectors.dense(0, 0, 0) absTol 1E-5, "variance mismatch") } + + test("merging summarizer when one side has zero mean (SPARK-4355)") { + val s0 = new MultivariateOnlineSummarizer() + .add(Vectors.dense(2.0)) + .add(Vectors.dense(2.0)) + val s1 = new MultivariateOnlineSummarizer() + .add(Vectors.dense(1.0)) + .add(Vectors.dense(-1.0)) + s0.merge(s1) + assert(s0.mean(0) ~== 1.0 absTol 1e-14) + } } |