aboutsummaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
Diffstat (limited to 'python')
-rw-r--r--python/pyspark/ml/regression.py7
-rw-r--r--python/pyspark/ml/tuning.py2
2 files changed, 5 insertions, 4 deletions
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index da74ab5070..8e76070e9a 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -561,7 +561,7 @@ class TreeRegressorParams(Params):
impurity = Param(Params._dummy(), "impurity",
"Criterion used for information gain calculation (case-insensitive). " +
"Supported options: " +
- ", ".join(supportedImpurities))
+ ", ".join(supportedImpurities), typeConverter=TypeConverters.toString)
def __init__(self):
super(TreeRegressorParams, self).__init__()
@@ -1261,11 +1261,12 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha
family = Param(Params._dummy(), "family", "The name of family which is a description of " +
"the error distribution to be used in the model. Supported options: " +
- "gaussian(default), binomial, poisson and gamma.")
+ "gaussian(default), binomial, poisson and gamma.",
+ typeConverter=TypeConverters.toString)
link = Param(Params._dummy(), "link", "The name of link function which provides the " +
"relationship between the linear predictor and the mean of the distribution " +
"function. Supported options: identity, log, inverse, logit, probit, cloglog " +
- "and sqrt.")
+ "and sqrt.", typeConverter=TypeConverters.toString)
@keyword_only
def __init__(self, labelCol="label", featuresCol="features", predictionCol="prediction",
diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py
index ef14da488e..b16628bc70 100644
--- a/python/pyspark/ml/tuning.py
+++ b/python/pyspark/ml/tuning.py
@@ -448,7 +448,7 @@ class TrainValidationSplit(Estimator, ValidatorParams, MLReadable, MLWritable):
"""
trainRatio = Param(Params._dummy(), "trainRatio", "Param for ratio between train and\
- validation data. Must be between 0 and 1.")
+ validation data. Must be between 0 and 1.", typeConverter=TypeConverters.toFloat)
@keyword_only
def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, trainRatio=0.75,