aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/ml/tests.py
diff options
context:
space:
mode:
authorYanbo Liang <ybliang8@gmail.com>2016-01-29 09:22:24 -0800
committerJoseph K. Bradley <joseph@databricks.com>2016-01-29 09:22:24 -0800
commite51b6eaa9e9c007e194d858195291b2b9fb27322 (patch)
treeb6af90c439154fe7514fd32e47a56a693ffd745a /python/pyspark/ml/tests.py
parent55561e7693dd2a5bf3c7f8026c725421801fd0ec (diff)
downloadspark-e51b6eaa9e9c007e194d858195291b2b9fb27322.tar.gz
spark-e51b6eaa9e9c007e194d858195291b2b9fb27322.tar.bz2
spark-e51b6eaa9e9c007e194d858195291b2b9fb27322.zip
[SPARK-13032][ML][PYSPARK] PySpark support model export/import and take LinearRegression as example
* Implement ```MLWriter/MLWritable/MLReader/MLReadable``` for PySpark. * Making ```LinearRegression``` to support ```save/load``` as example. After this merged, the work for other transformers/estimators will be easy, then we can list and distribute the tasks to the community. cc mengxr jkbradley Author: Yanbo Liang <ybliang8@gmail.com> Author: Joseph K. Bradley <joseph@databricks.com> Closes #10469 from yanboliang/spark-11939.
Diffstat (limited to 'python/pyspark/ml/tests.py')
-rw-r--r--python/pyspark/ml/tests.py36
1 file changed, 30 insertions, 6 deletions
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index c45a159c46..54806ee336 100644
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -34,18 +34,22 @@ if sys.version_info[:2] <= (2, 6):
else:
import unittest
-from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase
-from pyspark.sql import DataFrame, SQLContext, Row
-from pyspark.sql.functions import rand
+from shutil import rmtree
+import tempfile
+
+from pyspark.ml import Estimator, Model, Pipeline, Transformer
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.evaluation import RegressionEvaluator
+from pyspark.ml.feature import *
from pyspark.ml.param import Param, Params
from pyspark.ml.param.shared import HasMaxIter, HasInputCol, HasSeed
-from pyspark.ml.util import keyword_only
-from pyspark.ml import Estimator, Model, Pipeline, Transformer
-from pyspark.ml.feature import *
+from pyspark.ml.regression import LinearRegression
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator, CrossValidatorModel
+from pyspark.ml.util import keyword_only
from pyspark.mllib.linalg import DenseVector
+from pyspark.sql import DataFrame, SQLContext, Row
+from pyspark.sql.functions import rand
+from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase
class MockDataset(DataFrame):
@@ -405,6 +409,26 @@ class CrossValidatorTests(PySparkTestCase):
self.assertEqual(1.0, bestModelMetric, "Best model has R-squared of 1")
+class PersistenceTest(PySparkTestCase):
+
+ def test_linear_regression(self):
+ lr = LinearRegression(maxIter=1)
+ path = tempfile.mkdtemp()
+ lr_path = path + "/lr"
+ lr.save(lr_path)
+ lr2 = LinearRegression.load(lr_path)
+ self.assertEqual(lr2.uid, lr2.maxIter.parent,
+ "Loaded LinearRegression instance uid (%s) did not match Param's uid (%s)"
+ % (lr2.uid, lr2.maxIter.parent))
+ self.assertEqual(lr._defaultParamMap[lr.maxIter], lr2._defaultParamMap[lr2.maxIter],
+ "Loaded LinearRegression instance default params did not match " +
+ "original defaults")
+ try:
+ rmtree(path)
+ except OSError:
+ pass
+
+
if __name__ == "__main__":
from pyspark.ml.tests import *
if xmlrunner: