aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/ml/clustering.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/ml/clustering.py')
-rw-r--r--python/pyspark/ml/clustering.py162
1 files changed, 159 insertions, 3 deletions
diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py
index 7632f05c3b..e58ec1e7ac 100644
--- a/python/pyspark/ml/clustering.py
+++ b/python/pyspark/ml/clustering.py
@@ -17,16 +17,74 @@
from pyspark import since, keyword_only
from pyspark.ml.util import *
-from pyspark.ml.wrapper import JavaEstimator, JavaModel
+from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaWrapper
from pyspark.ml.param.shared import *
from pyspark.ml.common import inherit_doc
-__all__ = ['BisectingKMeans', 'BisectingKMeansModel',
+__all__ = ['BisectingKMeans', 'BisectingKMeansModel', 'BisectingKMeansSummary',
'KMeans', 'KMeansModel',
- 'GaussianMixture', 'GaussianMixtureModel',
+ 'GaussianMixture', 'GaussianMixtureModel', 'GaussianMixtureSummary',
'LDA', 'LDAModel', 'LocalLDAModel', 'DistributedLDAModel']
+class ClusteringSummary(JavaWrapper):
+ """
+ .. note:: Experimental
+
+ Clustering results for a given model.
+
+ .. versionadded:: 2.1.0
+ """
+
+ @property
+ @since("2.1.0")
+ def predictionCol(self):
+ """
+ Name for column of predicted clusters in `predictions`.
+ """
+ return self._call_java("predictionCol")
+
+ @property
+ @since("2.1.0")
+ def predictions(self):
+ """
+ DataFrame produced by the model's `transform` method.
+ """
+ return self._call_java("predictions")
+
+ @property
+ @since("2.1.0")
+ def featuresCol(self):
+ """
+ Name for column of features in `predictions`.
+ """
+ return self._call_java("featuresCol")
+
+ @property
+ @since("2.1.0")
+ def k(self):
+ """
+ The number of clusters the model was trained with.
+ """
+ return self._call_java("k")
+
+ @property
+ @since("2.1.0")
+ def cluster(self):
+ """
+ DataFrame of predicted cluster centers for each training data point.
+ """
+ return self._call_java("cluster")
+
+ @property
+ @since("2.1.0")
+ def clusterSizes(self):
+ """
+ Size of (number of data points in) each cluster.
+ """
+ return self._call_java("clusterSizes")
+
+
class GaussianMixtureModel(JavaModel, JavaMLWritable, JavaMLReadable):
"""
.. note:: Experimental
@@ -56,6 +114,28 @@ class GaussianMixtureModel(JavaModel, JavaMLWritable, JavaMLReadable):
"""
return self._call_java("gaussiansDF")
+ @property
+ @since("2.1.0")
+ def hasSummary(self):
+ """
+ Indicates whether a training summary exists for this model
+ instance.
+ """
+ return self._call_java("hasSummary")
+
+ @property
+ @since("2.1.0")
+ def summary(self):
+ """
+ Gets summary (e.g. cluster assignments, cluster sizes) of the model trained on the
+ training set. An exception is thrown if no summary exists.
+ """
+ if self.hasSummary:
+ return GaussianMixtureSummary(self._call_java("summary"))
+ else:
+ raise RuntimeError("No training summary available for this %s" %
+ self.__class__.__name__)
+
@inherit_doc
class GaussianMixture(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol, HasSeed,
@@ -92,6 +172,13 @@ class GaussianMixture(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIte
>>> gm = GaussianMixture(k=3, tol=0.0001,
... maxIter=10, seed=10)
>>> model = gm.fit(df)
+ >>> model.hasSummary
+ True
+ >>> summary = model.summary
+ >>> summary.k
+ 3
+ >>> summary.clusterSizes
+ [2, 2, 2]
>>> weights = model.weights
>>> len(weights)
3
@@ -118,6 +205,8 @@ class GaussianMixture(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIte
>>> model_path = temp_path + "/gmm_model"
>>> model.save(model_path)
>>> model2 = GaussianMixtureModel.load(model_path)
+ >>> model2.hasSummary
+ False
>>> model2.weights == model.weights
True
>>> model2.gaussiansDF.show()
@@ -181,6 +270,32 @@ class GaussianMixture(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIte
return self.getOrDefault(self.k)
+class GaussianMixtureSummary(ClusteringSummary):
+ """
+ .. note:: Experimental
+
+ Gaussian mixture clustering results for a given model.
+
+ .. versionadded:: 2.1.0
+ """
+
+ @property
+ @since("2.1.0")
+ def probabilityCol(self):
+ """
+ Name for column of predicted probability of each cluster in `predictions`.
+ """
+ return self._call_java("probabilityCol")
+
+ @property
+ @since("2.1.0")
+ def probability(self):
+ """
+ DataFrame of probabilities of each cluster for each training data point.
+ """
+ return self._call_java("probability")
+
+
class KMeansModel(JavaModel, JavaMLWritable, JavaMLReadable):
"""
Model fitted by KMeans.
@@ -346,6 +461,27 @@ class BisectingKMeansModel(JavaModel, JavaMLWritable, JavaMLReadable):
"""
return self._call_java("computeCost", dataset)
+ @property
+ @since("2.1.0")
+ def hasSummary(self):
+ """
+ Indicates whether a training summary exists for this model instance.
+ """
+ return self._call_java("hasSummary")
+
+ @property
+ @since("2.1.0")
+ def summary(self):
+ """
+ Gets summary (e.g. cluster assignments, cluster sizes) of the model trained on the
+ training set. An exception is thrown if no summary exists.
+ """
+ if self.hasSummary:
+ return BisectingKMeansSummary(self._call_java("summary"))
+ else:
+ raise RuntimeError("No training summary available for this %s" %
+ self.__class__.__name__)
+
@inherit_doc
class BisectingKMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasSeed,
@@ -373,6 +509,13 @@ class BisectingKMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIte
2
>>> model.computeCost(df)
2.000...
+ >>> model.hasSummary
+ True
+ >>> summary = model.summary
+ >>> summary.k
+ 2
+ >>> summary.clusterSizes
+ [2, 2]
>>> transformed = model.transform(df).select("features", "prediction")
>>> rows = transformed.collect()
>>> rows[0].prediction == rows[1].prediction
@@ -387,6 +530,8 @@ class BisectingKMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIte
>>> model_path = temp_path + "/bkm_model"
>>> model.save(model_path)
>>> model2 = BisectingKMeansModel.load(model_path)
+ >>> model2.hasSummary
+ False
>>> model.clusterCenters()[0] == model2.clusterCenters()[0]
array([ True, True], dtype=bool)
>>> model.clusterCenters()[1] == model2.clusterCenters()[1]
@@ -460,6 +605,17 @@ class BisectingKMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIte
return BisectingKMeansModel(java_model)
+class BisectingKMeansSummary(ClusteringSummary):
+ """
+ .. note:: Experimental
+
+ Bisecting KMeans clustering results for a given model.
+
+ .. versionadded:: 2.1.0
+ """
+ pass
+
+
@inherit_doc
class LDAModel(JavaModel):
"""