aboutsummaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
authorHolden Karau <holden@us.ibm.com>2016-05-13 08:39:59 +0200
committerNick Pentreath <nick.pentreath@gmail.com>2016-05-13 08:39:59 +0200
commitd1aadea05ab1c7350e46479cc68d08e11916a751 (patch)
tree1fabfff63cd5071c389f028f18d0306749599124 /python
parent51841d77d99a858f8fa1256e923b0364b9b28fa0 (diff)
downloadspark-d1aadea05ab1c7350e46479cc68d08e11916a751.tar.gz
spark-d1aadea05ab1c7350e46479cc68d08e11916a751.tar.bz2
spark-d1aadea05ab1c7350e46479cc68d08e11916a751.zip
[SPARK-15188] Add missing thresholds param to NaiveBayes in PySpark
## What changes were proposed in this pull request? Add missing thresholds param to NiaveBayes ## How was this patch tested? doctests Author: Holden Karau <holden@us.ibm.com> Closes #12963 from holdenk/SPARK-15188-add-missing-naive-bayes-param.
Diffstat (limited to 'python')
-rw-r--r--python/pyspark/ml/classification.py15
1 files changed, 10 insertions, 5 deletions
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index c26c2d7fa5..5c11aa71b4 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -872,7 +872,7 @@ class GBTClassificationModel(TreeEnsembleModels, JavaMLWritable, JavaMLReadable)
@inherit_doc
class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasProbabilityCol,
- HasRawPredictionCol, JavaMLWritable, JavaMLReadable):
+ HasRawPredictionCol, HasThresholds, JavaMLWritable, JavaMLReadable):
"""
Naive Bayes Classifiers.
It supports both Multinomial and Bernoulli NB. `Multinomial NB
@@ -918,6 +918,11 @@ class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, H
True
>>> model.theta == model2.theta
True
+ >>> nb = nb.setThresholds([0.01, 10.00])
+ >>> model3 = nb.fit(df)
+ >>> result = model3.transform(test0).head()
+ >>> result.prediction
+ 0.0
.. versionadded:: 1.5.0
"""
@@ -931,11 +936,11 @@ class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, H
@keyword_only
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
probabilityCol="probability", rawPredictionCol="rawPrediction", smoothing=1.0,
- modelType="multinomial"):
+ modelType="multinomial", thresholds=None):
"""
__init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
probabilityCol="probability", rawPredictionCol="rawPrediction", smoothing=1.0, \
- modelType="multinomial")
+ modelType="multinomial", thresholds=None)
"""
super(NaiveBayes, self).__init__()
self._java_obj = self._new_java_obj(
@@ -948,11 +953,11 @@ class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, H
@since("1.5.0")
def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
probabilityCol="probability", rawPredictionCol="rawPrediction", smoothing=1.0,
- modelType="multinomial"):
+ modelType="multinomial", thresholds=None):
"""
setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
probabilityCol="probability", rawPredictionCol="rawPrediction", smoothing=1.0, \
- modelType="multinomial")
+ modelType="multinomial", thresholds=None)
Sets params for Naive Bayes.
"""
kwargs = self.setParams._input_kwargs