aboutsummaryrefslogtreecommitdiff
path: root/mllib-local/src
diff options
context:
space:
mode:
Diffstat (limited to 'mllib-local/src')
-rw-r--r--mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala8
-rw-r--r--mllib-local/src/test/scala/org/apache/spark/ml/linalg/BLASSuite.scala17
2 files changed, 23 insertions, 2 deletions
diff --git a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala
index 41b0c6c89a..4ca19f3387 100644
--- a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala
+++ b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala
@@ -638,12 +638,16 @@ private[spark] object BLAS extends Serializable {
val indEnd = Arows(rowCounter + 1)
var sum = 0.0
var k = 0
- while (k < xNnz && i < indEnd) {
+ while (i < indEnd && k < xNnz) {
if (xIndices(k) == Acols(i)) {
sum += Avals(i) * xValues(k)
+ k += 1
+ i += 1
+ } else if (xIndices(k) < Acols(i)) {
+ k += 1
+ } else {
i += 1
}
- k += 1
}
yValues(rowCounter) = sum * alpha + beta * yValues(rowCounter)
rowCounter += 1
diff --git a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/BLASSuite.scala b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/BLASSuite.scala
index 8a9f49792c..6e72a5fff0 100644
--- a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/BLASSuite.scala
+++ b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/BLASSuite.scala
@@ -392,6 +392,23 @@ class BLASSuite extends SparkMLFunSuite {
}
}
+ val y17 = new DenseVector(Array(0.0, 0.0))
+ val y18 = y17.copy
+
+ val sA3 = new SparseMatrix(3, 2, Array(0, 2, 4), Array(1, 2, 0, 1), Array(2.0, 1.0, 1.0, 2.0))
+ .transpose
+ val sA4 =
+ new SparseMatrix(2, 3, Array(0, 1, 3, 4), Array(1, 0, 1, 0), Array(1.0, 2.0, 2.0, 1.0))
+ val sx3 = new SparseVector(3, Array(1, 2), Array(2.0, 1.0))
+
+ val expected4 = new DenseVector(Array(5.0, 4.0))
+
+ gemv(1.0, sA3, sx3, 0.0, y17)
+ gemv(1.0, sA4, sx3, 0.0, y18)
+
+ assert(y17 ~== expected4 absTol 1e-15)
+ assert(y18 ~== expected4 absTol 1e-15)
+
val dAT =
new DenseMatrix(3, 4, Array(0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 3.0))
val sAT =