aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMeihua Wu <meihuawu@umich.edu>2015-07-20 17:03:46 -0700
committerXiangrui Meng <meng@databricks.com>2015-07-20 17:03:46 -0700
commitff3c72dbafa16c6158fc36619f3c38344c452ba0 (patch)
tree1166d9818bf6cd316bb97b2e73b1dd17475b4ae4
parenta5d05819afcc9b19aeae4817d842205f32b34335 (diff)
downloadspark-ff3c72dbafa16c6158fc36619f3c38344c452ba0.tar.gz
spark-ff3c72dbafa16c6158fc36619f3c38344c452ba0.tar.bz2
spark-ff3c72dbafa16c6158fc36619f3c38344c452ba0.zip
[SPARK-9175] [MLLIB] BLAS.gemm fails to update matrix C when alpha==0 and beta!=1
Fix BLAS.gemm to update matrix C when alpha==0 and beta!=1 Also include unit tests to verify the fix. mengxr brkyvz Author: Meihua Wu <meihuawu@umich.edu> Closes #7503 from rotationsymmetry/fix_BLAS_gemm and squashes the following commits: fce199c [Meihua Wu] Fix BLAS.gemm to update C when alpha==0 and beta!=1
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala4
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala16
2 files changed, 18 insertions, 2 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
index 3523f18043..9029093e0f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
@@ -303,8 +303,8 @@ private[spark] object BLAS extends Serializable with Logging {
C: DenseMatrix): Unit = {
require(!C.isTransposed,
"The matrix C cannot be the product of a transpose() call. C.isTransposed must be false.")
- if (alpha == 0.0) {
- logDebug("gemm: alpha is equal to 0. Returning C.")
+ if (alpha == 0.0 && beta == 1.0) {
+ logDebug("gemm: alpha is equal to 0 and beta is equal to 1. Returning C.")
} else {
A match {
case sparse: SparseMatrix => gemm(alpha, sparse, B, beta, C)
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala
index b0f3f71113..d119e0b50a 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala
@@ -200,8 +200,14 @@ class BLASSuite extends SparkFunSuite {
val C10 = C1.copy
val C11 = C1.copy
val C12 = C1.copy
+ val C13 = C1.copy
+ val C14 = C1.copy
+ val C15 = C1.copy
+ val C16 = C1.copy
val expected2 = new DenseMatrix(4, 2, Array(2.0, 1.0, 4.0, 2.0, 4.0, 0.0, 4.0, 3.0))
val expected3 = new DenseMatrix(4, 2, Array(2.0, 2.0, 4.0, 2.0, 8.0, 0.0, 6.0, 6.0))
+ val expected4 = new DenseMatrix(4, 2, Array(5.0, 0.0, 10.0, 5.0, 0.0, 0.0, 5.0, 0.0))
+ val expected5 = C1.copy
gemm(1.0, dA, B, 2.0, C1)
gemm(1.0, sA, B, 2.0, C2)
@@ -248,6 +254,16 @@ class BLASSuite extends SparkFunSuite {
assert(C10 ~== expected2 absTol 1e-15)
assert(C11 ~== expected3 absTol 1e-15)
assert(C12 ~== expected3 absTol 1e-15)
+
+ gemm(0, dA, B, 5, C13)
+ gemm(0, sA, B, 5, C14)
+ gemm(0, dA, B, 1, C15)
+ gemm(0, sA, B, 1, C16)
+ assert(C13 ~== expected4 absTol 1e-15)
+ assert(C14 ~== expected4 absTol 1e-15)
+ assert(C15 ~== expected5 absTol 1e-15)
+ assert(C16 ~== expected5 absTol 1e-15)
+
}
test("gemv") {