diff options
author | Holden Karau <holden@us.ibm.com> | 2016-05-13 08:39:59 +0200 |
---|---|---|
committer | Nick Pentreath <nick.pentreath@gmail.com> | 2016-05-13 08:39:59 +0200 |
commit | d1aadea05ab1c7350e46479cc68d08e11916a751 (patch) | |
tree | 1fabfff63cd5071c389f028f18d0306749599124 | |
parent | 51841d77d99a858f8fa1256e923b0364b9b28fa0 (diff) | |
download | spark-d1aadea05ab1c7350e46479cc68d08e11916a751.tar.gz spark-d1aadea05ab1c7350e46479cc68d08e11916a751.tar.bz2 spark-d1aadea05ab1c7350e46479cc68d08e11916a751.zip |
[SPARK-15188] Add missing thresholds param to NaiveBayes in PySpark
## What changes were proposed in this pull request?
Add missing thresholds param to NiaveBayes
## How was this patch tested?
doctests
Author: Holden Karau <holden@us.ibm.com>
Closes #12963 from holdenk/SPARK-15188-add-missing-naive-bayes-param.
-rw-r--r-- | python/pyspark/ml/classification.py | 15 |
1 files changed, 10 insertions, 5 deletions
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index c26c2d7fa5..5c11aa71b4 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -872,7 +872,7 @@ class GBTClassificationModel(TreeEnsembleModels, JavaMLWritable, JavaMLReadable) @inherit_doc class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasProbabilityCol, - HasRawPredictionCol, JavaMLWritable, JavaMLReadable): + HasRawPredictionCol, HasThresholds, JavaMLWritable, JavaMLReadable): """ Naive Bayes Classifiers. It supports both Multinomial and Bernoulli NB. `Multinomial NB @@ -918,6 +918,11 @@ class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, H True >>> model.theta == model2.theta True + >>> nb = nb.setThresholds([0.01, 10.00]) + >>> model3 = nb.fit(df) + >>> result = model3.transform(test0).head() + >>> result.prediction + 0.0 .. versionadded:: 1.5.0 """ @@ -931,11 +936,11 @@ class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, H @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", probabilityCol="probability", rawPredictionCol="rawPrediction", smoothing=1.0, - modelType="multinomial"): + modelType="multinomial", thresholds=None): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ probabilityCol="probability", rawPredictionCol="rawPrediction", smoothing=1.0, \ - modelType="multinomial") + modelType="multinomial", thresholds=None) """ super(NaiveBayes, self).__init__() self._java_obj = self._new_java_obj( @@ -948,11 +953,11 @@ class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, H @since("1.5.0") def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", probabilityCol="probability", rawPredictionCol="rawPrediction", smoothing=1.0, - modelType="multinomial"): + modelType="multinomial", thresholds=None): """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ probabilityCol="probability", rawPredictionCol="rawPrediction", smoothing=1.0, \ - modelType="multinomial") + modelType="multinomial", thresholds=None) Sets params for Naive Bayes. """ kwargs = self.setParams._input_kwargs |