diff options
author | WeichenXu <WeichenXu123@outlook.com> | 2016-07-23 12:32:30 +0100 |
---|---|---|
committer | Sean Owen <sowen@cloudera.com> | 2016-07-23 12:32:30 +0100 |
commit | 25db51675f43048d61ced8221dcb4885cc5143c1 (patch) | |
tree | 7fab2d674faeb1b7f941a91f367f6f2b0e841329 /mllib/src/test/scala/org | |
parent | e10b8741d86a0a625d28bcb1c654736a260be85e (diff) | |
download | spark-25db51675f43048d61ced8221dcb4885cc5143c1.tar.gz spark-25db51675f43048d61ced8221dcb4885cc5143c1.tar.bz2 spark-25db51675f43048d61ced8221dcb4885cc5143c1.zip |
[SPARK-16561][MLLIB] fix multivarOnlineSummary min/max bug
## What changes were proposed in this pull request?
renaming var names to make code more clear:
nnz => weightSum
weightSum => totalWeightSum
and add a new member vector `nnz` (not `nnz` in previous code, which renamed to `weightSum`) to count each dimensions non-zero value number.
using `nnz` which I added above instead of `weightSum` when calculating min/max so that it fix several numerical error in some extreme case.
## How was this patch tested?
A new testcase added.
Author: WeichenXu <WeichenXu123@outlook.com>
Closes #14216 from WeichenXu123/multivarOnlineSummary.
Diffstat (limited to 'mllib/src/test/scala/org')
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala | 25 |
1 files changed, 25 insertions, 0 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala index b6d41db69b..165a3f314a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala @@ -245,4 +245,29 @@ class MultivariateOnlineSummarizerSuite extends SparkFunSuite { absTol 1E-8, "normL2 mismatch") assert(summarizer.normL1 ~== Vectors.dense(0.21, 0.4265, 0.61) absTol 1E-10, "normL1 mismatch") } + + test("test min/max with weighted samples (SPARK-16561)") { + val summarizer1 = new MultivariateOnlineSummarizer() + .add(Vectors.dense(10.0, -10.0), 1e10) + .add(Vectors.dense(0.0, 0.0), 1e-7) + + val summarizer2 = new MultivariateOnlineSummarizer() + summarizer2.add(Vectors.dense(10.0, -10.0), 1e10) + for (i <- 1 to 100) { + summarizer2.add(Vectors.dense(0.0, 0.0), 1e-7) + } + + val summarizer3 = new MultivariateOnlineSummarizer() + for (i <- 1 to 100) { + summarizer3.add(Vectors.dense(0.0, 0.0), 1e-7) + } + summarizer3.add(Vectors.dense(10.0, -10.0), 1e10) + + assert(summarizer1.max ~== Vectors.dense(10.0, 0.0) absTol 1e-14) + assert(summarizer1.min ~== Vectors.dense(0.0, -10.0) absTol 1e-14) + assert(summarizer2.max ~== Vectors.dense(10.0, 0.0) absTol 1e-14) + assert(summarizer2.min ~== Vectors.dense(0.0, -10.0) absTol 1e-14) + assert(summarizer3.max ~== Vectors.dense(10.0, 0.0) absTol 1e-14) + assert(summarizer3.min ~== Vectors.dense(0.0, -10.0) absTol 1e-14) + } } |