aboutsummaryrefslogtreecommitdiff
path: root/mllib/src
diff options
context:
space:
mode:
authorDB Tsai <dbtsai@alpinenow.com>2014-12-03 22:31:39 +0800
committerXiangrui Meng <meng@databricks.com>2014-12-03 22:31:39 +0800
commitd00542987ed80635782dcc826fc0bdbf434fff10 (patch)
tree587fc5c7cfc8756beccc55b25eb63006f2e0457c /mllib/src
parent7fc49ed91168999d24ae7b4cc46fbb4ec87febc1 (diff)
downloadspark-d00542987ed80635782dcc826fc0bdbf434fff10.tar.gz
spark-d00542987ed80635782dcc826fc0bdbf434fff10.tar.bz2
spark-d00542987ed80635782dcc826fc0bdbf434fff10.zip
[SPARK-4717][MLlib] Optimize BLAS library to avoid de-reference multiple times in loop
Have a local reference to `values` and `indices` array in the `Vector` object so JVM can locate the value with one operation call. See `SPARK-4581` for similar optimization, and the bytecode analysis. Author: DB Tsai <dbtsai@alpinenow.com> Closes #3577 from dbtsai/blasopt and squashes the following commits: 62d38c4 [DB Tsai] formating 0316cef [DB Tsai] first commit
Diffstat (limited to 'mllib/src')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala99
1 files changed, 60 insertions, 39 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
index 89539e600f..8c4c9c6cf6 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
@@ -72,17 +72,21 @@ private[spark] object BLAS extends Serializable with Logging {
* y += a * x
*/
private def axpy(a: Double, x: SparseVector, y: DenseVector): Unit = {
- val nnz = x.indices.size
+ val xValues = x.values
+ val xIndices = x.indices
+ val yValues = y.values
+ val nnz = xIndices.size
+
if (a == 1.0) {
var k = 0
while (k < nnz) {
- y.values(x.indices(k)) += x.values(k)
+ yValues(xIndices(k)) += xValues(k)
k += 1
}
} else {
var k = 0
while (k < nnz) {
- y.values(x.indices(k)) += a * x.values(k)
+ yValues(xIndices(k)) += a * xValues(k)
k += 1
}
}
@@ -119,11 +123,15 @@ private[spark] object BLAS extends Serializable with Logging {
* dot(x, y)
*/
private def dot(x: SparseVector, y: DenseVector): Double = {
- val nnz = x.indices.size
+ val xValues = x.values
+ val xIndices = x.indices
+ val yValues = y.values
+ val nnz = xIndices.size
+
var sum = 0.0
var k = 0
while (k < nnz) {
- sum += x.values(k) * y.values(x.indices(k))
+ sum += xValues(k) * yValues(xIndices(k))
k += 1
}
sum
@@ -133,19 +141,24 @@ private[spark] object BLAS extends Serializable with Logging {
* dot(x, y)
*/
private def dot(x: SparseVector, y: SparseVector): Double = {
+ val xValues = x.values
+ val xIndices = x.indices
+ val yValues = y.values
+ val yIndices = y.indices
+ val nnzx = xIndices.size
+ val nnzy = yIndices.size
+
var kx = 0
- val nnzx = x.indices.size
var ky = 0
- val nnzy = y.indices.size
var sum = 0.0
// y catching x
while (kx < nnzx && ky < nnzy) {
- val ix = x.indices(kx)
- while (ky < nnzy && y.indices(ky) < ix) {
+ val ix = xIndices(kx)
+ while (ky < nnzy && yIndices(ky) < ix) {
ky += 1
}
- if (ky < nnzy && y.indices(ky) == ix) {
- sum += x.values(kx) * y.values(ky)
+ if (ky < nnzy && yIndices(ky) == ix) {
+ sum += xValues(kx) * yValues(ky)
ky += 1
}
kx += 1
@@ -163,21 +176,25 @@ private[spark] object BLAS extends Serializable with Logging {
case dy: DenseVector =>
x match {
case sx: SparseVector =>
+ val sxIndices = sx.indices
+ val sxValues = sx.values
+ val dyValues = dy.values
+ val nnz = sxIndices.size
+
var i = 0
var k = 0
- val nnz = sx.indices.size
while (k < nnz) {
- val j = sx.indices(k)
+ val j = sxIndices(k)
while (i < j) {
- dy.values(i) = 0.0
+ dyValues(i) = 0.0
i += 1
}
- dy.values(i) = sx.values(k)
+ dyValues(i) = sxValues(k)
i += 1
k += 1
}
while (i < n) {
- dy.values(i) = 0.0
+ dyValues(i) = 0.0
i += 1
}
case dx: DenseVector =>
@@ -311,6 +328,8 @@ private[spark] object BLAS extends Serializable with Logging {
s"The columns of C don't match the columns of B. C: ${C.numCols}, A: $nB")
val Avals = A.values
+ val Bvals = B.values
+ val Cvals = C.values
val Arows = if (!transA) A.rowIndices else A.colPtrs
val Acols = if (!transA) A.colPtrs else A.rowIndices
@@ -327,11 +346,11 @@ private[spark] object BLAS extends Serializable with Logging {
val indEnd = Arows(rowCounterForA + 1)
var sum = 0.0
while (i < indEnd) {
- sum += Avals(i) * B.values(Bstart + Acols(i))
+ sum += Avals(i) * Bvals(Bstart + Acols(i))
i += 1
}
val Cindex = Cstart + rowCounterForA
- C.values(Cindex) = beta * C.values(Cindex) + sum * alpha
+ Cvals(Cindex) = beta * Cvals(Cindex) + sum * alpha
rowCounterForA += 1
}
colCounterForB += 1
@@ -349,7 +368,7 @@ private[spark] object BLAS extends Serializable with Logging {
i += 1
}
val Cindex = Cstart + rowCounter
- C.values(Cindex) = beta * C.values(Cindex) + sum * alpha
+ Cvals(Cindex) = beta * Cvals(Cindex) + sum * alpha
rowCounter += 1
}
colCounterForB += 1
@@ -357,7 +376,7 @@ private[spark] object BLAS extends Serializable with Logging {
}
} else {
// Scale matrix first if `beta` is not equal to 0.0
- if (beta != 0.0){
+ if (beta != 0.0) {
f2jBLAS.dscal(C.values.length, beta, C.values, 1)
}
// Perform matrix multiplication and add to C. The rows of A are multiplied by the columns of
@@ -371,9 +390,9 @@ private[spark] object BLAS extends Serializable with Logging {
while (colCounterForA < kA) {
var i = Acols(colCounterForA)
val indEnd = Acols(colCounterForA + 1)
- val Bval = B.values(Bstart + colCounterForA) * alpha
- while (i < indEnd){
- C.values(Cstart + Arows(i)) += Avals(i) * Bval
+ val Bval = Bvals(Bstart + colCounterForA) * alpha
+ while (i < indEnd) {
+ Cvals(Cstart + Arows(i)) += Avals(i) * Bval
i += 1
}
colCounterForA += 1
@@ -384,12 +403,12 @@ private[spark] object BLAS extends Serializable with Logging {
while (colCounterForB < nB) {
var colCounterForA = 0 // The column of A to multiply with the row of B
val Cstart = colCounterForB * mA
- while (colCounterForA < kA){
+ while (colCounterForA < kA) {
var i = Acols(colCounterForA)
val indEnd = Acols(colCounterForA + 1)
val Bval = B(colCounterForB, colCounterForA) * alpha
- while (i < indEnd){
- C.values(Cstart + Arows(i)) += Avals(i) * Bval
+ while (i < indEnd) {
+ Cvals(Cstart + Arows(i)) += Avals(i) * Bval
i += 1
}
colCounterForA += 1
@@ -484,41 +503,43 @@ private[spark] object BLAS extends Serializable with Logging {
beta: Double,
y: DenseVector): Unit = {
- val mA: Int = if(!trans) A.numRows else A.numCols
- val nA: Int = if(!trans) A.numCols else A.numRows
+ val xValues = x.values
+ val yValues = y.values
+
+ val mA: Int = if (!trans) A.numRows else A.numCols
+ val nA: Int = if (!trans) A.numCols else A.numRows
val Avals = A.values
val Arows = if (!trans) A.rowIndices else A.colPtrs
val Acols = if (!trans) A.colPtrs else A.rowIndices
-
// Slicing is easy in this case. This is the optimal multiplication setting for sparse matrices
- if (trans){
+ if (trans) {
var rowCounter = 0
- while (rowCounter < mA){
+ while (rowCounter < mA) {
var i = Arows(rowCounter)
val indEnd = Arows(rowCounter + 1)
var sum = 0.0
- while(i < indEnd){
- sum += Avals(i) * x.values(Acols(i))
+ while (i < indEnd) {
+ sum += Avals(i) * xValues(Acols(i))
i += 1
}
- y.values(rowCounter) = beta * y.values(rowCounter) + sum * alpha
+ yValues(rowCounter) = beta * yValues(rowCounter) + sum * alpha
rowCounter += 1
}
} else {
// Scale vector first if `beta` is not equal to 0.0
- if (beta != 0.0){
+ if (beta != 0.0) {
scal(beta, y)
}
// Perform matrix-vector multiplication and add to y
var colCounterForA = 0
- while (colCounterForA < nA){
+ while (colCounterForA < nA) {
var i = Acols(colCounterForA)
val indEnd = Acols(colCounterForA + 1)
- val xVal = x.values(colCounterForA) * alpha
- while (i < indEnd){
+ val xVal = xValues(colCounterForA) * alpha
+ while (i < indEnd) {
val rowIndex = Arows(i)
- y.values(rowIndex) += Avals(i) * xVal
+ yValues(rowIndex) += Avals(i) * xVal
i += 1
}
colCounterForA += 1