aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMechCoder <manojkumarsivaraj334@gmail.com>2015-03-25 14:45:23 -0700
committerXiangrui Meng <meng@databricks.com>2015-03-25 14:45:23 -0700
commit4fc4d0369e8240defe0ee83252426402f1a28a36 (patch)
tree0d4187756c9caf831a890fcf612b373642f5a92f
parent435337381f093f95248c8f0204e60c0b366edc81 (diff)
downloadspark-4fc4d0369e8240defe0ee83252426402f1a28a36.tar.gz
spark-4fc4d0369e8240defe0ee83252426402f1a28a36.tar.bz2
spark-4fc4d0369e8240defe0ee83252426402f1a28a36.zip
[SPARK-5987] [MLlib] Save/load for GaussianMixtureModels
Should be self explanatory. Author: MechCoder <manojkumarsivaraj334@gmail.com> Closes #4986 from MechCoder/spark-5987 and squashes the following commits: 7d2cd56 [MechCoder] Iterate over dataframe in a better way e7a14cb [MechCoder] Minor 33c84f9 [MechCoder] Store as Array[Data] instead of Data[Array] 505bd57 [MechCoder] Rebased over master and used MatrixUDT 7422bb4 [MechCoder] Store sigmas as Array[Double] instead of Array[Array[Double]] b9794e4 [MechCoder] Minor cb77095 [MechCoder] [SPARK-5987] Save/load for GaussianMixtureModels
-rw-r--r--docs/mllib-clustering.md8
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala96
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala52
3 files changed, 136 insertions, 20 deletions
diff --git a/docs/mllib-clustering.md b/docs/mllib-clustering.md
index 0b6db4fcb7..f5aa15b7d9 100644
--- a/docs/mllib-clustering.md
+++ b/docs/mllib-clustering.md
@@ -173,6 +173,7 @@ to the algorithm. We then output the parameters of the mixture model.
{% highlight scala %}
import org.apache.spark.mllib.clustering.GaussianMixture
+import org.apache.spark.mllib.clustering.GaussianMixtureModel
import org.apache.spark.mllib.linalg.Vectors
// Load and parse the data
@@ -182,6 +183,10 @@ val parsedData = data.map(s => Vectors.dense(s.trim.split(' ').map(_.toDouble)))
// Cluster the data into two classes using GaussianMixture
val gmm = new GaussianMixture().setK(2).run(parsedData)
+// Save and load model
+gmm.save(sc, "myGMMModel")
+val sameModel = GaussianMixtureModel.load(sc, "myGMMModel")
+
// output parameters of max-likelihood model
for (i <- 0 until gmm.k) {
println("weight=%f\nmu=%s\nsigma=\n%s\n" format
@@ -231,6 +236,9 @@ public class GaussianMixtureExample {
// Cluster the data into two classes using GaussianMixture
GaussianMixtureModel gmm = new GaussianMixture().setK(2).run(parsedData.rdd());
+ // Save and load GaussianMixtureModel
+ gmm.save(sc, "myGMMModel")
+ GaussianMixtureModel sameModel = GaussianMixtureModel.load(sc, "myGMMModel")
// Output the parameters of the mixture model
for(int j=0; j<gmm.k(); j++) {
System.out.println("weight=%f\nmu=%s\nsigma=\n%s\n",
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
index af6f83c74b..ec65a3da68 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
@@ -19,11 +19,17 @@ package org.apache.spark.mllib.clustering
import breeze.linalg.{DenseVector => BreezeVector}
+import org.json4s.DefaultFormats
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
+import org.apache.spark.SparkContext
import org.apache.spark.annotation.Experimental
-import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.linalg.{Vector, Matrices, Matrix}
import org.apache.spark.mllib.stat.distribution.MultivariateGaussian
-import org.apache.spark.mllib.util.MLUtils
+import org.apache.spark.mllib.util.{MLUtils, Loader, Saveable}
import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{SQLContext, Row}
/**
* :: Experimental ::
@@ -41,10 +47,16 @@ import org.apache.spark.rdd.RDD
@Experimental
class GaussianMixtureModel(
val weights: Array[Double],
- val gaussians: Array[MultivariateGaussian]) extends Serializable {
+ val gaussians: Array[MultivariateGaussian]) extends Serializable with Saveable{
require(weights.length == gaussians.length, "Length of weight and Gaussian arrays must match")
-
+
+ override protected def formatVersion = "1.0"
+
+ override def save(sc: SparkContext, path: String): Unit = {
+ GaussianMixtureModel.SaveLoadV1_0.save(sc, path, weights, gaussians)
+ }
+
/** Number of gaussians in mixture */
def k: Int = weights.length
@@ -83,5 +95,79 @@ class GaussianMixtureModel(
p(i) /= pSum
}
p
- }
+ }
+}
+
+@Experimental
+object GaussianMixtureModel extends Loader[GaussianMixtureModel] {
+
+ private object SaveLoadV1_0 {
+
+ case class Data(weight: Double, mu: Vector, sigma: Matrix)
+
+ val formatVersionV1_0 = "1.0"
+
+ val classNameV1_0 = "org.apache.spark.mllib.clustering.GaussianMixtureModel"
+
+ def save(
+ sc: SparkContext,
+ path: String,
+ weights: Array[Double],
+ gaussians: Array[MultivariateGaussian]): Unit = {
+
+ val sqlContext = new SQLContext(sc)
+ import sqlContext.implicits._
+
+ // Create JSON metadata.
+ val metadata = compact(render
+ (("class" -> classNameV1_0) ~ ("version" -> formatVersionV1_0) ~ ("k" -> weights.length)))
+ sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
+
+ // Create Parquet data.
+ val dataArray = Array.tabulate(weights.length) { i =>
+ Data(weights(i), gaussians(i).mu, gaussians(i).sigma)
+ }
+ sc.parallelize(dataArray, 1).toDF().saveAsParquetFile(Loader.dataPath(path))
+ }
+
+ def load(sc: SparkContext, path: String): GaussianMixtureModel = {
+ val dataPath = Loader.dataPath(path)
+ val sqlContext = new SQLContext(sc)
+ val dataFrame = sqlContext.parquetFile(dataPath)
+ val dataArray = dataFrame.select("weight", "mu", "sigma").collect()
+
+ // Check schema explicitly since erasure makes it hard to use match-case for checking.
+ Loader.checkSchema[Data](dataFrame.schema)
+
+ val (weights, gaussians) = dataArray.map {
+ case Row(weight: Double, mu: Vector, sigma: Matrix) =>
+ (weight, new MultivariateGaussian(mu, sigma))
+ }.unzip
+
+ return new GaussianMixtureModel(weights.toArray, gaussians.toArray)
+ }
+ }
+
+ override def load(sc: SparkContext, path: String) : GaussianMixtureModel = {
+ val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path)
+ implicit val formats = DefaultFormats
+ val k = (metadata \ "k").extract[Int]
+ val classNameV1_0 = SaveLoadV1_0.classNameV1_0
+ (loadedClassName, version) match {
+ case (classNameV1_0, "1.0") => {
+ val model = SaveLoadV1_0.load(sc, path)
+ require(model.weights.length == k,
+ s"GaussianMixtureModel requires weights of length $k " +
+ s"got weights of length ${model.weights.length}")
+ require(model.gaussians.length == k,
+ s"GaussianMixtureModel requires gaussians of length $k" +
+ s"got gaussians of length ${model.gaussians.length}")
+ model
+ }
+ case _ => throw new Exception(
+ s"GaussianMixtureModel.load did not recognize model with (className, format version):" +
+ s"($loadedClassName, $version). Supported:\n" +
+ s" ($classNameV1_0, 1.0)")
+ }
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala
index 1b46a4012d..f356ffa3e3 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala
@@ -23,6 +23,7 @@ import org.apache.spark.mllib.linalg.{Vectors, Matrices}
import org.apache.spark.mllib.stat.distribution.MultivariateGaussian
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.util.Utils
class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext {
test("single cluster") {
@@ -48,13 +49,7 @@ class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext {
}
test("two clusters") {
- val data = sc.parallelize(Array(
- Vectors.dense(-5.1971), Vectors.dense(-2.5359), Vectors.dense(-3.8220),
- Vectors.dense(-5.2211), Vectors.dense(-5.0602), Vectors.dense( 4.7118),
- Vectors.dense( 6.8989), Vectors.dense( 3.4592), Vectors.dense( 4.6322),
- Vectors.dense( 5.7048), Vectors.dense( 4.6567), Vectors.dense( 5.5026),
- Vectors.dense( 4.5605), Vectors.dense( 5.2043), Vectors.dense( 6.2734)
- ))
+ val data = sc.parallelize(GaussianTestData.data)
// we set an initial gaussian to induce expected results
val initialGmm = new GaussianMixtureModel(
@@ -105,14 +100,7 @@ class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext {
}
test("two clusters with sparse data") {
- val data = sc.parallelize(Array(
- Vectors.dense(-5.1971), Vectors.dense(-2.5359), Vectors.dense(-3.8220),
- Vectors.dense(-5.2211), Vectors.dense(-5.0602), Vectors.dense( 4.7118),
- Vectors.dense( 6.8989), Vectors.dense( 3.4592), Vectors.dense( 4.6322),
- Vectors.dense( 5.7048), Vectors.dense( 4.6567), Vectors.dense( 5.5026),
- Vectors.dense( 4.5605), Vectors.dense( 5.2043), Vectors.dense( 6.2734)
- ))
-
+ val data = sc.parallelize(GaussianTestData.data)
val sparseData = data.map(point => Vectors.sparse(1, Array(0), point.toArray))
// we set an initial gaussian to induce expected results
val initialGmm = new GaussianMixtureModel(
@@ -138,4 +126,38 @@ class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext {
assert(sparseGMM.gaussians(0).sigma ~== Esigma(0) absTol 1E-3)
assert(sparseGMM.gaussians(1).sigma ~== Esigma(1) absTol 1E-3)
}
+
+ test("model save / load") {
+ val data = sc.parallelize(GaussianTestData.data)
+
+ val gmm = new GaussianMixture().setK(2).setSeed(0).run(data)
+ val tempDir = Utils.createTempDir()
+ val path = tempDir.toURI.toString
+
+ try {
+ gmm.save(sc, path)
+
+ // TODO: GaussianMixtureModel should implement equals/hashcode directly.
+ val sameModel = GaussianMixtureModel.load(sc, path)
+ assert(sameModel.k === gmm.k)
+ (0 until sameModel.k).foreach { i =>
+ assert(sameModel.gaussians(i).mu === gmm.gaussians(i).mu)
+ assert(sameModel.gaussians(i).sigma === gmm.gaussians(i).sigma)
+ }
+ } finally {
+ Utils.deleteRecursively(tempDir)
+ }
+ }
+
+ object GaussianTestData {
+
+ val data = Array(
+ Vectors.dense(-5.1971), Vectors.dense(-2.5359), Vectors.dense(-3.8220),
+ Vectors.dense(-5.2211), Vectors.dense(-5.0602), Vectors.dense( 4.7118),
+ Vectors.dense( 6.8989), Vectors.dense( 3.4592), Vectors.dense( 4.6322),
+ Vectors.dense( 5.7048), Vectors.dense( 4.6567), Vectors.dense( 5.5026),
+ Vectors.dense( 4.5605), Vectors.dense( 5.2043), Vectors.dense( 6.2734)
+ )
+
+ }
}