aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorYanbo Liang <ybliang8@gmail.com>2017-01-09 21:38:46 -0800
committerYanbo Liang <ybliang8@gmail.com>2017-01-09 21:38:46 -0800
commit3ef6d98a803fdff182ab4556c3273ec5fa0ff002 (patch)
tree1d8dba974664353e4146b7d9ca801e0473d0d9f4 /mllib
parentfaabe69cc081145f43f9c68db1a7a8c5c39684fb (diff)
downloadspark-3ef6d98a803fdff182ab4556c3273ec5fa0ff002.tar.gz
spark-3ef6d98a803fdff182ab4556c3273ec5fa0ff002.tar.bz2
spark-3ef6d98a803fdff182ab4556c3273ec5fa0ff002.zip
[SPARK-17847][ML] Reduce shuffled data size of GaussianMixture & copy the implementation from mllib to ml
## What changes were proposed in this pull request? Copy `GaussianMixture` implementation from mllib to ml, then we can add new features to it. I left mllib `GaussianMixture` untouched, unlike some other algorithms to wrap the ml implementation. For the following reasons: - mllib `GaussianMixture` allows k == 1, but ml does not. - mllib `GaussianMixture` supports setting initial model, but ml does not support currently. (We will definitely add this feature for ml in the future) We can get around these issues to make mllib as a wrapper calling into ml, but I'd prefer to leave mllib untouched which can make ml clean. Meanwhile, There is a big performance improvement for `GaussianMixture` in this PR. Since the covariance matrix of multivariate gaussian distribution is symmetric, we can only store the upper triangular part of the matrix and it will greatly reduce the shuffled data size. In my test, this change will reduce shuffled data size by about 50% and accelerate the job execution. Before this PR: ![image](https://cloud.githubusercontent.com/assets/1962026/19641622/4bb017ac-9996-11e6-8ece-83db184b620a.png) After this PR: ![image](https://cloud.githubusercontent.com/assets/1962026/19641635/629c21fe-9996-11e6-91e9-83ab74ae0126.png) ## How was this patch tested? Existing tests and added new tests. Author: Yanbo Liang <ybliang8@gmail.com> Closes #15413 from yanboliang/spark-17847.
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala331
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala148
2 files changed, 461 insertions, 18 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
index ac56845581..a7bb413795 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
@@ -21,6 +21,7 @@ import breeze.linalg.{DenseVector => BDV}
import org.apache.hadoop.fs.Path
import org.apache.spark.annotation.{Experimental, Since}
+import org.apache.spark.broadcast.Broadcast
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.impl.Utils.EPSILON
import org.apache.spark.ml.linalg._
@@ -28,7 +29,6 @@ import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.stat.distribution.MultivariateGaussian
import org.apache.spark.ml.util._
-import org.apache.spark.mllib.clustering.{GaussianMixture => MLlibGM}
import org.apache.spark.mllib.linalg.{Matrices => OldMatrices, Matrix => OldMatrix,
Vector => OldVector, Vectors => OldVectors}
import org.apache.spark.rdd.RDD
@@ -45,6 +45,7 @@ private[clustering] trait GaussianMixtureParams extends Params with HasMaxIter w
/**
* Number of independent Gaussians in the mixture model. Must be greater than 1. Default: 2.
+ *
* @group param
*/
@Since("2.0.0")
@@ -57,6 +58,7 @@ private[clustering] trait GaussianMixtureParams extends Params with HasMaxIter w
/**
* Validates and transforms the input schema.
+ *
* @param schema input schema
* @return output schema
*/
@@ -238,6 +240,7 @@ object GaussianMixtureModel extends MLReadable[GaussianMixtureModel] {
/**
* Compute the probability (partial assignment) for each cluster for the given data point.
+ *
* @param features Data point
* @param dists Gaussians for model
* @param weights Weights for each Gaussian
@@ -323,31 +326,98 @@ class GaussianMixture @Since("2.0.0") (
@Since("2.0.0")
def setSeed(value: Long): this.type = set(seed, value)
+ /**
+ * Number of samples per cluster to use when initializing Gaussians.
+ */
+ private val numSamples = 5
+
@Since("2.0.0")
override def fit(dataset: Dataset[_]): GaussianMixtureModel = {
transformSchema(dataset.schema, logging = true)
- val rdd: RDD[OldVector] = dataset.select(col($(featuresCol))).rdd.map {
- case Row(point: Vector) => OldVectors.fromML(point)
- }
- val instr = Instrumentation.create(this, rdd)
+ val sc = dataset.sparkSession.sparkContext
+ val numClusters = $(k)
+
+ val instances: RDD[Vector] = dataset.select(col($(featuresCol))).rdd.map {
+ case Row(features: Vector) => features
+ }.cache()
+
+ // Extract the number of features.
+ val numFeatures = instances.first().size
+
+ val instr = Instrumentation.create(this, instances)
instr.logParams(featuresCol, predictionCol, probabilityCol, k, maxIter, seed, tol)
+ instr.logNumFeatures(numFeatures)
+
+ val shouldDistributeGaussians = GaussianMixture.shouldDistributeGaussians(
+ numClusters, numFeatures)
+
+ // TODO: SPARK-15785 Support users supplied initial GMM.
+ val (weights, gaussians) = initRandom(instances, numClusters, numFeatures)
+
+ var logLikelihood = Double.MinValue
+ var logLikelihoodPrev = 0.0
+
+ var iter = 0
+ while (iter < $(maxIter) && math.abs(logLikelihood - logLikelihoodPrev) > $(tol)) {
+
+ val bcWeights = instances.sparkContext.broadcast(weights)
+ val bcGaussians = instances.sparkContext.broadcast(gaussians)
+
+ // aggregate the cluster contribution for all sample points
+ val sums = instances.treeAggregate(
+ new ExpectationAggregator(numFeatures, bcWeights, bcGaussians))(
+ seqOp = (c, v) => (c, v) match {
+ case (aggregator, instance) => aggregator.add(instance)
+ },
+ combOp = (c1, c2) => (c1, c2) match {
+ case (aggregator1, aggregator2) => aggregator1.merge(aggregator2)
+ })
+
+ bcWeights.destroy(blocking = false)
+ bcGaussians.destroy(blocking = false)
+
+ /*
+ Create new distributions based on the partial assignments
+ (often referred to as the "M" step in literature)
+ */
+ val sumWeights = sums.weights.sum
+
+ if (shouldDistributeGaussians) {
+ val numPartitions = math.min(numClusters, 1024)
+ val tuples = Seq.tabulate(numClusters) { i =>
+ (sums.means(i), sums.covs(i), sums.weights(i))
+ }
+ val (ws, gs) = sc.parallelize(tuples, numPartitions).map { case (mean, cov, weight) =>
+ GaussianMixture.updateWeightsAndGaussians(mean, cov, weight, sumWeights)
+ }.collect().unzip
+ Array.copy(ws.toArray, 0, weights, 0, ws.length)
+ Array.copy(gs.toArray, 0, gaussians, 0, gs.length)
+ } else {
+ var i = 0
+ while (i < numClusters) {
+ val (weight, gaussian) = GaussianMixture.updateWeightsAndGaussians(
+ sums.means(i), sums.covs(i), sums.weights(i), sumWeights)
+ weights(i) = weight
+ gaussians(i) = gaussian
+ i += 1
+ }
+ }
+
+ logLikelihoodPrev = logLikelihood // current becomes previous
+ logLikelihood = sums.logLikelihood // this is the freshly computed log-likelihood
+ iter += 1
+ }
- val algo = new MLlibGM()
- .setK($(k))
- .setMaxIterations($(maxIter))
- .setSeed($(seed))
- .setConvergenceTol($(tol))
- val parentModel = algo.run(rdd)
- val gaussians = parentModel.gaussians.map { case g =>
- new MultivariateGaussian(g.mu.asML, g.sigma.asML)
+ val gaussianDists = gaussians.map { case (mean, covVec) =>
+ val cov = GaussianMixture.unpackUpperTriangularMatrix(numFeatures, covVec.values)
+ new MultivariateGaussian(mean, cov)
}
- val model = copyValues(new GaussianMixtureModel(uid, parentModel.weights, gaussians))
- .setParent(this)
+
+ val model = copyValues(new GaussianMixtureModel(uid, weights, gaussianDists)).setParent(this)
val summary = new GaussianMixtureSummary(model.transform(dataset),
$(predictionCol), $(probabilityCol), $(featuresCol), $(k))
model.setSummary(Some(summary))
- instr.logNumFeatures(model.gaussians.head.mean.size)
instr.logSuccess(model)
model
}
@@ -356,6 +426,61 @@ class GaussianMixture @Since("2.0.0") (
override def transformSchema(schema: StructType): StructType = {
validateAndTransformSchema(schema)
}
+
+ /**
+ * Initialize weights and corresponding gaussian distributions at random.
+ *
+ * We start with uniform weights, a random mean from the data, and diagonal covariance matrices
+ * using component variances derived from the samples.
+ *
+ * @param instances The training instances.
+ * @param numClusters The number of clusters.
+ * @param numFeatures The number of features of training instance.
+ * @return The initialized weights and corresponding gaussian distributions. Note the
+ * covariance matrix of multivariate gaussian distribution is symmetric and
+ * we only save the upper triangular part as a dense vector (column major).
+ */
+ private def initRandom(
+ instances: RDD[Vector],
+ numClusters: Int,
+ numFeatures: Int): (Array[Double], Array[(DenseVector, DenseVector)]) = {
+ val samples = instances.takeSample(withReplacement = true, numClusters * numSamples, $(seed))
+ val weights: Array[Double] = Array.fill(numClusters)(1.0 / numClusters)
+ val gaussians: Array[(DenseVector, DenseVector)] = Array.tabulate(numClusters) { i =>
+ val slice = samples.view(i * numSamples, (i + 1) * numSamples)
+ val mean = {
+ val v = new DenseVector(new Array[Double](numFeatures))
+ var i = 0
+ while (i < numSamples) {
+ BLAS.axpy(1.0, slice(i), v)
+ i += 1
+ }
+ BLAS.scal(1.0 / numSamples, v)
+ v
+ }
+ /*
+ Construct matrix where diagonal entries are element-wise
+ variance of input vectors (computes biased variance).
+ Since the covariance matrix of multivariate gaussian distribution is symmetric,
+ only the upper triangular part of the matrix (column major) will be saved as
+ a dense vector in order to reduce the shuffled data size.
+ */
+ val cov = {
+ val ss = new DenseVector(new Array[Double](numFeatures)).asBreeze
+ slice.foreach(xi => ss += (xi.asBreeze - mean.asBreeze) :^ 2.0)
+ val diagVec = Vectors.fromBreeze(ss)
+ BLAS.scal(1.0 / numSamples, diagVec)
+ val covVec = new DenseVector(Array.fill[Double](
+ numFeatures * (numFeatures + 1) / 2)(0.0))
+ diagVec.toArray.zipWithIndex.foreach { case (v: Double, i: Int) =>
+ covVec.values(i + i * (i + 1) / 2) = v
+ }
+ covVec
+ }
+ (mean, cov)
+ }
+ (weights, gaussians)
+ }
}
@Since("2.0.0")
@@ -363,6 +488,180 @@ object GaussianMixture extends DefaultParamsReadable[GaussianMixture] {
@Since("2.0.0")
override def load(path: String): GaussianMixture = super.load(path)
+
+ /**
+ * Heuristic to distribute the computation of the [[MultivariateGaussian]]s, approximately when
+ * numFeatures > 25 except for when numClusters is very small.
+ *
+ * @param numClusters Number of clusters
+ * @param numFeatures Number of features
+ */
+ private[clustering] def shouldDistributeGaussians(
+ numClusters: Int,
+ numFeatures: Int): Boolean = {
+ ((numClusters - 1.0) / numClusters) * numFeatures > 25.0
+ }
+
+ /**
+ * Convert an n * (n + 1) / 2 dimension array representing the upper triangular part of a matrix
+ * into an n * n array representing the full symmetric matrix (column major).
+ *
+ * @param n The order of the n by n matrix.
+ * @param triangularValues The upper triangular part of the matrix packed in an array
+ * (column major).
+ * @return A dense matrix which represents the symmetric matrix in column major.
+ */
+ private[clustering] def unpackUpperTriangularMatrix(
+ n: Int,
+ triangularValues: Array[Double]): DenseMatrix = {
+ val symmetricValues = new Array[Double](n * n)
+ var r = 0
+ var i = 0
+ while (i < n) {
+ var j = 0
+ while (j <= i) {
+ symmetricValues(i * n + j) = triangularValues(r)
+ symmetricValues(j * n + i) = triangularValues(r)
+ r += 1
+ j += 1
+ }
+ i += 1
+ }
+ new DenseMatrix(n, n, symmetricValues)
+ }
+
+ /**
+ * Update the weight, mean and covariance of gaussian distribution.
+ *
+ * @param mean The mean of the gaussian distribution.
+ * @param cov The covariance matrix of the gaussian distribution. Note we only
+ * save the upper triangular part as a dense vector (column major).
+ * @param weight The weight of the gaussian distribution.
+ * @param sumWeights The sum of weights of all clusters.
+ * @return The updated weight, mean and covariance.
+ */
+ private[clustering] def updateWeightsAndGaussians(
+ mean: DenseVector,
+ cov: DenseVector,
+ weight: Double,
+ sumWeights: Double): (Double, (DenseVector, DenseVector)) = {
+ BLAS.scal(1.0 / weight, mean)
+ BLAS.spr(-weight, mean, cov)
+ BLAS.scal(1.0 / weight, cov)
+ val newWeight = weight / sumWeights
+ val newGaussian = (mean, cov)
+ (newWeight, newGaussian)
+ }
+}
+
+/**
+ * ExpectationAggregator computes the partial expectation results.
+ *
+ * @param numFeatures The number of features.
+ * @param bcWeights The broadcast weights for each Gaussian distribution in the mixture.
+ * @param bcGaussians The broadcast array of Multivariate Gaussian (Normal) Distribution
+ * in the mixture. Note only upper triangular part of the covariance
+ * matrix of each distribution is stored as dense vector (column major)
+ * in order to reduce shuffled data size.
+ */
+private class ExpectationAggregator(
+ numFeatures: Int,
+ bcWeights: Broadcast[Array[Double]],
+ bcGaussians: Broadcast[Array[(DenseVector, DenseVector)]]) extends Serializable {
+
+ private val k: Int = bcWeights.value.length
+ private var totalCnt: Long = 0L
+ private var newLogLikelihood: Double = 0.0
+ private val newWeights: Array[Double] = new Array[Double](k)
+ private val newMeans: Array[DenseVector] = Array.fill(k)(
+ new DenseVector(Array.fill[Double](numFeatures)(0.0)))
+ private val newCovs: Array[DenseVector] = Array.fill(k)(
+ new DenseVector(Array.fill[Double](numFeatures * (numFeatures + 1) / 2)(0.0)))
+
+ @transient private lazy val oldGaussians = {
+ bcGaussians.value.map { case (mean, covVec) =>
+ val cov = GaussianMixture.unpackUpperTriangularMatrix(numFeatures, covVec.values)
+ new MultivariateGaussian(mean, cov)
+ }
+ }
+
+ def count: Long = totalCnt
+
+ def logLikelihood: Double = newLogLikelihood
+
+ def weights: Array[Double] = newWeights
+
+ def means: Array[DenseVector] = newMeans
+
+ def covs: Array[DenseVector] = newCovs
+
+ /**
+ * Add a new training instance to this ExpectationAggregator, update the weights,
+ * means and covariances for each distributions, and update the log likelihood.
+ *
+ * @param instance The instance of data point to be added.
+ * @return This ExpectationAggregator object.
+ */
+ def add(instance: Vector): this.type = {
+ val localWeights = bcWeights.value
+ val localOldGaussians = oldGaussians
+
+ val prob = new Array[Double](k)
+ var probSum = 0.0
+ var i = 0
+ while (i < k) {
+ val p = EPSILON + localWeights(i) * localOldGaussians(i).pdf(instance)
+ prob(i) = p
+ probSum += p
+ i += 1
+ }
+
+ newLogLikelihood += math.log(probSum)
+ val localNewWeights = newWeights
+ val localNewMeans = newMeans
+ val localNewCovs = newCovs
+ i = 0
+ while (i < k) {
+ prob(i) /= probSum
+ localNewWeights(i) += prob(i)
+ BLAS.axpy(prob(i), instance, localNewMeans(i))
+ BLAS.spr(prob(i), instance, localNewCovs(i))
+ i += 1
+ }
+
+ totalCnt += 1
+ this
+ }
+
+ /**
+ * Merge another ExpectationAggregator, update the weights, means and covariances
+ * for each distributions, and update the log likelihood.
+ * (Note that it's in place merging; as a result, `this` object will be modified.)
+ *
+ * @param other The other ExpectationAggregator to be merged.
+ * @return This ExpectationAggregator object.
+ */
+ def merge(other: ExpectationAggregator): this.type = {
+ if (other.count != 0) {
+ totalCnt += other.totalCnt
+
+ val localThisNewWeights = this.newWeights
+ val localOtherNewWeights = other.newWeights
+ val localThisNewMeans = this.newMeans
+ val localOtherNewMeans = other.newMeans
+ val localThisNewCovs = this.newCovs
+ val localOtherNewCovs = other.newCovs
+ var i = 0
+ while (i < k) {
+ localThisNewWeights(i) += localOtherNewWeights(i)
+ BLAS.axpy(1.0, localOtherNewMeans(i), localThisNewMeans(i))
+ BLAS.axpy(1.0, localOtherNewCovs(i), localThisNewCovs(i))
+ i += 1
+ }
+ newLogLikelihood += other.newLogLikelihood
+ }
+ this
+ }
}
/**
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala
index 07299123f8..a362aeea39 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala
@@ -18,22 +18,39 @@
package org.apache.spark.ml.clustering
import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.linalg.{DenseMatrix, Matrices, Vector, Vectors}
import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.stat.distribution.MultivariateGaussian
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.Dataset
+import org.apache.spark.sql.{Dataset, Row}
class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext
with DefaultReadWriteTest {
+ import testImplicits._
+ import GaussianMixtureSuite._
+
final val k = 5
+ private val seed = 538009335
@transient var dataset: Dataset[_] = _
+ @transient var denseDataset: Dataset[_] = _
+ @transient var sparseDataset: Dataset[_] = _
+ @transient var decompositionDataset: Dataset[_] = _
+ @transient var rDataset: Dataset[_] = _
override def beforeAll(): Unit = {
super.beforeAll()
dataset = KMeansSuite.generateKMeansData(spark, 50, 3, k)
+ denseDataset = denseData.map(FeatureData).toDF()
+ sparseDataset = denseData.map { point =>
+ FeatureData(point.toSparse)
+ }.toDF()
+ decompositionDataset = decompositionData.map(FeatureData).toDF()
+ rDataset = rData.map(FeatureData).toDF()
}
test("default parameters") {
@@ -94,6 +111,15 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext
assert(transformed.columns.contains(column))
}
+ // Check prediction matches the highest probability, and probabilities sum to one.
+ transformed.select(predictionColName, probabilityColName).collect().foreach {
+ case Row(pred: Int, prob: Vector) =>
+ val probArray = prob.toArray
+ val predFromProb = probArray.zipWithIndex.maxBy(_._1)._2
+ assert(pred === predFromProb)
+ assert(probArray.sum ~== 1.0 absTol 1E-5)
+ }
+
// Check validity of model summary
val numRows = dataset.count()
assert(model.hasSummary)
@@ -126,9 +152,93 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext
testEstimatorAndModelReadWrite(gm, dataset,
GaussianMixtureSuite.allParamSettings, checkModelData)
}
+
+ test("univariate dense/sparse data with two clusters") {
+ val weights = Array(2.0 / 3.0, 1.0 / 3.0)
+ val means = Array(Vectors.dense(5.1604), Vectors.dense(-4.3673))
+ val covs = Array(Matrices.dense(1, 1, Array(0.86644)), Matrices.dense(1, 1, Array(1.1098)))
+ val gaussians = means.zip(covs).map { case (mean, cov) =>
+ new MultivariateGaussian(mean, cov)
+ }
+ val expected = new GaussianMixtureModel("dummy", weights, gaussians)
+
+ Seq(denseDataset, sparseDataset).foreach { dataset =>
+ val actual = new GaussianMixture().setK(2).setSeed(seed).fit(dataset)
+ modelEquals(expected, actual)
+ }
+ }
+
+ test("check distributed decomposition") {
+ val k = 5
+ val d = decompositionData.head.size
+ assert(GaussianMixture.shouldDistributeGaussians(k, d))
+
+ val gmm = new GaussianMixture().setK(k).setSeed(seed).fit(decompositionDataset)
+ assert(gmm.getK === k)
+ }
+
+ test("multivariate data and check againt R mvnormalmixEM") {
+ /*
+ Using the following R code to generate data and train the model using mixtools package.
+ library(mvtnorm)
+ library(mixtools)
+ set.seed(1)
+ a <- rmvnorm(7, c(0, 0))
+ b <- rmvnorm(8, c(10, 10))
+ data <- rbind(a, b)
+ model <- mvnormalmixEM(data, k = 2)
+ model$lambda
+
+ [1] 0.4666667 0.5333333
+
+ model$mu
+
+ [1] 0.11731091 -0.06192351
+ [1] 10.363673 9.897081
+
+ model$sigma
+
+ [[1]]
+ [,1] [,2]
+ [1,] 0.62049934 0.06880802
+ [2,] 0.06880802 1.27431874
+
+ [[2]]
+ [,1] [,2]
+ [1,] 0.2961543 0.160783
+ [2,] 0.1607830 1.008878
+ */
+ val weights = Array(0.5333333, 0.4666667)
+ val means = Array(Vectors.dense(10.363673, 9.897081), Vectors.dense(0.11731091, -0.06192351))
+ val covs = Array(Matrices.dense(2, 2, Array(0.2961543, 0.1607830, 0.160783, 1.008878)),
+ Matrices.dense(2, 2, Array(0.62049934, 0.06880802, 0.06880802, 1.27431874)))
+ val gaussians = means.zip(covs).map { case (mean, cov) =>
+ new MultivariateGaussian(mean, cov)
+ }
+
+ val expected = new GaussianMixtureModel("dummy", weights, gaussians)
+ val actual = new GaussianMixture().setK(2).setSeed(seed).fit(rDataset)
+ modelEquals(expected, actual)
+ }
+
+ test("upper triangular matrix unpacking") {
+ /*
+ The full symmetric matrix is as follows:
+ 1.0 2.5 3.8 0.9
+ 2.5 2.0 7.2 3.8
+ 3.8 7.2 3.0 1.0
+ 0.9 3.8 1.0 4.0
+ */
+ val triangularValues = Array(1.0, 2.5, 2.0, 3.8, 7.2, 3.0, 0.9, 3.8, 1.0, 4.0)
+ val symmetricValues = Array(1.0, 2.5, 3.8, 0.9, 2.5, 2.0, 7.2, 3.8,
+ 3.8, 7.2, 3.0, 1.0, 0.9, 3.8, 1.0, 4.0)
+ val symmetricMatrix = new DenseMatrix(4, 4, symmetricValues)
+ val expectedMatrix = GaussianMixture.unpackUpperTriangularMatrix(4, triangularValues)
+ assert(symmetricMatrix === expectedMatrix)
+ }
}
-object GaussianMixtureSuite {
+object GaussianMixtureSuite extends SparkFunSuite {
/**
* Mapping from all Params to valid settings which differ from the defaults.
* This is useful for tests which need to exercise all Params, such as save/load.
@@ -141,4 +251,38 @@ object GaussianMixtureSuite {
"maxIter" -> 2,
"tol" -> 0.01
)
+
+ val denseData = Seq(
+ Vectors.dense(-5.1971), Vectors.dense(-2.5359), Vectors.dense(-3.8220),
+ Vectors.dense(-5.2211), Vectors.dense(-5.0602), Vectors.dense( 4.7118),
+ Vectors.dense( 6.8989), Vectors.dense( 3.4592), Vectors.dense( 4.6322),
+ Vectors.dense( 5.7048), Vectors.dense( 4.6567), Vectors.dense( 5.5026),
+ Vectors.dense( 4.5605), Vectors.dense( 5.2043), Vectors.dense( 6.2734)
+ )
+
+ val decompositionData: Seq[Vector] = Seq.tabulate(25) { i: Int =>
+ Vectors.dense(Array.tabulate(50)(i + _.toDouble))
+ }
+
+ val rData = Seq(
+ Vectors.dense(-0.6264538, 0.1836433), Vectors.dense(-0.8356286, 1.5952808),
+ Vectors.dense(0.3295078, -0.8204684), Vectors.dense(0.4874291, 0.7383247),
+ Vectors.dense(0.5757814, -0.3053884), Vectors.dense(1.5117812, 0.3898432),
+ Vectors.dense(-0.6212406, -2.2146999), Vectors.dense(11.1249309, 9.9550664),
+ Vectors.dense(9.9838097, 10.9438362), Vectors.dense(10.8212212, 10.5939013),
+ Vectors.dense(10.9189774, 10.7821363), Vectors.dense(10.0745650, 8.0106483),
+ Vectors.dense(10.6198257, 9.9438713), Vectors.dense(9.8442045, 8.5292476),
+ Vectors.dense(9.5218499, 10.4179416)
+ )
+
+ case class FeatureData(features: Vector)
+
+ def modelEquals(m1: GaussianMixtureModel, m2: GaussianMixtureModel): Unit = {
+ assert(m1.weights.length === m2.weights.length)
+ for (i <- m1.weights.indices) {
+ assert(m1.weights(i) ~== m2.weights(i) absTol 1E-3)
+ assert(m1.gaussians(i).mean ~== m2.gaussians(i).mean absTol 1E-3)
+ assert(m1.gaussians(i).cov ~== m2.gaussians(i).cov absTol 1E-3)
+ }
+ }
}