From e76ef5cb8eed6b78fb722b3d6fbeb9466a0e3499 Mon Sep 17 00:00:00 2001 From: Burak Date: Thu, 18 Sep 2014 22:18:51 -0700 Subject: [SPARK-3418] Sparse Matrix support (CCS) and additional native BLAS operations added Local `SparseMatrix` support added in Compressed Column Storage (CCS) format in addition to Level-2 and Level-3 BLAS operations such as dgemv and dgemm respectively. BLAS doesn't support sparse matrix operations, therefore support for `SparseMatrix`-`DenseMatrix` multiplication and `SparseMatrix`-`DenseVector` implementations have been added. I will post performance comparisons in the comments momentarily. Author: Burak Closes #2294 from brkyvz/SPARK-3418 and squashes the following commits: 88814ed [Burak] Hopefully fixed MiMa this time 47e49d5 [Burak] really fixed MiMa issue f0bae57 [Burak] [SPARK-3418] Fixed MiMa compatibility issues (excluded from check) 4b7dbec [Burak] 9/17 comments addressed 7af2f83 [Burak] sealed traits Vector and Matrix d3a8a16 [Burak] [SPARK-3418] Squashed missing alpha bug. 421045f [Burak] [SPARK-3418] New code review comments addressed f35a161 [Burak] [SPARK-3418] Code review comments addressed and multiplication further optimized 2508577 [Burak] [SPARK-3418] Fixed one more style issue d16e8a0 [Burak] [SPARK-3418] Fixed style issues and added documentation for methods 204a3f7 [Burak] [SPARK-3418] Fixed failing Matrix unit test 6025297 [Burak] [SPARK-3418] Fixed Scala-style errors dc7be71 [Burak] [SPARK-3418][MLlib] Matrix unit tests expanded with indexing and updating d2d5851 [Burak] [SPARK-3418][MLlib] Sparse Matrix support and additional native BLAS operations added --- .../scala/org/apache/spark/mllib/linalg/BLAS.scala | 330 ++++++++++++++++++++- .../org/apache/spark/mllib/linalg/Matrices.scala | 232 ++++++++++++++- .../org/apache/spark/mllib/linalg/Vectors.scala | 2 +- .../org/apache/spark/mllib/linalg/BLASSuite.scala | 111 +++++++ .../mllib/linalg/BreezeMatrixConversionSuite.scala | 24 +- .../apache/spark/mllib/linalg/MatricesSuite.scala | 76 +++++ .../org/apache/spark/mllib/util/TestingUtils.scala | 65 +++- 7 files changed, 831 insertions(+), 9 deletions(-) (limited to 'mllib/src') diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala index 70e23033c8..54ee930d61 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala @@ -18,13 +18,17 @@ package org.apache.spark.mllib.linalg import com.github.fommil.netlib.{BLAS => NetlibBLAS, F2jBLAS} +import com.github.fommil.netlib.BLAS.{getInstance => NativeBLAS} + +import org.apache.spark.Logging /** * BLAS routines for MLlib's vectors and matrices. */ -private[mllib] object BLAS extends Serializable { +private[mllib] object BLAS extends Serializable with Logging { @transient private var _f2jBLAS: NetlibBLAS = _ + @transient private var _nativeBLAS: NetlibBLAS = _ // For level-1 routines, we use Java implementation. private def f2jBLAS: NetlibBLAS = { @@ -197,4 +201,328 @@ private[mllib] object BLAS extends Serializable { throw new IllegalArgumentException(s"scal doesn't support vector type ${x.getClass}.") } } + + // For level-3 routines, we use the native BLAS. + private def nativeBLAS: NetlibBLAS = { + if (_nativeBLAS == null) { + _nativeBLAS = NativeBLAS + } + _nativeBLAS + } + + /** + * C := alpha * A * B + beta * C + * @param transA whether to use the transpose of matrix A (true), or A itself (false). + * @param transB whether to use the transpose of matrix B (true), or B itself (false). + * @param alpha a scalar to scale the multiplication A * B. + * @param A the matrix A that will be left multiplied to B. Size of m x k. + * @param B the matrix B that will be left multiplied by A. Size of k x n. + * @param beta a scalar that can be used to scale matrix C. + * @param C the resulting matrix C. Size of m x n. + */ + def gemm( + transA: Boolean, + transB: Boolean, + alpha: Double, + A: Matrix, + B: DenseMatrix, + beta: Double, + C: DenseMatrix): Unit = { + if (alpha == 0.0) { + logDebug("gemm: alpha is equal to 0. Returning C.") + } else { + A match { + case sparse: SparseMatrix => + gemm(transA, transB, alpha, sparse, B, beta, C) + case dense: DenseMatrix => + gemm(transA, transB, alpha, dense, B, beta, C) + case _ => + throw new IllegalArgumentException(s"gemm doesn't support matrix type ${A.getClass}.") + } + } + } + + /** + * C := alpha * A * B + beta * C + * + * @param alpha a scalar to scale the multiplication A * B. + * @param A the matrix A that will be left multiplied to B. Size of m x k. + * @param B the matrix B that will be left multiplied by A. Size of k x n. + * @param beta a scalar that can be used to scale matrix C. + * @param C the resulting matrix C. Size of m x n. + */ + def gemm( + alpha: Double, + A: Matrix, + B: DenseMatrix, + beta: Double, + C: DenseMatrix): Unit = { + gemm(false, false, alpha, A, B, beta, C) + } + + /** + * C := alpha * A * B + beta * C + * For `DenseMatrix` A. + */ + private def gemm( + transA: Boolean, + transB: Boolean, + alpha: Double, + A: DenseMatrix, + B: DenseMatrix, + beta: Double, + C: DenseMatrix): Unit = { + val mA: Int = if (!transA) A.numRows else A.numCols + val nB: Int = if (!transB) B.numCols else B.numRows + val kA: Int = if (!transA) A.numCols else A.numRows + val kB: Int = if (!transB) B.numRows else B.numCols + val tAstr = if (!transA) "N" else "T" + val tBstr = if (!transB) "N" else "T" + + require(kA == kB, s"The columns of A don't match the rows of B. A: $kA, B: $kB") + require(mA == C.numRows, s"The rows of C don't match the rows of A. C: ${C.numRows}, A: $mA") + require(nB == C.numCols, + s"The columns of C don't match the columns of B. C: ${C.numCols}, A: $nB") + + nativeBLAS.dgemm(tAstr, tBstr, mA, nB, kA, alpha, A.values, A.numRows, B.values, B.numRows, + beta, C.values, C.numRows) + } + + /** + * C := alpha * A * B + beta * C + * For `SparseMatrix` A. + */ + private def gemm( + transA: Boolean, + transB: Boolean, + alpha: Double, + A: SparseMatrix, + B: DenseMatrix, + beta: Double, + C: DenseMatrix): Unit = { + val mA: Int = if (!transA) A.numRows else A.numCols + val nB: Int = if (!transB) B.numCols else B.numRows + val kA: Int = if (!transA) A.numCols else A.numRows + val kB: Int = if (!transB) B.numRows else B.numCols + + require(kA == kB, s"The columns of A don't match the rows of B. A: $kA, B: $kB") + require(mA == C.numRows, s"The rows of C don't match the rows of A. C: ${C.numRows}, A: $mA") + require(nB == C.numCols, + s"The columns of C don't match the columns of B. C: ${C.numCols}, A: $nB") + + val Avals = A.values + val Arows = if (!transA) A.rowIndices else A.colPtrs + val Acols = if (!transA) A.colPtrs else A.rowIndices + + // Slicing is easy in this case. This is the optimal multiplication setting for sparse matrices + if (transA){ + var colCounterForB = 0 + if (!transB) { // Expensive to put the check inside the loop + while (colCounterForB < nB) { + var rowCounterForA = 0 + val Cstart = colCounterForB * mA + val Bstart = colCounterForB * kA + while (rowCounterForA < mA) { + var i = Arows(rowCounterForA) + val indEnd = Arows(rowCounterForA + 1) + var sum = 0.0 + while (i < indEnd) { + sum += Avals(i) * B.values(Bstart + Acols(i)) + i += 1 + } + val Cindex = Cstart + rowCounterForA + C.values(Cindex) = beta * C.values(Cindex) + sum * alpha + rowCounterForA += 1 + } + colCounterForB += 1 + } + } else { + while (colCounterForB < nB) { + var rowCounter = 0 + val Cstart = colCounterForB * mA + while (rowCounter < mA) { + var i = Arows(rowCounter) + val indEnd = Arows(rowCounter + 1) + var sum = 0.0 + while (i < indEnd) { + sum += Avals(i) * B(colCounterForB, Acols(i)) + i += 1 + } + val Cindex = Cstart + rowCounter + C.values(Cindex) = beta * C.values(Cindex) + sum * alpha + rowCounter += 1 + } + colCounterForB += 1 + } + } + } else { + // Scale matrix first if `beta` is not equal to 0.0 + if (beta != 0.0){ + f2jBLAS.dscal(C.values.length, beta, C.values, 1) + } + // Perform matrix multiplication and add to C. The rows of A are multiplied by the columns of + // B, and added to C. + var colCounterForB = 0 // the column to be updated in C + if (!transB) { // Expensive to put the check inside the loop + while (colCounterForB < nB) { + var colCounterForA = 0 // The column of A to multiply with the row of B + val Bstart = colCounterForB * kB + val Cstart = colCounterForB * mA + while (colCounterForA < kA) { + var i = Acols(colCounterForA) + val indEnd = Acols(colCounterForA + 1) + val Bval = B.values(Bstart + colCounterForA) * alpha + while (i < indEnd){ + C.values(Cstart + Arows(i)) += Avals(i) * Bval + i += 1 + } + colCounterForA += 1 + } + colCounterForB += 1 + } + } else { + while (colCounterForB < nB) { + var colCounterForA = 0 // The column of A to multiply with the row of B + val Cstart = colCounterForB * mA + while (colCounterForA < kA){ + var i = Acols(colCounterForA) + val indEnd = Acols(colCounterForA + 1) + val Bval = B(colCounterForB, colCounterForA) * alpha + while (i < indEnd){ + C.values(Cstart + Arows(i)) += Avals(i) * Bval + i += 1 + } + colCounterForA += 1 + } + colCounterForB += 1 + } + } + } + } + + /** + * y := alpha * A * x + beta * y + * @param trans whether to use the transpose of matrix A (true), or A itself (false). + * @param alpha a scalar to scale the multiplication A * x. + * @param A the matrix A that will be left multiplied to x. Size of m x n. + * @param x the vector x that will be left multiplied by A. Size of n x 1. + * @param beta a scalar that can be used to scale vector y. + * @param y the resulting vector y. Size of m x 1. + */ + def gemv( + trans: Boolean, + alpha: Double, + A: Matrix, + x: DenseVector, + beta: Double, + y: DenseVector): Unit = { + + val mA: Int = if (!trans) A.numRows else A.numCols + val nx: Int = x.size + val nA: Int = if (!trans) A.numCols else A.numRows + + require(nA == nx, s"The columns of A don't match the number of elements of x. A: $nA, x: $nx") + require(mA == y.size, + s"The rows of A don't match the number of elements of y. A: $mA, y:${y.size}}") + if (alpha == 0.0) { + logDebug("gemv: alpha is equal to 0. Returning y.") + } else { + A match { + case sparse: SparseMatrix => + gemv(trans, alpha, sparse, x, beta, y) + case dense: DenseMatrix => + gemv(trans, alpha, dense, x, beta, y) + case _ => + throw new IllegalArgumentException(s"gemv doesn't support matrix type ${A.getClass}.") + } + } + } + + /** + * y := alpha * A * x + beta * y + * + * @param alpha a scalar to scale the multiplication A * x. + * @param A the matrix A that will be left multiplied to x. Size of m x n. + * @param x the vector x that will be left multiplied by A. Size of n x 1. + * @param beta a scalar that can be used to scale vector y. + * @param y the resulting vector y. Size of m x 1. + */ + def gemv( + alpha: Double, + A: Matrix, + x: DenseVector, + beta: Double, + y: DenseVector): Unit = { + gemv(false, alpha, A, x, beta, y) + } + + /** + * y := alpha * A * x + beta * y + * For `DenseMatrix` A. + */ + private def gemv( + trans: Boolean, + alpha: Double, + A: DenseMatrix, + x: DenseVector, + beta: Double, + y: DenseVector): Unit = { + val tStrA = if (!trans) "N" else "T" + nativeBLAS.dgemv(tStrA, A.numRows, A.numCols, alpha, A.values, A.numRows, x.values, 1, beta, + y.values, 1) + } + + /** + * y := alpha * A * x + beta * y + * For `SparseMatrix` A. + */ + private def gemv( + trans: Boolean, + alpha: Double, + A: SparseMatrix, + x: DenseVector, + beta: Double, + y: DenseVector): Unit = { + + val mA: Int = if(!trans) A.numRows else A.numCols + val nA: Int = if(!trans) A.numCols else A.numRows + + val Avals = A.values + val Arows = if (!trans) A.rowIndices else A.colPtrs + val Acols = if (!trans) A.colPtrs else A.rowIndices + + // Slicing is easy in this case. This is the optimal multiplication setting for sparse matrices + if (trans){ + var rowCounter = 0 + while (rowCounter < mA){ + var i = Arows(rowCounter) + val indEnd = Arows(rowCounter + 1) + var sum = 0.0 + while(i < indEnd){ + sum += Avals(i) * x.values(Acols(i)) + i += 1 + } + y.values(rowCounter) = beta * y.values(rowCounter) + sum * alpha + rowCounter += 1 + } + } else { + // Scale vector first if `beta` is not equal to 0.0 + if (beta != 0.0){ + scal(beta, y) + } + // Perform matrix-vector multiplication and add to y + var colCounterForA = 0 + while (colCounterForA < nA){ + var i = Acols(colCounterForA) + val indEnd = Acols(colCounterForA + 1) + val xVal = x.values(colCounterForA) * alpha + while (i < indEnd){ + val rowIndex = Arows(i) + y.values(rowIndex) += Avals(i) * xVal + i += 1 + } + colCounterForA += 1 + } + } + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala index b11ba5d30f..5711532abc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala @@ -17,12 +17,16 @@ package org.apache.spark.mllib.linalg -import breeze.linalg.{Matrix => BM, DenseMatrix => BDM} +import breeze.linalg.{Matrix => BM, DenseMatrix => BDM, CSCMatrix => BSM} + +import org.apache.spark.util.random.XORShiftRandom + +import java.util.Arrays /** * Trait for a local matrix. */ -trait Matrix extends Serializable { +sealed trait Matrix extends Serializable { /** Number of rows. */ def numRows: Int @@ -37,8 +41,46 @@ trait Matrix extends Serializable { private[mllib] def toBreeze: BM[Double] /** Gets the (i, j)-th element. */ - private[mllib] def apply(i: Int, j: Int): Double = toBreeze(i, j) + private[mllib] def apply(i: Int, j: Int): Double + + /** Return the index for the (i, j)-th element in the backing array. */ + private[mllib] def index(i: Int, j: Int): Int + + /** Update element at (i, j) */ + private[mllib] def update(i: Int, j: Int, v: Double): Unit + + /** Get a deep copy of the matrix. */ + def copy: Matrix + /** Convenience method for `Matrix`-`DenseMatrix` multiplication. */ + def multiply(y: DenseMatrix): DenseMatrix = { + val C: DenseMatrix = Matrices.zeros(numRows, y.numCols).asInstanceOf[DenseMatrix] + BLAS.gemm(false, false, 1.0, this, y, 0.0, C) + C + } + + /** Convenience method for `Matrix`-`DenseVector` multiplication. */ + def multiply(y: DenseVector): DenseVector = { + val output = new DenseVector(new Array[Double](numRows)) + BLAS.gemv(1.0, this, y, 0.0, output) + output + } + + /** Convenience method for `Matrix`^T^-`DenseMatrix` multiplication. */ + def transposeMultiply(y: DenseMatrix): DenseMatrix = { + val C: DenseMatrix = Matrices.zeros(numCols, y.numCols).asInstanceOf[DenseMatrix] + BLAS.gemm(true, false, 1.0, this, y, 0.0, C) + C + } + + /** Convenience method for `Matrix`^T^-`DenseVector` multiplication. */ + def transposeMultiply(y: DenseVector): DenseVector = { + val output = new DenseVector(new Array[Double](numCols)) + BLAS.gemv(true, 1.0, this, y, 0.0, output) + output + } + + /** A human readable representation of the matrix */ override def toString: String = toBreeze.toString() } @@ -59,11 +101,98 @@ trait Matrix extends Serializable { */ class DenseMatrix(val numRows: Int, val numCols: Int, val values: Array[Double]) extends Matrix { - require(values.length == numRows * numCols) + require(values.length == numRows * numCols, "The number of values supplied doesn't match the " + + s"size of the matrix! values.length: ${values.length}, numRows * numCols: ${numRows * numCols}") override def toArray: Array[Double] = values - private[mllib] override def toBreeze: BM[Double] = new BDM[Double](numRows, numCols, values) + private[mllib] def toBreeze: BM[Double] = new BDM[Double](numRows, numCols, values) + + private[mllib] def apply(i: Int): Double = values(i) + + private[mllib] def apply(i: Int, j: Int): Double = values(index(i, j)) + + private[mllib] def index(i: Int, j: Int): Int = i + numRows * j + + private[mllib] def update(i: Int, j: Int, v: Double): Unit = { + values(index(i, j)) = v + } + + override def copy = new DenseMatrix(numRows, numCols, values.clone()) +} + +/** + * Column-majored sparse matrix. + * The entry values are stored in Compressed Sparse Column (CSC) format. + * For example, the following matrix + * {{{ + * 1.0 0.0 4.0 + * 0.0 3.0 5.0 + * 2.0 0.0 6.0 + * }}} + * is stored as `values: [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]`, + * `rowIndices=[0, 2, 1, 0, 1, 2]`, `colPointers=[0, 2, 3, 6]`. + * + * @param numRows number of rows + * @param numCols number of columns + * @param colPtrs the index corresponding to the start of a new column + * @param rowIndices the row index of the entry. They must be in strictly increasing order for each + * column + * @param values non-zero matrix entries in column major + */ +class SparseMatrix( + val numRows: Int, + val numCols: Int, + val colPtrs: Array[Int], + val rowIndices: Array[Int], + val values: Array[Double]) extends Matrix { + + require(values.length == rowIndices.length, "The number of row indices and values don't match! " + + s"values.length: ${values.length}, rowIndices.length: ${rowIndices.length}") + require(colPtrs.length == numCols + 1, "The length of the column indices should be the " + + s"number of columns + 1. Currently, colPointers.length: ${colPtrs.length}, " + + s"numCols: $numCols") + + override def toArray: Array[Double] = { + val arr = new Array[Double](numRows * numCols) + var j = 0 + while (j < numCols) { + var i = colPtrs(j) + val indEnd = colPtrs(j + 1) + val offset = j * numRows + while (i < indEnd) { + val rowIndex = rowIndices(i) + arr(offset + rowIndex) = values(i) + i += 1 + } + j += 1 + } + arr + } + + private[mllib] def toBreeze: BM[Double] = + new BSM[Double](values, numRows, numCols, colPtrs, rowIndices) + + private[mllib] def apply(i: Int, j: Int): Double = { + val ind = index(i, j) + if (ind < 0) 0.0 else values(ind) + } + + private[mllib] def index(i: Int, j: Int): Int = { + Arrays.binarySearch(rowIndices, colPtrs(j), colPtrs(j + 1), i) + } + + private[mllib] def update(i: Int, j: Int, v: Double): Unit = { + val ind = index(i, j) + if (ind == -1){ + throw new NoSuchElementException("The given row and column indices correspond to a zero " + + "value. Only non-zero elements in Sparse Matrices can be updated.") + } else { + values(index(i, j)) = v + } + } + + override def copy = new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values.clone()) } /** @@ -82,6 +211,24 @@ object Matrices { new DenseMatrix(numRows, numCols, values) } + /** + * Creates a column-majored sparse matrix in Compressed Sparse Column (CSC) format. + * + * @param numRows number of rows + * @param numCols number of columns + * @param colPtrs the index corresponding to the start of a new column + * @param rowIndices the row index of the entry + * @param values non-zero matrix entries in column major + */ + def sparse( + numRows: Int, + numCols: Int, + colPtrs: Array[Int], + rowIndices: Array[Int], + values: Array[Double]): Matrix = { + new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values) + } + /** * Creates a Matrix instance from a breeze matrix. * @param breeze a breeze matrix @@ -93,9 +240,84 @@ object Matrices { require(dm.majorStride == dm.rows, "Do not support stride size different from the number of rows.") new DenseMatrix(dm.rows, dm.cols, dm.data) + case sm: BSM[Double] => + new SparseMatrix(sm.rows, sm.cols, sm.colPtrs, sm.rowIndices, sm.data) case _ => throw new UnsupportedOperationException( s"Do not support conversion from type ${breeze.getClass.getName}.") } } + + /** + * Generate a `DenseMatrix` consisting of zeros. + * @param numRows number of rows of the matrix + * @param numCols number of columns of the matrix + * @return `DenseMatrix` with size `numRows` x `numCols` and values of zeros + */ + def zeros(numRows: Int, numCols: Int): Matrix = + new DenseMatrix(numRows, numCols, new Array[Double](numRows * numCols)) + + /** + * Generate a `DenseMatrix` consisting of ones. + * @param numRows number of rows of the matrix + * @param numCols number of columns of the matrix + * @return `DenseMatrix` with size `numRows` x `numCols` and values of ones + */ + def ones(numRows: Int, numCols: Int): Matrix = + new DenseMatrix(numRows, numCols, Array.fill(numRows * numCols)(1.0)) + + /** + * Generate an Identity Matrix in `DenseMatrix` format. + * @param n number of rows and columns of the matrix + * @return `DenseMatrix` with size `n` x `n` and values of ones on the diagonal + */ + def eye(n: Int): Matrix = { + val identity = Matrices.zeros(n, n) + var i = 0 + while (i < n){ + identity.update(i, i, 1.0) + i += 1 + } + identity + } + + /** + * Generate a `DenseMatrix` consisting of i.i.d. uniform random numbers. + * @param numRows number of rows of the matrix + * @param numCols number of columns of the matrix + * @return `DenseMatrix` with size `numRows` x `numCols` and values in U(0, 1) + */ + def rand(numRows: Int, numCols: Int): Matrix = { + val rand = new XORShiftRandom + new DenseMatrix(numRows, numCols, Array.fill(numRows * numCols)(rand.nextDouble())) + } + + /** + * Generate a `DenseMatrix` consisting of i.i.d. gaussian random numbers. + * @param numRows number of rows of the matrix + * @param numCols number of columns of the matrix + * @return `DenseMatrix` with size `numRows` x `numCols` and values in N(0, 1) + */ + def randn(numRows: Int, numCols: Int): Matrix = { + val rand = new XORShiftRandom + new DenseMatrix(numRows, numCols, Array.fill(numRows * numCols)(rand.nextGaussian())) + } + + /** + * Generate a diagonal matrix in `DenseMatrix` format from the supplied values. + * @param vector a `Vector` tat will form the values on the diagonal of the matrix + * @return Square `DenseMatrix` with size `values.length` x `values.length` and `values` + * on the diagonal + */ + def diag(vector: Vector): Matrix = { + val n = vector.size + val matrix = Matrices.eye(n) + val values = vector.toArray + var i = 0 + while (i < n) { + matrix.update(i, i, values(i)) + i += 1 + } + matrix + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index a45781d12e..6af225b7f4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -33,7 +33,7 @@ import org.apache.spark.SparkException * * Note: Users should not implement this interface. */ -trait Vector extends Serializable { +sealed trait Vector extends Serializable { /** * Size of the vector. diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala index 1952e6734e..5d70c914f1 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala @@ -126,4 +126,115 @@ class BLASSuite extends FunSuite { } } } + + test("gemm") { + + val dA = + new DenseMatrix(4, 3, Array(0.0, 1.0, 0.0, 0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 3.0)) + val sA = new SparseMatrix(4, 3, Array(0, 1, 3, 4), Array(1, 0, 2, 3), Array(1.0, 2.0, 1.0, 3.0)) + + val B = new DenseMatrix(3, 2, Array(1.0, 0.0, 0.0, 0.0, 2.0, 1.0)) + val expected = new DenseMatrix(4, 2, Array(0.0, 1.0, 0.0, 0.0, 4.0, 0.0, 2.0, 3.0)) + + assert(dA multiply B ~== expected absTol 1e-15) + assert(sA multiply B ~== expected absTol 1e-15) + + val C1 = new DenseMatrix(4, 2, Array(1.0, 0.0, 2.0, 1.0, 0.0, 0.0, 1.0, 0.0)) + val C2 = C1.copy + val C3 = C1.copy + val C4 = C1.copy + val C5 = C1.copy + val C6 = C1.copy + val C7 = C1.copy + val C8 = C1.copy + val expected2 = new DenseMatrix(4, 2, Array(2.0, 1.0, 4.0, 2.0, 4.0, 0.0, 4.0, 3.0)) + val expected3 = new DenseMatrix(4, 2, Array(2.0, 2.0, 4.0, 2.0, 8.0, 0.0, 6.0, 6.0)) + + gemm(1.0, dA, B, 2.0, C1) + gemm(1.0, sA, B, 2.0, C2) + gemm(2.0, dA, B, 2.0, C3) + gemm(2.0, sA, B, 2.0, C4) + assert(C1 ~== expected2 absTol 1e-15) + assert(C2 ~== expected2 absTol 1e-15) + assert(C3 ~== expected3 absTol 1e-15) + assert(C4 ~== expected3 absTol 1e-15) + + withClue("columns of A don't match the rows of B") { + intercept[Exception] { + gemm(true, false, 1.0, dA, B, 2.0, C1) + } + } + + val dAT = + new DenseMatrix(3, 4, Array(0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 3.0)) + val sAT = + new SparseMatrix(3, 4, Array(0, 1, 2, 3, 4), Array(1, 0, 1, 2), Array(2.0, 1.0, 1.0, 3.0)) + + assert(dAT transposeMultiply B ~== expected absTol 1e-15) + assert(sAT transposeMultiply B ~== expected absTol 1e-15) + + gemm(true, false, 1.0, dAT, B, 2.0, C5) + gemm(true, false, 1.0, sAT, B, 2.0, C6) + gemm(true, false, 2.0, dAT, B, 2.0, C7) + gemm(true, false, 2.0, sAT, B, 2.0, C8) + assert(C5 ~== expected2 absTol 1e-15) + assert(C6 ~== expected2 absTol 1e-15) + assert(C7 ~== expected3 absTol 1e-15) + assert(C8 ~== expected3 absTol 1e-15) + } + + test("gemv") { + + val dA = + new DenseMatrix(4, 3, Array(0.0, 1.0, 0.0, 0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 3.0)) + val sA = new SparseMatrix(4, 3, Array(0, 1, 3, 4), Array(1, 0, 2, 3), Array(1.0, 2.0, 1.0, 3.0)) + + val x = new DenseVector(Array(1.0, 2.0, 3.0)) + val expected = new DenseVector(Array(4.0, 1.0, 2.0, 9.0)) + + assert(dA multiply x ~== expected absTol 1e-15) + assert(sA multiply x ~== expected absTol 1e-15) + + val y1 = new DenseVector(Array(1.0, 3.0, 1.0, 0.0)) + val y2 = y1.copy + val y3 = y1.copy + val y4 = y1.copy + val y5 = y1.copy + val y6 = y1.copy + val y7 = y1.copy + val y8 = y1.copy + val expected2 = new DenseVector(Array(6.0, 7.0, 4.0, 9.0)) + val expected3 = new DenseVector(Array(10.0, 8.0, 6.0, 18.0)) + + gemv(1.0, dA, x, 2.0, y1) + gemv(1.0, sA, x, 2.0, y2) + gemv(2.0, dA, x, 2.0, y3) + gemv(2.0, sA, x, 2.0, y4) + assert(y1 ~== expected2 absTol 1e-15) + assert(y2 ~== expected2 absTol 1e-15) + assert(y3 ~== expected3 absTol 1e-15) + assert(y4 ~== expected3 absTol 1e-15) + withClue("columns of A don't match the rows of B") { + intercept[Exception] { + gemv(true, 1.0, dA, x, 2.0, y1) + } + } + + val dAT = + new DenseMatrix(3, 4, Array(0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 3.0)) + val sAT = + new SparseMatrix(3, 4, Array(0, 1, 2, 3, 4), Array(1, 0, 1, 2), Array(2.0, 1.0, 1.0, 3.0)) + + assert(dAT transposeMultiply x ~== expected absTol 1e-15) + assert(sAT transposeMultiply x ~== expected absTol 1e-15) + + gemv(true, 1.0, dAT, x, 2.0, y5) + gemv(true, 1.0, sAT, x, 2.0, y6) + gemv(true, 2.0, dAT, x, 2.0, y7) + gemv(true, 2.0, sAT, x, 2.0, y8) + assert(y5 ~== expected2 absTol 1e-15) + assert(y6 ~== expected2 absTol 1e-15) + assert(y7 ~== expected3 absTol 1e-15) + assert(y8 ~== expected3 absTol 1e-15) + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala index 82d49c76ed..73a6d3a27d 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.mllib.linalg import org.scalatest.FunSuite -import breeze.linalg.{DenseMatrix => BDM} +import breeze.linalg.{DenseMatrix => BDM, CSCMatrix => BSM} class BreezeMatrixConversionSuite extends FunSuite { test("dense matrix to breeze") { @@ -37,4 +37,26 @@ class BreezeMatrixConversionSuite extends FunSuite { assert(mat.numCols === breeze.cols) assert(mat.values.eq(breeze.data), "should not copy data") } + + test("sparse matrix to breeze") { + val values = Array(1.0, 2.0, 4.0, 5.0) + val colPtrs = Array(0, 2, 4) + val rowIndices = Array(1, 2, 1, 2) + val mat = Matrices.sparse(3, 2, colPtrs, rowIndices, values) + val breeze = mat.toBreeze.asInstanceOf[BSM[Double]] + assert(breeze.rows === mat.numRows) + assert(breeze.cols === mat.numCols) + assert(breeze.data.eq(mat.asInstanceOf[SparseMatrix].values), "should not copy data") + } + + test("sparse breeze matrix to sparse matrix") { + val values = Array(1.0, 2.0, 4.0, 5.0) + val colPtrs = Array(0, 2, 4) + val rowIndices = Array(1, 2, 1, 2) + val breeze = new BSM[Double](values, 3, 2, colPtrs, rowIndices) + val mat = Matrices.fromBreeze(breeze).asInstanceOf[SparseMatrix] + assert(mat.numRows === breeze.rows) + assert(mat.numCols === breeze.cols) + assert(mat.values.eq(breeze.data), "should not copy data") + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala index 9c66b4db9f..5f8b8c4b72 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala @@ -36,4 +36,80 @@ class MatricesSuite extends FunSuite { Matrices.dense(3, 2, Array(0.0, 1.0, 2.0)) } } + + test("sparse matrix construction") { + val m = 3 + val n = 2 + val values = Array(1.0, 2.0, 4.0, 5.0) + val colPtrs = Array(0, 2, 4) + val rowIndices = Array(1, 2, 1, 2) + val mat = Matrices.sparse(m, n, colPtrs, rowIndices, values).asInstanceOf[SparseMatrix] + assert(mat.numRows === m) + assert(mat.numCols === n) + assert(mat.values.eq(values), "should not copy data") + assert(mat.colPtrs.eq(colPtrs), "should not copy data") + assert(mat.rowIndices.eq(rowIndices), "should not copy data") + } + + test("sparse matrix construction with wrong number of elements") { + intercept[IllegalArgumentException] { + Matrices.sparse(3, 2, Array(0, 1), Array(1, 2, 1), Array(0.0, 1.0, 2.0)) + } + + intercept[IllegalArgumentException] { + Matrices.sparse(3, 2, Array(0, 1, 2), Array(1, 2), Array(0.0, 1.0, 2.0)) + } + } + + test("matrix copies are deep copies") { + val m = 3 + val n = 2 + + val denseMat = Matrices.dense(m, n, Array(0.0, 1.0, 2.0, 3.0, 4.0, 5.0)) + val denseCopy = denseMat.copy + + assert(!denseMat.toArray.eq(denseCopy.toArray)) + + val values = Array(1.0, 2.0, 4.0, 5.0) + val colPtrs = Array(0, 2, 4) + val rowIndices = Array(1, 2, 1, 2) + val sparseMat = Matrices.sparse(m, n, colPtrs, rowIndices, values) + val sparseCopy = sparseMat.copy + + assert(!sparseMat.toArray.eq(sparseCopy.toArray)) + } + + test("matrix indexing and updating") { + val m = 3 + val n = 2 + val allValues = Array(0.0, 1.0, 2.0, 3.0, 4.0, 0.0) + + val denseMat = new DenseMatrix(m, n, allValues) + + assert(denseMat(0, 1) === 3.0) + assert(denseMat(0, 1) === denseMat.values(3)) + assert(denseMat(0, 1) === denseMat(3)) + assert(denseMat(0, 0) === 0.0) + + denseMat.update(0, 0, 10.0) + assert(denseMat(0, 0) === 10.0) + assert(denseMat.values(0) === 10.0) + + val sparseValues = Array(1.0, 2.0, 3.0, 4.0) + val colPtrs = Array(0, 2, 4) + val rowIndices = Array(1, 2, 0, 1) + val sparseMat = new SparseMatrix(m, n, colPtrs, rowIndices, sparseValues) + + assert(sparseMat(0, 1) === 3.0) + assert(sparseMat(0, 1) === sparseMat.values(2)) + assert(sparseMat(0, 0) === 0.0) + + intercept[NoSuchElementException] { + sparseMat.update(0, 0, 10.0) + } + + sparseMat.update(0, 1, 10.0) + assert(sparseMat(0, 1) === 10.0) + assert(sparseMat.values(2) === 10.0) + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala index 29cc42d8cb..30b906aaa3 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.util -import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.linalg.{Matrix, Vector} import org.scalatest.exceptions.TestFailedException object TestingUtils { @@ -169,4 +169,67 @@ object TestingUtils { override def toString = x.toString } + case class CompareMatrixRightSide( + fun: (Matrix, Matrix, Double) => Boolean, y: Matrix, eps: Double, method: String) + + /** + * Implicit class for comparing two matrices using relative tolerance or absolute tolerance. + */ + implicit class MatrixWithAlmostEquals(val x: Matrix) { + + /** + * When the difference of two vectors are within eps, returns true; otherwise, returns false. + */ + def ~=(r: CompareMatrixRightSide): Boolean = r.fun(x, r.y, r.eps) + + /** + * When the difference of two vectors are within eps, returns false; otherwise, returns true. + */ + def !~=(r: CompareMatrixRightSide): Boolean = !r.fun(x, r.y, r.eps) + + /** + * Throws exception when the difference of two vectors are NOT within eps; + * otherwise, returns true. + */ + def ~==(r: CompareMatrixRightSide): Boolean = { + if (!r.fun(x, r.y, r.eps)) { + throw new TestFailedException( + s"Expected \n$x\n and \n${r.y}\n to be within ${r.eps}${r.method} for all elements.", 0) + } + true + } + + /** + * Throws exception when the difference of two matrices are within eps; otherwise, returns true. + */ + def !~==(r: CompareMatrixRightSide): Boolean = { + if (r.fun(x, r.y, r.eps)) { + throw new TestFailedException( + s"Did not expect \n$x\n and \n${r.y}\n to be within " + + "${r.eps}${r.method} for all elements.", 0) + } + true + } + + /** + * Comparison using absolute tolerance. + */ + def absTol(eps: Double): CompareMatrixRightSide = CompareMatrixRightSide( + (x: Matrix, y: Matrix, eps: Double) => { + x.toArray.zip(y.toArray).forall(x => x._1 ~= x._2 absTol eps) + }, x, eps, ABS_TOL_MSG) + + /** + * Comparison using relative tolerance. Note that comparing against sparse vector + * with elements having value of zero will raise exception because it involves with + * comparing against zero. + */ + def relTol(eps: Double): CompareMatrixRightSide = CompareMatrixRightSide( + (x: Matrix, y: Matrix, eps: Double) => { + x.toArray.zip(y.toArray).forall(x => x._1 ~= x._2 relTol eps) + }, x, eps, REL_TOL_MSG) + + override def toString = x.toString + } + } -- cgit v1.2.3