aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
author: Dariusz Kobylarz <darek.kobylarz@gmail.com> 2015-08-07 14:51:03 -0700
committer: Joseph K. Bradley <joseph@databricks.com> 2015-08-07 14:51:03 -0700
commit: e2fbbe73111d4624390f596a19a1799c86a05f6c (patch)
tree: 18ad0cb52348abe34845143c3f3578afdcfd85e7 /mllib
parent: 881548ab20fa4c4b635c51d956b14bd13981e2f4 (diff)
downloadspark-e2fbbe73111d4624390f596a19a1799c86a05f6c.tar.gz
spark-e2fbbe73111d4624390f596a19a1799c86a05f6c.tar.bz2
spark-e2fbbe73111d4624390f596a19a1799c86a05f6c.zip
[SPARK-8481] [MLLIB] GaussianMixtureModel predict accepting single vector
Resubmit of [https://github.com/apache/spark/pull/6906], which adds single-vector predict to GMMs. CC: dkobylarz, mengxr. To be merged with master and branch-1.5. Primary author: dkobylarz. Author: Dariusz Kobylarz <darek.kobylarz@gmail.com>. Closes #8039 from jkbradley/gmm-predict-vec and squashes the following commits: bfbedc4 [Dariusz Kobylarz] [SPARK-8481] [MLlib] GaussianMixtureModel predict accepting single vector
Diffstat (limited to 'mllib')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala | 13
-rw-r--r-- mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala | 10
2 files changed, 23 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
index cb807c8038..76aeebd703 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
@@ -66,6 +66,12 @@ class GaussianMixtureModel(
responsibilityMatrix.map(r => r.indexOf(r.max))
}
+ /** Maps the given point to the index of the mixture component with the largest soft assignment. */
+ def predict(point: Vector): Int = {
+ val r = computeSoftAssignments(point.toBreeze.toDenseVector, gaussians, weights, k)
+ r.indexOf(r.max) // position of the largest soft-assignment value
+ }
+
/** Java-friendly version of [[predict()]] */
def predict(points: JavaRDD[Vector]): JavaRDD[java.lang.Integer] =
predict(points.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Integer]]
@@ -84,6 +90,13 @@ class GaussianMixtureModel(
}
/**
+ * Given a single input vector, returns its membership values for all mixture components.
+ */
+ def predictSoft(point: Vector): Array[Double] = {
+ computeSoftAssignments(point.toBreeze.toDenseVector, gaussians, weights, k)
+ }
+
+ /**
* Compute the partial assignments for each vector
*/
private def computeSoftAssignments(
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala
index b218d72f12..b636d02f78 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala
@@ -148,6 +148,16 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext {
}
}
+ test("model prediction, parallel and local") {
+ val data = sc.parallelize(GaussianTestData.data)
+ val gmm = new GaussianMixture().setK(2).setSeed(0).run(data) // 2 components, fixed seed
+
+ val batchPredictions = gmm.predict(data) // predictions through the RDD overload
+ batchPredictions.zip(data).collect().foreach { case (batchPred, datum) =>
+ assert(batchPred === gmm.predict(datum)) // RDD and single-vector results must agree
+ }
+ }
+
object GaussianTestData {
val data = Array(