diff options
Diffstat (limited to 'mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala')
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala | 25 |
1 files changed, 25 insertions, 0 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala index b6d41db69b..165a3f314a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala @@ -245,4 +245,29 @@ class MultivariateOnlineSummarizerSuite extends SparkFunSuite { absTol 1E-8, "normL2 mismatch") assert(summarizer.normL1 ~== Vectors.dense(0.21, 0.4265, 0.61) absTol 1E-10, "normL1 mismatch") } + + test("test min/max with weighted samples (SPARK-16561)") { + val summarizer1 = new MultivariateOnlineSummarizer() + .add(Vectors.dense(10.0, -10.0), 1e10) + .add(Vectors.dense(0.0, 0.0), 1e-7) + + val summarizer2 = new MultivariateOnlineSummarizer() + summarizer2.add(Vectors.dense(10.0, -10.0), 1e10) + for (i <- 1 to 100) { + summarizer2.add(Vectors.dense(0.0, 0.0), 1e-7) + } + + val summarizer3 = new MultivariateOnlineSummarizer() + for (i <- 1 to 100) { + summarizer3.add(Vectors.dense(0.0, 0.0), 1e-7) + } + summarizer3.add(Vectors.dense(10.0, -10.0), 1e10) + + assert(summarizer1.max ~== Vectors.dense(10.0, 0.0) absTol 1e-14) + assert(summarizer1.min ~== Vectors.dense(0.0, -10.0) absTol 1e-14) + assert(summarizer2.max ~== Vectors.dense(10.0, 0.0) absTol 1e-14) + assert(summarizer2.min ~== Vectors.dense(0.0, -10.0) absTol 1e-14) + assert(summarizer3.max ~== Vectors.dense(10.0, 0.0) absTol 1e-14) + assert(summarizer3.min ~== Vectors.dense(0.0, -10.0) absTol 1e-14) + } } |