diff options
author | Dariusz Kobylarz <darek.kobylarz@gmail.com> | 2015-08-07 14:51:03 -0700 |
---|---|---|
committer | Joseph K. Bradley <joseph@databricks.com> | 2015-08-07 14:51:03 -0700 |
commit | e2fbbe73111d4624390f596a19a1799c86a05f6c (patch) | |
tree | 18ad0cb52348abe34845143c3f3578afdcfd85e7 | |
parent | 881548ab20fa4c4b635c51d956b14bd13981e2f4 (diff) | |
download | spark-e2fbbe73111d4624390f596a19a1799c86a05f6c.tar.gz spark-e2fbbe73111d4624390f596a19a1799c86a05f6c.tar.bz2 spark-e2fbbe73111d4624390f596a19a1799c86a05f6c.zip |
[SPARK-8481] [MLLIB] GaussianMixtureModel predict accepting single vector
Resubmit of [https://github.com/apache/spark/pull/6906] for adding single-vec predict to GMMs
CC: dkobylarz mengxr
To be merged with master and branch-1.5
Primary author: dkobylarz
Author: Dariusz Kobylarz <darek.kobylarz@gmail.com>
Closes #8039 from jkbradley/gmm-predict-vec and squashes the following commits:
bfbedc4 [Dariusz Kobylarz] [SPARK-8481] [MLlib] GaussianMixtureModel predict accepting single vector
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala | 13 | ||||
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala | 10 |
2 files changed, 23 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala index cb807c8038..76aeebd703 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala @@ -66,6 +66,12 @@ class GaussianMixtureModel( responsibilityMatrix.map(r => r.indexOf(r.max)) } + /** Maps given point to its cluster index. */ + def predict(point: Vector): Int = { + val r = computeSoftAssignments(point.toBreeze.toDenseVector, gaussians, weights, k) + r.indexOf(r.max) + } + /** Java-friendly version of [[predict()]] */ def predict(points: JavaRDD[Vector]): JavaRDD[java.lang.Integer] = predict(points.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Integer]] @@ -84,6 +90,13 @@ class GaussianMixtureModel( } /** + * Given the input vector, return the membership values to all mixture components. + */ + def predictSoft(point: Vector): Array[Double] = { + computeSoftAssignments(point.toBreeze.toDenseVector, gaussians, weights, k) + } + + /** * Compute the partial assignments for each vector */ private def computeSoftAssignments( diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala index b218d72f12..b636d02f78 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala @@ -148,6 +148,16 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext { } } + test("model prediction, parallel and local") { + val data = sc.parallelize(GaussianTestData.data) + val gmm = new GaussianMixture().setK(2).setSeed(0).run(data) + + val batchPredictions = gmm.predict(data) + batchPredictions.zip(data).collect().foreach { case (batchPred, datum) => + assert(batchPred === gmm.predict(datum)) + } + } + object GaussianTestData { val data = Array( |