diff options
author | Kai Jiang <jiangkai@gmail.com> | 2016-04-28 14:19:11 -0700 |
---|---|---|
committer | Joseph K. Bradley <joseph@databricks.com> | 2016-04-28 14:19:11 -0700 |
commit | d584a2b8ac57eff3bf230c760e5bda205c6ea747 (patch) | |
tree | 40b7f2992794445dfcea149edad57749531b856c /python/pyspark/ml/tests.py | |
parent | 0ee5419b6ce535c714718d0d33b80eedd4b0a5fd (diff) | |
download | spark-d584a2b8ac57eff3bf230c760e5bda205c6ea747.tar.gz spark-d584a2b8ac57eff3bf230c760e5bda205c6ea747.tar.bz2 spark-d584a2b8ac57eff3bf230c760e5bda205c6ea747.zip |
[SPARK-12810][PYSPARK] PySpark CrossValidatorModel should support avgMetrics
## What changes were proposed in this pull request?
support avgMetrics in CrossValidatorModel with Python
## How was this patch tested?
Doctest and `test_save_load` in `pyspark/ml/test.py`
[JIRA](https://issues.apache.org/jira/browse/SPARK-12810)
Author: Kai Jiang <jiangkai@gmail.com>
Closes #12464 from vectorijk/spark-12810.
Diffstat (limited to 'python/pyspark/ml/tests.py')
-rw-r--r-- | python/pyspark/ml/tests.py | 27 |
1 files changed, 27 insertions, 0 deletions
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index ebef656632..36cecd4682 100644 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -466,6 +466,31 @@ class InducedErrorEstimator(Estimator, HasInducedError): class CrossValidatorTests(PySparkTestCase): + def test_copy(self): + sqlContext = SQLContext(self.sc) + dataset = sqlContext.createDataFrame([ + (10, 10.0), + (50, 50.0), + (100, 100.0), + (500, 500.0)] * 10, + ["feature", "label"]) + + iee = InducedErrorEstimator() + evaluator = RegressionEvaluator(metricName="rmse") + + grid = (ParamGridBuilder() + .addGrid(iee.inducedError, [100.0, 0.0, 10000.0]) + .build()) + cv = CrossValidator(estimator=iee, estimatorParamMaps=grid, evaluator=evaluator) + cvCopied = cv.copy() + self.assertEqual(cv.getEstimator().uid, cvCopied.getEstimator().uid) + + cvModel = cv.fit(dataset) + cvModelCopied = cvModel.copy() + for index in range(len(cvModel.avgMetrics)): + self.assertTrue(abs(cvModel.avgMetrics[index] - cvModelCopied.avgMetrics[index]) + < 0.0001) + def test_fit_minimize_metric(self): sqlContext = SQLContext(self.sc) dataset = sqlContext.createDataFrame([ @@ -539,6 +564,8 @@ class CrossValidatorTests(PySparkTestCase): cvModel.save(cvModelPath) loadedModel = CrossValidatorModel.load(cvModelPath) self.assertEqual(loadedModel.bestModel.uid, cvModel.bestModel.uid) + for index in range(len(loadedModel.avgMetrics)): + self.assertTrue(abs(loadedModel.avgMetrics[index] - cvModel.avgMetrics[index]) < 0.0001) class TrainValidationSplitTests(PySparkTestCase): |