aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala
diff options
context:
space:
mode:
authorWeichenXu <WeichenXu123@outlook.com>2016-07-23 12:32:30 +0100
committerSean Owen <sowen@cloudera.com>2016-07-23 12:32:30 +0100
commit25db51675f43048d61ced8221dcb4885cc5143c1 (patch)
tree7fab2d674faeb1b7f941a91f367f6f2b0e841329 /mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala
parente10b8741d86a0a625d28bcb1c654736a260be85e (diff)
downloadspark-25db51675f43048d61ced8221dcb4885cc5143c1.tar.gz
spark-25db51675f43048d61ced8221dcb4885cc5143c1.tar.bz2
spark-25db51675f43048d61ced8221dcb4885cc5143c1.zip
[SPARK-16561][MLLIB] fix multivarOnlineSummary min/max bug
## What changes were proposed in this pull request? Rename variables to make the code clearer: `nnz` => `weightSum` and `weightSum` => `totalWeightSum`. Add a new member vector `nnz` (distinct from the `nnz` in the previous code, which has been renamed to `weightSum`) that counts the number of non-zero values in each dimension. Use this new `nnz` instead of `weightSum` when calculating min/max, which fixes numerical errors in some extreme cases. ## How was this patch tested? A new test case was added. Author: WeichenXu <WeichenXu123@outlook.com> Closes #14216 from WeichenXu123/multivarOnlineSummary.
Diffstat (limited to 'mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala')
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala25
1 files changed, 25 insertions, 0 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala
index b6d41db69b..165a3f314a 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala
@@ -245,4 +245,29 @@ class MultivariateOnlineSummarizerSuite extends SparkFunSuite {
absTol 1E-8, "normL2 mismatch")
assert(summarizer.normL1 ~== Vectors.dense(0.21, 0.4265, 0.61) absTol 1E-10, "normL1 mismatch")
}
+
+ test("test min/max with weighted samples (SPARK-16561)") { // min/max when sample weights differ by many orders of magnitude
+ val summarizer1 = new MultivariateOnlineSummarizer() // case 1: one heavily weighted sample, then one lightly weighted sample
+ .add(Vectors.dense(10.0, -10.0), 1e10) // non-zero sample, weight 1e10
+ .add(Vectors.dense(0.0, 0.0), 1e-7) // all-zero sample, weight 1e-7
+
+ val summarizer2 = new MultivariateOnlineSummarizer() // case 2: heavy sample first, then 100 light all-zero samples
+ summarizer2.add(Vectors.dense(10.0, -10.0), 1e10)
+ for (i <- 1 to 100) { // i is unused; the loop only repeats the add 100 times
+ summarizer2.add(Vectors.dense(0.0, 0.0), 1e-7)
+ }
+
+ val summarizer3 = new MultivariateOnlineSummarizer() // case 3: same samples as case 2, but the light zeros are added first
+ for (i <- 1 to 100) { // i is unused; the loop only repeats the add 100 times
+ summarizer3.add(Vectors.dense(0.0, 0.0), 1e-7)
+ }
+ summarizer3.add(Vectors.dense(10.0, -10.0), 1e10)
+
+ assert(summarizer1.max ~== Vectors.dense(10.0, 0.0) absTol 1e-14) // second component: the zero sample exceeds -10.0
+ assert(summarizer1.min ~== Vectors.dense(0.0, -10.0) absTol 1e-14) // first component: the zero sample is below 10.0
+ assert(summarizer2.max ~== Vectors.dense(10.0, 0.0) absTol 1e-14) // same expectation as case 1
+ assert(summarizer2.min ~== Vectors.dense(0.0, -10.0) absTol 1e-14)
+ assert(summarizer3.max ~== Vectors.dense(10.0, 0.0) absTol 1e-14) // same expectation regardless of insertion order
+ assert(summarizer3.min ~== Vectors.dense(0.0, -10.0) absTol 1e-14)
+ }
}