diff options
Diffstat (limited to 'python/pyspark/ml/tests.py')
-rw-r--r-- | python/pyspark/ml/tests.py | 56 |
1 files changed, 54 insertions, 2 deletions
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index f6159b2c95..e3f873e3a7 100644 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -44,7 +44,7 @@ import numpy as np from pyspark.ml import Estimator, Model, Pipeline, PipelineModel, Transformer from pyspark.ml.classification import LogisticRegression, DecisionTreeClassifier from pyspark.ml.clustering import KMeans -from pyspark.ml.evaluation import RegressionEvaluator +from pyspark.ml.evaluation import BinaryClassificationEvaluator, RegressionEvaluator from pyspark.ml.feature import * from pyspark.ml.param import Param, Params, TypeConverters from pyspark.ml.param.shared import HasMaxIter, HasInputCol, HasSeed @@ -53,7 +53,7 @@ from pyspark.ml.tuning import * from pyspark.ml.util import keyword_only from pyspark.ml.util import MLWritable, MLWriter from pyspark.ml.wrapper import JavaWrapper -from pyspark.mllib.linalg import DenseVector, SparseVector +from pyspark.mllib.linalg import Vectors, DenseVector, SparseVector from pyspark.sql import DataFrame, SQLContext, Row from pyspark.sql.functions import rand from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase @@ -479,6 +479,32 @@ class CrossValidatorTests(PySparkTestCase): "Best model should have zero induced error") self.assertEqual(1.0, bestModelMetric, "Best model has R-squared of 1") + def test_save_load(self): + temp_path = tempfile.mkdtemp() + sqlContext = SQLContext(self.sc) + dataset = sqlContext.createDataFrame( + [(Vectors.dense([0.0]), 0.0), + (Vectors.dense([0.4]), 1.0), + (Vectors.dense([0.5]), 0.0), + (Vectors.dense([0.6]), 1.0), + (Vectors.dense([1.0]), 1.0)] * 10, + ["features", "label"]) + lr = LogisticRegression() + grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build() + evaluator = BinaryClassificationEvaluator() + cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator) + cvModel = cv.fit(dataset) + cvPath = temp_path + "/cv" + cv.save(cvPath) + loadedCV = CrossValidator.load(cvPath) + self.assertEqual(loadedCV.getEstimator().uid, cv.getEstimator().uid) + self.assertEqual(loadedCV.getEvaluator().uid, cv.getEvaluator().uid) + self.assertEqual(loadedCV.getEstimatorParamMaps(), cv.getEstimatorParamMaps()) + cvModelPath = temp_path + "/cvModel" + cvModel.save(cvModelPath) + loadedModel = CrossValidatorModel.load(cvModelPath) + self.assertEqual(loadedModel.bestModel.uid, cvModel.bestModel.uid) + class TrainValidationSplitTests(PySparkTestCase): @@ -530,6 +556,32 @@ class TrainValidationSplitTests(PySparkTestCase): "Best model should have zero induced error") self.assertEqual(1.0, bestModelMetric, "Best model has R-squared of 1") + def test_save_load(self): + temp_path = tempfile.mkdtemp() + sqlContext = SQLContext(self.sc) + dataset = sqlContext.createDataFrame( + [(Vectors.dense([0.0]), 0.0), + (Vectors.dense([0.4]), 1.0), + (Vectors.dense([0.5]), 0.0), + (Vectors.dense([0.6]), 1.0), + (Vectors.dense([1.0]), 1.0)] * 10, + ["features", "label"]) + lr = LogisticRegression() + grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build() + evaluator = BinaryClassificationEvaluator() + tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator) + tvsModel = tvs.fit(dataset) + tvsPath = temp_path + "/tvs" + tvs.save(tvsPath) + loadedTvs = TrainValidationSplit.load(tvsPath) + self.assertEqual(loadedTvs.getEstimator().uid, tvs.getEstimator().uid) + self.assertEqual(loadedTvs.getEvaluator().uid, tvs.getEvaluator().uid) + self.assertEqual(loadedTvs.getEstimatorParamMaps(), tvs.getEstimatorParamMaps()) + tvsModelPath = temp_path + "/tvsModel" + tvsModel.save(tvsModelPath) + loadedModel = TrainValidationSplitModel.load(tvsModelPath) + self.assertEqual(loadedModel.bestModel.uid, tvsModel.bestModel.uid) + class PersistenceTest(PySparkTestCase): |