diff options
Diffstat (limited to 'python/pyspark/ml/tuning.py')
-rw-r--r-- | python/pyspark/ml/tuning.py | 5 |
1 files changed, 3 insertions, 2 deletions
diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index 77af0094df..a528d22e18 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -20,7 +20,7 @@ import numpy as np from pyspark import since from pyspark.ml import Estimator, Model -from pyspark.ml.param import Params, Param +from pyspark.ml.param import Params, Param, TypeConverters from pyspark.ml.param.shared import HasSeed from pyspark.ml.util import keyword_only from pyspark.sql.functions import rand @@ -121,7 +121,8 @@ class CrossValidator(Estimator, HasSeed): evaluator = Param( Params._dummy(), "evaluator", "evaluator used to select hyper-parameters that maximize the cross-validated metric") - numFolds = Param(Params._dummy(), "numFolds", "number of folds for cross validation") + numFolds = Param(Params._dummy(), "numFolds", "number of folds for cross validation", + typeConverter=TypeConverters.toInt) @keyword_only def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3, |