aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala4
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala2
2 files changed, 3 insertions, 3 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
index 964f419d12..7a2a7a35a9 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
@@ -231,9 +231,9 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
*/
@Since("1.1.0")
override def numNonzeros: Vector = {
- require(totalWeightSum > 0, s"Nothing has been added to this summarizer.")
+ require(totalCnt > 0, s"Nothing has been added to this summarizer.")
- Vectors.dense(weightSum)
+ Vectors.dense(nnz.map(_.toDouble))
}
/**
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala
index 165a3f314a..797e84fcc7 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala
@@ -237,7 +237,7 @@ class MultivariateOnlineSummarizerSuite extends SparkFunSuite {
absTol 1E-10, "mean mismatch")
assert(summarizer.variance ~== Vectors.dense(Array(0.17657142857, 1.645115714, 2.42057142857))
absTol 1E-8, "variance mismatch")
- assert(summarizer.numNonzeros ~== Vectors.dense(Array(0.3, 0.5, 0.4))
+ assert(summarizer.numNonzeros ~== Vectors.dense(Array(3.0, 4.0, 3.0))
absTol 1E-10, "numNonzeros mismatch")
assert(summarizer.max ~== Vectors.dense(Array(0.0, 1.7, 1.3)) absTol 1E-10, "max mismatch")
assert(summarizer.min ~== Vectors.dense(Array(-0.8, -1.2, -1.7)) absTol 1E-10, "min mismatch")