Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala      331
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala  148
2 files changed, 461 insertions(+), 18 deletions(-)
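For context, a minimal usage sketch of the reimplemented estimator. The toy data, the parameter values, and the live `spark` SparkSession (e.g. the spark-shell) are illustrative assumptions, not part of this commit:

    import org.apache.spark.ml.clustering.GaussianMixture
    import org.apache.spark.ml.linalg.Vectors

    // Hypothetical toy data: any DataFrame with a Vector column "features" works.
    val df = spark.createDataFrame(
      Seq(-5.2, -4.9, -5.1, 4.8, 5.1, 5.3).map(x => Tuple1(Vectors.dense(x)))
    ).toDF("features")

    val model = new GaussianMixture()
      .setK(2)
      .setSeed(1L)
      .setMaxIter(100)
      .setTol(0.01)
      .fit(df)

    // Mixing weights and per-component Gaussians fitted by the EM loop below.
    println(model.weights.mkString(", "))
    model.gaussians.foreach(g => println(s"mean=${g.mean} cov=${g.cov}"))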
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
index ac56845581..a7bb413795 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
@@ -21,6 +21,7 @@ import breeze.linalg.{DenseVector => BDV}
import org.apache.hadoop.fs.Path
import org.apache.spark.annotation.{Experimental, Since}
+import org.apache.spark.broadcast.Broadcast
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.impl.Utils.EPSILON
import org.apache.spark.ml.linalg._
@@ -28,7 +29,6 @@ import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.stat.distribution.MultivariateGaussian
import org.apache.spark.ml.util._
-import org.apache.spark.mllib.clustering.{GaussianMixture => MLlibGM}
import org.apache.spark.mllib.linalg.{Matrices => OldMatrices, Matrix => OldMatrix,
Vector => OldVector, Vectors => OldVectors}
import org.apache.spark.rdd.RDD
@@ -45,6 +45,7 @@ private[clustering] trait GaussianMixtureParams extends Params with HasMaxIter w
/**
* Number of independent Gaussians in the mixture model. Must be greater than 1. Default: 2.
+ *
* @group param
*/
@Since("2.0.0")
@@ -57,6 +58,7 @@ private[clustering] trait GaussianMixtureParams extends Params with HasMaxIter w
/**
* Validates and transforms the input schema.
+ *
* @param schema input schema
* @return output schema
*/
@@ -238,6 +240,7 @@ object GaussianMixtureModel extends MLReadable[GaussianMixtureModel] {
/**
* Compute the probability (partial assignment) for each cluster for the given data point.
+ *
* @param features Data point
* @param dists Gaussians for model
* @param weights Weights for each Gaussian
@@ -323,31 +326,98 @@ class GaussianMixture @Since("2.0.0") (
@Since("2.0.0")
def setSeed(value: Long): this.type = set(seed, value)
+ /**
+ * Number of samples per cluster to use when initializing Gaussians.
+ */
+ private val numSamples = 5
+
@Since("2.0.0")
override def fit(dataset: Dataset[_]): GaussianMixtureModel = {
transformSchema(dataset.schema, logging = true)
- val rdd: RDD[OldVector] = dataset.select(col($(featuresCol))).rdd.map {
- case Row(point: Vector) => OldVectors.fromML(point)
- }
- val instr = Instrumentation.create(this, rdd)
+ val sc = dataset.sparkSession.sparkContext
+ val numClusters = $(k)
+
+ val instances: RDD[Vector] = dataset.select(col($(featuresCol))).rdd.map {
+ case Row(features: Vector) => features
+ }.cache()
+
+ // Extract the number of features.
+ val numFeatures = instances.first().size
+
+ val instr = Instrumentation.create(this, instances)
instr.logParams(featuresCol, predictionCol, probabilityCol, k, maxIter, seed, tol)
+ instr.logNumFeatures(numFeatures)
+
+ val shouldDistributeGaussians = GaussianMixture.shouldDistributeGaussians(
+ numClusters, numFeatures)
+
+ // TODO: SPARK-15785 Support user-supplied initial GMM.
+ val (weights, gaussians) = initRandom(instances, numClusters, numFeatures)
+
+ var logLikelihood = Double.MinValue
+ var logLikelihoodPrev = 0.0
+
+ var iter = 0
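+ // EM loop: repeat until the change in log-likelihood drops to tol or
+ // maxIter is reached; the first iteration always runs because
+ // logLikelihood starts at Double.MinValue.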
+ while (iter < $(maxIter) && math.abs(logLikelihood - logLikelihoodPrev) > $(tol)) {
+
+ val bcWeights = instances.sparkContext.broadcast(weights)
+ val bcGaussians = instances.sparkContext.broadcast(gaussians)
+
+ // aggregate the cluster contribution for all sample points
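+ // (seqOp folds one instance into a partition-local ExpectationAggregator;
+ // combOp merges partial aggregators, so only O(k * d^2) state is shuffled
+ // rather than the instances themselves.)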
+ val sums = instances.treeAggregate(
+ new ExpectationAggregator(numFeatures, bcWeights, bcGaussians))(
+ seqOp = (c, v) => (c, v) match {
+ case (aggregator, instance) => aggregator.add(instance)
+ },
+ combOp = (c1, c2) => (c1, c2) match {
+ case (aggregator1, aggregator2) => aggregator1.merge(aggregator2)
+ })
+
+ bcWeights.destroy(blocking = false)
+ bcGaussians.destroy(blocking = false)
+
+ /*
+ Create new distributions based on the partial assignments
+ (often referred to as the "M" step in literature)
+ */
+ val sumWeights = sums.weights.sum
+
+ if (shouldDistributeGaussians) {
+ val numPartitions = math.min(numClusters, 1024)
+ val tuples = Seq.tabulate(numClusters) { i =>
+ (sums.means(i), sums.covs(i), sums.weights(i))
+ }
+ val (ws, gs) = sc.parallelize(tuples, numPartitions).map { case (mean, cov, weight) =>
+ GaussianMixture.updateWeightsAndGaussians(mean, cov, weight, sumWeights)
+ }.collect().unzip
+ Array.copy(ws.toArray, 0, weights, 0, ws.length)
+ Array.copy(gs.toArray, 0, gaussians, 0, gs.length)
+ } else {
+ var i = 0
+ while (i < numClusters) {
+ val (weight, gaussian) = GaussianMixture.updateWeightsAndGaussians(
+ sums.means(i), sums.covs(i), sums.weights(i), sumWeights)
+ weights(i) = weight
+ gaussians(i) = gaussian
+ i += 1
+ }
+ }
+
+ logLikelihoodPrev = logLikelihood // current becomes previous
+ logLikelihood = sums.logLikelihood // this is the freshly computed log-likelihood
+ iter += 1
+ }
- val algo = new MLlibGM()
- .setK($(k))
- .setMaxIterations($(maxIter))
- .setSeed($(seed))
- .setConvergenceTol($(tol))
- val parentModel = algo.run(rdd)
- val gaussians = parentModel.gaussians.map { case g =>
- new MultivariateGaussian(g.mu.asML, g.sigma.asML)
+ val gaussianDists = gaussians.map { case (mean, covVec) =>
+ val cov = GaussianMixture.unpackUpperTriangularMatrix(numFeatures, covVec.values)
+ new MultivariateGaussian(mean, cov)
}
- val model = copyValues(new GaussianMixtureModel(uid, parentModel.weights, gaussians))
- .setParent(this)
+
+ val model = copyValues(new GaussianMixtureModel(uid, weights, gaussianDists)).setParent(this)
val summary = new GaussianMixtureSummary(model.transform(dataset),
$(predictionCol), $(probabilityCol), $(featuresCol), $(k))
model.setSummary(Some(summary))
- instr.logNumFeatures(model.gaussians.head.mean.size)
instr.logSuccess(model)
model
}
@@ -356,6 +426,61 @@ class GaussianMixture @Since("2.0.0") (
override def transformSchema(schema: StructType): StructType = {
validateAndTransformSchema(schema)
}
+
+ /**
+ * Initialize weights and corresponding gaussian distributions at random.
+ *
+ * We start with uniform weights, a random mean from the data, and diagonal covariance matrices
+ * using component variances derived from the samples.
+ *
+ * @param instances The training instances.
+ * @param numClusters The number of clusters.
+ * @param numFeatures The number of features per training instance.
+ * @return The initialized weights and corresponding gaussian distributions. Note that
+ * the covariance matrix of a multivariate Gaussian distribution is symmetric,
+ * so we only save its upper triangular part as a dense vector (column major).
+ */
+ private def initRandom(
+ instances: RDD[Vector],
+ numClusters: Int,
+ numFeatures: Int): (Array[Double], Array[(DenseVector, DenseVector)]) = {
+ val samples = instances.takeSample(withReplacement = true, numClusters * numSamples, $(seed))
+ val weights: Array[Double] = Array.fill(numClusters)(1.0 / numClusters)
+ val gaussians: Array[(DenseVector, DenseVector)] = Array.tabulate(numClusters) { i =>
+ val slice = samples.view(i * numSamples, (i + 1) * numSamples)
+ val mean = {
+ val v = new DenseVector(new Array[Double](numFeatures))
+ var i = 0
+ while (i < numSamples) {
+ BLAS.axpy(1.0, slice(i), v)
+ i += 1
+ }
+ BLAS.scal(1.0 / numSamples, v)
+ v
+ }
+ /*
+ Construct a matrix whose diagonal entries are the element-wise
+ variances of the input vectors (biased variance).
+ Since the covariance matrix of a multivariate Gaussian distribution is
+ symmetric, only its upper triangular part (column major) is saved as a
+ dense vector in order to reduce the shuffled data size.
+ */
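+ // For example, with numFeatures = 3 the packed vector has length 6 and the
+ // diagonal entry (i, i) lands at offset i + i * (i + 1) / 2, i.e. 0, 2 and 5.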
+ val cov = {
+ val ss = new DenseVector(new Array[Double](numFeatures)).asBreeze
+ slice.foreach(xi => ss += (xi.asBreeze - mean.asBreeze) :^ 2.0)
+ val diagVec = Vectors.fromBreeze(ss)
+ BLAS.scal(1.0 / numSamples, diagVec)
+ val covVec = new DenseVector(Array.fill[Double](
+ numFeatures * (numFeatures + 1) / 2)(0.0))
+ diagVec.toArray.zipWithIndex.foreach { case (v: Double, i: Int) =>
+ covVec.values(i + i * (i + 1) / 2) = v
+ }
+ covVec
+ }
+ (mean, cov)
+ }
+ (weights, gaussians)
+ }
}
@Since("2.0.0")
@@ -363,6 +488,180 @@ object GaussianMixture extends DefaultParamsReadable[GaussianMixture] {
@Since("2.0.0")
override def load(path: String): GaussianMixture = super.load(path)
+
+ /**
+ * Heuristic to distribute the computation of the [[MultivariateGaussian]]s: roughly,
+ * distribute when numFeatures > 25, except when numClusters is very small.
+ *
+ * @param numClusters Number of clusters
+ * @param numFeatures Number of features
+ */
+ private[clustering] def shouldDistributeGaussians(
+ numClusters: Int,
+ numFeatures: Int): Boolean = {
+ ((numClusters - 1.0) / numClusters) * numFeatures > 25.0
+ }
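+
+ // For example, k = 5 gives a cut-off of 25 / 0.8 = 31.25 features:
+ // shouldDistributeGaussians(5, 50) is true, shouldDistributeGaussians(5, 20)
+ // is false, and a single cluster (k = 1) never distributes.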
+
+ /**
+ * Convert an n * (n + 1) / 2 dimension array representing the upper triangular part of a matrix
+ * into an n * n array representing the full symmetric matrix (column major).
+ *
+ * @param n The order of the n by n matrix.
+ * @param triangularValues The upper triangular part of the matrix packed in an array
+ * (column major).
+ * @return A dense matrix which represents the symmetric matrix in column major.
+ */
+ private[clustering] def unpackUpperTriangularMatrix(
+ n: Int,
+ triangularValues: Array[Double]): DenseMatrix = {
+ val symmetricValues = new Array[Double](n * n)
+ var r = 0
+ var i = 0
+ while (i < n) {
+ var j = 0
+ while (j <= i) {
+ symmetricValues(i * n + j) = triangularValues(r)
+ symmetricValues(j * n + i) = triangularValues(r)
+ r += 1
+ j += 1
+ }
+ i += 1
+ }
+ new DenseMatrix(n, n, symmetricValues)
+ }
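+
+ // For example, unpackUpperTriangularMatrix(3, Array(a, b, d, c, e, f))
+ // produces the symmetric matrix
+ //   a b c
+ //   b d e
+ //   c e f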
+
+ /**
+ * Update the weight, mean and covariance of a Gaussian distribution.
+ *
+ * @param mean The mean of the gaussian distribution.
+ * @param cov The covariance matrix of the gaussian distribution. Note we only
+ * save the upper triangular part as a dense vector (column major).
+ * @param weight The weight of the gaussian distribution.
+ * @param sumWeights The sum of weights of all clusters.
+ * @return The updated weight, mean and covariance.
+ */
+ private[clustering] def updateWeightsAndGaussians(
+ mean: DenseVector,
+ cov: DenseVector,
+ weight: Double,
+ sumWeights: Double): (Double, (DenseVector, DenseVector)) = {
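+ // `mean` arrives as sum_i(gamma_ik * x_i) and `cov` as the packed
+ // sum_i(gamma_ik * x_i * x_i^T); normalizing by the soft count `weight`
+ // and subtracting weight * mu * mu^T yields mu_k and
+ // E[x x^T] - mu_k mu_k^T, the maximum-likelihood covariance.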
+ BLAS.scal(1.0 / weight, mean)
+ BLAS.spr(-weight, mean, cov)
+ BLAS.scal(1.0 / weight, cov)
+ val newWeight = weight / sumWeights
+ val newGaussian = (mean, cov)
+ (newWeight, newGaussian)
+ }
+}
+
+/**
+ * ExpectationAggregator computes the partial expectation results.
+ *
+ * @param numFeatures The number of features.
+ * @param bcWeights The broadcast weights for each Gaussian distribution in the mixture.
+ * @param bcGaussians The broadcast array of Multivariate Gaussian (Normal) Distributions
+ * in the mixture. Note that only the upper triangular part of each
+ * distribution's covariance matrix is stored as a dense vector (column
+ * major) in order to reduce the shuffled data size.
+ */
+private class ExpectationAggregator(
+ numFeatures: Int,
+ bcWeights: Broadcast[Array[Double]],
+ bcGaussians: Broadcast[Array[(DenseVector, DenseVector)]]) extends Serializable {
+
+ private val k: Int = bcWeights.value.length
+ private var totalCnt: Long = 0L
+ private var newLogLikelihood: Double = 0.0
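+ // Per-cluster accumulators: soft counts (newWeights), sums of gamma * x
+ // (newMeans), and packed upper-triangular sums of gamma * x x^T (newCovs).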
+ private val newWeights: Array[Double] = new Array[Double](k)
+ private val newMeans: Array[DenseVector] = Array.fill(k)(
+ new DenseVector(Array.fill[Double](numFeatures)(0.0)))
+ private val newCovs: Array[DenseVector] = Array.fill(k)(
+ new DenseVector(Array.fill[Double](numFeatures * (numFeatures + 1) / 2)(0.0)))
+
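+ // Unpacked lazily on each executor from the broadcast packed form;
+ // @transient keeps the expanded Gaussians out of task serialization.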
+ @transient private lazy val oldGaussians = {
+ bcGaussians.value.map { case (mean, covVec) =>
+ val cov = GaussianMixture.unpackUpperTriangularMatrix(numFeatures, covVec.values)
+ new MultivariateGaussian(mean, cov)
+ }
+ }
+
+ def count: Long = totalCnt
+
+ def logLikelihood: Double = newLogLikelihood
+
+ def weights: Array[Double] = newWeights
+
+ def means: Array[DenseVector] = newMeans
+
+ def covs: Array[DenseVector] = newCovs
+
+ /**
+ * Add a new training instance to this ExpectationAggregator, update the weights,
+ * means and covariances for each distribution, and update the log likelihood.
+ *
+ * @param instance The data point to be added.
+ * @return This ExpectationAggregator object.
+ */
+ def add(instance: Vector): this.type = {
+ val localWeights = bcWeights.value
+ val localOldGaussians = oldGaussians
+
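+ // E step for one point: responsibility p_ik is proportional to
+ // weight_k * pdf_k(x_i); EPSILON guards against all densities underflowing
+ // to zero before the normalization below.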
+ val prob = new Array[Double](k)
+ var probSum = 0.0
+ var i = 0
+ while (i < k) {
+ val p = EPSILON + localWeights(i) * localOldGaussians(i).pdf(instance)
+ prob(i) = p
+ probSum += p
+ i += 1
+ }
+
+ newLogLikelihood += math.log(probSum)
+ val localNewWeights = newWeights
+ val localNewMeans = newMeans
+ val localNewCovs = newCovs
+ i = 0
+ while (i < k) {
+ prob(i) /= probSum
+ localNewWeights(i) += prob(i)
+ BLAS.axpy(prob(i), instance, localNewMeans(i))
+ BLAS.spr(prob(i), instance, localNewCovs(i))
+ i += 1
+ }
+
+ totalCnt += 1
+ this
+ }
+
+ /**
+ * Merge another ExpectationAggregator, update the weights, means and covariances
+ * for each distribution, and update the log likelihood.
+ * (Note that this is an in-place merge; as a result, `this` object will be modified.)
+ *
+ * @param other The other ExpectationAggregator to be merged.
+ * @return This ExpectationAggregator object.
+ */
+ def merge(other: ExpectationAggregator): this.type = {
+ if (other.count != 0) {
+ totalCnt += other.totalCnt
+
+ val localThisNewWeights = this.newWeights
+ val localOtherNewWeights = other.newWeights
+ val localThisNewMeans = this.newMeans
+ val localOtherNewMeans = other.newMeans
+ val localThisNewCovs = this.newCovs
+ val localOtherNewCovs = other.newCovs
+ var i = 0
+ while (i < k) {
+ localThisNewWeights(i) += localOtherNewWeights(i)
+ BLAS.axpy(1.0, localOtherNewMeans(i), localThisNewMeans(i))
+ BLAS.axpy(1.0, localOtherNewCovs(i), localThisNewCovs(i))
+ i += 1
+ }
+ newLogLikelihood += other.newLogLikelihood
+ }
+ this
+ }
}
/**
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala
index 07299123f8..a362aeea39 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala
@@ -18,22 +18,39 @@
package org.apache.spark.ml.clustering
import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.linalg.{DenseMatrix, Matrices, Vector, Vectors}
import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.stat.distribution.MultivariateGaussian
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.Dataset
+import org.apache.spark.sql.{Dataset, Row}
class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext
with DefaultReadWriteTest {
+ import testImplicits._
+ import GaussianMixtureSuite._
+
final val k = 5
+ private val seed = 538009335
@transient var dataset: Dataset[_] = _
+ @transient var denseDataset: Dataset[_] = _
+ @transient var sparseDataset: Dataset[_] = _
+ @transient var decompositionDataset: Dataset[_] = _
+ @transient var rDataset: Dataset[_] = _
override def beforeAll(): Unit = {
super.beforeAll()
dataset = KMeansSuite.generateKMeansData(spark, 50, 3, k)
+ denseDataset = denseData.map(FeatureData).toDF()
+ sparseDataset = denseData.map { point =>
+ FeatureData(point.toSparse)
+ }.toDF()
+ decompositionDataset = decompositionData.map(FeatureData).toDF()
+ rDataset = rData.map(FeatureData).toDF()
}
test("default parameters") {
@@ -94,6 +111,15 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext
assert(transformed.columns.contains(column))
}
+ // Check prediction matches the highest probability, and probabilities sum to one.
+ transformed.select(predictionColName, probabilityColName).collect().foreach {
+ case Row(pred: Int, prob: Vector) =>
+ val probArray = prob.toArray
+ val predFromProb = probArray.zipWithIndex.maxBy(_._1)._2
+ assert(pred === predFromProb)
+ assert(probArray.sum ~== 1.0 absTol 1E-5)
+ }
+
// Check validity of model summary
val numRows = dataset.count()
assert(model.hasSummary)
@@ -126,9 +152,93 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext
testEstimatorAndModelReadWrite(gm, dataset,
GaussianMixtureSuite.allParamSettings, checkModelData)
}
+
+ test("univariate dense/sparse data with two clusters") {
+ val weights = Array(2.0 / 3.0, 1.0 / 3.0)
+ val means = Array(Vectors.dense(5.1604), Vectors.dense(-4.3673))
+ val covs = Array(Matrices.dense(1, 1, Array(0.86644)), Matrices.dense(1, 1, Array(1.1098)))
+ val gaussians = means.zip(covs).map { case (mean, cov) =>
+ new MultivariateGaussian(mean, cov)
+ }
+ val expected = new GaussianMixtureModel("dummy", weights, gaussians)
+
+ Seq(denseDataset, sparseDataset).foreach { dataset =>
+ val actual = new GaussianMixture().setK(2).setSeed(seed).fit(dataset)
+ modelEquals(expected, actual)
+ }
+ }
+
+ test("check distributed decomposition") {
+ val k = 5
+ val d = decompositionData.head.size
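+ // d = 50 features with k = 5 clusters satisfies (k - 1) / k * d > 25,
+ // so this exercises the distributed M-step path.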
+ assert(GaussianMixture.shouldDistributeGaussians(k, d))
+
+ val gmm = new GaussianMixture().setK(k).setSeed(seed).fit(decompositionDataset)
+ assert(gmm.getK === k)
+ }
+
+ test("multivariate data and check againt R mvnormalmixEM") {
+ /*
+ Using the following R code to generate data and train the model using mixtools package.
+ library(mvtnorm)
+ library(mixtools)
+ set.seed(1)
+ a <- rmvnorm(7, c(0, 0))
+ b <- rmvnorm(8, c(10, 10))
+ data <- rbind(a, b)
+ model <- mvnormalmixEM(data, k = 2)
+ model$lambda
+
+ [1] 0.4666667 0.5333333
+
+ model$mu
+
+ [1] 0.11731091 -0.06192351
+ [1] 10.363673 9.897081
+
+ model$sigma
+
+ [[1]]
+ [,1] [,2]
+ [1,] 0.62049934 0.06880802
+ [2,] 0.06880802 1.27431874
+
+ [[2]]
+ [,1] [,2]
+ [1,] 0.2961543 0.160783
+ [2,] 0.1607830 1.008878
+ */
+ val weights = Array(0.5333333, 0.4666667)
+ val means = Array(Vectors.dense(10.363673, 9.897081), Vectors.dense(0.11731091, -0.06192351))
+ val covs = Array(Matrices.dense(2, 2, Array(0.2961543, 0.1607830, 0.160783, 1.008878)),
+ Matrices.dense(2, 2, Array(0.62049934, 0.06880802, 0.06880802, 1.27431874)))
+ val gaussians = means.zip(covs).map { case (mean, cov) =>
+ new MultivariateGaussian(mean, cov)
+ }
+
+ val expected = new GaussianMixtureModel("dummy", weights, gaussians)
+ val actual = new GaussianMixture().setK(2).setSeed(seed).fit(rDataset)
+ modelEquals(expected, actual)
+ }
+
+ test("upper triangular matrix unpacking") {
+ /*
+ The full symmetric matrix is as follows:
+ 1.0 2.5 3.8 0.9
+ 2.5 2.0 7.2 3.8
+ 3.8 7.2 3.0 1.0
+ 0.9 3.8 1.0 4.0
+ */
+ val triangularValues = Array(1.0, 2.5, 2.0, 3.8, 7.2, 3.0, 0.9, 3.8, 1.0, 4.0)
+ val symmetricValues = Array(1.0, 2.5, 3.8, 0.9, 2.5, 2.0, 7.2, 3.8,
+ 3.8, 7.2, 3.0, 1.0, 0.9, 3.8, 1.0, 4.0)
+ val symmetricMatrix = new DenseMatrix(4, 4, symmetricValues)
+ val expectedMatrix = GaussianMixture.unpackUpperTriangularMatrix(4, triangularValues)
+ assert(symmetricMatrix === expectedMatrix)
+ }
}
-object GaussianMixtureSuite {
+object GaussianMixtureSuite extends SparkFunSuite {
/**
* Mapping from all Params to valid settings which differ from the defaults.
* This is useful for tests which need to exercise all Params, such as save/load.
@@ -141,4 +251,38 @@ object GaussianMixtureSuite {
"maxIter" -> 2,
"tol" -> 0.01
)
+
+ val denseData = Seq(
+ Vectors.dense(-5.1971), Vectors.dense(-2.5359), Vectors.dense(-3.8220),
+ Vectors.dense(-5.2211), Vectors.dense(-5.0602), Vectors.dense( 4.7118),
+ Vectors.dense( 6.8989), Vectors.dense( 3.4592), Vectors.dense( 4.6322),
+ Vectors.dense( 5.7048), Vectors.dense( 4.6567), Vectors.dense( 5.5026),
+ Vectors.dense( 4.5605), Vectors.dense( 5.2043), Vectors.dense( 6.2734)
+ )
+
+ val decompositionData: Seq[Vector] = Seq.tabulate(25) { i: Int =>
+ Vectors.dense(Array.tabulate(50)(i + _.toDouble))
+ }
+
+ val rData = Seq(
+ Vectors.dense(-0.6264538, 0.1836433), Vectors.dense(-0.8356286, 1.5952808),
+ Vectors.dense(0.3295078, -0.8204684), Vectors.dense(0.4874291, 0.7383247),
+ Vectors.dense(0.5757814, -0.3053884), Vectors.dense(1.5117812, 0.3898432),
+ Vectors.dense(-0.6212406, -2.2146999), Vectors.dense(11.1249309, 9.9550664),
+ Vectors.dense(9.9838097, 10.9438362), Vectors.dense(10.8212212, 10.5939013),
+ Vectors.dense(10.9189774, 10.7821363), Vectors.dense(10.0745650, 8.0106483),
+ Vectors.dense(10.6198257, 9.9438713), Vectors.dense(9.8442045, 8.5292476),
+ Vectors.dense(9.5218499, 10.4179416)
+ )
+
+ case class FeatureData(features: Vector)
+
+ def modelEquals(m1: GaussianMixtureModel, m2: GaussianMixtureModel): Unit = {
+ assert(m1.weights.length === m2.weights.length)
+ for (i <- m1.weights.indices) {
+ assert(m1.weights(i) ~== m2.weights(i) absTol 1E-3)
+ assert(m1.gaussians(i).mean ~== m2.gaussians(i).mean absTol 1E-3)
+ assert(m1.gaussians(i).cov ~== m2.gaussians(i).cov absTol 1E-3)
+ }
+ }
}