aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMechCoder <manojkumarsivaraj334@gmail.com>2015-07-28 15:00:25 -0700
committerXiangrui Meng <meng@databricks.com>2015-07-28 15:00:25 -0700
commit198d181dfb2c04102afe40680a4637d951e92c0b (patch)
tree2c04e7e9bc11c8f6e5bd998919f455f148b2277d
parentb88b868eb378bdb7459978842b5572a0b498f412 (diff)
downloadspark-198d181dfb2c04102afe40680a4637d951e92c0b.tar.gz
spark-198d181dfb2c04102afe40680a4637d951e92c0b.tar.bz2
spark-198d181dfb2c04102afe40680a4637d951e92c0b.zip
[SPARK-7105] [PYSPARK] [MLLIB] Support model save/load in GMM
This PR introduces save / load for GMM's in python API. Also I refactored `GaussianMixtureModel` and inherited it from `JavaModelWrapper` with model being `GaussianMixtureModelWrapper`, a wrapper which provides convenience methods to `GaussianMixtureModel` (due to serialization and deserialization issues) and I moved the creation of gaussians to the scala backend. Author: MechCoder <manojkumarsivaraj334@gmail.com> Closes #7617 from MechCoder/python_gmm_save_load and squashes the following commits: 9c305aa [MechCoder] [SPARK-7105] [PySpark] [MLlib] Support model save/load in GMM
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala53
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala13
-rw-r--r--python/pyspark/mllib/clustering.py75
-rw-r--r--python/pyspark/mllib/util.py6
4 files changed, 114 insertions, 33 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala
new file mode 100644
index 0000000000..0ec88ef77d
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.api.python
+
+import java.util.{List => JList}
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.SparkContext
+import org.apache.spark.mllib.linalg.{Vector, Vectors, Matrix}
+import org.apache.spark.mllib.clustering.GaussianMixtureModel
+
+/**
+ * Wrapper around GaussianMixtureModel to provide helper methods in Python
+ */
+private[python] class GaussianMixtureModelWrapper(model: GaussianMixtureModel) {
+ val weights: Vector = Vectors.dense(model.weights)
+ val k: Int = weights.size
+
+ /**
+ * Returns gaussians as a List of Vectors and Matrices corresponding each MultivariateGaussian
+ */
+ val gaussians: JList[Object] = {
+ val modelGaussians = model.gaussians
+ var i = 0
+ var mu = ArrayBuffer.empty[Vector]
+ var sigma = ArrayBuffer.empty[Matrix]
+ while (i < k) {
+ mu += modelGaussians(i).mu
+ sigma += modelGaussians(i).sigma
+ i += 1
+ }
+ List(mu.toArray, sigma.toArray).map(_.asInstanceOf[Object]).asJava
+ }
+
+ def save(sc: SparkContext, path: String): Unit = model.save(sc, path)
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index fda8d5a0b0..6f080d32bb 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -364,7 +364,7 @@ private[python] class PythonMLLibAPI extends Serializable {
seed: java.lang.Long,
initialModelWeights: java.util.ArrayList[Double],
initialModelMu: java.util.ArrayList[Vector],
- initialModelSigma: java.util.ArrayList[Matrix]): JList[Object] = {
+ initialModelSigma: java.util.ArrayList[Matrix]): GaussianMixtureModelWrapper = {
val gmmAlg = new GaussianMixture()
.setK(k)
.setConvergenceTol(convergenceTol)
@@ -382,16 +382,7 @@ private[python] class PythonMLLibAPI extends Serializable {
if (seed != null) gmmAlg.setSeed(seed)
try {
- val model = gmmAlg.run(data.rdd.persist(StorageLevel.MEMORY_AND_DISK))
- var wt = ArrayBuffer.empty[Double]
- var mu = ArrayBuffer.empty[Vector]
- var sigma = ArrayBuffer.empty[Matrix]
- for (i <- 0 until model.k) {
- wt += model.weights(i)
- mu += model.gaussians(i).mu
- sigma += model.gaussians(i).sigma
- }
- List(Vectors.dense(wt.toArray), mu.toArray, sigma.toArray).map(_.asInstanceOf[Object]).asJava
+ new GaussianMixtureModelWrapper(gmmAlg.run(data.rdd.persist(StorageLevel.MEMORY_AND_DISK)))
} finally {
data.rdd.unpersist(blocking = false)
}
diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py
index 58ad99d46e..900ade248c 100644
--- a/python/pyspark/mllib/clustering.py
+++ b/python/pyspark/mllib/clustering.py
@@ -152,11 +152,19 @@ class KMeans(object):
return KMeansModel([c.toArray() for c in centers])
-class GaussianMixtureModel(object):
+@inherit_doc
+class GaussianMixtureModel(JavaModelWrapper, JavaSaveable, JavaLoader):
+
+ """
+ .. note:: Experimental
- """A clustering model derived from the Gaussian Mixture Model method.
+ A clustering model derived from the Gaussian Mixture Model method.
>>> from pyspark.mllib.linalg import Vectors, DenseMatrix
+ >>> from numpy.testing import assert_equal
+ >>> from shutil import rmtree
+ >>> import os, tempfile
+
>>> clusterdata_1 = sc.parallelize(array([-0.1,-0.05,-0.01,-0.1,
... 0.9,0.8,0.75,0.935,
... -0.83,-0.68,-0.91,-0.76 ]).reshape(6, 2))
@@ -169,6 +177,25 @@ class GaussianMixtureModel(object):
True
>>> labels[4]==labels[5]
True
+
+ >>> path = tempfile.mkdtemp()
+ >>> model.save(sc, path)
+ >>> sameModel = GaussianMixtureModel.load(sc, path)
+ >>> assert_equal(model.weights, sameModel.weights)
+ >>> mus, sigmas = list(
+ ... zip(*[(g.mu, g.sigma) for g in model.gaussians]))
+ >>> sameMus, sameSigmas = list(
+ ... zip(*[(g.mu, g.sigma) for g in sameModel.gaussians]))
+ >>> mus == sameMus
+ True
+ >>> sigmas == sameSigmas
+ True
+ >>> from shutil import rmtree
+ >>> try:
+ ... rmtree(path)
+ ... except OSError:
+ ... pass
+
>>> data = array([-5.1971, -2.5359, -3.8220,
... -5.2211, -5.0602, 4.7118,
... 6.8989, 3.4592, 4.6322,
@@ -182,25 +209,15 @@ class GaussianMixtureModel(object):
True
>>> labels[3]==labels[4]
True
- >>> clusterdata_3 = sc.parallelize(data.reshape(15, 1))
- >>> im = GaussianMixtureModel([0.5, 0.5],
- ... [MultivariateGaussian(Vectors.dense([-1.0]), DenseMatrix(1, 1, [1.0])),
- ... MultivariateGaussian(Vectors.dense([1.0]), DenseMatrix(1, 1, [1.0]))])
- >>> model = GaussianMixture.train(clusterdata_3, 2, initialModel=im)
"""
- def __init__(self, weights, gaussians):
- self._weights = weights
- self._gaussians = gaussians
- self._k = len(self._weights)
-
@property
def weights(self):
"""
Weights for each Gaussian distribution in the mixture, where weights[i] is
the weight for Gaussian i, and weights.sum == 1.
"""
- return self._weights
+ return array(self.call("weights"))
@property
def gaussians(self):
@@ -208,12 +225,14 @@ class GaussianMixtureModel(object):
Array of MultivariateGaussian where gaussians[i] represents
the Multivariate Gaussian (Normal) Distribution for Gaussian i.
"""
- return self._gaussians
+ return [
+ MultivariateGaussian(gaussian[0], gaussian[1])
+ for gaussian in zip(*self.call("gaussians"))]
@property
def k(self):
"""Number of gaussians in mixture."""
- return self._k
+ return len(self.weights)
def predict(self, x):
"""
@@ -238,17 +257,30 @@ class GaussianMixtureModel(object):
:return: membership_matrix. RDD of array of double values.
"""
if isinstance(x, RDD):
- means, sigmas = zip(*[(g.mu, g.sigma) for g in self._gaussians])
+ means, sigmas = zip(*[(g.mu, g.sigma) for g in self.gaussians])
membership_matrix = callMLlibFunc("predictSoftGMM", x.map(_convert_to_vector),
- _convert_to_vector(self._weights), means, sigmas)
+ _convert_to_vector(self.weights), means, sigmas)
return membership_matrix.map(lambda x: pyarray.array('d', x))
else:
raise TypeError("x should be represented by an RDD, "
"but got %s." % type(x))
+ @classmethod
+ def load(cls, sc, path):
+ """Load the GaussianMixtureModel from disk.
+
+ :param sc: SparkContext
+ :param path: str, path to where the model is stored.
+ """
+ model = cls._load_java(sc, path)
+ wrapper = sc._jvm.GaussianMixtureModelWrapper(model)
+ return cls(wrapper)
+
class GaussianMixture(object):
"""
+ .. note:: Experimental
+
Learning algorithm for Gaussian Mixtures using the expectation-maximization algorithm.
:param data: RDD of data points
@@ -271,11 +303,10 @@ class GaussianMixture(object):
initialModelWeights = initialModel.weights
initialModelMu = [initialModel.gaussians[i].mu for i in range(initialModel.k)]
initialModelSigma = [initialModel.gaussians[i].sigma for i in range(initialModel.k)]
- weight, mu, sigma = callMLlibFunc("trainGaussianMixtureModel", rdd.map(_convert_to_vector),
- k, convergenceTol, maxIterations, seed,
- initialModelWeights, initialModelMu, initialModelSigma)
- mvg_obj = [MultivariateGaussian(mu[i], sigma[i]) for i in range(k)]
- return GaussianMixtureModel(weight, mvg_obj)
+ java_model = callMLlibFunc("trainGaussianMixtureModel", rdd.map(_convert_to_vector),
+ k, convergenceTol, maxIterations, seed,
+ initialModelWeights, initialModelMu, initialModelSigma)
+ return GaussianMixtureModel(java_model)
class PowerIterationClusteringModel(JavaModelWrapper, JavaSaveable, JavaLoader):
diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py
index 875d3b2d64..916de2d6fc 100644
--- a/python/pyspark/mllib/util.py
+++ b/python/pyspark/mllib/util.py
@@ -21,7 +21,9 @@ import warnings
if sys.version > '3':
xrange = range
+ basestring = str
+from pyspark import SparkContext
from pyspark.mllib.common import callMLlibFunc, inherit_doc
from pyspark.mllib.linalg import Vectors, SparseVector, _convert_to_vector
@@ -223,6 +225,10 @@ class JavaSaveable(Saveable):
"""
def save(self, sc, path):
+ if not isinstance(sc, SparkContext):
+ raise TypeError("sc should be a SparkContext, got type %s" % type(sc))
+ if not isinstance(path, basestring):
+ raise TypeError("path should be a basestring, got type %s" % type(path))
self._java_model.save(sc._jsc.sc(), path)