aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
author Xiangrui Meng <meng@databricks.com> 2014-11-12 01:50:11 -0800
committer Xiangrui Meng <meng@databricks.com> 2014-11-12 01:50:11 -0800
commit 84324fbcb987db6e10e435f463eacace1bae43e2 (patch)
tree 7637baffa919381bda06a472d46619aedef6d531 /mllib
parent faeb41de215d3ac567ce72a43ab242ad433ca93e (diff)
download spark-84324fbcb987db6e10e435f463eacace1bae43e2.tar.gz
spark-84324fbcb987db6e10e435f463eacace1bae43e2.tar.bz2
spark-84324fbcb987db6e10e435f463eacace1bae43e2.zip
[SPARK-4355][MLLIB] fix OnlineSummarizer.merge when other.mean is zero
See the inline comment about the bug. I also did some code clean-up.

dbtsai: I moved `update` to a private method of `MultivariateOnlineSummarizer`. I don't think it will cause a performance regression, but it would be great if you have some time to test.

Author: Xiangrui Meng <meng@databricks.com>

Closes #3220 from mengxr/SPARK-4355 and squashes the following commits:

5ef601f [Xiangrui Meng] fix OnlineSummarizer.merge when other.mean is zero and some code clean-up
Diffstat (limited to 'mllib')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala | 85
-rw-r--r-- mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala | 11
2 files changed, 51 insertions, 45 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
index fab7c4405c..654479ac2d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
@@ -50,6 +50,29 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
private var currMin: BDV[Double] = _
/**
+ * Adds input value to position i.
+ */
+ private[this] def add(i: Int, value: Double) = {
+ if (value != 0.0) {
+ if (currMax(i) < value) {
+ currMax(i) = value
+ }
+ if (currMin(i) > value) {
+ currMin(i) = value
+ }
+
+ val prevMean = currMean(i)
+ val diff = value - prevMean
+ currMean(i) = prevMean + diff / (nnz(i) + 1.0)
+ currM2n(i) += (value - currMean(i)) * diff
+ currM2(i) += value * value
+ currL1(i) += math.abs(value)
+
+ nnz(i) += 1.0
+ }
+ }
+
+ /**
* Add a new sample to this summarizer, and update the statistical summary.
*
* @param sample The sample in dense/sparse vector format to be added into this summarizer.
@@ -72,37 +95,18 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
require(n == sample.size, s"Dimensions mismatch when adding new sample." +
s" Expecting $n but got ${sample.size}.")
- @inline def update(i: Int, value: Double) = {
- if (value != 0.0) {
- if (currMax(i) < value) {
- currMax(i) = value
- }
- if (currMin(i) > value) {
- currMin(i) = value
- }
-
- val tmpPrevMean = currMean(i)
- currMean(i) = (currMean(i) * nnz(i) + value) / (nnz(i) + 1.0)
- currM2n(i) += (value - currMean(i)) * (value - tmpPrevMean)
- currM2(i) += value * value
- currL1(i) += math.abs(value)
-
- nnz(i) += 1.0
- }
- }
-
sample match {
case dv: DenseVector => {
var j = 0
while (j < dv.size) {
- update(j, dv.values(j))
+ add(j, dv.values(j))
j += 1
}
}
case sv: SparseVector =>
var j = 0
while (j < sv.indices.size) {
- update(sv.indices(j), sv.values(j))
+ add(sv.indices(j), sv.values(j))
j += 1
}
case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass)
@@ -124,37 +128,28 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
require(n == other.n, s"Dimensions mismatch when merging with another summarizer. " +
s"Expecting $n but got ${other.n}.")
totalCnt += other.totalCnt
- val deltaMean: BDV[Double] = currMean - other.currMean
var i = 0
while (i < n) {
- // merge mean together
- if (other.currMean(i) != 0.0) {
- currMean(i) = (currMean(i) * nnz(i) + other.currMean(i) * other.nnz(i)) /
- (nnz(i) + other.nnz(i))
- }
- // merge m2n together
- if (nnz(i) + other.nnz(i) != 0.0) {
- currM2n(i) += other.currM2n(i) + deltaMean(i) * deltaMean(i) * nnz(i) * other.nnz(i) /
- (nnz(i) + other.nnz(i))
- }
- // merge m2 together
- if (nnz(i) + other.nnz(i) != 0.0) {
+ val thisNnz = nnz(i)
+ val otherNnz = other.nnz(i)
+ val totalNnz = thisNnz + otherNnz
+ if (totalNnz != 0.0) {
+ val deltaMean = other.currMean(i) - currMean(i)
+ // merge mean together
+ currMean(i) += deltaMean * otherNnz / totalNnz
+ // merge m2n together
+ currM2n(i) += other.currM2n(i) + deltaMean * deltaMean * thisNnz * otherNnz / totalNnz
+ // merge m2 together
currM2(i) += other.currM2(i)
- }
- // merge l1 together
- if (nnz(i) + other.nnz(i) != 0.0) {
+ // merge l1 together
currL1(i) += other.currL1(i)
+ // merge max and min
+ currMax(i) = math.max(currMax(i), other.currMax(i))
+ currMin(i) = math.min(currMin(i), other.currMin(i))
}
-
- if (currMax(i) < other.currMax(i)) {
- currMax(i) = other.currMax(i)
- }
- if (currMin(i) > other.currMin(i)) {
- currMin(i) = other.currMin(i)
- }
+ nnz(i) = totalNnz
i += 1
}
- nnz += other.nnz
} else if (totalCnt == 0 && other.totalCnt != 0) {
this.n = other.n
this.currMean = other.currMean.copy
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala
index 1e94152491..23b0eec865 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala
@@ -208,4 +208,15 @@ class MultivariateOnlineSummarizerSuite extends FunSuite {
assert(summarizer2.variance ~== Vectors.dense(0, 0, 0) absTol 1E-5, "variance mismatch")
}
+
+ test("merging summarizer when one side has zero mean (SPARK-4355)") {
+ val s0 = new MultivariateOnlineSummarizer()
+ .add(Vectors.dense(2.0))
+ .add(Vectors.dense(2.0))
+ val s1 = new MultivariateOnlineSummarizer()
+ .add(Vectors.dense(1.0))
+ .add(Vectors.dense(-1.0))
+ s0.merge(s1)
+ assert(s0.mean(0) ~== 1.0 absTol 1e-14)
+ }
}