diff options
Diffstat (limited to 'python/pyspark/ml/classification.py')
-rw-r--r-- | python/pyspark/ml/classification.py | 20 |
1 files changed, 13 insertions, 7 deletions
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 8075108114..fdeccf822c 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -20,6 +20,7 @@ import warnings from pyspark import since from pyspark.ml.util import * from pyspark.ml.wrapper import JavaEstimator, JavaModel +from pyspark.ml.param import TypeConverters from pyspark.ml.param.shared import * from pyspark.ml.regression import ( RandomForestParams, TreeEnsembleParams, DecisionTreeModel, TreeEnsembleModels) @@ -87,7 +88,8 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti threshold = Param(Params._dummy(), "threshold", "Threshold in binary classification prediction, in range [0, 1]." + - " If threshold and thresholds are both set, they must match.") + " If threshold and thresholds are both set, they must match.", + typeConverter=TypeConverters.toFloat) @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", @@ -243,7 +245,7 @@ class TreeClassifierParams(object): impurity = Param(Params._dummy(), "impurity", "Criterion used for information gain calculation (case-insensitive). " + "Supported options: " + - ", ".join(supportedImpurities)) + ", ".join(supportedImpurities), typeConverter=TypeConverters.toString) def __init__(self): super(TreeClassifierParams, self).__init__() @@ -534,7 +536,8 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol lossType = Param(Params._dummy(), "lossType", "Loss function which GBT tries to minimize (case-insensitive). " + - "Supported options: " + ", ".join(GBTParams.supportedLossTypes)) + "Supported options: " + ", ".join(GBTParams.supportedLossTypes), + typeConverter=TypeConverters.toString) @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", @@ -652,9 +655,10 @@ class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, H """ smoothing = Param(Params._dummy(), "smoothing", "The smoothing parameter, should be >= 0, " + - "default is 1.0") + "default is 1.0", typeConverter=TypeConverters.toFloat) modelType = Param(Params._dummy(), "modelType", "The model type which is a string " + - "(case-sensitive). Supported options: multinomial (default) and bernoulli.") + "(case-sensitive). Supported options: multinomial (default) and bernoulli.", + typeConverter=TypeConverters.toString) @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", @@ -782,11 +786,13 @@ class MultilayerPerceptronClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, layers = Param(Params._dummy(), "layers", "Sizes of layers from input layer to output layer " + "E.g., Array(780, 100, 10) means 780 inputs, one hidden layer with 100 " + - "neurons and output layer of 10 neurons, default is [1, 1].") + "neurons and output layer of 10 neurons, default is [1, 1].", + typeConverter=TypeConverters.toListInt) blockSize = Param(Params._dummy(), "blockSize", "Block size for stacking input data in " + "matrices. Data is stacked within partitions. If block size is more than " + "remaining data in a partition then it is adjusted to the size of this " + - "data. Recommended size is between 10 and 1000, default is 128.") + "data. Recommended size is between 10 and 1000, default is 128.", + typeConverter=TypeConverters.toInt) @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", |