diff options
Diffstat (limited to 'mllib/src')
3 files changed, 105 insertions, 72 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 60ab2aaa8f..c6d5fe5bc6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -76,6 +76,15 @@ sealed trait Vector extends Serializable { def copy: Vector = { throw new NotImplementedError(s"copy is not implemented for ${this.getClass}.") } + + /** + * Applies a function `f` to all the active elements of dense and sparse vector. + * + * @param f the function takes two parameters where the first parameter is the index of + * the vector with type `Int`, and the second parameter is the corresponding value + * with type `Double`. + */ + private[spark] def foreachActive(f: (Int, Double) => Unit) } /** @@ -273,6 +282,17 @@ class DenseVector(val values: Array[Double]) extends Vector { override def copy: DenseVector = { new DenseVector(values.clone()) } + + private[spark] override def foreachActive(f: (Int, Double) => Unit) = { + var i = 0 + val localValuesSize = values.size + val localValues = values + + while (i < localValuesSize) { + f(i, localValues(i)) + i += 1 + } + } } /** @@ -309,4 +329,16 @@ class SparseVector( } private[mllib] override def toBreeze: BV[Double] = new BSV[Double](indices, values, size) + + private[spark] override def foreachActive(f: (Int, Double) => Unit) = { + var i = 0 + val localValuesSize = values.size + val localIndices = indices + val localValues = values + + while (i < localValuesSize) { + f(localIndices(i), localValues(i)) + i += 1 + } + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala index 654479ac2d..fcc2a14879 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala @@ -17,10 +17,8 @@ package org.apache.spark.mllib.stat -import breeze.linalg.{DenseVector => BDV} - import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors, Vector} +import org.apache.spark.mllib.linalg.{Vectors, Vector} /** * :: DeveloperApi :: @@ -40,37 +38,14 @@ import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors, Vector class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with Serializable { private var n = 0 - private var currMean: BDV[Double] = _ - private var currM2n: BDV[Double] = _ - private var currM2: BDV[Double] = _ - private var currL1: BDV[Double] = _ + private var currMean: Array[Double] = _ + private var currM2n: Array[Double] = _ + private var currM2: Array[Double] = _ + private var currL1: Array[Double] = _ private var totalCnt: Long = 0 - private var nnz: BDV[Double] = _ - private var currMax: BDV[Double] = _ - private var currMin: BDV[Double] = _ - - /** - * Adds input value to position i. - */ - private[this] def add(i: Int, value: Double) = { - if (value != 0.0) { - if (currMax(i) < value) { - currMax(i) = value - } - if (currMin(i) > value) { - currMin(i) = value - } - - val prevMean = currMean(i) - val diff = value - prevMean - currMean(i) = prevMean + diff / (nnz(i) + 1.0) - currM2n(i) += (value - currMean(i)) * diff - currM2(i) += value * value - currL1(i) += math.abs(value) - - nnz(i) += 1.0 - } - } + private var nnz: Array[Double] = _ + private var currMax: Array[Double] = _ + private var currMin: Array[Double] = _ /** * Add a new sample to this summarizer, and update the statistical summary. @@ -83,33 +58,36 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S require(sample.size > 0, s"Vector should have dimension larger than zero.") n = sample.size - currMean = BDV.zeros[Double](n) - currM2n = BDV.zeros[Double](n) - currM2 = BDV.zeros[Double](n) - currL1 = BDV.zeros[Double](n) - nnz = BDV.zeros[Double](n) - currMax = BDV.fill(n)(Double.MinValue) - currMin = BDV.fill(n)(Double.MaxValue) + currMean = Array.ofDim[Double](n) + currM2n = Array.ofDim[Double](n) + currM2 = Array.ofDim[Double](n) + currL1 = Array.ofDim[Double](n) + nnz = Array.ofDim[Double](n) + currMax = Array.fill[Double](n)(Double.MinValue) + currMin = Array.fill[Double](n)(Double.MaxValue) } require(n == sample.size, s"Dimensions mismatch when adding new sample." + s" Expecting $n but got ${sample.size}.") - sample match { - case dv: DenseVector => { - var j = 0 - while (j < dv.size) { - add(j, dv.values(j)) - j += 1 + sample.foreachActive { (index, value) => + if (value != 0.0) { + if (currMax(index) < value) { + currMax(index) = value } - } - case sv: SparseVector => - var j = 0 - while (j < sv.indices.size) { - add(sv.indices(j), sv.values(j)) - j += 1 + if (currMin(index) > value) { + currMin(index) = value } - case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass) + + val prevMean = currMean(index) + val diff = value - prevMean + currMean(index) = prevMean + diff / (nnz(index) + 1.0) + currM2n(index) += (value - currMean(index)) * diff + currM2(index) += value * value + currL1(index) += math.abs(value) + + nnz(index) += 1.0 + } } totalCnt += 1 @@ -152,14 +130,14 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S } } else if (totalCnt == 0 && other.totalCnt != 0) { this.n = other.n - this.currMean = other.currMean.copy - this.currM2n = other.currM2n.copy - this.currM2 = other.currM2.copy - this.currL1 = other.currL1.copy + this.currMean = other.currMean.clone + this.currM2n = other.currM2n.clone + this.currM2 = other.currM2.clone + this.currL1 = other.currL1.clone this.totalCnt = other.totalCnt - this.nnz = other.nnz.copy - this.currMax = other.currMax.copy - this.currMin = other.currMin.copy + this.nnz = other.nnz.clone + this.currMax = other.currMax.clone + this.currMin = other.currMin.clone } this } @@ -167,19 +145,19 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S override def mean: Vector = { require(totalCnt > 0, s"Nothing has been added to this summarizer.") - val realMean = BDV.zeros[Double](n) + val realMean = Array.ofDim[Double](n) var i = 0 while (i < n) { realMean(i) = currMean(i) * (nnz(i) / totalCnt) i += 1 } - Vectors.fromBreeze(realMean) + Vectors.dense(realMean) } override def variance: Vector = { require(totalCnt > 0, s"Nothing has been added to this summarizer.") - val realVariance = BDV.zeros[Double](n) + val realVariance = Array.ofDim[Double](n) val denominator = totalCnt - 1.0 @@ -194,8 +172,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S i += 1 } } - - Vectors.fromBreeze(realVariance) + Vectors.dense(realVariance) } override def count: Long = totalCnt @@ -203,7 +180,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S override def numNonzeros: Vector = { require(totalCnt > 0, s"Nothing has been added to this summarizer.") - Vectors.fromBreeze(nnz) + Vectors.dense(nnz) } override def max: Vector = { @@ -214,7 +191,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S if ((nnz(i) < totalCnt) && (currMax(i) < 0.0)) currMax(i) = 0.0 i += 1 } - Vectors.fromBreeze(currMax) + Vectors.dense(currMax) } override def min: Vector = { @@ -225,25 +202,25 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S if ((nnz(i) < totalCnt) && (currMin(i) > 0.0)) currMin(i) = 0.0 i += 1 } - Vectors.fromBreeze(currMin) + Vectors.dense(currMin) } override def normL2: Vector = { require(totalCnt > 0, s"Nothing has been added to this summarizer.") - val realMagnitude = BDV.zeros[Double](n) + val realMagnitude = Array.ofDim[Double](n) var i = 0 while (i < currM2.size) { realMagnitude(i) = math.sqrt(currM2(i)) i += 1 } - - Vectors.fromBreeze(realMagnitude) + Vectors.dense(realMagnitude) } override def normL1: Vector = { require(totalCnt > 0, s"Nothing has been added to this summarizer.") - Vectors.fromBreeze(currL1) + + Vectors.dense(currL1) } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala index 59cd85eab2..9492f604af 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala @@ -173,4 +173,28 @@ class VectorsSuite extends FunSuite { val v = Vectors.fromBreeze(x(::, 0)) assert(v.size === x.rows) } + + test("foreachActive") { + val dv = Vectors.dense(0.0, 1.2, 3.1, 0.0) + val sv = Vectors.sparse(4, Seq((1, 1.2), (2, 3.1), (3, 0.0))) + + val dvMap = scala.collection.mutable.Map[Int, Double]() + dv.foreachActive { (index, value) => + dvMap.put(index, value) + } + assert(dvMap.size === 4) + assert(dvMap.get(0) === Some(0.0)) + assert(dvMap.get(1) === Some(1.2)) + assert(dvMap.get(2) === Some(3.1)) + assert(dvMap.get(3) === Some(0.0)) + + val svMap = scala.collection.mutable.Map[Int, Double]() + sv.foreachActive { (index, value) => + svMap.put(index, value) + } + assert(svMap.size === 3) + assert(svMap.get(1) === Some(1.2)) + assert(svMap.get(2) === Some(3.1)) + assert(svMap.get(3) === Some(0.0)) + } } |