diff options
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala | 4 | ||||
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala | 16 |
2 files changed, 18 insertions, 2 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala index 3523f18043..9029093e0f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala @@ -303,8 +303,8 @@ private[spark] object BLAS extends Serializable with Logging { C: DenseMatrix): Unit = { require(!C.isTransposed, "The matrix C cannot be the product of a transpose() call. C.isTransposed must be false.") - if (alpha == 0.0) { - logDebug("gemm: alpha is equal to 0. Returning C.") + if (alpha == 0.0 && beta == 1.0) { + logDebug("gemm: alpha is equal to 0 and beta is equal to 1. Returning C.") } else { A match { case sparse: SparseMatrix => gemm(alpha, sparse, B, beta, C) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala index b0f3f71113..d119e0b50a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala @@ -200,8 +200,14 @@ class BLASSuite extends SparkFunSuite { val C10 = C1.copy val C11 = C1.copy val C12 = C1.copy + val C13 = C1.copy + val C14 = C1.copy + val C15 = C1.copy + val C16 = C1.copy val expected2 = new DenseMatrix(4, 2, Array(2.0, 1.0, 4.0, 2.0, 4.0, 0.0, 4.0, 3.0)) val expected3 = new DenseMatrix(4, 2, Array(2.0, 2.0, 4.0, 2.0, 8.0, 0.0, 6.0, 6.0)) + val expected4 = new DenseMatrix(4, 2, Array(5.0, 0.0, 10.0, 5.0, 0.0, 0.0, 5.0, 0.0)) + val expected5 = C1.copy gemm(1.0, dA, B, 2.0, C1) gemm(1.0, sA, B, 2.0, C2) @@ -248,6 +254,16 @@ class BLASSuite extends SparkFunSuite { assert(C10 ~== expected2 absTol 1e-15) assert(C11 ~== expected3 absTol 1e-15) assert(C12 ~== expected3 absTol 1e-15) + + gemm(0, dA, B, 5, C13) + gemm(0, sA, B, 5, C14) + gemm(0, dA, B, 1, C15) + gemm(0, sA, B, 1, C16) + assert(C13 ~== expected4 absTol 1e-15) + assert(C14 ~== expected4 absTol 1e-15) + assert(C15 ~== expected5 absTol 1e-15) + assert(C16 ~== expected5 absTol 1e-15) + } test("gemv") { |