diff options
author | Yanbo Liang <ybliang8@gmail.com> | 2016-02-11 15:55:40 -0800 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2016-02-11 15:55:40 -0800 |
commit | 30e00955663dfe6079effe4465bbcecedb5d93b9 (patch) | |
tree | 302e7e592525fc4d8f709b43973378f5fd8a0179 /python | |
parent | 2426eb3e167fece19831070594247e9481dbbe2a (diff) | |
download | spark-30e00955663dfe6079effe4465bbcecedb5d93b9.tar.gz spark-30e00955663dfe6079effe4465bbcecedb5d93b9.tar.bz2 spark-30e00955663dfe6079effe4465bbcecedb5d93b9.zip |
[SPARK-13035][ML][PYSPARK] PySpark ml.clustering support export/import
PySpark ml.clustering support export/import.
Author: Yanbo Liang <ybliang8@gmail.com>
Closes #10999 from yanboliang/spark-13035.
Diffstat (limited to 'python')
-rw-r--r-- | python/pyspark/ml/clustering.py | 29 |
1 files changed, 25 insertions, 4 deletions
diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py index 12afb88563..f156eda125 100644 --- a/python/pyspark/ml/clustering.py +++ b/python/pyspark/ml/clustering.py @@ -16,7 +16,7 @@ # from pyspark import since -from pyspark.ml.util import keyword_only +from pyspark.ml.util import * from pyspark.ml.wrapper import JavaEstimator, JavaModel from pyspark.ml.param.shared import * from pyspark.mllib.common import inherit_doc @@ -24,7 +24,7 @@ from pyspark.mllib.common import inherit_doc __all__ = ['KMeans', 'KMeansModel'] -class KMeansModel(JavaModel): +class KMeansModel(JavaModel, MLWritable, MLReadable): """ Model fitted by KMeans. @@ -46,7 +46,8 @@ class KMeansModel(JavaModel): @inherit_doc -class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol, HasSeed): +class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol, HasSeed, + MLWritable, MLReadable): """ K-means clustering with support for multiple parallel runs and a k-means++ like initialization mode (the k-means|| algorithm by Bahmani et al). When multiple concurrent runs are requested, @@ -69,6 +70,25 @@ class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol True >>> rows[2].prediction == rows[3].prediction True + >>> import os, tempfile + >>> path = tempfile.mkdtemp() + >>> kmeans_path = path + "/kmeans" + >>> kmeans.save(kmeans_path) + >>> kmeans2 = KMeans.load(kmeans_path) + >>> kmeans2.getK() + 2 + >>> model_path = path + "/kmeans_model" + >>> model.save(model_path) + >>> model2 = KMeansModel.load(model_path) + >>> model.clusterCenters()[0] == model2.clusterCenters()[0] + array([ True, True], dtype=bool) + >>> model.clusterCenters()[1] == model2.clusterCenters()[1] + array([ True, True], dtype=bool) + >>> from shutil import rmtree + >>> try: + ... rmtree(path) + ... except OSError: + ... pass .. versionadded:: 1.5.0 """ @@ -157,9 +177,10 @@ class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol if __name__ == "__main__": import doctest + import pyspark.ml.clustering from pyspark.context import SparkContext from pyspark.sql import SQLContext - globs = globals().copy() + globs = pyspark.ml.clustering.__dict__.copy() # The small batch size here ensures that we see multiple batches, # even in these small test examples: sc = SparkContext("local[2]", "ml.clustering tests") |