aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--python/pyspark/ml/clustering.py10
1 files changed, 10 insertions, 0 deletions
diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py
index 7bb8ab94e1..9189c02220 100644
--- a/python/pyspark/ml/clustering.py
+++ b/python/pyspark/ml/clustering.py
@@ -36,6 +36,14 @@ class KMeansModel(JavaModel):
"""Get the cluster centers, represented as a list of NumPy arrays."""
return [c.toArray() for c in self._call_java("clusterCenters")]
+ @since("2.0.0")
+ def computeCost(self, dataset):
+ """
+ Return the K-means cost (sum of squared distances of points to their nearest center)
+ for this model on the given data.
+ """
+ return self._call_java("computeCost", dataset)
+
@inherit_doc
class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol, HasSeed):
@@ -53,6 +61,8 @@ class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol
>>> centers = model.clusterCenters()
>>> len(centers)
2
+ >>> model.computeCost(df)
+ 2.000...
>>> transformed = model.transform(df).select("features", "prediction")
>>> rows = transformed.collect()
>>> rows[0].prediction == rows[1].prediction