diff options
author    Yanbo Liang <ybliang8@gmail.com>        2016-01-29 09:22:24 -0800
committer Joseph K. Bradley <joseph@databricks.com>  2016-01-29 09:22:24 -0800
commit    e51b6eaa9e9c007e194d858195291b2b9fb27322 (patch)
tree      b6af90c439154fe7514fd32e47a56a693ffd745a /python/pyspark/ml/tests.py
parent    55561e7693dd2a5bf3c7f8026c725421801fd0ec (diff)
download  spark-e51b6eaa9e9c007e194d858195291b2b9fb27322.tar.gz
          spark-e51b6eaa9e9c007e194d858195291b2b9fb27322.tar.bz2
          spark-e51b6eaa9e9c007e194d858195291b2b9fb27322.zip
[SPARK-13032][ML][PYSPARK] PySpark support model export/import and take LinearRegression as example
* Implement ```MLWriter/MLWritable/MLReader/MLReadable``` for PySpark.
* Making ```LinearRegression``` to support ```save/load``` as example. After this merged, the work for other transformers/estimators will be easy, then we can list and distribute the tasks to the community.
cc mengxr jkbradley
Author: Yanbo Liang <ybliang8@gmail.com>
Author: Joseph K. Bradley <joseph@databricks.com>
Closes #10469 from yanboliang/spark-11939.
Diffstat (limited to 'python/pyspark/ml/tests.py')
-rw-r--r--  python/pyspark/ml/tests.py | 36
1 file changed, 30 insertions(+), 6 deletions(-)
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index c45a159c46..54806ee336 100644
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -34,18 +34,22 @@ if sys.version_info[:2] <= (2, 6):
 else:
     import unittest

-from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase
-from pyspark.sql import DataFrame, SQLContext, Row
-from pyspark.sql.functions import rand
+from shutil import rmtree
+import tempfile
+
+from pyspark.ml import Estimator, Model, Pipeline, Transformer
 from pyspark.ml.classification import LogisticRegression
 from pyspark.ml.evaluation import RegressionEvaluator
+from pyspark.ml.feature import *
 from pyspark.ml.param import Param, Params
 from pyspark.ml.param.shared import HasMaxIter, HasInputCol, HasSeed
-from pyspark.ml.util import keyword_only
-from pyspark.ml import Estimator, Model, Pipeline, Transformer
-from pyspark.ml.feature import *
+from pyspark.ml.regression import LinearRegression
 from pyspark.ml.tuning import ParamGridBuilder, CrossValidator, CrossValidatorModel
+from pyspark.ml.util import keyword_only
 from pyspark.mllib.linalg import DenseVector
+from pyspark.sql import DataFrame, SQLContext, Row
+from pyspark.sql.functions import rand
+from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase


 class MockDataset(DataFrame):
@@ -405,6 +409,26 @@ class CrossValidatorTests(PySparkTestCase):
         self.assertEqual(1.0, bestModelMetric, "Best model has R-squared of 1")


+class PersistenceTest(PySparkTestCase):
+
+    def test_linear_regression(self):
+        lr = LinearRegression(maxIter=1)
+        path = tempfile.mkdtemp()
+        lr_path = path + "/lr"
+        lr.save(lr_path)
+        lr2 = LinearRegression.load(lr_path)
+        self.assertEqual(lr2.uid, lr2.maxIter.parent,
+                         "Loaded LinearRegression instance uid (%s) did not match Param's uid (%s)"
+                         % (lr2.uid, lr2.maxIter.parent))
+        self.assertEqual(lr._defaultParamMap[lr.maxIter], lr2._defaultParamMap[lr2.maxIter],
+                         "Loaded LinearRegression instance default params did not match " +
+                         "original defaults")
+        try:
+            rmtree(path)
+        except OSError:
+            pass
+
+
 if __name__ == "__main__":
     from pyspark.ml.tests import *
     if xmlrunner: