From 30e00955663dfe6079effe4465bbcecedb5d93b9 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 11 Feb 2016 15:55:40 -0800 Subject: [SPARK-13035][ML][PYSPARK] PySpark ml.clustering support export/import PySpark ml.clustering support export/import. Author: Yanbo Liang Closes #10999 from yanboliang/spark-13035. --- python/pyspark/ml/clustering.py | 29 +++++++++++++++++++++++++---- 1 file changed, 25 insertions(+), 4 deletions(-) (limited to 'python/pyspark/ml/clustering.py') diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py index 12afb88563..f156eda125 100644 --- a/python/pyspark/ml/clustering.py +++ b/python/pyspark/ml/clustering.py @@ -16,7 +16,7 @@ # from pyspark import since -from pyspark.ml.util import keyword_only +from pyspark.ml.util import * from pyspark.ml.wrapper import JavaEstimator, JavaModel from pyspark.ml.param.shared import * from pyspark.mllib.common import inherit_doc @@ -24,7 +24,7 @@ from pyspark.mllib.common import inherit_doc __all__ = ['KMeans', 'KMeansModel'] -class KMeansModel(JavaModel): +class KMeansModel(JavaModel, MLWritable, MLReadable): """ Model fitted by KMeans. @@ -46,7 +46,8 @@ class KMeansModel(JavaModel): @inherit_doc -class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol, HasSeed): +class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol, HasSeed, + MLWritable, MLReadable): """ K-means clustering with support for multiple parallel runs and a k-means++ like initialization mode (the k-means|| algorithm by Bahmani et al). When multiple concurrent runs are requested, @@ -69,6 +70,25 @@ class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol True >>> rows[2].prediction == rows[3].prediction True + >>> import os, tempfile + >>> path = tempfile.mkdtemp() + >>> kmeans_path = path + "/kmeans" + >>> kmeans.save(kmeans_path) + >>> kmeans2 = KMeans.load(kmeans_path) + >>> kmeans2.getK() + 2 + >>> model_path = path + "/kmeans_model" + >>> model.save(model_path) + >>> model2 = KMeansModel.load(model_path) + >>> model.clusterCenters()[0] == model2.clusterCenters()[0] + array([ True, True], dtype=bool) + >>> model.clusterCenters()[1] == model2.clusterCenters()[1] + array([ True, True], dtype=bool) + >>> from shutil import rmtree + >>> try: + ... rmtree(path) + ... except OSError: + ... pass .. versionadded:: 1.5.0 """ @@ -157,9 +177,10 @@ class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol if __name__ == "__main__": import doctest + import pyspark.ml.clustering from pyspark.context import SparkContext from pyspark.sql import SQLContext - globs = globals().copy() + globs = pyspark.ml.clustering.__dict__.copy() # The small batch size here ensures that we see multiple batches, # even in these small test examples: sc = SparkContext("local[2]", "ml.clustering tests") -- cgit v1.2.3