aboutsummaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
Diffstat (limited to 'python')
-rw-r--r--python/pyspark/ml/classification.py15
1 files changed, 10 insertions, 5 deletions
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index c26c2d7fa5..5c11aa71b4 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -872,7 +872,7 @@ class GBTClassificationModel(TreeEnsembleModels, JavaMLWritable, JavaMLReadable)
@inherit_doc
class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasProbabilityCol,
- HasRawPredictionCol, JavaMLWritable, JavaMLReadable):
+ HasRawPredictionCol, HasThresholds, JavaMLWritable, JavaMLReadable):
"""
Naive Bayes Classifiers.
It supports both Multinomial and Bernoulli NB. `Multinomial NB
@@ -918,6 +918,11 @@ class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, H
True
>>> model.theta == model2.theta
True
+ >>> nb = nb.setThresholds([0.01, 10.00])
+ >>> model3 = nb.fit(df)
+ >>> result = model3.transform(test0).head()
+ >>> result.prediction
+ 0.0
.. versionadded:: 1.5.0
"""
@@ -931,11 +936,11 @@ class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, H
@keyword_only
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
probabilityCol="probability", rawPredictionCol="rawPrediction", smoothing=1.0,
- modelType="multinomial"):
+ modelType="multinomial", thresholds=None):
"""
__init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
probabilityCol="probability", rawPredictionCol="rawPrediction", smoothing=1.0, \
- modelType="multinomial")
+ modelType="multinomial", thresholds=None)
"""
super(NaiveBayes, self).__init__()
self._java_obj = self._new_java_obj(
@@ -948,11 +953,11 @@ class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, H
@since("1.5.0")
def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
probabilityCol="probability", rawPredictionCol="rawPrediction", smoothing=1.0,
- modelType="multinomial"):
+ modelType="multinomial", thresholds=None):
"""
setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
probabilityCol="probability", rawPredictionCol="rawPrediction", smoothing=1.0, \
- modelType="multinomial")
+ modelType="multinomial", thresholds=None)
Sets params for Naive Bayes.
"""
kwargs = self.setParams._input_kwargs