diff options
author | GayathriMurali <gayathri.m.softie@gmail.com> | 2016-03-16 14:21:42 -0700 |
---|---|---|
committer | Joseph K. Bradley <joseph@databricks.com> | 2016-03-16 14:21:42 -0700 |
commit | 27e1f38851a8f28a28544b2021b3c5641d0ff3ab (patch) | |
tree | ff82e41fadcb181ae134ac4a9313279beba7d9d4 /python/pyspark/ml/tests.py | |
parent | 85c42fda99973a0c35c743816a06ce9117bb1aad (diff) | |
download | spark-27e1f38851a8f28a28544b2021b3c5641d0ff3ab.tar.gz spark-27e1f38851a8f28a28544b2021b3c5641d0ff3ab.tar.bz2 spark-27e1f38851a8f28a28544b2021b3c5641d0ff3ab.zip |
[SPARK-13034] PySpark ml.classification support export/import
## What changes were proposed in this pull request?
Add export/import for all estimators and transformers(which have Scala implementation) under pyspark/ml/classification.py.
## How was this patch tested?
./python/run-tests
./dev/lint-python
Unit tests added to check persistence in Logistic Regression
Author: GayathriMurali <gayathri.m.softie@gmail.com>
Closes #11707 from GayathriMurali/SPARK-13034.
Diffstat (limited to 'python/pyspark/ml/tests.py')
-rw-r--r-- | python/pyspark/ml/tests.py | 18 |
1 files changed, 18 insertions, 0 deletions
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index c76f893e43..9783ce7e77 100644 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -499,6 +499,24 @@ class PersistenceTest(PySparkTestCase): except OSError: pass + def test_logistic_regression(self): + lr = LogisticRegression(maxIter=1) + path = tempfile.mkdtemp() + lr_path = path + "/logreg" + lr.save(lr_path) + lr2 = LogisticRegression.load(lr_path) + self.assertEqual(lr2.uid, lr2.maxIter.parent, + "Loaded LogisticRegression instance uid (%s) " + "did not match Param's uid (%s)" + % (lr2.uid, lr2.maxIter.parent)) + self.assertEqual(lr._defaultParamMap[lr.maxIter], lr2._defaultParamMap[lr2.maxIter], + "Loaded LogisticRegression instance default params did not match " + + "original defaults") + try: + rmtree(path) + except OSError: + pass + def test_pipeline_persistence(self): sqlContext = SQLContext(self.sc) temp_path = tempfile.mkdtemp() |