aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark
diff options
context:
space:
mode:
authorGayathriMurali <gayathri.m.softie@gmail.com>2016-03-16 14:21:42 -0700
committerJoseph K. Bradley <joseph@databricks.com>2016-03-16 14:21:42 -0700
commit27e1f38851a8f28a28544b2021b3c5641d0ff3ab (patch)
treeff82e41fadcb181ae134ac4a9313279beba7d9d4 /python/pyspark
parent85c42fda99973a0c35c743816a06ce9117bb1aad (diff)
downloadspark-27e1f38851a8f28a28544b2021b3c5641d0ff3ab.tar.gz
spark-27e1f38851a8f28a28544b2021b3c5641d0ff3ab.tar.bz2
spark-27e1f38851a8f28a28544b2021b3c5641d0ff3ab.zip
[SPARK-13034] PySpark ml.classification support export/import
## What changes were proposed in this pull request? Add export/import for all estimators and transformers (which have a Scala implementation) under pyspark/ml/classification.py. ## How was this patch tested? ./python/run-tests ./dev/lint-python Unit tests added to check persistence in Logistic Regression. Author: GayathriMurali <gayathri.m.softie@gmail.com> Closes #11707 from GayathriMurali/SPARK-13034.
Diffstat (limited to 'python/pyspark')
-rw-r--r--python/pyspark/ml/classification.py52
-rw-r--r--python/pyspark/ml/tests.py18
2 files changed, 61 insertions, 9 deletions
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index ec8834a89e..16ad76483d 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -18,7 +18,7 @@
import warnings
from pyspark import since
-from pyspark.ml.util import keyword_only
+from pyspark.ml.util import *
from pyspark.ml.wrapper import JavaEstimator, JavaModel
from pyspark.ml.param.shared import *
from pyspark.ml.regression import (
@@ -38,7 +38,7 @@ __all__ = ['LogisticRegression', 'LogisticRegressionModel',
class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,
HasRegParam, HasTol, HasProbabilityCol, HasRawPredictionCol,
HasElasticNetParam, HasFitIntercept, HasStandardization, HasThresholds,
- HasWeightCol):
+ HasWeightCol, MLWritable, MLReadable):
"""
Logistic regression.
Currently, this class only supports binary classification.
@@ -69,6 +69,18 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
Traceback (most recent call last):
...
TypeError: Method setParams forces keyword arguments.
+ >>> lr_path = temp_path + "/lr"
+ >>> lr.save(lr_path)
+ >>> lr2 = LogisticRegression.load(lr_path)
+ >>> lr2.getMaxIter()
+ 5
+ >>> model_path = temp_path + "/lr_model"
+ >>> model.save(model_path)
+ >>> model2 = LogisticRegressionModel.load(model_path)
+ >>> model.coefficients[0] == model2.coefficients[0]
+ True
+ >>> model.intercept == model2.intercept
+ True
.. versionadded:: 1.3.0
"""
@@ -186,7 +198,7 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
" threshold (%g) and thresholds (equivalent to %g)" % (t2, t))
-class LogisticRegressionModel(JavaModel):
+class LogisticRegressionModel(JavaModel, MLWritable, MLReadable):
"""
Model fitted by LogisticRegression.
@@ -589,7 +601,7 @@ class GBTClassificationModel(TreeEnsembleModels):
@inherit_doc
class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasProbabilityCol,
- HasRawPredictionCol):
+ HasRawPredictionCol, MLWritable, MLReadable):
"""
Naive Bayes Classifiers.
It supports both Multinomial and Bernoulli NB. Multinomial NB
@@ -623,6 +635,18 @@ class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, H
>>> test1 = sc.parallelize([Row(features=Vectors.sparse(2, [0], [1.0]))]).toDF()
>>> model.transform(test1).head().prediction
1.0
+ >>> nb_path = temp_path + "/nb"
+ >>> nb.save(nb_path)
+ >>> nb2 = NaiveBayes.load(nb_path)
+ >>> nb2.getSmoothing()
+ 1.0
+ >>> model_path = temp_path + "/nb_model"
+ >>> model.save(model_path)
+ >>> model2 = NaiveBayesModel.load(model_path)
+ >>> model.pi == model2.pi
+ True
+ >>> model.theta == model2.theta
+ True
.. versionadded:: 1.5.0
"""
@@ -696,7 +720,7 @@ class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, H
return self.getOrDefault(self.modelType)
-class NaiveBayesModel(JavaModel):
+class NaiveBayesModel(JavaModel, MLWritable, MLReadable):
"""
Model fitted by NaiveBayes.
@@ -853,17 +877,27 @@ class MultilayerPerceptronClassificationModel(JavaModel):
if __name__ == "__main__":
import doctest
+ import pyspark.ml.classification
from pyspark.context import SparkContext
from pyspark.sql import SQLContext
- globs = globals().copy()
+ globs = pyspark.ml.classification.__dict__.copy()
# The small batch size here ensures that we see multiple batches,
# even in these small test examples:
sc = SparkContext("local[2]", "ml.classification tests")
sqlContext = SQLContext(sc)
globs['sc'] = sc
globs['sqlContext'] = sqlContext
- (failure_count, test_count) = doctest.testmod(
- globs=globs, optionflags=doctest.ELLIPSIS)
- sc.stop()
+ import tempfile
+ temp_path = tempfile.mkdtemp()
+ globs['temp_path'] = temp_path
+ try:
+ (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
+ sc.stop()
+ finally:
+ from shutil import rmtree
+ try:
+ rmtree(temp_path)
+ except OSError:
+ pass
if failure_count:
exit(-1)
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index c76f893e43..9783ce7e77 100644
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -499,6 +499,24 @@ class PersistenceTest(PySparkTestCase):
except OSError:
pass
+ def test_logistic_regression(self):
+ lr = LogisticRegression(maxIter=1)
+ path = tempfile.mkdtemp()
+ lr_path = path + "/logreg"
+ lr.save(lr_path)
+ lr2 = LogisticRegression.load(lr_path)
+ self.assertEqual(lr2.uid, lr2.maxIter.parent,
+ "Loaded LogisticRegression instance uid (%s) "
+ "did not match Param's uid (%s)"
+ % (lr2.uid, lr2.maxIter.parent))
+ self.assertEqual(lr._defaultParamMap[lr.maxIter], lr2._defaultParamMap[lr2.maxIter],
+ "Loaded LogisticRegression instance default params did not match " +
+ "original defaults")
+ try:
+ rmtree(path)
+ except OSError:
+ pass
+
def test_pipeline_persistence(self):
sqlContext = SQLContext(self.sc)
temp_path = tempfile.mkdtemp()