author    Liang-Chi Hsieh <viirya@gmail.com>    2015-05-18 21:32:36 -0700
committer Xiangrui Meng <meng@databricks.com>  2015-05-18 21:32:36 -0700
commit    d03638cc2d414cee9ac7481084672e454495dfc1 (patch)
tree      13f728f288fbaee293828269efab59507ef5c859
parent    3a6003866ade45974b43a9e785ec35fb76a32b99 (diff)
[SPARK-7681] [MLLIB] Add SparseVector support for gemv
JIRA: https://issues.apache.org/jira/browse/SPARK-7681

Author: Liang-Chi Hsieh <viirya@gmail.com>

Closes #6209 from viirya/sparsevector_gemv and squashes the following commits:

ce0bb8b [Liang-Chi Hsieh] Still need to scal y when beta is 0.0 because it clears out y.
b890e63 [Liang-Chi Hsieh] Do not delete multiply for DenseVector.
57a8c1e [Liang-Chi Hsieh] Add MimaExcludes for v1.4.
458d1ae [Liang-Chi Hsieh] List DenseMatrix.multiply and SparseMatrix.multiply to MimaExcludes too.
054f05d [Liang-Chi Hsieh] Fix scala style.
410381a [Liang-Chi Hsieh] Address comments. Make Matrix.multiply more generalized.
4616696 [Liang-Chi Hsieh] Add support for SparseVector with SparseMatrix.
5d6d07a [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into sparsevector_gemv
c069507 [Liang-Chi Hsieh] Add SparseVector support for gemv with DenseMatrix.
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala      | 152
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala  |   7
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala |  96
-rw-r--r--  project/MimaExcludes.scala                                         |  18
4 files changed, 240 insertions(+), 33 deletions(-)
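
What this change enables, end to end: Matrix.multiply and BLAS.gemv now accept any Vector, so a SparseVector no longer has to be densified before a matrix-vector product. A minimal sketch of the new call path (assumes spark-mllib 1.4+ on the classpath; the object name is illustrative):

  import org.apache.spark.mllib.linalg.{DenseMatrix, Vectors}

  object GemvSparseDemo {
    def main(args: Array[String]): Unit = {
      // 3x2 column-major matrix:
      //   [1.0  3.0]
      //   [0.0  4.0]
      //   [2.0  0.0]
      val A = new DenseMatrix(3, 2, Array(1.0, 0.0, 2.0, 3.0, 4.0, 0.0))
      // Size-2 sparse vector with one nonzero, x(1) = 2.0.
      val x = Vectors.sparse(2, Array(1), Array(2.0))
      // Before this patch, multiply only accepted a DenseVector.
      val y = A.multiply(x)
      println(y)  // [6.0,8.0,0.0]
    }
  }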
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
index 87052e1ba8..ec38529cf8 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
@@ -463,7 +463,7 @@ private[spark] object BLAS extends Serializable with Logging {
def gemv(
alpha: Double,
A: Matrix,
- x: DenseVector,
+ x: Vector,
beta: Double,
y: DenseVector): Unit = {
require(A.numCols == x.size,
@@ -473,44 +473,169 @@ private[spark] object BLAS extends Serializable with Logging {
if (alpha == 0.0) {
logDebug("gemv: alpha is equal to 0. Returning y.")
} else {
- A match {
- case sparse: SparseMatrix =>
- gemv(alpha, sparse, x, beta, y)
- case dense: DenseMatrix =>
- gemv(alpha, dense, x, beta, y)
+ (A, x) match {
+ case (smA: SparseMatrix, dvx: DenseVector) =>
+ gemv(alpha, smA, dvx, beta, y)
+ case (smA: SparseMatrix, svx: SparseVector) =>
+ gemv(alpha, smA, svx, beta, y)
+ case (dmA: DenseMatrix, dvx: DenseVector) =>
+ gemv(alpha, dmA, dvx, beta, y)
+ case (dmA: DenseMatrix, svx: SparseVector) =>
+ gemv(alpha, dmA, svx, beta, y)
case _ =>
- throw new IllegalArgumentException(s"gemv doesn't support matrix type ${A.getClass}.")
+ throw new IllegalArgumentException(s"gemv doesn't support running on matrix type " +
+ s"${A.getClass} and vector type ${x.getClass}.")
}
}
}
/**
* y := alpha * A * x + beta * y
- * For `DenseMatrix` A.
+ * For `DenseMatrix` A and `DenseVector` x.
*/
private def gemv(
alpha: Double,
A: DenseMatrix,
x: DenseVector,
beta: Double,
- y: DenseVector): Unit = {
+ y: DenseVector): Unit = {
val tStrA = if (A.isTransposed) "T" else "N"
val mA = if (!A.isTransposed) A.numRows else A.numCols
val nA = if (!A.isTransposed) A.numCols else A.numRows
nativeBLAS.dgemv(tStrA, mA, nA, alpha, A.values, mA, x.values, 1, beta,
y.values, 1)
}
+
+ /**
+ * y := alpha * A * x + beta * y
+ * For `DenseMatrix` A and `SparseVector` x.
+ */
+ private def gemv(
+ alpha: Double,
+ A: DenseMatrix,
+ x: SparseVector,
+ beta: Double,
+ y: DenseVector): Unit = {
+ val mA: Int = A.numRows
+ val nA: Int = A.numCols
+
+ val Avals = A.values
+
+ val xIndices = x.indices
+ val xNnz = xIndices.length
+ val xValues = x.values
+ val yValues = y.values
+ if (alpha == 0.0) {
+ scal(beta, y)
+ return
+ }
+
+ if (A.isTransposed) {
+ var rowCounterForA = 0
+ while (rowCounterForA < mA) {
+ var sum = 0.0
+ var k = 0
+ while (k < xNnz) {
+ sum += xValues(k) * Avals(xIndices(k) + rowCounterForA * nA)
+ k += 1
+ }
+ yValues(rowCounterForA) = sum * alpha + beta * yValues(rowCounterForA)
+ rowCounterForA += 1
+ }
+ } else {
+ var rowCounterForA = 0
+ while (rowCounterForA < mA) {
+ var sum = 0.0
+ var k = 0
+ while (k < xNnz) {
+ sum += xValues(k) * Avals(xIndices(k) * mA + rowCounterForA)
+ k += 1
+ }
+ yValues(rowCounterForA) = sum * alpha + beta * yValues(rowCounterForA)
+ rowCounterForA += 1
+ }
+ }
+ }
+
/**
* y := alpha * A * x + beta * y
- * For `SparseMatrix` A.
+ * For `SparseMatrix` A and `SparseVector` x.
+ */
+ private def gemv(
+ alpha: Double,
+ A: SparseMatrix,
+ x: SparseVector,
+ beta: Double,
+ y: DenseVector): Unit = {
+ val xValues = x.values
+ val xIndices = x.indices
+ val xNnz = xIndices.length
+
+ val yValues = y.values
+
+ val mA: Int = A.numRows
+ val nA: Int = A.numCols
+
+ val Avals = A.values
+ val Arows = if (!A.isTransposed) A.rowIndices else A.colPtrs
+ val Acols = if (!A.isTransposed) A.colPtrs else A.rowIndices
+
+ if (alpha == 0.0) {
+ scal(beta, y)
+ return
+ }
+
+ if (A.isTransposed) {
+ var rowCounter = 0
+ while (rowCounter < mA) {
+ var i = Arows(rowCounter)
+ val indEnd = Arows(rowCounter + 1)
+ var sum = 0.0
+ var k = 0
+ while (k < xNnz && i < indEnd) {
+ if (xIndices(k) == Acols(i)) {
+ sum += Avals(i) * xValues(k)
+ k += 1
+ i += 1
+ } else if (xIndices(k) < Acols(i)) {
+ k += 1
+ } else {
+ // Must advance i when Acols(i) < xIndices(k), or later matches are missed.
+ i += 1
+ }
+ }
+ yValues(rowCounter) = sum * alpha + beta * yValues(rowCounter)
+ rowCounter += 1
+ }
+ } else {
+ scal(beta, y)
+
+ var colCounterForA = 0
+ var k = 0
+ while (colCounterForA < nA && k < xNnz) {
+ if (xIndices(k) == colCounterForA) {
+ var i = Acols(colCounterForA)
+ val indEnd = Acols(colCounterForA + 1)
+
+ val xTemp = xValues(k) * alpha
+ while (i < indEnd) {
+ val rowIndex = Arows(i)
+ yValues(rowIndex) += Avals(i) * xTemp
+ i += 1
+ }
+ k += 1
+ }
+ colCounterForA += 1
+ }
+ }
+ }
+
+ /**
+ * y := alpha * A * x + beta * y
+ * For `SparseMatrix` A and `DenseVector` x.
*/
private def gemv(
alpha: Double,
A: SparseMatrix,
x: DenseVector,
beta: Double,
- y: DenseVector): Unit = {
+ y: DenseVector): Unit = {
val xValues = x.values
val yValues = y.values
val mA: Int = A.numRows
@@ -534,10 +659,7 @@ private[spark] object BLAS extends Serializable with Logging {
rowCounter += 1
}
} else {
- // Scale vector first if `beta` is not equal to 0.0
- if (beta != 0.0) {
- scal(beta, y)
- }
+ scal(beta, y)
// Perform matrix-vector multiplication and add to y
var colCounterForA = 0
while (colCounterForA < nA) {
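
A note on the indexing in the new DenseMatrix-times-SparseVector path above: A.values is column-major, so element (row, col) lives at col * mA + row, and the inner loop visits only the nonzero columns of x. A standalone sketch of the non-transposed loop on plain arrays (hypothetical helper name, same arithmetic as the patch):

  // y(row) = alpha * sum_k xValues(k) * A(row, xIndices(k)) + beta * y(row)
  // for an mA x nA matrix stored flat, column-major, in avals.
  def denseTimesSparse(
      mA: Int,
      avals: Array[Double],
      xIndices: Array[Int],
      xValues: Array[Double],
      alpha: Double,
      beta: Double,
      y: Array[Double]): Unit = {
    var row = 0
    while (row < mA) {
      var sum = 0.0
      var k = 0
      while (k < xIndices.length) {
        // Column-major lookup: A(row, xIndices(k)) = avals(xIndices(k) * mA + row).
        sum += xValues(k) * avals(xIndices(k) * mA + row)
        k += 1
      }
      y(row) = alpha * sum + beta * y(row)
      row += 1
    }
  }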
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
index a609674df6..9584da8e3a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
@@ -77,8 +77,13 @@ sealed trait Matrix extends Serializable {
C
}
- /** Convenience method for `Matrix`-`DenseVector` multiplication. */
+ /** Convenience method for `Matrix`-`DenseVector` multiplication. For binary compatibility. */
def multiply(y: DenseVector): DenseVector = {
+ multiply(y.asInstanceOf[Vector])
+ }
+
+ /** Convenience method for `Matrix`-`Vector` multiplication. */
+ def multiply(y: Vector): DenseVector = {
val output = new DenseVector(new Array[Double](numRows))
BLAS.gemv(1.0, this, y, 0.0, output)
output
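
The retained DenseVector overload above exists purely for binary compatibility: callers compiled against 1.3 invoke the old method descriptor, which now just forwards to the generalized Vector method. The same pattern in isolation (hypothetical trait name; the ascription steers overload resolution to the new method):

  import org.apache.spark.mllib.linalg.{DenseVector, Vector}

  trait MultiplyCompat {
    // New, generalized entry point.
    def multiply(y: Vector): DenseVector

    // Old signature kept so bytecode compiled against the previous
    // release still links; (y: Vector) picks the Vector overload.
    def multiply(y: DenseVector): DenseVector = multiply(y: Vector)
  }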
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala
index 002cb25386..64ecd12ea7 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala
@@ -256,42 +256,108 @@ class BLASSuite extends FunSuite {
val dA =
new DenseMatrix(4, 3, Array(0.0, 1.0, 0.0, 0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 3.0))
val sA = new SparseMatrix(4, 3, Array(0, 1, 3, 4), Array(1, 0, 2, 3), Array(1.0, 2.0, 1.0, 3.0))
-
- val x = new DenseVector(Array(1.0, 2.0, 3.0))
+
+ val dA2 =
+ new DenseMatrix(4, 3, Array(0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 3.0), true)
+ val sA2 =
+ new SparseMatrix(4, 3, Array(0, 1, 2, 3, 4), Array(1, 0, 1, 2), Array(2.0, 1.0, 1.0, 3.0),
+ true)
+
+ val dx = new DenseVector(Array(1.0, 2.0, 3.0))
+ val sx = dx.toSparse
val expected = new DenseVector(Array(4.0, 1.0, 2.0, 9.0))
- assert(dA.multiply(x) ~== expected absTol 1e-15)
- assert(sA.multiply(x) ~== expected absTol 1e-15)
-
+ assert(dA.multiply(dx) ~== expected absTol 1e-15)
+ assert(sA.multiply(dx) ~== expected absTol 1e-15)
+ assert(dA.multiply(sx) ~== expected absTol 1e-15)
+ assert(sA.multiply(sx) ~== expected absTol 1e-15)
+
val y1 = new DenseVector(Array(1.0, 3.0, 1.0, 0.0))
val y2 = y1.copy
val y3 = y1.copy
val y4 = y1.copy
+ val y5 = y1.copy
+ val y6 = y1.copy
+ val y7 = y1.copy
+ val y8 = y1.copy
+ val y9 = y1.copy
+ val y10 = y1.copy
+ val y11 = y1.copy
+ val y12 = y1.copy
+ val y13 = y1.copy
+ val y14 = y1.copy
+ val y15 = y1.copy
+ val y16 = y1.copy
+
val expected2 = new DenseVector(Array(6.0, 7.0, 4.0, 9.0))
val expected3 = new DenseVector(Array(10.0, 8.0, 6.0, 18.0))
- gemv(1.0, dA, x, 2.0, y1)
- gemv(1.0, sA, x, 2.0, y2)
- gemv(2.0, dA, x, 2.0, y3)
- gemv(2.0, sA, x, 2.0, y4)
+ gemv(1.0, dA, dx, 2.0, y1)
+ gemv(1.0, sA, dx, 2.0, y2)
+ gemv(1.0, dA, sx, 2.0, y3)
+ gemv(1.0, sA, sx, 2.0, y4)
+
+ gemv(1.0, dA2, dx, 2.0, y5)
+ gemv(1.0, sA2, dx, 2.0, y6)
+ gemv(1.0, dA2, sx, 2.0, y7)
+ gemv(1.0, sA2, sx, 2.0, y8)
+
+ gemv(2.0, dA, dx, 2.0, y9)
+ gemv(2.0, sA, dx, 2.0, y10)
+ gemv(2.0, dA, sx, 2.0, y11)
+ gemv(2.0, sA, sx, 2.0, y12)
+
+ gemv(2.0, dA2, dx, 2.0, y13)
+ gemv(2.0, sA2, dx, 2.0, y14)
+ gemv(2.0, dA2, sx, 2.0, y15)
+ gemv(2.0, sA2, sx, 2.0, y16)
+
assert(y1 ~== expected2 absTol 1e-15)
assert(y2 ~== expected2 absTol 1e-15)
- assert(y3 ~== expected3 absTol 1e-15)
- assert(y4 ~== expected3 absTol 1e-15)
+ assert(y3 ~== expected2 absTol 1e-15)
+ assert(y4 ~== expected2 absTol 1e-15)
+
+ assert(y5 ~== expected2 absTol 1e-15)
+ assert(y6 ~== expected2 absTol 1e-15)
+ assert(y7 ~== expected2 absTol 1e-15)
+ assert(y8 ~== expected2 absTol 1e-15)
+
+ assert(y9 ~== expected3 absTol 1e-15)
+ assert(y10 ~== expected3 absTol 1e-15)
+ assert(y11 ~== expected3 absTol 1e-15)
+ assert(y12 ~== expected3 absTol 1e-15)
+
+ assert(y13 ~== expected3 absTol 1e-15)
+ assert(y14 ~== expected3 absTol 1e-15)
+ assert(y15 ~== expected3 absTol 1e-15)
+ assert(y16 ~== expected3 absTol 1e-15)
+
withClue("columns of A don't match the rows of B") {
intercept[Exception] {
- gemv(1.0, dA.transpose, x, 2.0, y1)
+ gemv(1.0, dA.transpose, dx, 2.0, y1)
+ }
+ intercept[Exception] {
+ gemv(1.0, sA.transpose, dx, 2.0, y1)
+ }
+ intercept[Exception] {
+ gemv(1.0, dA.transpose, sx, 2.0, y1)
+ }
+ intercept[Exception] {
+ gemv(1.0, sA.transpose, sx, 2.0, y1)
}
}
+
val dAT =
new DenseMatrix(3, 4, Array(0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 3.0))
val sAT =
new SparseMatrix(3, 4, Array(0, 1, 2, 3, 4), Array(1, 0, 1, 2), Array(2.0, 1.0, 1.0, 3.0))
-
+
val dATT = dAT.transpose
val sATT = sAT.transpose
- assert(dATT.multiply(x) ~== expected absTol 1e-15)
- assert(sATT.multiply(x) ~== expected absTol 1e-15)
+ assert(dATT.multiply(dx) ~== expected absTol 1e-15)
+ assert(sATT.multiply(dx) ~== expected absTol 1e-15)
+ assert(dATT.multiply(sx) ~== expected absTol 1e-15)
+ assert(sATT.multiply(sx) ~== expected absTol 1e-15)
}
}
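
For readers decoding the fixtures above: sA is CSC, so colPtrs = [0, 1, 3, 4] means column 0 owns entry 0, column 1 owns entries 1-2, and column 2 owns entry 3, with rowIndices/values giving each entry's row and value. A small sketch that expands such a triple into the equivalent column-major dense array (hypothetical helper, standard CSC conventions):

  // Expand CSC (colPtrs, rowIndices, values) into a dense column-major array.
  def cscToDense(
      numRows: Int,
      numCols: Int,
      colPtrs: Array[Int],
      rowIndices: Array[Int],
      values: Array[Double]): Array[Double] = {
    val dense = new Array[Double](numRows * numCols)
    var col = 0
    while (col < numCols) {
      var i = colPtrs(col)
      while (i < colPtrs(col + 1)) {
        dense(col * numRows + rowIndices(i)) = values(i)
        i += 1
      }
      col += 1
    }
    dense
  }

  // The sA fixture expands to the same values as dA:
  // nonzeros (1,0)=1.0, (0,1)=2.0, (2,1)=1.0, (3,2)=3.0.
  val denseOfSA = cscToDense(4, 3, Array(0, 1, 3, 4), Array(1, 0, 2, 3),
    Array(1.0, 2.0, 1.0, 3.0))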
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 513bbaf98d..f8d0160f64 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -87,7 +87,14 @@ object MimaExcludes {
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.mllib.linalg.Vector.toSparse"),
ProblemFilters.exclude[MissingMethodProblem](
- "org.apache.spark.mllib.linalg.Vector.numActives")
+ "org.apache.spark.mllib.linalg.Vector.numActives"),
+ // SPARK-7681 add SparseVector support for gemv
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.mllib.linalg.Matrix.multiply"),
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.mllib.linalg.DenseMatrix.multiply"),
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.mllib.linalg.SparseMatrix.multiply")
) ++ Seq(
// Execution should never be included as its always internal.
MimaBuild.excludeSparkPackage("sql.execution"),
@@ -180,7 +187,14 @@ object MimaExcludes {
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.mllib.linalg.Matrix.isTransposed"),
ProblemFilters.exclude[MissingMethodProblem](
- "org.apache.spark.mllib.linalg.Matrix.foreachActive")
+ "org.apache.spark.mllib.linalg.Matrix.foreachActive"),
+ // SPARK-7681 add SparseVector support for gemv
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.mllib.linalg.Matrix.multiply"),
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.mllib.linalg.DenseMatrix.multiply"),
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.mllib.linalg.SparseMatrix.multiply")
) ++ Seq(
// SPARK-5540
ProblemFilters.exclude[MissingMethodProblem](