aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/ml/tuning.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/ml/tuning.py')
-rw-r--r--python/pyspark/ml/tuning.py5
1 files changed, 3 insertions, 2 deletions
diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py
index 77af0094df..a528d22e18 100644
--- a/python/pyspark/ml/tuning.py
+++ b/python/pyspark/ml/tuning.py
@@ -20,7 +20,7 @@ import numpy as np
from pyspark import since
from pyspark.ml import Estimator, Model
-from pyspark.ml.param import Params, Param
+from pyspark.ml.param import Params, Param, TypeConverters
from pyspark.ml.param.shared import HasSeed
from pyspark.ml.util import keyword_only
from pyspark.sql.functions import rand
@@ -121,7 +121,8 @@ class CrossValidator(Estimator, HasSeed):
evaluator = Param(
Params._dummy(), "evaluator",
"evaluator used to select hyper-parameters that maximize the cross-validated metric")
- numFolds = Param(Params._dummy(), "numFolds", "number of folds for cross validation")
+ numFolds = Param(Params._dummy(), "numFolds", "number of folds for cross validation",
+ typeConverter=TypeConverters.toInt)
@keyword_only
def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,