aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
author WeichenXu <WeichenXu123@outlook.com> 2016-07-23 12:32:30 +0100
committer Sean Owen <sowen@cloudera.com> 2016-07-23 12:32:30 +0100
commit 25db51675f43048d61ced8221dcb4885cc5143c1 (patch)
tree 7fab2d674faeb1b7f941a91f367f6f2b0e841329 /mllib
parent e10b8741d86a0a625d28bcb1c654736a260be85e (diff)
download spark-25db51675f43048d61ced8221dcb4885cc5143c1.tar.gz
spark-25db51675f43048d61ced8221dcb4885cc5143c1.tar.bz2
spark-25db51675f43048d61ced8221dcb4885cc5143c1.zip
[SPARK-16561][MLLIB] fix multivarOnlineSummary min/max bug
## What changes were proposed in this pull request? Rename variables to make the code clearer: `nnz` => `weightSum` and `weightSum` => `totalWeightSum`. Add a new member vector `nnz` (distinct from the `nnz` in the previous code, which is now named `weightSum`) that counts the number of non-zero values in each dimension. Use this new `nnz` count, compared against the total sample count, instead of `weightSum` when calculating min/max; this fixes numerical errors in some extreme cases, e.g. when a tiny sample weight (1e-7) is lost to floating-point rounding after being added to a huge weight sum (1e10). ## How was this patch tested? A new test case was added. Author: WeichenXu <WeichenXu123@outlook.com> Closes #14216 from WeichenXu123/multivarOnlineSummary.
Diffstat (limited to 'mllib')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala | 63
-rw-r--r-- mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala | 25
2 files changed, 60 insertions, 28 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
index d4de0fd7d5..964f419d12 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
@@ -47,9 +47,10 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
private var currM2: Array[Double] = _
private var currL1: Array[Double] = _
private var totalCnt: Long = 0
- private var weightSum: Double = 0.0
+ private var totalWeightSum: Double = 0.0
private var weightSquareSum: Double = 0.0
- private var nnz: Array[Double] = _
+ private var weightSum: Array[Double] = _
+ private var nnz: Array[Long] = _
private var currMax: Array[Double] = _
private var currMin: Array[Double] = _
@@ -74,7 +75,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
currM2n = Array.ofDim[Double](n)
currM2 = Array.ofDim[Double](n)
currL1 = Array.ofDim[Double](n)
- nnz = Array.ofDim[Double](n)
+ weightSum = Array.ofDim[Double](n)
+ nnz = Array.ofDim[Long](n)
currMax = Array.fill[Double](n)(Double.MinValue)
currMin = Array.fill[Double](n)(Double.MaxValue)
}
@@ -86,7 +88,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
val localCurrM2n = currM2n
val localCurrM2 = currM2
val localCurrL1 = currL1
- val localNnz = nnz
+ val localWeightSum = weightSum
+ val localNumNonzeros = nnz
val localCurrMax = currMax
val localCurrMin = currMin
instance.foreachActive { (index, value) =>
@@ -100,16 +103,17 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
val prevMean = localCurrMean(index)
val diff = value - prevMean
- localCurrMean(index) = prevMean + weight * diff / (localNnz(index) + weight)
+ localCurrMean(index) = prevMean + weight * diff / (localWeightSum(index) + weight)
localCurrM2n(index) += weight * (value - localCurrMean(index)) * diff
localCurrM2(index) += weight * value * value
localCurrL1(index) += weight * math.abs(value)
- localNnz(index) += weight
+ localWeightSum(index) += weight
+ localNumNonzeros(index) += 1
}
}
- weightSum += weight
+ totalWeightSum += weight
weightSquareSum += weight * weight
totalCnt += 1
this
@@ -124,17 +128,18 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
*/
@Since("1.1.0")
def merge(other: MultivariateOnlineSummarizer): this.type = {
- if (this.weightSum != 0.0 && other.weightSum != 0.0) {
+ if (this.totalWeightSum != 0.0 && other.totalWeightSum != 0.0) {
require(n == other.n, s"Dimensions mismatch when merging with another summarizer. " +
s"Expecting $n but got ${other.n}.")
totalCnt += other.totalCnt
- weightSum += other.weightSum
+ totalWeightSum += other.totalWeightSum
weightSquareSum += other.weightSquareSum
var i = 0
while (i < n) {
- val thisNnz = nnz(i)
- val otherNnz = other.nnz(i)
+ val thisNnz = weightSum(i)
+ val otherNnz = other.weightSum(i)
val totalNnz = thisNnz + otherNnz
+ val totalCnnz = nnz(i) + other.nnz(i)
if (totalNnz != 0.0) {
val deltaMean = other.currMean(i) - currMean(i)
// merge mean together
@@ -149,18 +154,20 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
currMax(i) = math.max(currMax(i), other.currMax(i))
currMin(i) = math.min(currMin(i), other.currMin(i))
}
- nnz(i) = totalNnz
+ weightSum(i) = totalNnz
+ nnz(i) = totalCnnz
i += 1
}
- } else if (weightSum == 0.0 && other.weightSum != 0.0) {
+ } else if (totalWeightSum == 0.0 && other.totalWeightSum != 0.0) {
this.n = other.n
this.currMean = other.currMean.clone()
this.currM2n = other.currM2n.clone()
this.currM2 = other.currM2.clone()
this.currL1 = other.currL1.clone()
this.totalCnt = other.totalCnt
- this.weightSum = other.weightSum
+ this.totalWeightSum = other.totalWeightSum
this.weightSquareSum = other.weightSquareSum
+ this.weightSum = other.weightSum.clone()
this.nnz = other.nnz.clone()
this.currMax = other.currMax.clone()
this.currMin = other.currMin.clone()
@@ -174,12 +181,12 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
*/
@Since("1.1.0")
override def mean: Vector = {
- require(weightSum > 0, s"Nothing has been added to this summarizer.")
+ require(totalWeightSum > 0, s"Nothing has been added to this summarizer.")
val realMean = Array.ofDim[Double](n)
var i = 0
while (i < n) {
- realMean(i) = currMean(i) * (nnz(i) / weightSum)
+ realMean(i) = currMean(i) * (weightSum(i) / totalWeightSum)
i += 1
}
Vectors.dense(realMean)
@@ -191,11 +198,11 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
*/
@Since("1.1.0")
override def variance: Vector = {
- require(weightSum > 0, s"Nothing has been added to this summarizer.")
+ require(totalWeightSum > 0, s"Nothing has been added to this summarizer.")
val realVariance = Array.ofDim[Double](n)
- val denominator = weightSum - (weightSquareSum / weightSum)
+ val denominator = totalWeightSum - (weightSquareSum / totalWeightSum)
// Sample variance is computed, if the denominator is less than 0, the variance is just 0.
if (denominator > 0.0) {
@@ -203,8 +210,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
var i = 0
val len = currM2n.length
while (i < len) {
- realVariance(i) = (currM2n(i) + deltaMean(i) * deltaMean(i) * nnz(i) *
- (weightSum - nnz(i)) / weightSum) / denominator
+ realVariance(i) = (currM2n(i) + deltaMean(i) * deltaMean(i) * weightSum(i) *
+ (totalWeightSum - weightSum(i)) / totalWeightSum) / denominator
i += 1
}
}
@@ -224,9 +231,9 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
*/
@Since("1.1.0")
override def numNonzeros: Vector = {
- require(weightSum > 0, s"Nothing has been added to this summarizer.")
+ require(totalWeightSum > 0, s"Nothing has been added to this summarizer.")
- Vectors.dense(nnz)
+ Vectors.dense(weightSum)
}
/**
@@ -235,11 +242,11 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
*/
@Since("1.1.0")
override def max: Vector = {
- require(weightSum > 0, s"Nothing has been added to this summarizer.")
+ require(totalWeightSum > 0, s"Nothing has been added to this summarizer.")
var i = 0
while (i < n) {
- if ((nnz(i) < weightSum) && (currMax(i) < 0.0)) currMax(i) = 0.0
+ if ((nnz(i) < totalCnt) && (currMax(i) < 0.0)) currMax(i) = 0.0
i += 1
}
Vectors.dense(currMax)
@@ -251,11 +258,11 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
*/
@Since("1.1.0")
override def min: Vector = {
- require(weightSum > 0, s"Nothing has been added to this summarizer.")
+ require(totalWeightSum > 0, s"Nothing has been added to this summarizer.")
var i = 0
while (i < n) {
- if ((nnz(i) < weightSum) && (currMin(i) > 0.0)) currMin(i) = 0.0
+ if ((nnz(i) < totalCnt) && (currMin(i) > 0.0)) currMin(i) = 0.0
i += 1
}
Vectors.dense(currMin)
@@ -267,7 +274,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
*/
@Since("1.2.0")
override def normL2: Vector = {
- require(weightSum > 0, s"Nothing has been added to this summarizer.")
+ require(totalWeightSum > 0, s"Nothing has been added to this summarizer.")
val realMagnitude = Array.ofDim[Double](n)
@@ -286,7 +293,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
*/
@Since("1.2.0")
override def normL1: Vector = {
- require(weightSum > 0, s"Nothing has been added to this summarizer.")
+ require(totalWeightSum > 0, s"Nothing has been added to this summarizer.")
Vectors.dense(currL1)
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala
index b6d41db69b..165a3f314a 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala
@@ -245,4 +245,29 @@ class MultivariateOnlineSummarizerSuite extends SparkFunSuite {
absTol 1E-8, "normL2 mismatch")
assert(summarizer.normL1 ~== Vectors.dense(0.21, 0.4265, 0.61) absTol 1E-10, "normL1 mismatch")
}
+
+ test("test min/max with weighted samples (SPARK-16561)") {
+ val summarizer1 = new MultivariateOnlineSummarizer()
+ .add(Vectors.dense(10.0, -10.0), 1e10)
+ .add(Vectors.dense(0.0, 0.0), 1e-7)
+
+ val summarizer2 = new MultivariateOnlineSummarizer()
+ summarizer2.add(Vectors.dense(10.0, -10.0), 1e10)
+ for (i <- 1 to 100) {
+ summarizer2.add(Vectors.dense(0.0, 0.0), 1e-7)
+ }
+
+ val summarizer3 = new MultivariateOnlineSummarizer()
+ for (i <- 1 to 100) {
+ summarizer3.add(Vectors.dense(0.0, 0.0), 1e-7)
+ }
+ summarizer3.add(Vectors.dense(10.0, -10.0), 1e10)
+
+ assert(summarizer1.max ~== Vectors.dense(10.0, 0.0) absTol 1e-14)
+ assert(summarizer1.min ~== Vectors.dense(0.0, -10.0) absTol 1e-14)
+ assert(summarizer2.max ~== Vectors.dense(10.0, 0.0) absTol 1e-14)
+ assert(summarizer2.min ~== Vectors.dense(0.0, -10.0) absTol 1e-14)
+ assert(summarizer3.max ~== Vectors.dense(10.0, 0.0) absTol 1e-14)
+ assert(summarizer3.min ~== Vectors.dense(0.0, -10.0) absTol 1e-14)
+ }
}