From d1aadea05ab1c7350e46479cc68d08e11916a751 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Fri, 13 May 2016 08:39:59 +0200 Subject: [SPARK-15188] Add missing thresholds param to NaiveBayes in PySpark ## What changes were proposed in this pull request? Add missing thresholds param to NiaveBayes ## How was this patch tested? doctests Author: Holden Karau Closes #12963 from holdenk/SPARK-15188-add-missing-naive-bayes-param. --- python/pyspark/ml/classification.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) (limited to 'python') diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index c26c2d7fa5..5c11aa71b4 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -872,7 +872,7 @@ class GBTClassificationModel(TreeEnsembleModels, JavaMLWritable, JavaMLReadable) @inherit_doc class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasProbabilityCol, - HasRawPredictionCol, JavaMLWritable, JavaMLReadable): + HasRawPredictionCol, HasThresholds, JavaMLWritable, JavaMLReadable): """ Naive Bayes Classifiers. It supports both Multinomial and Bernoulli NB. `Multinomial NB @@ -918,6 +918,11 @@ class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, H True >>> model.theta == model2.theta True + >>> nb = nb.setThresholds([0.01, 10.00]) + >>> model3 = nb.fit(df) + >>> result = model3.transform(test0).head() + >>> result.prediction + 0.0 .. versionadded:: 1.5.0 """ @@ -931,11 +936,11 @@ class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, H @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", probabilityCol="probability", rawPredictionCol="rawPrediction", smoothing=1.0, - modelType="multinomial"): + modelType="multinomial", thresholds=None): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ probabilityCol="probability", rawPredictionCol="rawPrediction", smoothing=1.0, \ - modelType="multinomial") + modelType="multinomial", thresholds=None) """ super(NaiveBayes, self).__init__() self._java_obj = self._new_java_obj( @@ -948,11 +953,11 @@ class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, H @since("1.5.0") def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", probabilityCol="probability", rawPredictionCol="rawPrediction", smoothing=1.0, - modelType="multinomial"): + modelType="multinomial", thresholds=None): """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ probabilityCol="probability", rawPredictionCol="rawPrediction", smoothing=1.0, \ - modelType="multinomial") + modelType="multinomial", thresholds=None) Sets params for Naive Bayes. """ kwargs = self.setParams._input_kwargs -- cgit v1.2.3