aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala32
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala121
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala24
3 files changed, 105 insertions, 72 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
index 60ab2aaa8f..c6d5fe5bc6 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -76,6 +76,15 @@ sealed trait Vector extends Serializable {
def copy: Vector = {
throw new NotImplementedError(s"copy is not implemented for ${this.getClass}.")
}
+
+ /**
+ * Applies a function `f` to all the active elements of dense and sparse vector.
+ *
+ * @param f the function takes two parameters where the first parameter is the index of
+ * the vector with type `Int`, and the second parameter is the corresponding value
+ * with type `Double`.
+ */
+ private[spark] def foreachActive(f: (Int, Double) => Unit)
}
/**
@@ -273,6 +282,17 @@ class DenseVector(val values: Array[Double]) extends Vector {
override def copy: DenseVector = {
new DenseVector(values.clone())
}
+
+ private[spark] override def foreachActive(f: (Int, Double) => Unit) = {
+ var i = 0
+ val localValuesSize = values.size
+ val localValues = values
+
+ while (i < localValuesSize) {
+ f(i, localValues(i))
+ i += 1
+ }
+ }
}
/**
@@ -309,4 +329,16 @@ class SparseVector(
}
private[mllib] override def toBreeze: BV[Double] = new BSV[Double](indices, values, size)
+
+ private[spark] override def foreachActive(f: (Int, Double) => Unit) = {
+ var i = 0
+ val localValuesSize = values.size
+ val localIndices = indices
+ val localValues = values
+
+ while (i < localValuesSize) {
+ f(localIndices(i), localValues(i))
+ i += 1
+ }
+ }
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
index 654479ac2d..fcc2a14879 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
@@ -17,10 +17,8 @@
package org.apache.spark.mllib.stat
-import breeze.linalg.{DenseVector => BDV}
-
import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors, Vector}
+import org.apache.spark.mllib.linalg.{Vectors, Vector}
/**
* :: DeveloperApi ::
@@ -40,37 +38,14 @@ import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors, Vector
class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with Serializable {
private var n = 0
- private var currMean: BDV[Double] = _
- private var currM2n: BDV[Double] = _
- private var currM2: BDV[Double] = _
- private var currL1: BDV[Double] = _
+ private var currMean: Array[Double] = _
+ private var currM2n: Array[Double] = _
+ private var currM2: Array[Double] = _
+ private var currL1: Array[Double] = _
private var totalCnt: Long = 0
- private var nnz: BDV[Double] = _
- private var currMax: BDV[Double] = _
- private var currMin: BDV[Double] = _
-
- /**
- * Adds input value to position i.
- */
- private[this] def add(i: Int, value: Double) = {
- if (value != 0.0) {
- if (currMax(i) < value) {
- currMax(i) = value
- }
- if (currMin(i) > value) {
- currMin(i) = value
- }
-
- val prevMean = currMean(i)
- val diff = value - prevMean
- currMean(i) = prevMean + diff / (nnz(i) + 1.0)
- currM2n(i) += (value - currMean(i)) * diff
- currM2(i) += value * value
- currL1(i) += math.abs(value)
-
- nnz(i) += 1.0
- }
- }
+ private var nnz: Array[Double] = _
+ private var currMax: Array[Double] = _
+ private var currMin: Array[Double] = _
/**
* Add a new sample to this summarizer, and update the statistical summary.
@@ -83,33 +58,36 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
require(sample.size > 0, s"Vector should have dimension larger than zero.")
n = sample.size
- currMean = BDV.zeros[Double](n)
- currM2n = BDV.zeros[Double](n)
- currM2 = BDV.zeros[Double](n)
- currL1 = BDV.zeros[Double](n)
- nnz = BDV.zeros[Double](n)
- currMax = BDV.fill(n)(Double.MinValue)
- currMin = BDV.fill(n)(Double.MaxValue)
+ currMean = Array.ofDim[Double](n)
+ currM2n = Array.ofDim[Double](n)
+ currM2 = Array.ofDim[Double](n)
+ currL1 = Array.ofDim[Double](n)
+ nnz = Array.ofDim[Double](n)
+ currMax = Array.fill[Double](n)(Double.MinValue)
+ currMin = Array.fill[Double](n)(Double.MaxValue)
}
require(n == sample.size, s"Dimensions mismatch when adding new sample." +
s" Expecting $n but got ${sample.size}.")
- sample match {
- case dv: DenseVector => {
- var j = 0
- while (j < dv.size) {
- add(j, dv.values(j))
- j += 1
+ sample.foreachActive { (index, value) =>
+ if (value != 0.0) {
+ if (currMax(index) < value) {
+ currMax(index) = value
}
- }
- case sv: SparseVector =>
- var j = 0
- while (j < sv.indices.size) {
- add(sv.indices(j), sv.values(j))
- j += 1
+ if (currMin(index) > value) {
+ currMin(index) = value
}
- case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass)
+
+ val prevMean = currMean(index)
+ val diff = value - prevMean
+ currMean(index) = prevMean + diff / (nnz(index) + 1.0)
+ currM2n(index) += (value - currMean(index)) * diff
+ currM2(index) += value * value
+ currL1(index) += math.abs(value)
+
+ nnz(index) += 1.0
+ }
}
totalCnt += 1
@@ -152,14 +130,14 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
}
} else if (totalCnt == 0 && other.totalCnt != 0) {
this.n = other.n
- this.currMean = other.currMean.copy
- this.currM2n = other.currM2n.copy
- this.currM2 = other.currM2.copy
- this.currL1 = other.currL1.copy
+ this.currMean = other.currMean.clone
+ this.currM2n = other.currM2n.clone
+ this.currM2 = other.currM2.clone
+ this.currL1 = other.currL1.clone
this.totalCnt = other.totalCnt
- this.nnz = other.nnz.copy
- this.currMax = other.currMax.copy
- this.currMin = other.currMin.copy
+ this.nnz = other.nnz.clone
+ this.currMax = other.currMax.clone
+ this.currMin = other.currMin.clone
}
this
}
@@ -167,19 +145,19 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
override def mean: Vector = {
require(totalCnt > 0, s"Nothing has been added to this summarizer.")
- val realMean = BDV.zeros[Double](n)
+ val realMean = Array.ofDim[Double](n)
var i = 0
while (i < n) {
realMean(i) = currMean(i) * (nnz(i) / totalCnt)
i += 1
}
- Vectors.fromBreeze(realMean)
+ Vectors.dense(realMean)
}
override def variance: Vector = {
require(totalCnt > 0, s"Nothing has been added to this summarizer.")
- val realVariance = BDV.zeros[Double](n)
+ val realVariance = Array.ofDim[Double](n)
val denominator = totalCnt - 1.0
@@ -194,8 +172,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
i += 1
}
}
-
- Vectors.fromBreeze(realVariance)
+ Vectors.dense(realVariance)
}
override def count: Long = totalCnt
@@ -203,7 +180,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
override def numNonzeros: Vector = {
require(totalCnt > 0, s"Nothing has been added to this summarizer.")
- Vectors.fromBreeze(nnz)
+ Vectors.dense(nnz)
}
override def max: Vector = {
@@ -214,7 +191,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
if ((nnz(i) < totalCnt) && (currMax(i) < 0.0)) currMax(i) = 0.0
i += 1
}
- Vectors.fromBreeze(currMax)
+ Vectors.dense(currMax)
}
override def min: Vector = {
@@ -225,25 +202,25 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
if ((nnz(i) < totalCnt) && (currMin(i) > 0.0)) currMin(i) = 0.0
i += 1
}
- Vectors.fromBreeze(currMin)
+ Vectors.dense(currMin)
}
override def normL2: Vector = {
require(totalCnt > 0, s"Nothing has been added to this summarizer.")
- val realMagnitude = BDV.zeros[Double](n)
+ val realMagnitude = Array.ofDim[Double](n)
var i = 0
while (i < currM2.size) {
realMagnitude(i) = math.sqrt(currM2(i))
i += 1
}
-
- Vectors.fromBreeze(realMagnitude)
+ Vectors.dense(realMagnitude)
}
override def normL1: Vector = {
require(totalCnt > 0, s"Nothing has been added to this summarizer.")
- Vectors.fromBreeze(currL1)
+
+ Vectors.dense(currL1)
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
index 59cd85eab2..9492f604af 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
@@ -173,4 +173,28 @@ class VectorsSuite extends FunSuite {
val v = Vectors.fromBreeze(x(::, 0))
assert(v.size === x.rows)
}
+
+ test("foreachActive") {
+ val dv = Vectors.dense(0.0, 1.2, 3.1, 0.0)
+ val sv = Vectors.sparse(4, Seq((1, 1.2), (2, 3.1), (3, 0.0)))
+
+ val dvMap = scala.collection.mutable.Map[Int, Double]()
+ dv.foreachActive { (index, value) =>
+ dvMap.put(index, value)
+ }
+ assert(dvMap.size === 4)
+ assert(dvMap.get(0) === Some(0.0))
+ assert(dvMap.get(1) === Some(1.2))
+ assert(dvMap.get(2) === Some(3.1))
+ assert(dvMap.get(3) === Some(0.0))
+
+ val svMap = scala.collection.mutable.Map[Int, Double]()
+ sv.foreachActive { (index, value) =>
+ svMap.put(index, value)
+ }
+ assert(svMap.size === 3)
+ assert(svMap.get(1) === Some(1.2))
+ assert(svMap.get(2) === Some(3.1))
+ assert(svMap.get(3) === Some(0.0))
+ }
}