aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/ml/tuning.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/ml/tuning.py')
-rw-r--r--python/pyspark/ml/tuning.py6
1 files changed, 5 insertions, 1 deletions
diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py
index dcfee6a317..cae778869e 100644
--- a/python/pyspark/ml/tuning.py
+++ b/python/pyspark/ml/tuning.py
@@ -223,7 +223,11 @@ class CrossValidator(Estimator):
# TODO: duplicate evaluator to take extra params from input
metric = eva.evaluate(model.transform(validation, epm[j]))
metrics[j] += metric
- bestIndex = np.argmax(metrics)
+
+ if eva.isLargerBetter():
+ bestIndex = np.argmax(metrics)
+ else:
+ bestIndex = np.argmin(metrics)
bestModel = est.fit(dataset, epm[bestIndex])
return CrossValidatorModel(bestModel)