diff options
author | Bjarne Fruergaard <bwahlgreen@gmail.com> | 2016-09-29 15:39:57 -0700 |
---|---|---|
committer | Joseph K. Bradley <joseph@databricks.com> | 2016-09-29 15:39:57 -0700 |
commit | 29396e7d1483d027960b9a1bed47008775c4253e (patch) | |
tree | 61ffe766fd7463cc7776a7b034eacee293ddb7cc /mllib/src/main | |
parent | 4ecc648ad713f9d618adf0406b5d39981779059d (diff) | |
download | spark-29396e7d1483d027960b9a1bed47008775c4253e.tar.gz spark-29396e7d1483d027960b9a1bed47008775c4253e.tar.bz2 spark-29396e7d1483d027960b9a1bed47008775c4253e.zip |
[SPARK-17721][MLLIB][ML] Fix for multiplying transposed SparseMatrix with SparseVector
## What changes were proposed in this pull request?
* changes the implementation of gemv with transposed SparseMatrix and SparseVector both in mllib-local and mllib (identical)
* adds a test that was failing before this change, but succeeds with these changes.
The problem in the previous implementation was that it only increments `i`, that is enumerating the columns of a row in the SparseMatrix, when the row-index of the vector matches the column-index of the SparseMatrix. In cases where a particular row of the SparseMatrix has non-zero values at column-indices lower than corresponding non-zero row-indices of the SparseVector, the non-zero values of the SparseVector are enumerated without ever matching the column-index at index `i` and the remaining column-indices i+1,...,indEnd-1 are never attempted. The test cases in this PR illustrate this issue.
## How was this patch tested?
I have run the specific `gemv` tests in both mllib-local and mllib. I am currently still running `./dev/run-tests`.
## ___
As per instructions, I hereby state that this is my original work and that I license the work to the project (Apache Spark) under the project's open source license.
Mentioning dbtsai, viirya and brkyvz whom I can see have worked/authored on these parts before.
Author: Bjarne Fruergaard <bwahlgreen@gmail.com>
Closes #15296 from bwahlgreen/bugfix-spark-17721.
Diffstat (limited to 'mllib/src/main')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala | 8 |
1 files changed, 6 insertions, 2 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala index 6a85608706..0cd68a633c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala @@ -637,12 +637,16 @@ private[spark] object BLAS extends Serializable with Logging { val indEnd = Arows(rowCounter + 1) var sum = 0.0 var k = 0 - while (k < xNnz && i < indEnd) { + while (i < indEnd && k < xNnz) { if (xIndices(k) == Acols(i)) { sum += Avals(i) * xValues(k) + k += 1 + i += 1 + } else if (xIndices(k) < Acols(i)) { + k += 1 + } else { i += 1 } - k += 1 } yValues(rowCounter) = sum * alpha + beta * yValues(rowCounter) rowCounter += 1 |