From 27e1f38851a8f28a28544b2021b3c5641d0ff3ab Mon Sep 17 00:00:00 2001
From: GayathriMurali
Date: Wed, 16 Mar 2016 14:21:42 -0700
Subject: [SPARK-13034] PySpark ml.classification support export/import

## What changes were proposed in this pull request?

Add export/import for all estimators and transformers (which have a Scala implementation) under pyspark/ml/classification.py.

## How was this patch tested?

./python/run-tests
./dev/lint-python

Unit tests added to check persistence in Logistic Regression

Author: GayathriMurali

Closes #11707 from GayathriMurali/SPARK-13034.
---
 python/pyspark/ml/classification.py | 52 ++++++++++++++++++++++++++++++-------
 1 file changed, 43 insertions(+), 9 deletions(-)

(limited to 'python/pyspark/ml/classification.py')

diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index ec8834a89e..16ad76483d 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -18,7 +18,7 @@
 import warnings
 
 from pyspark import since
-from pyspark.ml.util import keyword_only
+from pyspark.ml.util import *
 from pyspark.ml.wrapper import JavaEstimator, JavaModel
 from pyspark.ml.param.shared import *
 from pyspark.ml.regression import (
@@ -38,7 +38,7 @@ __all__ = ['LogisticRegression', 'LogisticRegressionModel',
 class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,
                          HasRegParam, HasTol, HasProbabilityCol, HasRawPredictionCol,
                          HasElasticNetParam, HasFitIntercept, HasStandardization, HasThresholds,
-                         HasWeightCol):
+                         HasWeightCol, MLWritable, MLReadable):
     """
     Logistic regression.
     Currently, this class only supports binary classification.
@@ -69,6 +69,18 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
     Traceback (most recent call last):
         ...
     TypeError: Method setParams forces keyword arguments.
+    >>> lr_path = temp_path + "/lr"
+    >>> lr.save(lr_path)
+    >>> lr2 = LogisticRegression.load(lr_path)
+    >>> lr2.getMaxIter()
+    5
+    >>> model_path = temp_path + "/lr_model"
+    >>> model.save(model_path)
+    >>> model2 = LogisticRegressionModel.load(model_path)
+    >>> model.coefficients[0] == model2.coefficients[0]
+    True
+    >>> model.intercept == model2.intercept
+    True
 
     .. versionadded:: 1.3.0
     """
@@ -186,7 +198,7 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
                              " threshold (%g) and thresholds (equivalent to %g)" % (t2, t))
 
 
-class LogisticRegressionModel(JavaModel):
+class LogisticRegressionModel(JavaModel, MLWritable, MLReadable):
     """
     Model fitted by LogisticRegression.
 
@@ -589,7 +601,7 @@ class GBTClassificationModel(TreeEnsembleModels):
 
 @inherit_doc
 class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasProbabilityCol,
                  HasRawPredictionCol,
-                 HasRawPredictionCol):
+                 HasRawPredictionCol, MLWritable, MLReadable):
     """
     Naive Bayes Classifiers.
     It supports both Multinomial and Bernoulli NB. Multinomial NB
@@ -623,6 +635,18 @@ class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, H
     >>> test1 = sc.parallelize([Row(features=Vectors.sparse(2, [0], [1.0]))]).toDF()
     >>> model.transform(test1).head().prediction
     1.0
+    >>> nb_path = temp_path + "/nb"
+    >>> nb.save(nb_path)
+    >>> nb2 = NaiveBayes.load(nb_path)
+    >>> nb2.getSmoothing()
+    1.0
+    >>> model_path = temp_path + "/nb_model"
+    >>> model.save(model_path)
+    >>> model2 = NaiveBayesModel.load(model_path)
+    >>> model.pi == model2.pi
+    True
+    >>> model.theta == model2.theta
+    True
 
     .. versionadded:: 1.5.0
     """
@@ -696,7 +720,7 @@ class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, H
         return self.getOrDefault(self.modelType)
 
 
-class NaiveBayesModel(JavaModel):
+class NaiveBayesModel(JavaModel, MLWritable, MLReadable):
     """
     Model fitted by NaiveBayes.
 
@@ -853,17 +877,27 @@ class MultilayerPerceptronClassificationModel(JavaModel):
 
 if __name__ == "__main__":
     import doctest
+    import pyspark.ml.classification
     from pyspark.context import SparkContext
     from pyspark.sql import SQLContext
-    globs = globals().copy()
+    globs = pyspark.ml.classification.__dict__.copy()
     # The small batch size here ensures that we see multiple batches,
     # even in these small test examples:
     sc = SparkContext("local[2]", "ml.classification tests")
     sqlContext = SQLContext(sc)
     globs['sc'] = sc
     globs['sqlContext'] = sqlContext
-    (failure_count, test_count) = doctest.testmod(
-        globs=globs, optionflags=doctest.ELLIPSIS)
-    sc.stop()
+    import tempfile
+    temp_path = tempfile.mkdtemp()
+    globs['temp_path'] = temp_path
+    try:
+        (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
+        sc.stop()
+    finally:
+        from shutil import rmtree
+        try:
+            rmtree(temp_path)
+        except OSError:
+            pass
     if failure_count:
         exit(-1)
--
cgit v1.2.3
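
For reference, a minimal standalone sketch (not part of the patch) of how the save/load API introduced by this change is used; the toy data, the temporary directory, and the 1.6-era import paths (pyspark.mllib.linalg, SQLContext) are illustrative assumptions, not part of the commit.

    import tempfile

    from pyspark import SparkContext
    from pyspark.ml.classification import LogisticRegression, LogisticRegressionModel
    from pyspark.mllib.linalg import Vectors
    from pyspark.sql import Row, SQLContext

    sc = SparkContext("local[2]", "persistence example")
    sqlContext = SQLContext(sc)

    # Fit a toy binary model on two labeled points.
    df = sc.parallelize([
        Row(label=1.0, features=Vectors.dense(1.0)),
        Row(label=0.0, features=Vectors.sparse(1, [], []))]).toDF()
    lr = LogisticRegression(maxIter=5, regParam=0.01)
    model = lr.fit(df)

    # Round-trip both the estimator and the fitted model through save()/load().
    temp_path = tempfile.mkdtemp()
    lr.save(temp_path + "/lr")
    model.save(temp_path + "/lr_model")
    lr2 = LogisticRegression.load(temp_path + "/lr")
    model2 = LogisticRegressionModel.load(temp_path + "/lr_model")
    assert lr2.getMaxIter() == 5
    assert model2.intercept == model.intercept

    sc.stop()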