about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r-- examples/src/main/python/mllib/gaussian_mixture_model.py 4
-rw-r--r-- examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala 6
-rw-r--r-- mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala 4
-rw-r--r-- mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala 2
-rw-r--r-- python/pyspark/mllib/clustering.py 35
5 files changed, 37 insertions, 14 deletions
diff --git a/examples/src/main/python/mllib/gaussian_mixture_model.py b/examples/src/main/python/mllib/gaussian_mixture_model.py
index 2cb8010cdc..69e836fc1d 100644
--- a/examples/src/main/python/mllib/gaussian_mixture_model.py
+++ b/examples/src/main/python/mllib/gaussian_mixture_model.py
@@ -62,5 +62,9 @@ if __name__ == "__main__":
for i in range(args.k):
print(("weight = ", model.weights[i], "mu = ", model.gaussians[i].mu,
"sigma = ", model.gaussians[i].sigma.toArray()))
+ print("\n")
+ print(("The membership value of each vector to all mixture components (first 100): ",
+ model.predictSoft(data).take(100)))
+ print("\n")
print(("Cluster labels (first 100): ", model.predict(data).take(100)))
sc.stop()
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala
index 1fce4ba7ef..90b817b23e 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala
@@ -58,6 +58,12 @@ object DenseGaussianMixture {
(clusters.weights(i), clusters.gaussians(i).mu, clusters.gaussians(i).sigma))
}
+ println("The membership value of each vector to all mixture components (first <= 100):")
+ val membership = clusters.predictSoft(data)
+ membership.take(100).foreach { x =>
+ print(" " + x.mkString(","))
+ }
+ println()
println("Cluster labels (first <= 100):")
val clusterLabels = clusters.predict(data)
clusterLabels.take(100).foreach { x =>
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala
index 6a3b20c88d..a689b09341 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala
@@ -40,5 +40,9 @@ private[python] class GaussianMixtureModelWrapper(model: GaussianMixtureModel) {
SerDe.dumps(JavaConverters.seqAsJavaListConverter(modelGaussians).asJava)
}
+ def predictSoft(point: Vector): Vector = {
+ Vectors.dense(model.predictSoft(point))
+ }
+
def save(sc: SparkContext, path: String): Unit = model.save(sc, path)
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
index 16bc45bcb6..42fe27024f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
@@ -75,7 +75,7 @@ class GaussianMixtureModel @Since("1.3.0") (
*/
@Since("1.5.0")
def predict(point: Vector): Int = {
- val r = computeSoftAssignments(point.toBreeze.toDenseVector, gaussians, weights, k)
+ val r = predictSoft(point)
r.indexOf(r.max)
}
diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py
index d22a7f4c3b..580cb512d8 100644
--- a/python/pyspark/mllib/clustering.py
+++ b/python/pyspark/mllib/clustering.py
@@ -202,16 +202,25 @@ class GaussianMixtureModel(JavaModelWrapper, JavaSaveable, JavaLoader):
>>> clusterdata_1 = sc.parallelize(array([-0.1,-0.05,-0.01,-0.1,
... 0.9,0.8,0.75,0.935,
- ... -0.83,-0.68,-0.91,-0.76 ]).reshape(6, 2))
+ ... -0.83,-0.68,-0.91,-0.76 ]).reshape(6, 2), 2)
>>> model = GaussianMixture.train(clusterdata_1, 3, convergenceTol=0.0001,
... maxIterations=50, seed=10)
>>> labels = model.predict(clusterdata_1).collect()
>>> labels[0]==labels[1]
False
>>> labels[1]==labels[2]
- True
+ False
>>> labels[4]==labels[5]
True
+ >>> model.predict([-0.1,-0.05])
+ 0
+ >>> softPredicted = model.predictSoft([-0.1,-0.05])
+ >>> abs(softPredicted[0] - 1.0) < 0.001
+ True
+ >>> abs(softPredicted[1] - 0.0) < 0.001
+ True
+ >>> abs(softPredicted[2] - 0.0) < 0.001
+ True
>>> path = tempfile.mkdtemp()
>>> model.save(sc, path)
@@ -277,26 +286,27 @@ class GaussianMixtureModel(JavaModelWrapper, JavaSaveable, JavaLoader):
@since('1.3.0')
def predict(self, x):
"""
- Find the cluster to which the points in 'x' has maximum membership
- in this model.
+ Find the cluster to which the point 'x' or each point in RDD 'x'
+ has maximum membership in this model.
- :param x: RDD of data points.
- :return: cluster_labels. RDD of cluster labels.
+ :param x: A vector or an RDD of vectors representing data points.
+ :return: cluster label or RDD of cluster labels.
"""
if isinstance(x, RDD):
cluster_labels = self.predictSoft(x).map(lambda z: z.index(max(z)))
return cluster_labels
else:
- raise TypeError("x should be represented by an RDD, "
- "but got %s." % type(x))
+ z = self.predictSoft(x)
+ return z.argmax()
@since('1.3.0')
def predictSoft(self, x):
"""
- Find the membership of each point in 'x' to all mixture components.
+ Find the membership of point 'x' or each point in RDD 'x' to all mixture components.
- :param x: RDD of data points.
- :return: membership_matrix. RDD of array of double values.
+ :param x: A vector or an RDD of vectors representing data points.
+ :return: the membership value to all mixture components for vector 'x'
+ or each vector in RDD 'x'.
"""
if isinstance(x, RDD):
means, sigmas = zip(*[(g.mu, g.sigma) for g in self.gaussians])
@@ -304,8 +314,7 @@ class GaussianMixtureModel(JavaModelWrapper, JavaSaveable, JavaLoader):
_convert_to_vector(self.weights), means, sigmas)
return membership_matrix.map(lambda x: pyarray.array('d', x))
else:
- raise TypeError("x should be represented by an RDD, "
- "but got %s." % type(x))
+ return self.call("predictSoft", _convert_to_vector(x)).toArray()
@classmethod
@since('1.5.0')