diff options
-rw-r--r-- | python/pyspark/ml/tuning.py | 4 |
1 files changed, 3 insertions, 1 deletions
diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index 7f967e5463..2dcc99cef8 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -166,6 +166,8 @@ class CrossValidator(Estimator, ValidatorParams): >>> evaluator = BinaryClassificationEvaluator() >>> cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator) >>> cvModel = cv.fit(dataset) + >>> cvModel.avgMetrics[0] + 0.5 >>> evaluator.evaluate(cvModel.transform(dataset)) 0.8333... @@ -234,7 +236,7 @@ class CrossValidator(Estimator, ValidatorParams): model = est.fit(train, epm[j]) # TODO: duplicate evaluator to take extra params from input metric = eva.evaluate(model.transform(validation, epm[j])) - metrics[j] += metric + metrics[j] += metric/nFolds if eva.isLargerBetter(): bestIndex = np.argmax(metrics) |