author    MechCoder <manojkumarsivaraj334@gmail.com>  2015-02-10 14:05:55 -0800
committer Xiangrui Meng <meng@databricks.com>         2015-02-10 14:05:55 -0800
commit    fd2c032f95bbee342ca539df9e44927482981659 (patch)
tree      d62c6a533c9ae2d06c8d8888d5197f481955969f /mllib
parent    f98707c043f1be9569ec774796edb783132773a8 (diff)
[SPARK-5021] [MLlib] Gaussian Mixture now supports Sparse Input
Following discussion in the Jira.

Author: MechCoder <manojkumarsivaraj334@gmail.com>

Closes #4459 from MechCoder/sparse_gmm and squashes the following commits:

1b18dab [MechCoder] Rewrite syr for sparse matrices
e579041 [MechCoder] Add test for covariance matrix
5cb370b [MechCoder] Separate tests for sparse data
5e096bd [MechCoder] Alphabetize and correct error message
e180f4c [MechCoder] [SPARK-5021] Gaussian Mixture now supports Sparse Input
Diffstat (limited to 'mllib')
 mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala             | 31
 mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala                            | 36
 mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala | 10
 mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala        | 66
 mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala                       |  8
 5 files changed, 125 insertions(+), 26 deletions(-)
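
What the change enables, as a minimal sketch: GaussianMixture.run now accepts an RDD of sparse vectors directly instead of requiring dense input. The data below is illustrative (it mirrors the new tests further down), and `sc` is an assumed, already-created SparkContext:

    import org.apache.spark.mllib.clustering.GaussianMixture
    import org.apache.spark.mllib.linalg.Vectors

    // Sparse points can now be fed to run() without densifying them first.
    val sparsePoints = sc.parallelize(Seq(
      Vectors.sparse(3, Array(0, 2), Array(4.0, 2.0)),
      Vectors.sparse(3, Array(0, 2), Array(2.0, 4.0)),
      Vectors.sparse(3, Array(1), Array(6.0))
    ))
    val model = new GaussianMixture().setK(1).run(sparsePoints)
    println(model.gaussians(0).mu)  // estimated mean of the single component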
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala
index 0be3014de8..80584ef5e5 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala
@@ -19,10 +19,12 @@ package org.apache.spark.mllib.clustering
import scala.collection.mutable.IndexedSeq
-import breeze.linalg.{DenseMatrix => BreezeMatrix, DenseVector => BreezeVector, Transpose, diag}
+import breeze.linalg.{diag, DenseMatrix => BreezeMatrix, DenseVector => BDV, SparseVector => BSV,
+ Transpose, Vector => BV}
import org.apache.spark.annotation.Experimental
-import org.apache.spark.mllib.linalg.{BLAS, DenseMatrix, DenseVector, Matrices, Vector, Vectors}
+import org.apache.spark.mllib.linalg.{BLAS, DenseVector, DenseMatrix, Matrices,
+ SparseVector, Vector, Vectors}
import org.apache.spark.mllib.stat.distribution.MultivariateGaussian
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
@@ -130,7 +132,7 @@ class GaussianMixture private (
val sc = data.sparkContext
// we will operate on the data as breeze data
- val breezeData = data.map(u => u.toBreeze.toDenseVector).cache()
+ val breezeData = data.map(_.toBreeze).cache()
// Get length of the input vectors
val d = breezeData.first().length
@@ -148,7 +150,7 @@ class GaussianMixture private (
(Array.fill(k)(1.0 / k), Array.tabulate(k) { i =>
val slice = samples.view(i * nSamples, (i + 1) * nSamples)
new MultivariateGaussian(vectorMean(slice), initCovariance(slice))
- })
+ })
}
}
@@ -169,7 +171,7 @@ class GaussianMixture private (
var i = 0
while (i < k) {
val mu = sums.means(i) / sums.weights(i)
- BLAS.syr(-sums.weights(i), Vectors.fromBreeze(mu).asInstanceOf[DenseVector],
+ BLAS.syr(-sums.weights(i), Vectors.fromBreeze(mu),
Matrices.fromBreeze(sums.sigmas(i)).asInstanceOf[DenseMatrix])
weights(i) = sums.weights(i) / sumWeights
gaussians(i) = new MultivariateGaussian(mu, sums.sigmas(i) / sums.weights(i))
@@ -185,8 +187,8 @@ class GaussianMixture private (
}
/** Average of dense breeze vectors */
- private def vectorMean(x: IndexedSeq[BreezeVector[Double]]): BreezeVector[Double] = {
- val v = BreezeVector.zeros[Double](x(0).length)
+ private def vectorMean(x: IndexedSeq[BV[Double]]): BDV[Double] = {
+ val v = BDV.zeros[Double](x(0).length)
x.foreach(xi => v += xi)
v / x.length.toDouble
}
@@ -195,10 +197,10 @@ class GaussianMixture private (
* Construct matrix where diagonal entries are element-wise
* variance of input vectors (computes biased variance)
*/
- private def initCovariance(x: IndexedSeq[BreezeVector[Double]]): BreezeMatrix[Double] = {
+ private def initCovariance(x: IndexedSeq[BV[Double]]): BreezeMatrix[Double] = {
val mu = vectorMean(x)
- val ss = BreezeVector.zeros[Double](x(0).length)
- x.map(xi => (xi - mu) :^ 2.0).foreach(u => ss += u)
+ val ss = BDV.zeros[Double](x(0).length)
+ x.foreach(xi => ss += (xi - mu) :^ 2.0)
diag(ss / x.length.toDouble)
}
}
@@ -207,7 +209,7 @@ class GaussianMixture private (
private object ExpectationSum {
def zero(k: Int, d: Int): ExpectationSum = {
new ExpectationSum(0.0, Array.fill(k)(0.0),
- Array.fill(k)(BreezeVector.zeros(d)), Array.fill(k)(BreezeMatrix.zeros(d,d)))
+ Array.fill(k)(BDV.zeros(d)), Array.fill(k)(BreezeMatrix.zeros(d,d)))
}
// compute cluster contributions for each input point
@@ -215,19 +217,18 @@ private object ExpectationSum {
def add(
weights: Array[Double],
dists: Array[MultivariateGaussian])
- (sums: ExpectationSum, x: BreezeVector[Double]): ExpectationSum = {
+ (sums: ExpectationSum, x: BV[Double]): ExpectationSum = {
val p = weights.zip(dists).map {
case (weight, dist) => MLUtils.EPSILON + weight * dist.pdf(x)
}
val pSum = p.sum
sums.logLikelihood += math.log(pSum)
- val xxt = x * new Transpose(x)
var i = 0
while (i < sums.k) {
p(i) /= pSum
sums.weights(i) += p(i)
sums.means(i) += x * p(i)
- BLAS.syr(p(i), Vectors.fromBreeze(x).asInstanceOf[DenseVector],
+ BLAS.syr(p(i), Vectors.fromBreeze(x),
Matrices.fromBreeze(sums.sigmas(i)).asInstanceOf[DenseMatrix])
i = i + 1
}
@@ -239,7 +240,7 @@ private object ExpectationSum {
private class ExpectationSum(
var logLikelihood: Double,
val weights: Array[Double],
- val means: Array[BreezeVector[Double]],
+ val means: Array[BDV[Double]],
val sigmas: Array[BreezeMatrix[Double]]) extends Serializable {
val k = weights.length
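
The two BLAS.syr call sites above replace the old `x * new Transpose(x)` outer product: instead of allocating a d x d matrix per point, the covariance accumulator is updated in place. A rough sketch of the rank-1 update being performed, written in plain Breeze rather than taken from the MLlib code itself:

    import breeze.linalg.{DenseMatrix, DenseVector}

    // In-place rank-1 update: sigma := sigma + w * x * x^T.
    // This is what syr accomplishes without materializing x * x^T.
    def rankOneUpdate(w: Double, x: DenseVector[Double], sigma: DenseMatrix[Double]): Unit = {
      val d = x.length
      var col = 0
      while (col < d) {
        val wxc = w * x(col)
        var row = 0
        while (row < d) {
          sigma(row, col) += wxc * x(row)
          row += 1
        }
        col += 1
      }
    }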
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
index 079f7ca564..87052e1ba8 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
@@ -235,12 +235,24 @@ private[spark] object BLAS extends Serializable with Logging {
* @param x the vector x that contains the n elements.
* @param A the symmetric matrix A. Size of n x n.
*/
- def syr(alpha: Double, x: DenseVector, A: DenseMatrix) {
+ def syr(alpha: Double, x: Vector, A: DenseMatrix) {
val mA = A.numRows
val nA = A.numCols
- require(mA == nA, s"A is not a symmetric matrix. A: $mA x $nA")
+ require(mA == nA, s"A is not a square matrix (and hence is not symmetric). A: $mA x $nA")
require(mA == x.size, s"The size of x doesn't match the rank of A. A: $mA x $nA, x: ${x.size}")
+ x match {
+ case dv: DenseVector => syr(alpha, dv, A)
+ case sv: SparseVector => syr(alpha, sv, A)
+ case _ =>
+ throw new IllegalArgumentException(s"syr doesn't support vector type ${x.getClass}.")
+ }
+ }
+
+ private def syr(alpha: Double, x: DenseVector, A: DenseMatrix) {
+ val nA = A.numRows
+ val mA = A.numCols
+
nativeBLAS.dsyr("U", x.size, alpha, x.values, 1, A.values, nA)
// Fill lower triangular part of A
@@ -255,6 +267,26 @@ private[spark] object BLAS extends Serializable with Logging {
}
}
+ private def syr(alpha: Double, x: SparseVector, A: DenseMatrix) {
+ val mA = A.numCols
+ val xIndices = x.indices
+ val xValues = x.values
+ val nnz = xValues.length
+ val Avalues = A.values
+
+ var i = 0
+ while (i < nnz) {
+ val multiplier = alpha * xValues(i)
+ val offset = xIndices(i) * mA
+ var j = 0
+ while (j < nnz) {
+ Avalues(xIndices(j) + offset) += multiplier * xValues(j)
+ j += 1
+ }
+ i += 1
+ }
+ }
+
/**
* C := alpha * A * B + beta * C
* @param alpha a scalar to scale the multiplication A * B.
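
In the new sparse syr, `A.values` stores the matrix column-major, so entry (row, col) lives at `col * mA + row`; `offset = xIndices(i) * mA` selects column `xIndices(i)`, and both triangles are covered because `i` and `j` each range over all nonzeros. A self-contained sketch of the same loop over plain arrays (names are illustrative, not MLlib API):

    // A := A + alpha * x * x^T for a column-major d x d array `a`,
    // touching only entries whose row and column are nonzero positions of x.
    def sparseSyr(alpha: Double, indices: Array[Int], values: Array[Double],
                  a: Array[Double], d: Int): Unit = {
      var i = 0
      while (i < indices.length) {
        val colOffset = indices(i) * d      // start of column indices(i)
        val multiplier = alpha * values(i)
        var j = 0
        while (j < indices.length) {
          a(colOffset + indices(j)) += multiplier * values(j)  // A(indices(j), indices(i))
          j += 1
        }
        i += 1
      }
    }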
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala
index fd186b5ee6..cd6add9d60 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala
@@ -17,7 +17,7 @@
package org.apache.spark.mllib.stat.distribution
-import breeze.linalg.{DenseVector => DBV, DenseMatrix => DBM, diag, max, eigSym}
+import breeze.linalg.{DenseVector => DBV, DenseMatrix => DBM, diag, max, eigSym, Vector => BV}
import org.apache.spark.annotation.DeveloperApi;
import org.apache.spark.mllib.linalg.{Vectors, Vector, Matrices, Matrix}
@@ -62,21 +62,21 @@ class MultivariateGaussian (
/** Returns density of this multivariate Gaussian at given point, x */
def pdf(x: Vector): Double = {
- pdf(x.toBreeze.toDenseVector)
+ pdf(x.toBreeze)
}
/** Returns the log-density of this multivariate Gaussian at given point, x */
def logpdf(x: Vector): Double = {
- logpdf(x.toBreeze.toDenseVector)
+ logpdf(x.toBreeze)
}
/** Returns density of this multivariate Gaussian at given point, x */
- private[mllib] def pdf(x: DBV[Double]): Double = {
+ private[mllib] def pdf(x: BV[Double]): Double = {
math.exp(logpdf(x))
}
/** Returns the log-density of this multivariate Gaussian at given point, x */
- private[mllib] def logpdf(x: DBV[Double]): Double = {
+ private[mllib] def logpdf(x: BV[Double]): Double = {
val delta = x - breezeMu
val v = rootSigmaInv * delta
u + v.t * v * -0.5
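
For reference, the quantity logpdf evaluates (unchanged here apart from the widened argument type): with `v = rootSigmaInv * delta`, the product `v.t * v` is the Mahalanobis quadratic form, so, assuming `u` caches the log-normalization constant of the distribution,

    \log p(x) = \underbrace{-\tfrac{1}{2}\bigl(d \log(2\pi) + \log\lvert\Sigma\rvert\bigr)}_{u}
                \; - \; \tfrac{1}{2}\, v^{\top} v,
    \qquad v^{\top} v = (x - \mu)^{\top} \Sigma^{-1} (x - \mu)
    \text{ for } v = \Sigma^{-1/2} (x - \mu),

which is the standard multivariate normal log-density. Nothing in it requires x to be dense, which is why relaxing DBV to BV is safe.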
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala
index c2cd56ea40..1b46a4012d 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala
@@ -31,7 +31,7 @@ class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext {
Vectors.dense(5.0, 10.0),
Vectors.dense(4.0, 11.0)
))
-
+
// expectations
val Ew = 1.0
val Emu = Vectors.dense(5.0, 10.0)
@@ -44,6 +44,7 @@ class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext {
assert(gmm.gaussians(0).mu ~== Emu absTol 1E-5)
assert(gmm.gaussians(0).sigma ~== Esigma absTol 1E-5)
}
+
}
test("two clusters") {
@@ -54,7 +55,7 @@ class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext {
Vectors.dense( 5.7048), Vectors.dense( 4.6567), Vectors.dense( 5.5026),
Vectors.dense( 4.5605), Vectors.dense( 5.2043), Vectors.dense( 6.2734)
))
-
+
// we set an initial gaussian to induce expected results
val initialGmm = new GaussianMixtureModel(
Array(0.5, 0.5),
@@ -63,7 +64,7 @@ class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext {
new MultivariateGaussian(Vectors.dense(1.0), Matrices.dense(1, 1, Array(1.0)))
)
)
-
+
val Ew = Array(1.0 / 3.0, 2.0 / 3.0)
val Emu = Array(Vectors.dense(-4.3673), Vectors.dense(5.1604))
val Esigma = Array(Matrices.dense(1, 1, Array(1.1098)), Matrices.dense(1, 1, Array(0.86644)))
@@ -72,7 +73,7 @@ class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext {
.setK(2)
.setInitialModel(initialGmm)
.run(data)
-
+
assert(gmm.weights(0) ~== Ew(0) absTol 1E-3)
assert(gmm.weights(1) ~== Ew(1) absTol 1E-3)
assert(gmm.gaussians(0).mu ~== Emu(0) absTol 1E-3)
@@ -80,4 +81,61 @@ class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext {
assert(gmm.gaussians(0).sigma ~== Esigma(0) absTol 1E-3)
assert(gmm.gaussians(1).sigma ~== Esigma(1) absTol 1E-3)
}
+
+ test("single cluster with sparse data") {
+ val data = sc.parallelize(Array(
+ Vectors.sparse(3, Array(0, 2), Array(4.0, 2.0)),
+ Vectors.sparse(3, Array(0, 2), Array(2.0, 4.0)),
+ Vectors.sparse(3, Array(1), Array(6.0))
+ ))
+
+ val Ew = 1.0
+ val Emu = Vectors.dense(2.0, 2.0, 2.0)
+ val Esigma = Matrices.dense(3, 3,
+ Array(8.0 / 3.0, -4.0, 4.0 / 3.0, -4.0, 8.0, -4.0, 4.0 / 3.0, -4.0, 8.0 / 3.0)
+ )
+
+ val seeds = Array(42, 1994, 27, 11, 0)
+ seeds.foreach { seed =>
+ val gmm = new GaussianMixture().setK(1).setSeed(seed).run(data)
+ assert(gmm.weights(0) ~== Ew absTol 1E-5)
+ assert(gmm.gaussians(0).mu ~== Emu absTol 1E-5)
+ assert(gmm.gaussians(0).sigma ~== Esigma absTol 1E-5)
+ }
+ }
+
+ test("two clusters with sparse data") {
+ val data = sc.parallelize(Array(
+ Vectors.dense(-5.1971), Vectors.dense(-2.5359), Vectors.dense(-3.8220),
+ Vectors.dense(-5.2211), Vectors.dense(-5.0602), Vectors.dense( 4.7118),
+ Vectors.dense( 6.8989), Vectors.dense( 3.4592), Vectors.dense( 4.6322),
+ Vectors.dense( 5.7048), Vectors.dense( 4.6567), Vectors.dense( 5.5026),
+ Vectors.dense( 4.5605), Vectors.dense( 5.2043), Vectors.dense( 6.2734)
+ ))
+
+ val sparseData = data.map(point => Vectors.sparse(1, Array(0), point.toArray))
+ // we set an initial gaussian to induce expected results
+ val initialGmm = new GaussianMixtureModel(
+ Array(0.5, 0.5),
+ Array(
+ new MultivariateGaussian(Vectors.dense(-1.0), Matrices.dense(1, 1, Array(1.0))),
+ new MultivariateGaussian(Vectors.dense(1.0), Matrices.dense(1, 1, Array(1.0)))
+ )
+ )
+ val Ew = Array(1.0 / 3.0, 2.0 / 3.0)
+ val Emu = Array(Vectors.dense(-4.3673), Vectors.dense(5.1604))
+ val Esigma = Array(Matrices.dense(1, 1, Array(1.1098)), Matrices.dense(1, 1, Array(0.86644)))
+
+ val sparseGMM = new GaussianMixture()
+ .setK(2)
+ .setInitialModel(initialGmm)
+ .run(sparseData)
+
+ assert(sparseGMM.weights(0) ~== Ew(0) absTol 1E-3)
+ assert(sparseGMM.weights(1) ~== Ew(1) absTol 1E-3)
+ assert(sparseGMM.gaussians(0).mu ~== Emu(0) absTol 1E-3)
+ assert(sparseGMM.gaussians(1).mu ~== Emu(1) absTol 1E-3)
+ assert(sparseGMM.gaussians(0).sigma ~== Esigma(0) absTol 1E-3)
+ assert(sparseGMM.gaussians(1).sigma ~== Esigma(1) absTol 1E-3)
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala
index b0b78acd6d..002cb25386 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala
@@ -166,6 +166,14 @@ class BLASSuite extends FunSuite {
syr(alpha, y, dA)
}
}
+
+ val xSparse = new SparseVector(4, Array(0, 2, 3), Array(1.0, 3.0, 4.0))
+ val dD = new DenseMatrix(4, 4,
+ Array(0.0, 1.2, 2.2, 3.1, 1.2, 3.2, 5.3, 4.6, 2.2, 5.3, 1.8, 3.0, 3.1, 4.6, 3.0, 0.8))
+ syr(0.1, xSparse, dD)
+ val expectedSparse = new DenseMatrix(4, 4,
+ Array(0.1, 1.2, 2.5, 3.5, 1.2, 3.2, 5.3, 4.6, 2.5, 5.3, 2.7, 4.2, 3.5, 4.6, 4.2, 2.4))
+ assert(dD ~== expectedSparse absTol 1e-15)
}
test("gemm") {