aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/ml/tests.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/ml/tests.py')
-rw-r--r--python/pyspark/ml/tests.py53
1 files changed, 52 insertions, 1 deletions
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 5fcfa9e61f..8182fcfb4e 100644
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -45,7 +45,7 @@ from pyspark.ml.feature import *
from pyspark.ml.param import Param, Params
from pyspark.ml.param.shared import HasMaxIter, HasInputCol, HasSeed
from pyspark.ml.regression import LinearRegression
-from pyspark.ml.tuning import ParamGridBuilder, CrossValidator, CrossValidatorModel
+from pyspark.ml.tuning import *
from pyspark.ml.util import keyword_only
from pyspark.mllib.linalg import DenseVector
from pyspark.sql import DataFrame, SQLContext, Row
@@ -423,6 +423,57 @@ class CrossValidatorTests(PySparkTestCase):
self.assertEqual(1.0, bestModelMetric, "Best model has R-squared of 1")
+class TrainValidationSplitTests(PySparkTestCase):
+
+ def test_fit_minimize_metric(self):
+ sqlContext = SQLContext(self.sc)
+ dataset = sqlContext.createDataFrame([
+ (10, 10.0),
+ (50, 50.0),
+ (100, 100.0),
+ (500, 500.0)] * 10,
+ ["feature", "label"])
+
+ iee = InducedErrorEstimator()
+ evaluator = RegressionEvaluator(metricName="rmse")
+
+ grid = (ParamGridBuilder()
+ .addGrid(iee.inducedError, [100.0, 0.0, 10000.0])
+ .build())
+ tvs = TrainValidationSplit(estimator=iee, estimatorParamMaps=grid, evaluator=evaluator)
+ tvsModel = tvs.fit(dataset)
+ bestModel = tvsModel.bestModel
+ bestModelMetric = evaluator.evaluate(bestModel.transform(dataset))
+
+ self.assertEqual(0.0, bestModel.getOrDefault('inducedError'),
+ "Best model should have zero induced error")
+ self.assertEqual(0.0, bestModelMetric, "Best model has RMSE of 0")
+
+ def test_fit_maximize_metric(self):
+ sqlContext = SQLContext(self.sc)
+ dataset = sqlContext.createDataFrame([
+ (10, 10.0),
+ (50, 50.0),
+ (100, 100.0),
+ (500, 500.0)] * 10,
+ ["feature", "label"])
+
+ iee = InducedErrorEstimator()
+ evaluator = RegressionEvaluator(metricName="r2")
+
+ grid = (ParamGridBuilder()
+ .addGrid(iee.inducedError, [100.0, 0.0, 10000.0])
+ .build())
+ tvs = TrainValidationSplit(estimator=iee, estimatorParamMaps=grid, evaluator=evaluator)
+ tvsModel = tvs.fit(dataset)
+ bestModel = tvsModel.bestModel
+ bestModelMetric = evaluator.evaluate(bestModel.transform(dataset))
+
+ self.assertEqual(0.0, bestModel.getOrDefault('inducedError'),
+ "Best model should have zero induced error")
+ self.assertEqual(1.0, bestModelMetric, "Best model has R-squared of 1")
+
+
class PersistenceTest(PySparkTestCase):
def test_linear_regression(self):