aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLiang-Chi Hsieh <viirya@gmail.com>2015-01-09 10:27:33 -0800
committerXiangrui Meng <meng@databricks.com>2015-01-09 10:27:33 -0800
commite9ca16ec943b9553056482d0c085eacb6046821e (patch)
tree4027c56b14b926ae40cc75fadef2b14fe33806f3
parentb6aa557300275b835cce7baa7bc8a80eb5425cbb (diff)
downloadspark-e9ca16ec943b9553056482d0c085eacb6046821e.tar.gz
spark-e9ca16ec943b9553056482d0c085eacb6046821e.tar.bz2
spark-e9ca16ec943b9553056482d0c085eacb6046821e.zip
[SPARK-5145][Mllib] Add BLAS.dsyr and use it in GaussianMixtureEM
This pr uses BLAS.dsyr to replace few implementations in GaussianMixtureEM. Author: Liang-Chi Hsieh <viirya@gmail.com> Closes #3949 from viirya/blas_dsyr and squashes the following commits: 4e4d6cf [Liang-Chi Hsieh] Add unit test. Rename function name, modify doc and style. 3f57fd2 [Liang-Chi Hsieh] Add BLAS.dsyr and use it in GaussianMixtureEM.
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala10
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala26
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala41
3 files changed, 73 insertions, 4 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala
index bdf984aee4..3a6c0e681e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala
@@ -21,7 +21,7 @@ import scala.collection.mutable.IndexedSeq
import breeze.linalg.{DenseVector => BreezeVector, DenseMatrix => BreezeMatrix, diag, Transpose}
import org.apache.spark.rdd.RDD
-import org.apache.spark.mllib.linalg.{Matrices, Vector, Vectors}
+import org.apache.spark.mllib.linalg.{Matrices, Vector, Vectors, DenseVector, DenseMatrix, BLAS}
import org.apache.spark.mllib.stat.impl.MultivariateGaussian
import org.apache.spark.mllib.util.MLUtils
@@ -151,9 +151,10 @@ class GaussianMixtureEM private (
var i = 0
while (i < k) {
val mu = sums.means(i) / sums.weights(i)
- val sigma = sums.sigmas(i) / sums.weights(i) - mu * new Transpose(mu) // TODO: Use BLAS.dsyr
+ BLAS.syr(-sums.weights(i), Vectors.fromBreeze(mu).asInstanceOf[DenseVector],
+ Matrices.fromBreeze(sums.sigmas(i)).asInstanceOf[DenseMatrix])
weights(i) = sums.weights(i) / sumWeights
- gaussians(i) = new MultivariateGaussian(mu, sigma)
+ gaussians(i) = new MultivariateGaussian(mu, sums.sigmas(i) / sums.weights(i))
i = i + 1
}
@@ -211,7 +212,8 @@ private object ExpectationSum {
p(i) /= pSum
sums.weights(i) += p(i)
sums.means(i) += x * p(i)
- sums.sigmas(i) += xxt * p(i) // TODO: use BLAS.dsyr
+ BLAS.syr(p(i), Vectors.fromBreeze(x).asInstanceOf[DenseVector],
+ Matrices.fromBreeze(sums.sigmas(i)).asInstanceOf[DenseMatrix])
i = i + 1
}
sums
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
index 9fed513bec..3414daccd7 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
@@ -228,6 +228,32 @@ private[spark] object BLAS extends Serializable with Logging {
}
_nativeBLAS
}
+
+ /**
+ * A := alpha * x * x^T^ + A
+ * @param alpha a real scalar that will be multiplied to x * x^T^.
+ * @param x the vector x that contains the n elements.
+ * @param A the symmetric matrix A. Size of n x n.
+ */
+ def syr(alpha: Double, x: DenseVector, A: DenseMatrix) {
+ val mA = A.numRows
+ val nA = A.numCols
+ require(mA == nA, s"A is not a symmetric matrix. A: $mA x $nA")
+ require(mA == x.size, s"The size of x doesn't match the rank of A. A: $mA x $nA, x: ${x.size}")
+
+ nativeBLAS.dsyr("U", x.size, alpha, x.values, 1, A.values, nA)
+
+ // Fill lower triangular part of A
+ var i = 0
+ while (i < mA) {
+ var j = i + 1
+ while (j < nA) {
+ A(j, i) = A(i, j)
+ j += 1
+ }
+ i += 1
+ }
+ }
/**
* C := alpha * A * B + beta * C
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala
index 5d70c914f1..771878e925 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala
@@ -127,6 +127,47 @@ class BLASSuite extends FunSuite {
}
}
+ test("syr") {
+ val dA = new DenseMatrix(4, 4,
+ Array(0.0, 1.2, 2.2, 3.1, 1.2, 3.2, 5.3, 4.6, 2.2, 5.3, 1.8, 3.0, 3.1, 4.6, 3.0, 0.8))
+ val x = new DenseVector(Array(0.0, 2.7, 3.5, 2.1))
+ val alpha = 0.15
+
+ val expected = new DenseMatrix(4, 4,
+ Array(0.0, 1.2, 2.2, 3.1, 1.2, 4.2935, 6.7175, 5.4505, 2.2, 6.7175, 3.6375, 4.1025, 3.1,
+ 5.4505, 4.1025, 1.4615))
+
+ syr(alpha, x, dA)
+
+ assert(dA ~== expected absTol 1e-15)
+
+ val dB =
+ new DenseMatrix(3, 4, Array(0.0, 1.2, 2.2, 3.1, 1.2, 3.2, 5.3, 4.6, 2.2, 5.3, 1.8, 3.0))
+
+ withClue("Matrix A must be a symmetric Matrix") {
+ intercept[Exception] {
+ syr(alpha, x, dB)
+ }
+ }
+
+ val dC =
+ new DenseMatrix(3, 3, Array(0.0, 1.2, 2.2, 1.2, 3.2, 5.3, 2.2, 5.3, 1.8))
+
+ withClue("Size of vector must match the rank of matrix") {
+ intercept[Exception] {
+ syr(alpha, x, dC)
+ }
+ }
+
+ val y = new DenseVector(Array(0.0, 2.7, 3.5, 2.1, 1.5))
+
+ withClue("Size of vector must match the rank of matrix") {
+ intercept[Exception] {
+ syr(alpha, y, dA)
+ }
+ }
+ }
+
test("gemm") {
val dA =