author     Yanbo Liang <ybliang8@gmail.com>     2016-02-11 15:55:40 -0800
committer  Xiangrui Meng <meng@databricks.com>  2016-02-11 15:55:40 -0800
commit     30e00955663dfe6079effe4465bbcecedb5d93b9 (patch)
tree       302e7e592525fc4d8f709b43973378f5fd8a0179 /python/pyspark/ml/clustering.py
parent     2426eb3e167fece19831070594247e9481dbbe2a (diff)
[SPARK-13035][ML][PYSPARK] PySpark ml.clustering support export/import
PySpark ml.clustering support export/import.

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #10999 from yanboliang/spark-13035.
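In practice, the new MLWritable/MLReadable support lets a PySpark user persist both the KMeans estimator and a fitted KMeansModel and reload them later. A minimal sketch of that round trip, mirroring the doctest added in the diff below; the DataFrame df with a "features" vector column and the k=2, seed=1 parameters are illustrative assumptions, not part of this page:

    import tempfile
    from pyspark.ml.clustering import KMeans, KMeansModel

    kmeans = KMeans(k=2, seed=1)          # assumed parameters, for illustration
    model = kmeans.fit(df)                # df: DataFrame with a "features" vector column (assumed)

    path = tempfile.mkdtemp()
    kmeans.save(path + "/kmeans")         # persist the estimator (its params)
    model.save(path + "/kmeans_model")    # persist the fitted model (cluster centers)

    kmeans2 = KMeans.load(path + "/kmeans")
    model2 = KMeansModel.load(path + "/kmeans_model")
    assert kmeans2.getK() == 2            # params survive the round trip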
Diffstat (limited to 'python/pyspark/ml/clustering.py')
-rw-r--r--  python/pyspark/ml/clustering.py | 29
1 file changed, 25 insertions(+), 4 deletions(-)
diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py
index 12afb88563..f156eda125 100644
--- a/python/pyspark/ml/clustering.py
+++ b/python/pyspark/ml/clustering.py
@@ -16,7 +16,7 @@
#
from pyspark import since
-from pyspark.ml.util import keyword_only
+from pyspark.ml.util import *
from pyspark.ml.wrapper import JavaEstimator, JavaModel
from pyspark.ml.param.shared import *
from pyspark.mllib.common import inherit_doc
@@ -24,7 +24,7 @@ from pyspark.mllib.common import inherit_doc
__all__ = ['KMeans', 'KMeansModel']
-class KMeansModel(JavaModel):
+class KMeansModel(JavaModel, MLWritable, MLReadable):
"""
Model fitted by KMeans.
@@ -46,7 +46,8 @@ class KMeansModel(JavaModel):
@inherit_doc
-class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol, HasSeed):
+class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol, HasSeed,
+ MLWritable, MLReadable):
"""
K-means clustering with support for multiple parallel runs and a k-means++ like initialization
mode (the k-means|| algorithm by Bahmani et al). When multiple concurrent runs are requested,
@@ -69,6 +70,25 @@ class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol
True
>>> rows[2].prediction == rows[3].prediction
True
+ >>> import os, tempfile
+ >>> path = tempfile.mkdtemp()
+ >>> kmeans_path = path + "/kmeans"
+ >>> kmeans.save(kmeans_path)
+ >>> kmeans2 = KMeans.load(kmeans_path)
+ >>> kmeans2.getK()
+ 2
+ >>> model_path = path + "/kmeans_model"
+ >>> model.save(model_path)
+ >>> model2 = KMeansModel.load(model_path)
+ >>> model.clusterCenters()[0] == model2.clusterCenters()[0]
+ array([ True, True], dtype=bool)
+ >>> model.clusterCenters()[1] == model2.clusterCenters()[1]
+ array([ True, True], dtype=bool)
+ >>> from shutil import rmtree
+ >>> try:
+ ... rmtree(path)
+ ... except OSError:
+ ... pass
.. versionadded:: 1.5.0
"""
@@ -157,9 +177,10 @@ class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol
if __name__ == "__main__":
import doctest
+ import pyspark.ml.clustering
from pyspark.context import SparkContext
from pyspark.sql import SQLContext
- globs = globals().copy()
+ globs = pyspark.ml.clustering.__dict__.copy()
# The small batch size here ensures that we see multiple batches,
# even in these small test examples:
sc = SparkContext("local[2]", "ml.clustering tests")
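The KMeans docstring touched above mentions a k-means++ like initialization mode (the k-means|| algorithm by Bahmani et al.). For intuition only, here is a plain single-machine k-means++ seeding sketch; it is not Spark's k-means|| implementation, which roughly parallelizes this step by oversampling candidate centers over several passes and then reclustering those candidates locally:

    import random

    def kmeans_pp_init(points, k, seed=None):
        # points: list of equal-length numeric sequences; returns k initial centers.
        rng = random.Random(seed)
        centers = [list(rng.choice(points))]
        while len(centers) < k:
            # Squared distance from each point to its nearest chosen center.
            d2 = [min(sum((p[i] - c[i]) ** 2 for i in range(len(p))) for c in centers)
                  for p in points]
            total = sum(d2)
            if total == 0:
                # All points coincide with existing centers; fall back to a random pick.
                centers.append(list(rng.choice(points)))
                continue
            # Sample the next center with probability proportional to its squared distance.
            r = rng.uniform(0.0, total)
            acc = 0.0
            for p, w in zip(points, d2):
                acc += w
                if acc >= r:
                    centers.append(list(p))
                    break
        return centers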