aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
author: Yanbo Liang <ybliang8@gmail.com> 2015-05-15 00:18:39 -0700
committer: Joseph K. Bradley <joseph@databricks.com> 2015-05-15 00:18:39 -0700
commit: 94761485b207fa1f12a8410a68920300d851bf61 (patch)
tree: 9accedf34fe4df4d9c157ba9c4b5b05c5b69a4a9 /mllib
parent: cf842d42a70398671c4bc5ebfa70f6fdb8c57c7f (diff)
download: spark-94761485b207fa1f12a8410a68920300d851bf61.tar.gz
spark-94761485b207fa1f12a8410a68920300d851bf61.tar.bz2
spark-94761485b207fa1f12a8410a68920300d851bf61.zip
[SPARK-6258] [MLLIB] GaussianMixture Python API parity check
Implement Python API for major disparities of GaussianMixture cluster algorithm between Scala & Python

```scala
GaussianMixture
    setInitialModel
GaussianMixtureModel
    k
```

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #6087 from yanboliang/spark-6258 and squashes the following commits:

b3af21c [Yanbo Liang] fix typo
2b645c1 [Yanbo Liang] fix doc
638b4b7 [Yanbo Liang] address comments
b5bcade [Yanbo Liang] GaussianMixture Python API parity check
Diffstat (limited to 'mllib')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala | 24
-rw-r--r-- mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala | 9
2 files changed, 22 insertions, 11 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index f4c4775965..2fa54df6fc 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -345,28 +345,40 @@ private[python] class PythonMLLibAPI extends Serializable {
* Returns a list containing weights, mean and covariance of each mixture component.
*/
def trainGaussianMixture(
- data: JavaRDD[Vector],
- k: Int,
- convergenceTol: Double,
+ data: JavaRDD[Vector],
+ k: Int,
+ convergenceTol: Double,
maxIterations: Int,
- seed: java.lang.Long): JList[Object] = {
+ seed: java.lang.Long,
+ initialModelWeights: java.util.ArrayList[Double],
+ initialModelMu: java.util.ArrayList[Vector],
+ initialModelSigma: java.util.ArrayList[Matrix]): JList[Object] = {
val gmmAlg = new GaussianMixture()
.setK(k)
.setConvergenceTol(convergenceTol)
.setMaxIterations(maxIterations)
+ if (initialModelWeights != null && initialModelMu != null && initialModelSigma != null) {
+ val gaussians = initialModelMu.asScala.toSeq.zip(initialModelSigma.asScala.toSeq).map {
+ case (x, y) => new MultivariateGaussian(x.asInstanceOf[Vector], y.asInstanceOf[Matrix])
+ }
+ val initialModel = new GaussianMixtureModel(
+ initialModelWeights.asScala.toArray, gaussians.toArray)
+ gmmAlg.setInitialModel(initialModel)
+ }
+
if (seed != null) gmmAlg.setSeed(seed)
try {
val model = gmmAlg.run(data.rdd.persist(StorageLevel.MEMORY_AND_DISK))
var wt = ArrayBuffer.empty[Double]
- var mu = ArrayBuffer.empty[Vector]
+ var mu = ArrayBuffer.empty[Vector]
var sigma = ArrayBuffer.empty[Matrix]
for (i <- 0 until model.k) {
wt += model.weights(i)
mu += model.gaussians(i).mu
sigma += model.gaussians(i).sigma
- }
+ }
List(Vectors.dense(wt.toArray), mu.toArray, sigma.toArray).map(_.asInstanceOf[Object]).asJava
} finally {
data.rdd.unpersist(blocking = false)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
index ec65a3da68..c22862c130 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
@@ -38,11 +38,10 @@ import org.apache.spark.sql.{SQLContext, Row}
* are drawn from each Gaussian i=1..k with probability w(i); mu(i) and sigma(i) are
* the respective mean and covariance for each Gaussian distribution i=1..k.
*
- * @param weight Weights for each Gaussian distribution in the mixture, where weight(i) is
- * the weight for Gaussian i, and weight.sum == 1
- * @param mu Means for each Gaussian in the mixture, where mu(i) is the mean for Gaussian i
- * @param sigma Covariance maxtrix for each Gaussian in the mixture, where sigma(i) is the
- * covariance matrix for Gaussian i
+ * @param weights Weights for each Gaussian distribution in the mixture, where weights(i) is
+ * the weight for Gaussian i, and weights.sum == 1
+ * @param gaussians Array of MultivariateGaussian where gaussians(i) represents
+ * the Multivariate Gaussian (Normal) Distribution for Gaussian i
*/
@Experimental
class GaussianMixtureModel(