From 551def5d6972440365bd7436d484a67138d9a8f3 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Wed, 12 Aug 2015 14:27:13 -0700 Subject: [SPARK-9789] [ML] Added logreg threshold param back Reinstated LogisticRegression.threshold Param for binary compatibility. Param thresholds overrides threshold, if set. CC: mengxr dbtsai feynmanliang Author: Joseph K. Bradley Closes #8079 from jkbradley/logreg-reinstate-threshold. --- python/pyspark/ml/classification.py | 98 +++++++++++++++++++++++-------------- 1 file changed, 62 insertions(+), 36 deletions(-) (limited to 'python/pyspark/ml/classification.py') diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 6702dce554..83f808efc3 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -76,19 +76,21 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti " Array must have length equal to the number of classes, with values >= 0." + " The class with largest value p/t is predicted, where p is the original" + " probability of that class and t is the class' threshold.") + threshold = Param(Params._dummy(), "threshold", + "Threshold in binary classification prediction, in range [0, 1]." + + " If threshold and thresholds are both set, they must match.") @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, - threshold=None, thresholds=None, + threshold=0.5, thresholds=None, probabilityCol="probability", rawPredictionCol="rawPrediction"): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \ - threshold=None, thresholds=None, \ + threshold=0.5, thresholds=None, \ probabilityCol="probability", rawPredictionCol="rawPrediction") - Param thresholds overrides Param threshold; threshold is provided - for backwards compatibility and only applies to binary classification. + If the threshold and thresholds Params are both set, they must be equivalent. """ super(LogisticRegression, self).__init__() self._java_obj = self._new_java_obj( @@ -101,7 +103,11 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti "the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.") #: param for whether to fit an intercept term. self.fitIntercept = Param(self, "fitIntercept", "whether to fit an intercept term.") - #: param for threshold in binary classification prediction, in range [0, 1]. + #: param for threshold in binary classification, in range [0, 1]. + self.threshold = Param(self, "threshold", + "Threshold in binary classification prediction, in range [0, 1]." + + " If threshold and thresholds are both set, they must match.") + #: param for thresholds or cutoffs in binary or multiclass classification self.thresholds = \ Param(self, "thresholds", "Thresholds in multi-class classification" + @@ -110,29 +116,28 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti " The class with largest value p/t is predicted, where p is the original" + " probability of that class and t is the class' threshold.") self._setDefault(maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1E-6, - fitIntercept=True) + fitIntercept=True, threshold=0.5) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) + self._checkThresholdConsistency() @keyword_only def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, - threshold=None, thresholds=None, + threshold=0.5, thresholds=None, probabilityCol="probability", rawPredictionCol="rawPrediction"): """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \ - threshold=None, thresholds=None, \ + threshold=0.5, thresholds=None, \ probabilityCol="probability", rawPredictionCol="rawPrediction") Sets params for logistic regression. - Param thresholds overrides Param threshold; threshold is provided - for backwards compatibility and only applies to binary classification. + If the threshold and thresholds Params are both set, they must be equivalent. """ - # Under the hood we use thresholds so translate threshold to thresholds if applicable - if thresholds is None and threshold is not None: - kwargs[thresholds] = [1-threshold, threshold] kwargs = self.setParams._input_kwargs - return self._set(**kwargs) + self._set(**kwargs) + self._checkThresholdConsistency() + return self def _create_model(self, java_model): return LogisticRegressionModel(java_model) @@ -165,44 +170,65 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti def setThreshold(self, value): """ - Sets the value of :py:attr:`thresholds` using [1-value, value]. + Sets the value of :py:attr:`threshold`. + Clears value of :py:attr:`thresholds` if it has been set. + """ + self._paramMap[self.threshold] = value + if self.isSet(self.thresholds): + del self._paramMap[self.thresholds] + return self - >>> lr = LogisticRegression() - >>> lr.getThreshold() - 0.5 - >>> lr.setThreshold(0.6) - LogisticRegression_... - >>> abs(lr.getThreshold() - 0.6) < 1e-5 - True + def getThreshold(self): + """ + Gets the value of threshold or its default value. """ - return self.setThresholds([1-value, value]) + self._checkThresholdConsistency() + if self.isSet(self.thresholds): + ts = self.getOrDefault(self.thresholds) + if len(ts) != 2: + raise ValueError("Logistic Regression getThreshold only applies to" + + " binary classification, but thresholds has length != 2." + + " thresholds: " + ",".join(ts)) + return 1.0/(1.0 + ts[0]/ts[1]) + else: + return self.getOrDefault(self.threshold) def setThresholds(self, value): """ Sets the value of :py:attr:`thresholds`. + Clears value of :py:attr:`threshold` if it has been set. """ self._paramMap[self.thresholds] = value + if self.isSet(self.threshold): + del self._paramMap[self.threshold] return self def getThresholds(self): """ - Gets the value of thresholds or its default value. + If :py:attr:`thresholds` is set, return its value. + Otherwise, if :py:attr:`threshold` is set, return the equivalent thresholds for binary + classification: (1-threshold, threshold). + If neither are set, throw an error. """ - return self.getOrDefault(self.thresholds) + self._checkThresholdConsistency() + if not self.isSet(self.thresholds) and self.isSet(self.threshold): + t = self.getOrDefault(self.threshold) + return [1.0-t, t] + else: + return self.getOrDefault(self.thresholds) - def getThreshold(self): - """ - Gets the value of threshold or its default value. - """ - if self.isDefined(self.thresholds): - thresholds = self.getOrDefault(self.thresholds) - if len(thresholds) != 2: + def _checkThresholdConsistency(self): + if self.isSet(self.threshold) and self.isSet(self.thresholds): + ts = self.getParam(self.thresholds) + if len(ts) != 2: raise ValueError("Logistic Regression getThreshold only applies to" + " binary classification, but thresholds has length != 2." + - " thresholds: " + ",".join(thresholds)) - return 1.0/(1.0+thresholds[0]/thresholds[1]) - else: - return 0.5 + " thresholds: " + ",".join(ts)) + t = 1.0/(1.0 + ts[0]/ts[1]) + t2 = self.getParam(self.threshold) + if abs(t2 - t) >= 1E-5: + raise ValueError("Logistic Regression getThreshold found inconsistent values for" + + " threshold (%g) and thresholds (equivalent to %g)" % (t2, t)) class LogisticRegressionModel(JavaModel): -- cgit v1.2.3