about summary refs log tree commit diff
path: root/python/pyspark/ml/classification.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/ml/classification.py')
-rw-r--r--  python/pyspark/ml/classification.py  20
1 files changed, 13 insertions, 7 deletions
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index 8075108114..fdeccf822c 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -20,6 +20,7 @@ import warnings
from pyspark import since
from pyspark.ml.util import *
from pyspark.ml.wrapper import JavaEstimator, JavaModel
+from pyspark.ml.param import TypeConverters
from pyspark.ml.param.shared import *
from pyspark.ml.regression import (
RandomForestParams, TreeEnsembleParams, DecisionTreeModel, TreeEnsembleModels)
@@ -87,7 +88,8 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
threshold = Param(Params._dummy(), "threshold",
"Threshold in binary classification prediction, in range [0, 1]." +
- " If threshold and thresholds are both set, they must match.")
+ " If threshold and thresholds are both set, they must match.",
+ typeConverter=TypeConverters.toFloat)
@keyword_only
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
@@ -243,7 +245,7 @@ class TreeClassifierParams(object):
impurity = Param(Params._dummy(), "impurity",
"Criterion used for information gain calculation (case-insensitive). " +
"Supported options: " +
- ", ".join(supportedImpurities))
+ ", ".join(supportedImpurities), typeConverter=TypeConverters.toString)
def __init__(self):
super(TreeClassifierParams, self).__init__()
@@ -534,7 +536,8 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol
lossType = Param(Params._dummy(), "lossType",
"Loss function which GBT tries to minimize (case-insensitive). " +
- "Supported options: " + ", ".join(GBTParams.supportedLossTypes))
+ "Supported options: " + ", ".join(GBTParams.supportedLossTypes),
+ typeConverter=TypeConverters.toString)
@keyword_only
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
@@ -652,9 +655,10 @@ class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, H
"""
smoothing = Param(Params._dummy(), "smoothing", "The smoothing parameter, should be >= 0, " +
- "default is 1.0")
+ "default is 1.0", typeConverter=TypeConverters.toFloat)
modelType = Param(Params._dummy(), "modelType", "The model type which is a string " +
- "(case-sensitive). Supported options: multinomial (default) and bernoulli.")
+ "(case-sensitive). Supported options: multinomial (default) and bernoulli.",
+ typeConverter=TypeConverters.toString)
@keyword_only
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
@@ -782,11 +786,13 @@ class MultilayerPerceptronClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol,
layers = Param(Params._dummy(), "layers", "Sizes of layers from input layer to output layer " +
"E.g., Array(780, 100, 10) means 780 inputs, one hidden layer with 100 " +
- "neurons and output layer of 10 neurons, default is [1, 1].")
+ "neurons and output layer of 10 neurons, default is [1, 1].",
+ typeConverter=TypeConverters.toListInt)
blockSize = Param(Params._dummy(), "blockSize", "Block size for stacking input data in " +
"matrices. Data is stacked within partitions. If block size is more than " +
"remaining data in a partition then it is adjusted to the size of this " +
- "data. Recommended size is between 10 and 1000, default is 128.")
+ "data. Recommended size is between 10 and 1000, default is 128.",
+ typeConverter=TypeConverters.toInt)
@keyword_only
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",