From db0b06c6ea7412266158b1c710bdc8ca30e26430 Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Wed, 6 Apr 2016 11:24:11 -0700 Subject: [SPARK-13786][ML][PYSPARK] Add save/load for pyspark.ml.tuning ## What changes were proposed in this pull request? https://issues.apache.org/jira/browse/SPARK-13786 Add save/load for Python CrossValidator/Model and TrainValidationSplit/Model. ## How was this patch tested? Test with Python doctest. Author: Xusen Yin Closes #12020 from yinxusen/SPARK-13786. --- python/pyspark/ml/tests.py | 56 ++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 54 insertions(+), 2 deletions(-) (limited to 'python/pyspark/ml/tests.py') diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index f6159b2c95..e3f873e3a7 100644 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -44,7 +44,7 @@ import numpy as np from pyspark.ml import Estimator, Model, Pipeline, PipelineModel, Transformer from pyspark.ml.classification import LogisticRegression, DecisionTreeClassifier from pyspark.ml.clustering import KMeans -from pyspark.ml.evaluation import RegressionEvaluator +from pyspark.ml.evaluation import BinaryClassificationEvaluator, RegressionEvaluator from pyspark.ml.feature import * from pyspark.ml.param import Param, Params, TypeConverters from pyspark.ml.param.shared import HasMaxIter, HasInputCol, HasSeed @@ -53,7 +53,7 @@ from pyspark.ml.tuning import * from pyspark.ml.util import keyword_only from pyspark.ml.util import MLWritable, MLWriter from pyspark.ml.wrapper import JavaWrapper -from pyspark.mllib.linalg import DenseVector, SparseVector +from pyspark.mllib.linalg import Vectors, DenseVector, SparseVector from pyspark.sql import DataFrame, SQLContext, Row from pyspark.sql.functions import rand from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase @@ -479,6 +479,32 @@ class CrossValidatorTests(PySparkTestCase): "Best model should have zero induced error") self.assertEqual(1.0, bestModelMetric, "Best model has R-squared of 1") + def test_save_load(self): + temp_path = tempfile.mkdtemp() + sqlContext = SQLContext(self.sc) + dataset = sqlContext.createDataFrame( + [(Vectors.dense([0.0]), 0.0), + (Vectors.dense([0.4]), 1.0), + (Vectors.dense([0.5]), 0.0), + (Vectors.dense([0.6]), 1.0), + (Vectors.dense([1.0]), 1.0)] * 10, + ["features", "label"]) + lr = LogisticRegression() + grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build() + evaluator = BinaryClassificationEvaluator() + cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator) + cvModel = cv.fit(dataset) + cvPath = temp_path + "/cv" + cv.save(cvPath) + loadedCV = CrossValidator.load(cvPath) + self.assertEqual(loadedCV.getEstimator().uid, cv.getEstimator().uid) + self.assertEqual(loadedCV.getEvaluator().uid, cv.getEvaluator().uid) + self.assertEqual(loadedCV.getEstimatorParamMaps(), cv.getEstimatorParamMaps()) + cvModelPath = temp_path + "/cvModel" + cvModel.save(cvModelPath) + loadedModel = CrossValidatorModel.load(cvModelPath) + self.assertEqual(loadedModel.bestModel.uid, cvModel.bestModel.uid) + class TrainValidationSplitTests(PySparkTestCase): @@ -530,6 +556,32 @@ class TrainValidationSplitTests(PySparkTestCase): "Best model should have zero induced error") self.assertEqual(1.0, bestModelMetric, "Best model has R-squared of 1") + def test_save_load(self): + temp_path = tempfile.mkdtemp() + sqlContext = SQLContext(self.sc) + dataset = sqlContext.createDataFrame( + [(Vectors.dense([0.0]), 0.0), + (Vectors.dense([0.4]), 1.0), + (Vectors.dense([0.5]), 0.0), + (Vectors.dense([0.6]), 1.0), + (Vectors.dense([1.0]), 1.0)] * 10, + ["features", "label"]) + lr = LogisticRegression() + grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build() + evaluator = BinaryClassificationEvaluator() + tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator) + tvsModel = tvs.fit(dataset) + tvsPath = temp_path + "/tvs" + tvs.save(tvsPath) + loadedTvs = TrainValidationSplit.load(tvsPath) + self.assertEqual(loadedTvs.getEstimator().uid, tvs.getEstimator().uid) + self.assertEqual(loadedTvs.getEvaluator().uid, tvs.getEvaluator().uid) + self.assertEqual(loadedTvs.getEstimatorParamMaps(), tvs.getEstimatorParamMaps()) + tvsModelPath = temp_path + "/tvsModel" + tvsModel.save(tvsModelPath) + loadedModel = TrainValidationSplitModel.load(tvsModelPath) + self.assertEqual(loadedModel.bestModel.uid, tvsModel.bestModel.uid) + class PersistenceTest(PySparkTestCase): -- cgit v1.2.3