aboutsummaryrefslogtreecommitdiff
path: root/mllib-local/src/test/scala/org/apache/spark/ml/linalg/BLASSuite.scala
diff options
context:
space:
mode:
Diffstat (limited to 'mllib-local/src/test/scala/org/apache/spark/ml/linalg/BLASSuite.scala')
-rw-r--r--mllib-local/src/test/scala/org/apache/spark/ml/linalg/BLASSuite.scala45
1 files changed, 45 insertions, 0 deletions
diff --git a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/BLASSuite.scala b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/BLASSuite.scala
index 6e72a5fff0..877ac68983 100644
--- a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/BLASSuite.scala
+++ b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/BLASSuite.scala
@@ -422,4 +422,49 @@ class BLASSuite extends SparkMLFunSuite {
assert(dATT.multiply(sx) ~== expected absTol 1e-15)
assert(sATT.multiply(sx) ~== expected absTol 1e-15)
}
+
+ test("spmv") {
+ /*
+ A = [[3.0, -2.0, 2.0, -4.0],
+ [-2.0, -8.0, 4.0, 7.0],
+ [2.0, 4.0, -3.0, -3.0],
+ [-4.0, 7.0, -3.0, 0.0]]
+ x = [5.0, 2.0, -1.0, -9.0]
+ Ax = [ 45., -93., 48., -3.]
+ */
+ val A = new DenseVector(Array(3.0, -2.0, -8.0, 2.0, 4.0, -3.0, -4.0, 7.0, -3.0, 0.0))
+ val x = new DenseVector(Array(5.0, 2.0, -1.0, -9.0))
+ val n = 4
+
+ val y1 = new DenseVector(Array(-3.0, 6.0, -8.0, -3.0))
+ val y2 = y1.copy
+ val y3 = y1.copy
+ val y4 = y1.copy
+ val y5 = y1.copy
+ val y6 = y1.copy
+ val y7 = y1.copy
+
+ val expected1 = new DenseVector(Array(42.0, -87.0, 40.0, -6.0))
+ val expected2 = new DenseVector(Array(19.5, -40.5, 16.0, -4.5))
+ val expected3 = new DenseVector(Array(-25.5, 52.5, -32.0, -1.5))
+ val expected4 = new DenseVector(Array(-3.0, 6.0, -8.0, -3.0))
+ val expected5 = new DenseVector(Array(43.5, -90.0, 44.0, -4.5))
+ val expected6 = new DenseVector(Array(46.5, -96.0, 52.0, -1.5))
+ val expected7 = new DenseVector(Array(45.0, -93.0, 48.0, -3.0))
+
+ dspmv(n, 1.0, A, x, 1.0, y1)
+ dspmv(n, 0.5, A, x, 1.0, y2)
+ dspmv(n, -0.5, A, x, 1.0, y3)
+ dspmv(n, 0.0, A, x, 1.0, y4)
+ dspmv(n, 1.0, A, x, 0.5, y5)
+ dspmv(n, 1.0, A, x, -0.5, y6)
+ dspmv(n, 1.0, A, x, 0.0, y7)
+ assert(y1 ~== expected1 absTol 1e-8)
+ assert(y2 ~== expected2 absTol 1e-8)
+ assert(y3 ~== expected3 absTol 1e-8)
+ assert(y4 ~== expected4 absTol 1e-8)
+ assert(y5 ~== expected5 absTol 1e-8)
+ assert(y6 ~== expected6 absTol 1e-8)
+ assert(y7 ~== expected7 absTol 1e-8)
+ }
}