aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/ml/tests.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/ml/tests.py')
-rw-r--r--python/pyspark/ml/tests.py56
1 files changed, 54 insertions, 2 deletions
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index f6159b2c95..e3f873e3a7 100644
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -44,7 +44,7 @@ import numpy as np
from pyspark.ml import Estimator, Model, Pipeline, PipelineModel, Transformer
from pyspark.ml.classification import LogisticRegression, DecisionTreeClassifier
from pyspark.ml.clustering import KMeans
-from pyspark.ml.evaluation import RegressionEvaluator
+from pyspark.ml.evaluation import BinaryClassificationEvaluator, RegressionEvaluator
from pyspark.ml.feature import *
from pyspark.ml.param import Param, Params, TypeConverters
from pyspark.ml.param.shared import HasMaxIter, HasInputCol, HasSeed
@@ -53,7 +53,7 @@ from pyspark.ml.tuning import *
from pyspark.ml.util import keyword_only
from pyspark.ml.util import MLWritable, MLWriter
from pyspark.ml.wrapper import JavaWrapper
-from pyspark.mllib.linalg import DenseVector, SparseVector
+from pyspark.mllib.linalg import Vectors, DenseVector, SparseVector
from pyspark.sql import DataFrame, SQLContext, Row
from pyspark.sql.functions import rand
from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase
@@ -479,6 +479,32 @@ class CrossValidatorTests(PySparkTestCase):
"Best model should have zero induced error")
self.assertEqual(1.0, bestModelMetric, "Best model has R-squared of 1")
+ def test_save_load(self):
+ temp_path = tempfile.mkdtemp()
+ sqlContext = SQLContext(self.sc)
+ dataset = sqlContext.createDataFrame(
+ [(Vectors.dense([0.0]), 0.0),
+ (Vectors.dense([0.4]), 1.0),
+ (Vectors.dense([0.5]), 0.0),
+ (Vectors.dense([0.6]), 1.0),
+ (Vectors.dense([1.0]), 1.0)] * 10,
+ ["features", "label"])
+ lr = LogisticRegression()
+ grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
+ evaluator = BinaryClassificationEvaluator()
+ cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator)
+ cvModel = cv.fit(dataset)
+ cvPath = temp_path + "/cv"
+ cv.save(cvPath)
+ loadedCV = CrossValidator.load(cvPath)
+ self.assertEqual(loadedCV.getEstimator().uid, cv.getEstimator().uid)
+ self.assertEqual(loadedCV.getEvaluator().uid, cv.getEvaluator().uid)
+ self.assertEqual(loadedCV.getEstimatorParamMaps(), cv.getEstimatorParamMaps())
+ cvModelPath = temp_path + "/cvModel"
+ cvModel.save(cvModelPath)
+ loadedModel = CrossValidatorModel.load(cvModelPath)
+ self.assertEqual(loadedModel.bestModel.uid, cvModel.bestModel.uid)
+
class TrainValidationSplitTests(PySparkTestCase):
@@ -530,6 +556,32 @@ class TrainValidationSplitTests(PySparkTestCase):
"Best model should have zero induced error")
self.assertEqual(1.0, bestModelMetric, "Best model has R-squared of 1")
+ def test_save_load(self):
+ temp_path = tempfile.mkdtemp()
+ sqlContext = SQLContext(self.sc)
+ dataset = sqlContext.createDataFrame(
+ [(Vectors.dense([0.0]), 0.0),
+ (Vectors.dense([0.4]), 1.0),
+ (Vectors.dense([0.5]), 0.0),
+ (Vectors.dense([0.6]), 1.0),
+ (Vectors.dense([1.0]), 1.0)] * 10,
+ ["features", "label"])
+ lr = LogisticRegression()
+ grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
+ evaluator = BinaryClassificationEvaluator()
+ tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator)
+ tvsModel = tvs.fit(dataset)
+ tvsPath = temp_path + "/tvs"
+ tvs.save(tvsPath)
+ loadedTvs = TrainValidationSplit.load(tvsPath)
+ self.assertEqual(loadedTvs.getEstimator().uid, tvs.getEstimator().uid)
+ self.assertEqual(loadedTvs.getEvaluator().uid, tvs.getEvaluator().uid)
+ self.assertEqual(loadedTvs.getEstimatorParamMaps(), tvs.getEstimatorParamMaps())
+ tvsModelPath = temp_path + "/tvsModel"
+ tvsModel.save(tvsModelPath)
+ loadedModel = TrainValidationSplitModel.load(tvsModelPath)
+ self.assertEqual(loadedModel.bestModel.uid, tvsModel.bestModel.uid)
+
class PersistenceTest(PySparkTestCase):