diff options
author | Yanbo Liang <ybliang8@gmail.com> | 2016-01-06 10:50:02 -0800 |
---|---|---|
committer | Joseph K. Bradley <joseph@databricks.com> | 2016-01-06 10:50:02 -0800 |
commit | 95eb65163391b9e910277a948b72efccf6136e0c (patch) | |
tree | 82ad62d5db87ab5921ca448148a8376888ec62c7 /python/pyspark/ml/clustering.py | |
parent | 007da1a9dc3bb912da841cc0f5832a4fa28e6d9d (diff) | |
download | spark-95eb65163391b9e910277a948b72efccf6136e0c.tar.gz spark-95eb65163391b9e910277a948b72efccf6136e0c.tar.bz2 spark-95eb65163391b9e910277a948b72efccf6136e0c.zip |
[SPARK-11945][ML][PYSPARK] Add computeCost to KMeansModel for PySpark spark.ml
Add `computeCost` to `KMeansModel` as an evaluator for PySpark spark.ml.
Author: Yanbo Liang <ybliang8@gmail.com>
Closes #9931 from yanboliang/SPARK-11945.
Diffstat (limited to 'python/pyspark/ml/clustering.py')
-rw-r--r-- | python/pyspark/ml/clustering.py | 10 |
1 files changed, 10 insertions, 0 deletions
```diff
diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py
index 7bb8ab94e1..9189c02220 100644
--- a/python/pyspark/ml/clustering.py
+++ b/python/pyspark/ml/clustering.py
@@ -36,6 +36,14 @@ class KMeansModel(JavaModel):
         """Get the cluster centers, represented as a list of NumPy arrays."""
         return [c.toArray() for c in self._call_java("clusterCenters")]
 
+    @since("2.0.0")
+    def computeCost(self, dataset):
+        """
+        Return the K-means cost (sum of squared distances of points to their nearest center)
+        for this model on the given data.
+        """
+        return self._call_java("computeCost", dataset)
+
 
 @inherit_doc
 class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol, HasSeed):
@@ -53,6 +61,8 @@ class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol
     >>> centers = model.clusterCenters()
     >>> len(centers)
     2
+    >>> model.computeCost(df)
+    2.000...
     >>> transformed = model.transform(df).select("features", "prediction")
     >>> rows = transformed.collect()
     >>> rows[0].prediction == rows[1].prediction
```