aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorDB Tsai <dbtsai@alpinenow.com>2014-11-21 18:15:07 -0800
committerXiangrui Meng <meng@databricks.com>2014-11-21 18:15:07 -0800
commitb5d17ef10e2509d9886c660945449a89750c8116 (patch)
treee646e6711ccc3b0c7d465c8b27f02999ffdeebb4 /mllib
parentce95bd8e130b2c7688b94be40683bdd90d86012d (diff)
downloadspark-b5d17ef10e2509d9886c660945449a89750c8116.tar.gz
spark-b5d17ef10e2509d9886c660945449a89750c8116.tar.bz2
spark-b5d17ef10e2509d9886c660945449a89750c8116.zip
[SPARK-4431][MLlib] Implement efficient foreachActive for dense and sparse vector
Previously, we used Breeze's activeIterator to access the non-zero elements of dense/sparse vectors. Due to its overhead, we switched back to a native `while` loop in SPARK-4129. However, the SPARK-4129 approach has to dereference dv.values/sv.values on every access to a value, which is very expensive. Also, MultivariateOnlineSummarizer uses Breeze dense vectors to store the partial stats, which is very expensive compared with using primitive Scala arrays. In this PR, an efficient foreachActive is implemented to unify the code path for dense and sparse vector operations, which makes the codebase easier to maintain. The Breeze dense vectors are replaced by primitive arrays to reduce the overhead further. Benchmarked with the mnist8m dataset on a single JVM, with the first 200 samples loaded in memory and the run repeated 5000 times. Before change: Sparse Vector - 30.02, Dense Vector - 38.27. With this PR: Sparse Vector - 6.29, Dense Vector - 11.72. Author: DB Tsai <dbtsai@alpinenow.com> Closes #3288 from dbtsai/activeIterator and squashes the following commits: 844b0e6 [DB Tsai] formating 03dd693 [DB Tsai] futher performance tunning. 1907ae1 [DB Tsai] address feedback 98448bb [DB Tsai] Made the override final, and had a local copy of variables which made the accessing a single step operation. c0cbd5a [DB Tsai] fix a bug 6441f92 [DB Tsai] Finished SPARK-4431
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala32
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala121
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala24
3 files changed, 105 insertions, 72 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
index 60ab2aaa8f..c6d5fe5bc6 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -76,6 +76,15 @@ sealed trait Vector extends Serializable {
def copy: Vector = {
throw new NotImplementedError(s"copy is not implemented for ${this.getClass}.")
}
+
+ /**
+ * Applies a function `f` to all the active elements of dense and sparse vector.
+ *
+ * @param f the function takes two parameters where the first parameter is the index of
+ * the vector with type `Int`, and the second parameter is the corresponding value
+ * with type `Double`.
+ */
+ private[spark] def foreachActive(f: (Int, Double) => Unit)
}
/**
@@ -273,6 +282,17 @@ class DenseVector(val values: Array[Double]) extends Vector {
override def copy: DenseVector = {
new DenseVector(values.clone())
}
+
+ private[spark] override def foreachActive(f: (Int, Double) => Unit) = {
+ var i = 0
+ val localValuesSize = values.size
+ val localValues = values
+
+ while (i < localValuesSize) {
+ f(i, localValues(i))
+ i += 1
+ }
+ }
}
/**
@@ -309,4 +329,16 @@ class SparseVector(
}
private[mllib] override def toBreeze: BV[Double] = new BSV[Double](indices, values, size)
+
+ private[spark] override def foreachActive(f: (Int, Double) => Unit) = {
+ var i = 0
+ val localValuesSize = values.size
+ val localIndices = indices
+ val localValues = values
+
+ while (i < localValuesSize) {
+ f(localIndices(i), localValues(i))
+ i += 1
+ }
+ }
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
index 654479ac2d..fcc2a14879 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
@@ -17,10 +17,8 @@
package org.apache.spark.mllib.stat
-import breeze.linalg.{DenseVector => BDV}
-
import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors, Vector}
+import org.apache.spark.mllib.linalg.{Vectors, Vector}
/**
* :: DeveloperApi ::
@@ -40,37 +38,14 @@ import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors, Vector
class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with Serializable {
private var n = 0
- private var currMean: BDV[Double] = _
- private var currM2n: BDV[Double] = _
- private var currM2: BDV[Double] = _
- private var currL1: BDV[Double] = _
+ private var currMean: Array[Double] = _
+ private var currM2n: Array[Double] = _
+ private var currM2: Array[Double] = _
+ private var currL1: Array[Double] = _
private var totalCnt: Long = 0
- private var nnz: BDV[Double] = _
- private var currMax: BDV[Double] = _
- private var currMin: BDV[Double] = _
-
- /**
- * Adds input value to position i.
- */
- private[this] def add(i: Int, value: Double) = {
- if (value != 0.0) {
- if (currMax(i) < value) {
- currMax(i) = value
- }
- if (currMin(i) > value) {
- currMin(i) = value
- }
-
- val prevMean = currMean(i)
- val diff = value - prevMean
- currMean(i) = prevMean + diff / (nnz(i) + 1.0)
- currM2n(i) += (value - currMean(i)) * diff
- currM2(i) += value * value
- currL1(i) += math.abs(value)
-
- nnz(i) += 1.0
- }
- }
+ private var nnz: Array[Double] = _
+ private var currMax: Array[Double] = _
+ private var currMin: Array[Double] = _
/**
* Add a new sample to this summarizer, and update the statistical summary.
@@ -83,33 +58,36 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
require(sample.size > 0, s"Vector should have dimension larger than zero.")
n = sample.size
- currMean = BDV.zeros[Double](n)
- currM2n = BDV.zeros[Double](n)
- currM2 = BDV.zeros[Double](n)
- currL1 = BDV.zeros[Double](n)
- nnz = BDV.zeros[Double](n)
- currMax = BDV.fill(n)(Double.MinValue)
- currMin = BDV.fill(n)(Double.MaxValue)
+ currMean = Array.ofDim[Double](n)
+ currM2n = Array.ofDim[Double](n)
+ currM2 = Array.ofDim[Double](n)
+ currL1 = Array.ofDim[Double](n)
+ nnz = Array.ofDim[Double](n)
+ currMax = Array.fill[Double](n)(Double.MinValue)
+ currMin = Array.fill[Double](n)(Double.MaxValue)
}
require(n == sample.size, s"Dimensions mismatch when adding new sample." +
s" Expecting $n but got ${sample.size}.")
- sample match {
- case dv: DenseVector => {
- var j = 0
- while (j < dv.size) {
- add(j, dv.values(j))
- j += 1
+ sample.foreachActive { (index, value) =>
+ if (value != 0.0) {
+ if (currMax(index) < value) {
+ currMax(index) = value
}
- }
- case sv: SparseVector =>
- var j = 0
- while (j < sv.indices.size) {
- add(sv.indices(j), sv.values(j))
- j += 1
+ if (currMin(index) > value) {
+ currMin(index) = value
}
- case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass)
+
+ val prevMean = currMean(index)
+ val diff = value - prevMean
+ currMean(index) = prevMean + diff / (nnz(index) + 1.0)
+ currM2n(index) += (value - currMean(index)) * diff
+ currM2(index) += value * value
+ currL1(index) += math.abs(value)
+
+ nnz(index) += 1.0
+ }
}
totalCnt += 1
@@ -152,14 +130,14 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
}
} else if (totalCnt == 0 && other.totalCnt != 0) {
this.n = other.n
- this.currMean = other.currMean.copy
- this.currM2n = other.currM2n.copy
- this.currM2 = other.currM2.copy
- this.currL1 = other.currL1.copy
+ this.currMean = other.currMean.clone
+ this.currM2n = other.currM2n.clone
+ this.currM2 = other.currM2.clone
+ this.currL1 = other.currL1.clone
this.totalCnt = other.totalCnt
- this.nnz = other.nnz.copy
- this.currMax = other.currMax.copy
- this.currMin = other.currMin.copy
+ this.nnz = other.nnz.clone
+ this.currMax = other.currMax.clone
+ this.currMin = other.currMin.clone
}
this
}
@@ -167,19 +145,19 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
override def mean: Vector = {
require(totalCnt > 0, s"Nothing has been added to this summarizer.")
- val realMean = BDV.zeros[Double](n)
+ val realMean = Array.ofDim[Double](n)
var i = 0
while (i < n) {
realMean(i) = currMean(i) * (nnz(i) / totalCnt)
i += 1
}
- Vectors.fromBreeze(realMean)
+ Vectors.dense(realMean)
}
override def variance: Vector = {
require(totalCnt > 0, s"Nothing has been added to this summarizer.")
- val realVariance = BDV.zeros[Double](n)
+ val realVariance = Array.ofDim[Double](n)
val denominator = totalCnt - 1.0
@@ -194,8 +172,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
i += 1
}
}
-
- Vectors.fromBreeze(realVariance)
+ Vectors.dense(realVariance)
}
override def count: Long = totalCnt
@@ -203,7 +180,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
override def numNonzeros: Vector = {
require(totalCnt > 0, s"Nothing has been added to this summarizer.")
- Vectors.fromBreeze(nnz)
+ Vectors.dense(nnz)
}
override def max: Vector = {
@@ -214,7 +191,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
if ((nnz(i) < totalCnt) && (currMax(i) < 0.0)) currMax(i) = 0.0
i += 1
}
- Vectors.fromBreeze(currMax)
+ Vectors.dense(currMax)
}
override def min: Vector = {
@@ -225,25 +202,25 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
if ((nnz(i) < totalCnt) && (currMin(i) > 0.0)) currMin(i) = 0.0
i += 1
}
- Vectors.fromBreeze(currMin)
+ Vectors.dense(currMin)
}
override def normL2: Vector = {
require(totalCnt > 0, s"Nothing has been added to this summarizer.")
- val realMagnitude = BDV.zeros[Double](n)
+ val realMagnitude = Array.ofDim[Double](n)
var i = 0
while (i < currM2.size) {
realMagnitude(i) = math.sqrt(currM2(i))
i += 1
}
-
- Vectors.fromBreeze(realMagnitude)
+ Vectors.dense(realMagnitude)
}
override def normL1: Vector = {
require(totalCnt > 0, s"Nothing has been added to this summarizer.")
- Vectors.fromBreeze(currL1)
+
+ Vectors.dense(currL1)
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
index 59cd85eab2..9492f604af 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
@@ -173,4 +173,28 @@ class VectorsSuite extends FunSuite {
val v = Vectors.fromBreeze(x(::, 0))
assert(v.size === x.rows)
}
+
+ test("foreachActive") {
+ val dv = Vectors.dense(0.0, 1.2, 3.1, 0.0)
+ val sv = Vectors.sparse(4, Seq((1, 1.2), (2, 3.1), (3, 0.0)))
+
+ val dvMap = scala.collection.mutable.Map[Int, Double]()
+ dv.foreachActive { (index, value) =>
+ dvMap.put(index, value)
+ }
+ assert(dvMap.size === 4)
+ assert(dvMap.get(0) === Some(0.0))
+ assert(dvMap.get(1) === Some(1.2))
+ assert(dvMap.get(2) === Some(3.1))
+ assert(dvMap.get(3) === Some(0.0))
+
+ val svMap = scala.collection.mutable.Map[Int, Double]()
+ sv.foreachActive { (index, value) =>
+ svMap.put(index, value)
+ }
+ assert(svMap.size === 3)
+ assert(svMap.get(1) === Some(1.2))
+ assert(svMap.get(2) === Some(3.1))
+ assert(svMap.get(3) === Some(0.0))
+ }
}