aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
Diffstat (limited to 'mllib')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala | 25
1 files changed, 21 insertions, 4 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
index 3025d4837c..fab7c4405c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
@@ -20,7 +20,7 @@ package org.apache.spark.mllib.stat
import breeze.linalg.{DenseVector => BDV}
import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.mllib.linalg.{Vectors, Vector}
+import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors, Vector}
/**
* :: DeveloperApi ::
@@ -72,9 +72,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
require(n == sample.size, s"Dimensions mismatch when adding new sample." +
s" Expecting $n but got ${sample.size}.")
- sample.toBreeze.activeIterator.foreach {
- case (_, 0.0) => // Skip explicit zero elements.
- case (i, value) =>
+ @inline def update(i: Int, value: Double) = {
+ if (value != 0.0) {
if (currMax(i) < value) {
currMax(i) = value
}
@@ -89,6 +88,24 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
currL1(i) += math.abs(value)
nnz(i) += 1.0
+ }
+ }
+
+ sample match {
+ case dv: DenseVector => {
+ var j = 0
+ while (j < dv.size) {
+ update(j, dv.values(j))
+ j += 1
+ }
+ }
+ case sv: SparseVector =>
+ var j = 0
+ while (j < sv.indices.size) {
+ update(sv.indices(j), sv.values(j))
+ j += 1
+ }
+ case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass)
}
totalCnt += 1