aboutsummaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
Diffstat (limited to 'python')
-rw-r--r--python/pyspark/ml/tuning.py4
1 files changed, 3 insertions, 1 deletions
diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py
index 7f967e5463..2dcc99cef8 100644
--- a/python/pyspark/ml/tuning.py
+++ b/python/pyspark/ml/tuning.py
@@ -166,6 +166,8 @@ class CrossValidator(Estimator, ValidatorParams):
>>> evaluator = BinaryClassificationEvaluator()
>>> cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator)
>>> cvModel = cv.fit(dataset)
+ >>> cvModel.avgMetrics[0]
+ 0.5
>>> evaluator.evaluate(cvModel.transform(dataset))
0.8333...
@@ -234,7 +236,7 @@ class CrossValidator(Estimator, ValidatorParams):
model = est.fit(train, epm[j])
# TODO: duplicate evaluator to take extra params from input
metric = eva.evaluate(model.transform(validation, epm[j]))
- metrics[j] += metric
+ metrics[j] += metric/nFolds
if eva.isLargerBetter():
bestIndex = np.argmax(metrics)