about summary refs log tree commit diff
path: root/mllib
diff options
context:
space:
mode:
author Burak Yavuz <brkyvz@gmail.com> 2015-08-30 12:21:15 -0700
committer Xiangrui Meng <meng@databricks.com> 2015-08-30 12:21:15 -0700
commit 8d2ab75d3b71b632f2394f2453af32f417cb45e5 (patch)
tree 54dec9190cc9748571e99c4400cf8d5212ef5b5e /mllib
parent 1bfd9347822df65e76201c4c471a26488d722319 (diff)
download spark-8d2ab75d3b71b632f2394f2453af32f417cb45e5.tar.gz
spark-8d2ab75d3b71b632f2394f2453af32f417cb45e5.tar.bz2
spark-8d2ab75d3b71b632f2394f2453af32f417cb45e5.zip
[SPARK-10353] [MLLIB] BLAS gemm not scaling when beta = 0.0 for some subset of matrix multiplications
mengxr jkbradley rxin It would be great if this fix made it into RC3!

Author: Burak Yavuz <brkyvz@gmail.com>

Closes #8525 from brkyvz/blas-scaling.
Diffstat (limited to 'mllib')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala 26
-rw-r--r-- mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala 5
2 files changed, 15 insertions, 16 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
index bbbcc8436b..ab475af264 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
@@ -305,6 +305,8 @@ private[spark] object BLAS extends Serializable with Logging {
"The matrix C cannot be the product of a transpose() call. C.isTransposed must be false.")
if (alpha == 0.0 && beta == 1.0) {
logDebug("gemm: alpha is equal to 0 and beta is equal to 1. Returning C.")
+ } else if (alpha == 0.0) {
+ f2jBLAS.dscal(C.values.length, beta, C.values, 1)
} else {
A match {
case sparse: SparseMatrix => gemm(alpha, sparse, B, beta, C)
@@ -408,8 +410,8 @@ private[spark] object BLAS extends Serializable with Logging {
}
}
} else {
- // Scale matrix first if `beta` is not equal to 0.0
- if (beta != 0.0) {
+ // Scale matrix first if `beta` is not equal to 1.0
+ if (beta != 1.0) {
f2jBLAS.dscal(C.values.length, beta, C.values, 1)
}
// Perform matrix multiplication and add to C. The rows of A are multiplied by the columns of
@@ -470,8 +472,10 @@ private[spark] object BLAS extends Serializable with Logging {
s"The columns of A don't match the number of elements of x. A: ${A.numCols}, x: ${x.size}")
require(A.numRows == y.size,
s"The rows of A don't match the number of elements of y. A: ${A.numRows}, y:${y.size}")
- if (alpha == 0.0) {
- logDebug("gemv: alpha is equal to 0. Returning y.")
+ if (alpha == 0.0 && beta == 1.0) {
+ logDebug("gemv: alpha is equal to 0 and beta is equal to 1. Returning y.")
+ } else if (alpha == 0.0) {
+ scal(beta, y)
} else {
(A, x) match {
case (smA: SparseMatrix, dvx: DenseVector) =>
@@ -526,11 +530,6 @@ private[spark] object BLAS extends Serializable with Logging {
val xValues = x.values
val yValues = y.values
- if (alpha == 0.0) {
- scal(beta, y)
- return
- }
-
if (A.isTransposed) {
var rowCounterForA = 0
while (rowCounterForA < mA) {
@@ -581,11 +580,6 @@ private[spark] object BLAS extends Serializable with Logging {
val Arows = if (!A.isTransposed) A.rowIndices else A.colPtrs
val Acols = if (!A.isTransposed) A.colPtrs else A.rowIndices
- if (alpha == 0.0) {
- scal(beta, y)
- return
- }
-
if (A.isTransposed) {
var rowCounter = 0
while (rowCounter < mA) {
@@ -604,7 +598,7 @@ private[spark] object BLAS extends Serializable with Logging {
rowCounter += 1
}
} else {
- scal(beta, y)
+ if (beta != 1.0) scal(beta, y)
var colCounterForA = 0
var k = 0
@@ -659,7 +653,7 @@ private[spark] object BLAS extends Serializable with Logging {
rowCounter += 1
}
} else {
- scal(beta, y)
+ if (beta != 1.0) scal(beta, y)
// Perform matrix-vector multiplication and add to y
var colCounterForA = 0
while (colCounterForA < nA) {
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala
index d119e0b50a..8db5c8424a 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala
@@ -204,6 +204,7 @@ class BLASSuite extends SparkFunSuite {
val C14 = C1.copy
val C15 = C1.copy
val C16 = C1.copy
+ val C17 = C1.copy
val expected2 = new DenseMatrix(4, 2, Array(2.0, 1.0, 4.0, 2.0, 4.0, 0.0, 4.0, 3.0))
val expected3 = new DenseMatrix(4, 2, Array(2.0, 2.0, 4.0, 2.0, 8.0, 0.0, 6.0, 6.0))
val expected4 = new DenseMatrix(4, 2, Array(5.0, 0.0, 10.0, 5.0, 0.0, 0.0, 5.0, 0.0))
@@ -217,6 +218,10 @@ class BLASSuite extends SparkFunSuite {
assert(C2 ~== expected2 absTol 1e-15)
assert(C3 ~== expected3 absTol 1e-15)
assert(C4 ~== expected3 absTol 1e-15)
+ gemm(1.0, dA, B, 0.0, C17)
+ assert(C17 ~== expected absTol 1e-15)
+ gemm(1.0, sA, B, 0.0, C17)
+ assert(C17 ~== expected absTol 1e-15)
withClue("columns of A don't match the rows of B") {
intercept[Exception] {