aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test
diff options
context:
space:
mode:
authorMechCoder <manojkumarsivaraj334@gmail.com>2015-03-25 14:45:23 -0700
committerXiangrui Meng <meng@databricks.com>2015-03-25 14:45:23 -0700
commit4fc4d0369e8240defe0ee83252426402f1a28a36 (patch)
tree0d4187756c9caf831a890fcf612b373642f5a92f /mllib/src/test
parent435337381f093f95248c8f0204e60c0b366edc81 (diff)
downloadspark-4fc4d0369e8240defe0ee83252426402f1a28a36.tar.gz
spark-4fc4d0369e8240defe0ee83252426402f1a28a36.tar.bz2
spark-4fc4d0369e8240defe0ee83252426402f1a28a36.zip
[SPARK-5987] [MLlib] Save/load for GaussianMixtureModels
Should be self explanatory. Author: MechCoder <manojkumarsivaraj334@gmail.com> Closes #4986 from MechCoder/spark-5987 and squashes the following commits: 7d2cd56 [MechCoder] Iterate over dataframe in a better way e7a14cb [MechCoder] Minor 33c84f9 [MechCoder] Store as Array[Data] instead of Data[Array] 505bd57 [MechCoder] Rebased over master and used MatrixUDT 7422bb4 [MechCoder] Store sigmas as Array[Double] instead of Array[Array[Double]] b9794e4 [MechCoder] Minor cb77095 [MechCoder] [SPARK-5987] Save/load for GaussianMixtureModels
Diffstat (limited to 'mllib/src/test')
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala52
1 files changed, 37 insertions, 15 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala
index 1b46a4012d..f356ffa3e3 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala
@@ -23,6 +23,7 @@ import org.apache.spark.mllib.linalg.{Vectors, Matrices}
import org.apache.spark.mllib.stat.distribution.MultivariateGaussian
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.util.Utils
class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext {
test("single cluster") {
@@ -48,13 +49,7 @@ class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext {
}
test("two clusters") {
- val data = sc.parallelize(Array(
- Vectors.dense(-5.1971), Vectors.dense(-2.5359), Vectors.dense(-3.8220),
- Vectors.dense(-5.2211), Vectors.dense(-5.0602), Vectors.dense( 4.7118),
- Vectors.dense( 6.8989), Vectors.dense( 3.4592), Vectors.dense( 4.6322),
- Vectors.dense( 5.7048), Vectors.dense( 4.6567), Vectors.dense( 5.5026),
- Vectors.dense( 4.5605), Vectors.dense( 5.2043), Vectors.dense( 6.2734)
- ))
+ val data = sc.parallelize(GaussianTestData.data)
// we set an initial gaussian to induce expected results
val initialGmm = new GaussianMixtureModel(
@@ -105,14 +100,7 @@ class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext {
}
test("two clusters with sparse data") {
- val data = sc.parallelize(Array(
- Vectors.dense(-5.1971), Vectors.dense(-2.5359), Vectors.dense(-3.8220),
- Vectors.dense(-5.2211), Vectors.dense(-5.0602), Vectors.dense( 4.7118),
- Vectors.dense( 6.8989), Vectors.dense( 3.4592), Vectors.dense( 4.6322),
- Vectors.dense( 5.7048), Vectors.dense( 4.6567), Vectors.dense( 5.5026),
- Vectors.dense( 4.5605), Vectors.dense( 5.2043), Vectors.dense( 6.2734)
- ))
-
+ val data = sc.parallelize(GaussianTestData.data)
val sparseData = data.map(point => Vectors.sparse(1, Array(0), point.toArray))
// we set an initial gaussian to induce expected results
val initialGmm = new GaussianMixtureModel(
@@ -138,4 +126,38 @@ class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext {
assert(sparseGMM.gaussians(0).sigma ~== Esigma(0) absTol 1E-3)
assert(sparseGMM.gaussians(1).sigma ~== Esigma(1) absTol 1E-3)
}
+
+ test("model save / load") {
+ val data = sc.parallelize(GaussianTestData.data)
+
+ val gmm = new GaussianMixture().setK(2).setSeed(0).run(data)
+ val tempDir = Utils.createTempDir()
+ val path = tempDir.toURI.toString
+
+ try {
+ gmm.save(sc, path)
+
+ // TODO: GaussianMixtureModel should implement equals/hashcode directly.
+ val sameModel = GaussianMixtureModel.load(sc, path)
+ assert(sameModel.k === gmm.k)
+ (0 until sameModel.k).foreach { i =>
+ assert(sameModel.gaussians(i).mu === gmm.gaussians(i).mu)
+ assert(sameModel.gaussians(i).sigma === gmm.gaussians(i).sigma)
+ }
+ } finally {
+ Utils.deleteRecursively(tempDir)
+ }
+ }
+
+ object GaussianTestData {
+
+ val data = Array(
+ Vectors.dense(-5.1971), Vectors.dense(-2.5359), Vectors.dense(-3.8220),
+ Vectors.dense(-5.2211), Vectors.dense(-5.0602), Vectors.dense( 4.7118),
+ Vectors.dense( 6.8989), Vectors.dense( 3.4592), Vectors.dense( 4.6322),
+ Vectors.dense( 5.7048), Vectors.dense( 4.6567), Vectors.dense( 5.5026),
+ Vectors.dense( 4.5605), Vectors.dense( 5.2043), Vectors.dense( 6.2734)
+ )
+
+ }
}