diff options
-rw-r--r-- | python/pyspark/ml/clustering.py | 10 |
1 files changed, 10 insertions, 0 deletions
diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py
index 7bb8ab94e1..9189c02220 100644
--- a/python/pyspark/ml/clustering.py
+++ b/python/pyspark/ml/clustering.py
@@ -36,6 +36,14 @@ class KMeansModel(JavaModel):
         """Get the cluster centers, represented as a list of NumPy arrays."""
         return [c.toArray() for c in self._call_java("clusterCenters")]
 
+    @since("2.0.0")
+    def computeCost(self, dataset):
+        """
+        Return the K-means cost (sum of squared distances of points to their nearest center)
+        for this model on the given data.
+        """
+        return self._call_java("computeCost", dataset)
+
 
 @inherit_doc
 class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol, HasSeed):
@@ -53,6 +61,8 @@ class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol
     >>> centers = model.clusterCenters()
     >>> len(centers)
     2
+    >>> model.computeCost(df)
+    2.000...
     >>> transformed = model.transform(df).select("features", "prediction")
     >>> rows = transformed.collect()
     >>> rows[0].prediction == rows[1].prediction