aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/ml/classification.py
diff options
context:
space:
mode:
authorJoseph K. Bradley <joseph@databricks.com>2015-08-12 14:27:13 -0700
committerJoseph K. Bradley <joseph@databricks.com>2015-08-12 14:27:13 -0700
commit551def5d6972440365bd7436d484a67138d9a8f3 (patch)
treeaf2280c3849497b4236099ec84fe7b4b64d63f2e /python/pyspark/ml/classification.py
parent762bacc16ac5e74c8b05a7c1e3e367d1d1633cef (diff)
downloadspark-551def5d6972440365bd7436d484a67138d9a8f3.tar.gz
spark-551def5d6972440365bd7436d484a67138d9a8f3.tar.bz2
spark-551def5d6972440365bd7436d484a67138d9a8f3.zip
[SPARK-9789] [ML] Added logreg threshold param back
Reinstated LogisticRegression.threshold Param for binary compatibility. Param thresholds overrides threshold, if set. CC: mengxr dbtsai feynmanliang Author: Joseph K. Bradley <joseph@databricks.com> Closes #8079 from jkbradley/logreg-reinstate-threshold.
Diffstat (limited to 'python/pyspark/ml/classification.py')
-rw-r--r--python/pyspark/ml/classification.py98
1 files changed, 62 insertions, 36 deletions
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index 6702dce554..83f808efc3 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -76,19 +76,21 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
" Array must have length equal to the number of classes, with values >= 0." +
" The class with largest value p/t is predicted, where p is the original" +
" probability of that class and t is the class' threshold.")
+ threshold = Param(Params._dummy(), "threshold",
+ "Threshold in binary classification prediction, in range [0, 1]." +
+ " If threshold and thresholds are both set, they must match.")
@keyword_only
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
- threshold=None, thresholds=None,
+ threshold=0.5, thresholds=None,
probabilityCol="probability", rawPredictionCol="rawPrediction"):
"""
__init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \
- threshold=None, thresholds=None, \
+ threshold=0.5, thresholds=None, \
probabilityCol="probability", rawPredictionCol="rawPrediction")
- Param thresholds overrides Param threshold; threshold is provided
- for backwards compatibility and only applies to binary classification.
+ If the threshold and thresholds Params are both set, they must be equivalent.
"""
super(LogisticRegression, self).__init__()
self._java_obj = self._new_java_obj(
@@ -101,7 +103,11 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
"the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.")
#: param for whether to fit an intercept term.
self.fitIntercept = Param(self, "fitIntercept", "whether to fit an intercept term.")
- #: param for threshold in binary classification prediction, in range [0, 1].
+ #: param for threshold in binary classification, in range [0, 1].
+ self.threshold = Param(self, "threshold",
+ "Threshold in binary classification prediction, in range [0, 1]." +
+ " If threshold and thresholds are both set, they must match.")
+ #: param for thresholds or cutoffs in binary or multiclass classification
self.thresholds = \
Param(self, "thresholds",
"Thresholds in multi-class classification" +
@@ -110,29 +116,28 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
" The class with largest value p/t is predicted, where p is the original" +
" probability of that class and t is the class' threshold.")
self._setDefault(maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1E-6,
- fitIntercept=True)
+ fitIntercept=True, threshold=0.5)
kwargs = self.__init__._input_kwargs
self.setParams(**kwargs)
+ self._checkThresholdConsistency()
@keyword_only
def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
- threshold=None, thresholds=None,
+ threshold=0.5, thresholds=None,
probabilityCol="probability", rawPredictionCol="rawPrediction"):
"""
setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \
- threshold=None, thresholds=None, \
+ threshold=0.5, thresholds=None, \
probabilityCol="probability", rawPredictionCol="rawPrediction")
Sets params for logistic regression.
- Param thresholds overrides Param threshold; threshold is provided
- for backwards compatibility and only applies to binary classification.
+ If the threshold and thresholds Params are both set, they must be equivalent.
"""
- # Under the hood we use thresholds so translate threshold to thresholds if applicable
- if thresholds is None and threshold is not None:
- kwargs[thresholds] = [1-threshold, threshold]
kwargs = self.setParams._input_kwargs
- return self._set(**kwargs)
+ self._set(**kwargs)
+ self._checkThresholdConsistency()
+ return self
def _create_model(self, java_model):
return LogisticRegressionModel(java_model)
@@ -165,44 +170,65 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
def setThreshold(self, value):
"""
- Sets the value of :py:attr:`thresholds` using [1-value, value].
+ Sets the value of :py:attr:`threshold`.
+ Clears value of :py:attr:`thresholds` if it has been set.
+ """
+ self._paramMap[self.threshold] = value
+ if self.isSet(self.thresholds):
+ del self._paramMap[self.thresholds]
+ return self
- >>> lr = LogisticRegression()
- >>> lr.getThreshold()
- 0.5
- >>> lr.setThreshold(0.6)
- LogisticRegression_...
- >>> abs(lr.getThreshold() - 0.6) < 1e-5
- True
+ def getThreshold(self):
+ """
+ Gets the value of threshold or its default value.
"""
- return self.setThresholds([1-value, value])
+ self._checkThresholdConsistency()
+ if self.isSet(self.thresholds):
+ ts = self.getOrDefault(self.thresholds)
+ if len(ts) != 2:
+ raise ValueError("Logistic Regression getThreshold only applies to" +
+ " binary classification, but thresholds has length != 2." +
+ " thresholds: " + ",".join(ts))
+ return 1.0/(1.0 + ts[0]/ts[1])
+ else:
+ return self.getOrDefault(self.threshold)
def setThresholds(self, value):
"""
Sets the value of :py:attr:`thresholds`.
+ Clears value of :py:attr:`threshold` if it has been set.
"""
self._paramMap[self.thresholds] = value
+ if self.isSet(self.threshold):
+ del self._paramMap[self.threshold]
return self
def getThresholds(self):
"""
- Gets the value of thresholds or its default value.
+ If :py:attr:`thresholds` is set, return its value.
+ Otherwise, if :py:attr:`threshold` is set, return the equivalent thresholds for binary
+ classification: (1-threshold, threshold).
+ If neither are set, throw an error.
"""
- return self.getOrDefault(self.thresholds)
+ self._checkThresholdConsistency()
+ if not self.isSet(self.thresholds) and self.isSet(self.threshold):
+ t = self.getOrDefault(self.threshold)
+ return [1.0-t, t]
+ else:
+ return self.getOrDefault(self.thresholds)
- def getThreshold(self):
- """
- Gets the value of threshold or its default value.
- """
- if self.isDefined(self.thresholds):
- thresholds = self.getOrDefault(self.thresholds)
- if len(thresholds) != 2:
+ def _checkThresholdConsistency(self):
+ if self.isSet(self.threshold) and self.isSet(self.thresholds):
+ ts = self.getParam(self.thresholds)
+ if len(ts) != 2:
raise ValueError("Logistic Regression getThreshold only applies to" +
" binary classification, but thresholds has length != 2." +
- " thresholds: " + ",".join(thresholds))
- return 1.0/(1.0+thresholds[0]/thresholds[1])
- else:
- return 0.5
+ " thresholds: " + ",".join(ts))
+ t = 1.0/(1.0 + ts[0]/ts[1])
+ t2 = self.getParam(self.threshold)
+ if abs(t2 - t) >= 1E-5:
+ raise ValueError("Logistic Regression getThreshold found inconsistent values for" +
+ " threshold (%g) and thresholds (equivalent to %g)" % (t2, t))
class LogisticRegressionModel(JavaModel):