author    Xusen Yin <yinxusen@gmail.com>    2016-04-06 11:24:11 -0700
committer Joseph K. Bradley <joseph@databricks.com>    2016-04-06 11:24:11 -0700
commit    db0b06c6ea7412266158b1c710bdc8ca30e26430 (patch)
tree      58c218ecdbe61927b7f9c3addf11b0bf245ffb2a /python/pyspark/ml/tests.py
parent    3c8d8821654e3d82ef927c55272348e1bcc34a79 (diff)
[SPARK-13786][ML][PYSPARK] Add save/load for pyspark.ml.tuning
## What changes were proposed in this pull request?

https://issues.apache.org/jira/browse/SPARK-13786

Add save/load for Python CrossValidator/Model and TrainValidationSplit/Model.

## How was this patch tested?

Test with Python doctest.

Author: Xusen Yin <yinxusen@gmail.com>

Closes #12020 from yinxusen/SPARK-13786.
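For context, a minimal usage sketch of the persistence API this patch adds. The SparkContext setup and the /tmp paths are illustrative, not part of the patch; the data and pipeline mirror the new tests in the diff below.

```python
from pyspark import SparkContext
from pyspark.sql import SQLContext
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.ml.tuning import CrossValidator, CrossValidatorModel, ParamGridBuilder
from pyspark.mllib.linalg import Vectors

sc = SparkContext(appName="cv-save-load-sketch")
sqlContext = SQLContext(sc)

# Tiny binary-classification dataset, mirroring the one in the new tests below.
dataset = sqlContext.createDataFrame(
    [(Vectors.dense([0.0]), 0.0),
     (Vectors.dense([1.0]), 1.0)] * 10,
    ["features", "label"])

lr = LogisticRegression()
grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
evaluator = BinaryClassificationEvaluator()
cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator)

# Persist the unfitted validator and the fitted model, then read both back.
cv.save("/tmp/cv")                              # illustrative path
loadedCV = CrossValidator.load("/tmp/cv")

cvModel = cv.fit(dataset)
cvModel.save("/tmp/cvModel")
loadedModel = CrossValidatorModel.load("/tmp/cvModel")
print(loadedModel.bestModel.uid == cvModel.bestModel.uid)  # True
```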
Diffstat (limited to 'python/pyspark/ml/tests.py')
-rw-r--r--  python/pyspark/ml/tests.py | 56
1 file changed, 54 insertions(+), 2 deletions(-)
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index f6159b2c95..e3f873e3a7 100644
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -44,7 +44,7 @@ import numpy as np
from pyspark.ml import Estimator, Model, Pipeline, PipelineModel, Transformer
from pyspark.ml.classification import LogisticRegression, DecisionTreeClassifier
from pyspark.ml.clustering import KMeans
-from pyspark.ml.evaluation import RegressionEvaluator
+from pyspark.ml.evaluation import BinaryClassificationEvaluator, RegressionEvaluator
from pyspark.ml.feature import *
from pyspark.ml.param import Param, Params, TypeConverters
from pyspark.ml.param.shared import HasMaxIter, HasInputCol, HasSeed
@@ -53,7 +53,7 @@ from pyspark.ml.tuning import *
from pyspark.ml.util import keyword_only
from pyspark.ml.util import MLWritable, MLWriter
from pyspark.ml.wrapper import JavaWrapper
-from pyspark.mllib.linalg import DenseVector, SparseVector
+from pyspark.mllib.linalg import Vectors, DenseVector, SparseVector
from pyspark.sql import DataFrame, SQLContext, Row
from pyspark.sql.functions import rand
from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase
@@ -479,6 +479,32 @@ class CrossValidatorTests(PySparkTestCase):
"Best model should have zero induced error")
self.assertEqual(1.0, bestModelMetric, "Best model has R-squared of 1")
+ def test_save_load(self):
+ temp_path = tempfile.mkdtemp()
+ sqlContext = SQLContext(self.sc)
+ dataset = sqlContext.createDataFrame(
+ [(Vectors.dense([0.0]), 0.0),
+ (Vectors.dense([0.4]), 1.0),
+ (Vectors.dense([0.5]), 0.0),
+ (Vectors.dense([0.6]), 1.0),
+ (Vectors.dense([1.0]), 1.0)] * 10,
+ ["features", "label"])
+ lr = LogisticRegression()
+ grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
+ evaluator = BinaryClassificationEvaluator()
+ cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator)
+ cvModel = cv.fit(dataset)
+ cvPath = temp_path + "/cv"
+ cv.save(cvPath)
+ loadedCV = CrossValidator.load(cvPath)
+ self.assertEqual(loadedCV.getEstimator().uid, cv.getEstimator().uid)
+ self.assertEqual(loadedCV.getEvaluator().uid, cv.getEvaluator().uid)
+ self.assertEqual(loadedCV.getEstimatorParamMaps(), cv.getEstimatorParamMaps())
+ cvModelPath = temp_path + "/cvModel"
+ cvModel.save(cvModelPath)
+ loadedModel = CrossValidatorModel.load(cvModelPath)
+ self.assertEqual(loadedModel.bestModel.uid, cvModel.bestModel.uid)
+
class TrainValidationSplitTests(PySparkTestCase):
@@ -530,6 +556,32 @@ class TrainValidationSplitTests(PySparkTestCase):
"Best model should have zero induced error")
self.assertEqual(1.0, bestModelMetric, "Best model has R-squared of 1")
+ def test_save_load(self):
+ temp_path = tempfile.mkdtemp()
+ sqlContext = SQLContext(self.sc)
+ dataset = sqlContext.createDataFrame(
+ [(Vectors.dense([0.0]), 0.0),
+ (Vectors.dense([0.4]), 1.0),
+ (Vectors.dense([0.5]), 0.0),
+ (Vectors.dense([0.6]), 1.0),
+ (Vectors.dense([1.0]), 1.0)] * 10,
+ ["features", "label"])
+ lr = LogisticRegression()
+ grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
+ evaluator = BinaryClassificationEvaluator()
+ tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator)
+ tvsModel = tvs.fit(dataset)
+ tvsPath = temp_path + "/tvs"
+ tvs.save(tvsPath)
+ loadedTvs = TrainValidationSplit.load(tvsPath)
+ self.assertEqual(loadedTvs.getEstimator().uid, tvs.getEstimator().uid)
+ self.assertEqual(loadedTvs.getEvaluator().uid, tvs.getEvaluator().uid)
+ self.assertEqual(loadedTvs.getEstimatorParamMaps(), tvs.getEstimatorParamMaps())
+ tvsModelPath = temp_path + "/tvsModel"
+ tvsModel.save(tvsModelPath)
+ loadedModel = TrainValidationSplitModel.load(tvsModelPath)
+ self.assertEqual(loadedModel.bestModel.uid, tvsModel.bestModel.uid)
+
class PersistenceTest(PySparkTestCase):
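The TrainValidationSplit test follows the same pattern. Continuing from the variables in the CrossValidator sketch above, a brief hedged sketch of the analogous round trip (the /tmp paths are again illustrative):

```python
from pyspark.ml.tuning import TrainValidationSplit, TrainValidationSplitModel

# Same estimator, param grid, and evaluator as above; TrainValidationSplit
# performs a single train/validation split instead of k-fold cross-validation.
tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator)
tvsModel = tvs.fit(dataset)

tvs.save("/tmp/tvs")                            # illustrative path
loadedTvs = TrainValidationSplit.load("/tmp/tvs")
tvsModel.save("/tmp/tvsModel")
loadedModel = TrainValidationSplitModel.load("/tmp/tvsModel")

# The loaded best model is usable directly, e.g. for scoring:
loadedModel.bestModel.transform(dataset).show(3)
```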