aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/ml/tests.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/ml/tests.py')
-rw-r--r--python/pyspark/ml/tests.py27
1 files changed, 27 insertions, 0 deletions
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index ebef656632..36cecd4682 100644
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -466,6 +466,31 @@ class InducedErrorEstimator(Estimator, HasInducedError):
class CrossValidatorTests(PySparkTestCase):
+ def test_copy(self):
+ sqlContext = SQLContext(self.sc)
+ dataset = sqlContext.createDataFrame([
+ (10, 10.0),
+ (50, 50.0),
+ (100, 100.0),
+ (500, 500.0)] * 10,
+ ["feature", "label"])
+
+ iee = InducedErrorEstimator()
+ evaluator = RegressionEvaluator(metricName="rmse")
+
+ grid = (ParamGridBuilder()
+ .addGrid(iee.inducedError, [100.0, 0.0, 10000.0])
+ .build())
+ cv = CrossValidator(estimator=iee, estimatorParamMaps=grid, evaluator=evaluator)
+ cvCopied = cv.copy()
+ self.assertEqual(cv.getEstimator().uid, cvCopied.getEstimator().uid)
+
+ cvModel = cv.fit(dataset)
+ cvModelCopied = cvModel.copy()
+ for index in range(len(cvModel.avgMetrics)):
+ self.assertTrue(abs(cvModel.avgMetrics[index] - cvModelCopied.avgMetrics[index])
+ < 0.0001)
+
def test_fit_minimize_metric(self):
sqlContext = SQLContext(self.sc)
dataset = sqlContext.createDataFrame([
@@ -539,6 +564,8 @@ class CrossValidatorTests(PySparkTestCase):
cvModel.save(cvModelPath)
loadedModel = CrossValidatorModel.load(cvModelPath)
self.assertEqual(loadedModel.bestModel.uid, cvModel.bestModel.uid)
+ for index in range(len(loadedModel.avgMetrics)):
+ self.assertTrue(abs(loadedModel.avgMetrics[index] - cvModel.avgMetrics[index]) < 0.0001)
class TrainValidationSplitTests(PySparkTestCase):