author     GayathriMurali <gayathri.m.softie@gmail.com>    2016-03-16 14:21:42 -0700
committer  Joseph K. Bradley <joseph@databricks.com>       2016-03-16 14:21:42 -0700
commit     27e1f38851a8f28a28544b2021b3c5641d0ff3ab (patch)
tree       ff82e41fadcb181ae134ac4a9313279beba7d9d4 /python/pyspark/ml/tests.py
parent     85c42fda99973a0c35c743816a06ce9117bb1aad (diff)
[SPARK-13034] PySpark ml.classification support export/import
## What changes were proposed in this pull request?

Add export/import for all estimators and transformers (which have a Scala implementation) under pyspark/ml/classification.py.

## How was this patch tested?

./python/run-tests
./dev/lint-python

Unit tests were added to check persistence in LogisticRegression.

Author: GayathriMurali <gayathri.m.softie@gmail.com>

Closes #11707 from GayathriMurali/SPARK-13034.
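The persistence API exercised by the new test is used roughly as in the sketch below. This is an illustrative example, not code from this commit: the SparkSession setup, the toy DataFrame, and the temporary paths are assumptions, and it presumes a PySpark build (2.0+) where pyspark.ml.linalg and fitted-model save/load are available.

# Illustrative sketch only -- not part of this commit. Assumes a PySpark
# version with ml persistence wired up (SparkSession, pyspark.ml.linalg).
import shutil
import tempfile

from pyspark.sql import SparkSession
from pyspark.ml.classification import LogisticRegression, LogisticRegressionModel
from pyspark.ml.linalg import Vectors

spark = SparkSession.builder.appName("lr-persistence-sketch").getOrCreate()

# Toy training data: (label, features) rows, purely for illustration.
df = spark.createDataFrame(
    [(0.0, Vectors.dense(0.0, 1.1)), (1.0, Vectors.dense(2.0, 1.0))],
    ["label", "features"])

lr = LogisticRegression(maxIter=5, regParam=0.01)
model = lr.fit(df)

path = tempfile.mkdtemp()
try:
    # Save the estimator and the fitted model, then load them back.
    lr.save(path + "/lr")
    model.save(path + "/lr_model")

    lr2 = LogisticRegression.load(path + "/lr")
    model2 = LogisticRegressionModel.load(path + "/lr_model")

    # Params and model state round-trip through save/load.
    assert lr2.getMaxIter() == 5
    assert model2.coefficients == model.coefficients
finally:
    shutil.rmtree(path, ignore_errors=True)

spark.stop()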
Diffstat (limited to 'python/pyspark/ml/tests.py')
-rw-r--r--    python/pyspark/ml/tests.py    18
1 file changed, 18 insertions(+), 0 deletions(-)
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index c76f893e43..9783ce7e77 100644
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -499,6 +499,24 @@ class PersistenceTest(PySparkTestCase):
         except OSError:
             pass
 
+    def test_logistic_regression(self):
+        lr = LogisticRegression(maxIter=1)
+        path = tempfile.mkdtemp()
+        lr_path = path + "/logreg"
+        lr.save(lr_path)
+        lr2 = LogisticRegression.load(lr_path)
+        self.assertEqual(lr2.uid, lr2.maxIter.parent,
+                         "Loaded LogisticRegression instance uid (%s) "
+                         "did not match Param's uid (%s)"
+                         % (lr2.uid, lr2.maxIter.parent))
+        self.assertEqual(lr._defaultParamMap[lr.maxIter], lr2._defaultParamMap[lr2.maxIter],
+                         "Loaded LogisticRegression instance default params did not match " +
+                         "original defaults")
+        try:
+            rmtree(path)
+        except OSError:
+            pass
+
     def test_pipeline_persistence(self):
         sqlContext = SQLContext(self.sc)
         temp_path = tempfile.mkdtemp()