aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark
diff options
context:
space:
mode:
authorYanbo Liang <ybliang8@gmail.com>2016-01-06 10:50:02 -0800
committerJoseph K. Bradley <joseph@databricks.com>2016-01-06 10:50:02 -0800
commit95eb65163391b9e910277a948b72efccf6136e0c (patch)
tree82ad62d5db87ab5921ca448148a8376888ec62c7 /python/pyspark
parent007da1a9dc3bb912da841cc0f5832a4fa28e6d9d (diff)
downloadspark-95eb65163391b9e910277a948b72efccf6136e0c.tar.gz
spark-95eb65163391b9e910277a948b72efccf6136e0c.tar.bz2
spark-95eb65163391b9e910277a948b72efccf6136e0c.zip
[SPARK-11945][ML][PYSPARK] Add computeCost to KMeansModel for PySpark spark.ml
Add ```computeCost``` to ```KMeansModel``` as evaluator for PySpark spark.ml. Author: Yanbo Liang <ybliang8@gmail.com> Closes #9931 from yanboliang/SPARK-11945.
Diffstat (limited to 'python/pyspark')
-rw-r--r--python/pyspark/ml/clustering.py10
1 files changed, 10 insertions, 0 deletions
diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py
index 7bb8ab94e1..9189c02220 100644
--- a/python/pyspark/ml/clustering.py
+++ b/python/pyspark/ml/clustering.py
@@ -36,6 +36,14 @@ class KMeansModel(JavaModel):
"""Get the cluster centers, represented as a list of NumPy arrays."""
return [c.toArray() for c in self._call_java("clusterCenters")]
+ @since("2.0.0")
+ def computeCost(self, dataset):
+ """
+ Return the K-means cost (sum of squared distances of points to their nearest center)
+ for this model on the given data.
+ """
+ return self._call_java("computeCost", dataset)
+
@inherit_doc
class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol, HasSeed):
@@ -53,6 +61,8 @@ class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol
>>> centers = model.clusterCenters()
>>> len(centers)
2
+ >>> model.computeCost(df)
+ 2.000...
>>> transformed = model.transform(df).select("features", "prediction")
>>> rows = transformed.collect()
>>> rows[0].prediction == rows[1].prediction