aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorDB Tsai <dbtsai@alpinenow.com>2014-10-29 10:14:53 -0700
committerXiangrui Meng <meng@databricks.com>2014-10-29 10:14:53 -0700
commit51ce997355465fc5c29d0e49b92f9bae0bab90ed (patch)
treed6c2b06140d87641b7b5ee56a2d4dccd7a7eb11a /mllib
parent1559495dd961d299299a27aae2cb940e8c6697c5 (diff)
downloadspark-51ce997355465fc5c29d0e49b92f9bae0bab90ed.tar.gz
spark-51ce997355465fc5c29d0e49b92f9bae0bab90ed.tar.bz2
spark-51ce997355465fc5c29d0e49b92f9bae0bab90ed.zip
[SPARK-4129][MLlib] Performance tuning in MultivariateOnlineSummarizer
In MultivariateOnlineSummarizer, breeze's activeIterator is used to loop through the nonZero elements in the vector. However, activeIterator doesn't perform well due to lots of overhead. In this PR, native while loop is used for both DenseVector and SparseVector. The benchmark result with 20 executors using mnist8m dataset: Before: DenseVector: 48.2 seconds SparseVector: 16.3 seconds After: DenseVector: 17.8 seconds SparseVector: 11.2 seconds Since MultivariateOnlineSummarizer is used in several places, the overall performance gain in mllib library will be significant with this PR. Author: DB Tsai <dbtsai@alpinenow.com> Closes #2992 from dbtsai/SPARK-4129 and squashes the following commits: b99db6c [DB Tsai] fixed java.lang.ArrayIndexOutOfBoundsException 2b5e882 [DB Tsai] small refactoring ebe3e74 [DB Tsai] First commit
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala25
1 files changed, 21 insertions, 4 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
index 3025d4837c..fab7c4405c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
@@ -20,7 +20,7 @@ package org.apache.spark.mllib.stat
import breeze.linalg.{DenseVector => BDV}
import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.mllib.linalg.{Vectors, Vector}
+import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors, Vector}
/**
* :: DeveloperApi ::
@@ -72,9 +72,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
require(n == sample.size, s"Dimensions mismatch when adding new sample." +
s" Expecting $n but got ${sample.size}.")
- sample.toBreeze.activeIterator.foreach {
- case (_, 0.0) => // Skip explicit zero elements.
- case (i, value) =>
+ @inline def update(i: Int, value: Double) = {
+ if (value != 0.0) {
if (currMax(i) < value) {
currMax(i) = value
}
@@ -89,6 +88,24 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
currL1(i) += math.abs(value)
nnz(i) += 1.0
+ }
+ }
+
+ sample match {
+ case dv: DenseVector => {
+ var j = 0
+ while (j < dv.size) {
+ update(j, dv.values(j))
+ j += 1
+ }
+ }
+ case sv: SparseVector =>
+ var j = 0
+ while (j < sv.indices.size) {
+ update(sv.indices(j), sv.values(j))
+ j += 1
+ }
+ case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass)
}
totalCnt += 1