aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala')
-rw-r--r-- mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala 25
1 files changed, 25 insertions, 0 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala
index b6d41db69b..165a3f314a 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala
@@ -245,4 +245,29 @@ class MultivariateOnlineSummarizerSuite extends SparkFunSuite {
absTol 1E-8, "normL2 mismatch")
assert(summarizer.normL1 ~== Vectors.dense(0.21, 0.4265, 0.61) absTol 1E-10, "normL1 mismatch")
}
+
+ test("test min/max with weighted samples (SPARK-16561)") {
+   // Mix one sample of weight 1e10 with zero vectors of weight 1e-7, varying order and count;
+   // every summarizer must still report the element-wise min/max over both kinds of sample.
+   def heavySample = Vectors.dense(10.0, -10.0)
+   def lightSample = Vectors.dense(0.0, 0.0)
+   val (heavyWeight, lightWeight) = (1e10, 1e-7)
+   val (expectedMax, expectedMin) = (Vectors.dense(10.0, 0.0), Vectors.dense(0.0, -10.0))
+
+   val heavyThenOneLight = new MultivariateOnlineSummarizer()
+     .add(heavySample, heavyWeight).add(lightSample, lightWeight)
+   val heavyThenManyLight = new MultivariateOnlineSummarizer()
+   heavyThenManyLight.add(heavySample, heavyWeight)
+   (1 to 100).foreach { _ => heavyThenManyLight.add(lightSample, lightWeight) }
+   val manyLightThenHeavy = new MultivariateOnlineSummarizer()
+   (1 to 100).foreach { _ => manyLightThenHeavy.add(lightSample, lightWeight) }
+   manyLightThenHeavy.add(heavySample, heavyWeight)
+
+   assert(heavyThenOneLight.max ~== expectedMax absTol 1e-14)
+   assert(heavyThenOneLight.min ~== expectedMin absTol 1e-14)
+   assert(heavyThenManyLight.max ~== expectedMax absTol 1e-14)
+   assert(heavyThenManyLight.min ~== expectedMin absTol 1e-14)
+   assert(manyLightThenHeavy.max ~== expectedMax absTol 1e-14)
+   assert(manyLightThenHeavy.min ~== expectedMin absTol 1e-14)
+ }
}