diff options
Diffstat (limited to 'python/pyspark/ml/param')
-rw-r--r-- | python/pyspark/ml/param/_shared_params_code_gen.py | 2 | ||||
-rw-r--r-- | python/pyspark/ml/param/shared.py | 24 |
2 files changed, 26 insertions, 0 deletions
diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py b/python/pyspark/ml/param/_shared_params_code_gen.py index 929591236d..51d49b524c 100644 --- a/python/pyspark/ml/param/_shared_params_code_gen.py +++ b/python/pyspark/ml/param/_shared_params_code_gen.py @@ -143,6 +143,8 @@ if __name__ == "__main__": "The class with largest value p/t is predicted, where p is the original " + "probability of that class and t is the class's threshold.", None, "TypeConverters.toListFloat"), + ("threshold", "threshold in binary classification prediction, in range [0, 1]", + "0.5", "TypeConverters.toFloat"), ("weightCol", "weight column name. If this is not set or empty, we treat " + "all instance weights as 1.0.", None, "TypeConverters.toString"), ("solver", "the solver algorithm for optimization. If this is not set or empty, " + diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py index cc596936d8..163a0e2b3a 100644 --- a/python/pyspark/ml/param/shared.py +++ b/python/pyspark/ml/param/shared.py @@ -490,6 +490,30 @@ class HasThresholds(Params): return self.getOrDefault(self.thresholds) +class HasThreshold(Params): + """ + Mixin for param threshold: threshold in binary classification prediction, in range [0, 1] + """ + + threshold = Param(Params._dummy(), "threshold", "threshold in binary classification prediction, in range [0, 1]", typeConverter=TypeConverters.toFloat) + + def __init__(self): + super(HasThreshold, self).__init__() + self._setDefault(threshold=0.5) + + def setThreshold(self, value): + """ + Sets the value of :py:attr:`threshold`. + """ + return self._set(threshold=value) + + def getThreshold(self): + """ + Gets the value of threshold or its default value. + """ + return self.getOrDefault(self.threshold) + + class HasWeightCol(Params): """ Mixin for param weightCol: weight column name. If this is not set or empty, we treat all instance weights as 1.0. |