aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorMechCoder <manojkumarsivaraj334@gmail.com>2015-07-28 15:00:25 -0700
committerXiangrui Meng <meng@databricks.com>2015-07-28 15:00:25 -0700
commit198d181dfb2c04102afe40680a4637d951e92c0b (patch)
tree2c04e7e9bc11c8f6e5bd998919f455f148b2277d /mllib
parentb88b868eb378bdb7459978842b5572a0b498f412 (diff)
downloadspark-198d181dfb2c04102afe40680a4637d951e92c0b.tar.gz
spark-198d181dfb2c04102afe40680a4637d951e92c0b.tar.bz2
spark-198d181dfb2c04102afe40680a4637d951e92c0b.zip
[SPARK-7105] [PYSPARK] [MLLIB] Support model save/load in GMM
This PR introduces save / load for GMMs in the Python API. I also refactored `GaussianMixtureModel` to inherit from `JavaModelWrapper`, with its model being a `GaussianMixtureModelWrapper`, a wrapper that provides convenience methods for `GaussianMixtureModel` (needed because of serialization and deserialization issues), and I moved the creation of the gaussians to the Scala backend. Author: MechCoder <manojkumarsivaraj334@gmail.com> Closes #7617 from MechCoder/python_gmm_save_load and squashes the following commits: 9c305aa [MechCoder] [SPARK-7105] [PySpark] [MLlib] Support model save/load in GMM
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala53
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala13
2 files changed, 55 insertions, 11 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala
new file mode 100644
index 0000000000..0ec88ef77d
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.api.python
+
+import java.util.{List => JList}
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.SparkContext
+import org.apache.spark.mllib.linalg.{Vector, Vectors, Matrix}
+import org.apache.spark.mllib.clustering.GaussianMixtureModel
+
+/**
+ * Wrapper around GaussianMixtureModel to provide helper methods in Python
+ */
+private[python] class GaussianMixtureModelWrapper(model: GaussianMixtureModel) {
+ val weights: Vector = Vectors.dense(model.weights)
+ val k: Int = weights.size
+
+ /**
+ * Returns gaussians as a two-element list: the array of mu Vectors, then the array of sigma Matrices
+ */
+ val gaussians: JList[Object] = {
+ val modelGaussians = model.gaussians
+ var i = 0
+ var mu = ArrayBuffer.empty[Vector]
+ var sigma = ArrayBuffer.empty[Matrix]
+ while (i < k) {
+ mu += modelGaussians(i).mu
+ sigma += modelGaussians(i).sigma
+ i += 1
+ }
+ List(mu.toArray, sigma.toArray).map(_.asInstanceOf[Object]).asJava
+ }
+
+ def save(sc: SparkContext, path: String): Unit = model.save(sc, path)
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index fda8d5a0b0..6f080d32bb 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -364,7 +364,7 @@ private[python] class PythonMLLibAPI extends Serializable {
seed: java.lang.Long,
initialModelWeights: java.util.ArrayList[Double],
initialModelMu: java.util.ArrayList[Vector],
- initialModelSigma: java.util.ArrayList[Matrix]): JList[Object] = {
+ initialModelSigma: java.util.ArrayList[Matrix]): GaussianMixtureModelWrapper = {
val gmmAlg = new GaussianMixture()
.setK(k)
.setConvergenceTol(convergenceTol)
@@ -382,16 +382,7 @@ private[python] class PythonMLLibAPI extends Serializable {
if (seed != null) gmmAlg.setSeed(seed)
try {
- val model = gmmAlg.run(data.rdd.persist(StorageLevel.MEMORY_AND_DISK))
- var wt = ArrayBuffer.empty[Double]
- var mu = ArrayBuffer.empty[Vector]
- var sigma = ArrayBuffer.empty[Matrix]
- for (i <- 0 until model.k) {
- wt += model.weights(i)
- mu += model.gaussians(i).mu
- sigma += model.gaussians(i).sigma
- }
- List(Vectors.dense(wt.toArray), mu.toArray, sigma.toArray).map(_.asInstanceOf[Object]).asJava
+ new GaussianMixtureModelWrapper(gmmAlg.run(data.rdd.persist(StorageLevel.MEMORY_AND_DISK)))
} finally {
data.rdd.unpersist(blocking = false)
}