diff options
author | Holden Karau <holden@pigscanfly.ca> | 2015-08-04 10:12:22 -0700 |
---|---|---|
committer | Joseph K. Bradley <joseph@databricks.com> | 2015-08-04 10:12:33 -0700 |
commit | c5250ddc5242a071549e980f69fa8bd785168979 (patch) | |
tree | 4a2456be89a08cafc4d19ab453b604c18a964c50 /python | |
parent | a9277cd5aedd570f550e2a807768c8ffada9576f (diff) | |
download | spark-c5250ddc5242a071549e980f69fa8bd785168979.tar.gz spark-c5250ddc5242a071549e980f69fa8bd785168979.tar.bz2 spark-c5250ddc5242a071549e980f69fa8bd785168979.zip |
[SPARK-8069] [ML] Add multiclass thresholds for ProbabilisticClassifier
This PR replaces the old "threshold" Param with a generalized "thresholds" Param. We keep getThreshold and setThreshold for backwards compatibility with binary classification.
Note that the primary author of this PR is holdenk
Author: Holden Karau <holden@pigscanfly.ca>
Author: Joseph K. Bradley <joseph@databricks.com>
Closes #7909 from jkbradley/holdenk-SPARK-8069-add-cutoff-aka-threshold-to-random-forest and squashes the following commits:
3952977 [Joseph K. Bradley] fixed pyspark doc test
85febc8 [Joseph K. Bradley] made python unit tests a little more robust
7eb1d86 [Joseph K. Bradley] small cleanups
6cc2ed8 [Joseph K. Bradley] Fixed remaining merge issues.
0255e44 [Joseph K. Bradley] Many cleanups for thresholds, some more tests
7565a60 [Holden Karau] fix pep8 style checks, add a getThreshold method similar to our LogisticRegression.scala one for API compat
be87f26 [Holden Karau] Convert threshold to thresholds in the python code, add specialized support for Array[Double] to shared params codegen, etc.
6747dad [Holden Karau] Override raw2prediction for ProbabilisticClassifier, fix some tests
25df168 [Holden Karau] Fix handling of thresholds in LogisticRegression
c02d6c0 [Holden Karau] No default for thresholds
5e43628 [Holden Karau] CR feedback and fixed the renamed test
f3fbbd1 [Holden Karau] revert the changes to random forest :(
51f581c [Holden Karau] Add explicit types to public methods, fix long line
f7032eb [Holden Karau] Fix a java test bug, remove some unnecessary changes
adf15b4 [Holden Karau] rename the classifier suite test to ProbabilisticClassifierSuite now that we only have it in Probabilistic
398078a [Holden Karau] move the thresholding around a bunch based on the design doc
4893bdc [Holden Karau] Use numtrees of 3 since previous result was tied (one tree for each) and the switch from different max methods picked a different element (since they were equal I think this is ok)
638854c [Holden Karau] Add a scala RandomForestClassifierSuite test based on corresponding python test
e09919c [Holden Karau] Fix return type, I need more coffee....
8d92cac [Holden Karau] Use ClassifierParams as the head
3456ed3 [Holden Karau] Add explicit return types even though just test
a0f3b0c [Holden Karau] scala style fixes
6f14314 [Holden Karau] Since hasthreshold/hasthresholds is in root classifier now
ffc8dab [Holden Karau] Update the sharedParams
0420290 [Holden Karau] Allow us to override the get methods selectively
978e77a [Holden Karau] Move HasThreshold into classifier params and start defining the overloaded getThreshold/getThresholds functions
1433e52 [Holden Karau] Revert "try and hide threshold but changes the API so no dice there"
1f09a2e [Holden Karau] try and hide threshold but changes the API so no dice there
efb9084 [Holden Karau] move setThresholds only to where it's used
6b34809 [Holden Karau] Add a test with thresholding for the RFCS
74f54c3 [Holden Karau] Fix creation of vote array
1986fa8 [Holden Karau] Setting the thresholds only makes sense if the underlying class hasn't overridden predict, so let's push it down.
2f44b18 [Holden Karau] Add a global default of null for thresholds param
f338cfc [Holden Karau] Wait that wasn't a good idea, Revert "Some progress towards unifying threshold and thresholds"
634b06f [Holden Karau] Some progress towards unifying threshold and thresholds
85c9e01 [Holden Karau] Test passes again... little fnur
099c0f3 [Holden Karau] Move thresholds around some more (set on model not trainer)
0f46836 [Holden Karau] Start adding a classifiersuite
f70eb5e [Holden Karau] Fix test compile issues
a7d59c8 [Holden Karau] Move thresholding into Classifier trait
5d999d2 [Holden Karau] Some more progress, start adding a test (maybe try and see if we can find a better thing to use for the base of the test)
1fed644 [Holden Karau] Use thresholds to scale scores in random forest classification
31d6bf2 [Holden Karau] Start threading the threshold info through
0ef228c [Holden Karau] Add hasthresholds
(cherry picked from commit 5a23213c148bfe362514f9c71f5273ebda0a848a)
Signed-off-by: Joseph K. Bradley <joseph@databricks.com>
Diffstat (limited to 'python')
-rw-r--r-- | python/pyspark/ml/classification.py | 72 |
1 files changed, 60 insertions, 12 deletions
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index b5814f76de..291320f881 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -69,17 +69,25 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, " + "the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.") fitIntercept = Param(Params._dummy(), "fitIntercept", "whether to fit an intercept term.") - threshold = Param(Params._dummy(), "threshold", - "threshold in binary classification prediction, in range [0, 1].") + thresholds = Param(Params._dummy(), "thresholds", + "Thresholds in multi-class classification" + + " to adjust the probability of predicting each class." + + " Array must have length equal to the number of classes, with values >= 0." + + " The class with largest value p/t is predicted, where p is the original" + + " probability of that class and t is the class' threshold.") @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, - threshold=0.5, probabilityCol="probability", rawPredictionCol="rawPrediction"): + threshold=None, thresholds=None, + probabilityCol="probability", rawPredictionCol="rawPrediction"): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \ - threshold=0.5, probabilityCol="probability", rawPredictionCol="rawPrediction") + threshold=None, thresholds=None, \ + probabilityCol="probability", rawPredictionCol="rawPrediction") + Param thresholds overrides Param threshold; threshold is provided + for backwards compatibility and only applies to binary classification. 
""" super(LogisticRegression, self).__init__() self._java_obj = self._new_java_obj( @@ -93,23 +101,35 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti #: param for whether to fit an intercept term. self.fitIntercept = Param(self, "fitIntercept", "whether to fit an intercept term.") #: param for threshold in binary classification prediction, in range [0, 1]. - self.threshold = Param(self, "threshold", - "threshold in binary classification prediction, in range [0, 1].") + self.thresholds = \ + Param(self, "thresholds", + "Thresholds in multi-class classification" + + " to adjust the probability of predicting each class." + + " Array must have length equal to the number of classes, with values >= 0." + + " The class with largest value p/t is predicted, where p is the original" + + " probability of that class and t is the class' threshold.") self._setDefault(maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1E-6, - fitIntercept=True, threshold=0.5) + fitIntercept=True) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @keyword_only def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, - threshold=0.5, probabilityCol="probability", rawPredictionCol="rawPrediction"): + threshold=None, thresholds=None, + probabilityCol="probability", rawPredictionCol="rawPrediction"): """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \ - threshold=0.5, probabilityCol="probability", rawPredictionCol="rawPrediction") + threshold=None, thresholds=None, \ + probabilityCol="probability", rawPredictionCol="rawPrediction") Sets params for logistic regression. + Param thresholds overrides Param threshold; threshold is provided + for backwards compatibility and only applies to binary classification. 
""" + # Under the hood we use thresholds so translate threshold to thresholds if applicable + if thresholds is None and threshold is not None: + kwargs[thresholds] = [1-threshold, threshold] kwargs = self.setParams._input_kwargs return self._set(**kwargs) @@ -144,16 +164,44 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti def setThreshold(self, value): """ - Sets the value of :py:attr:`threshold`. + Sets the value of :py:attr:`thresholds` using [1-value, value]. + + >>> lr = LogisticRegression() + >>> lr.getThreshold() + 0.5 + >>> lr.setThreshold(0.6) + LogisticRegression_... + >>> abs(lr.getThreshold() - 0.6) < 1e-5 + True + """ + return self.setThresholds([1-value, value]) + + def setThresholds(self, value): + """ + Sets the value of :py:attr:`thresholds`. """ - self._paramMap[self.threshold] = value + self._paramMap[self.thresholds] = value return self + def getThresholds(self): + """ + Gets the value of thresholds or its default value. + """ + return self.getOrDefault(self.thresholds) + def getThreshold(self): """ Gets the value of threshold or its default value. """ - return self.getOrDefault(self.threshold) + if self.isDefined(self.thresholds): + thresholds = self.getOrDefault(self.thresholds) + if len(thresholds) != 2: + raise ValueError("Logistic Regression getThreshold only applies to" + + " binary classification, but thresholds has length != 2." + + " thresholds: " + ",".join(thresholds)) + return 1.0/(1.0+thresholds[0]/thresholds[1]) + else: + return 0.5 class LogisticRegressionModel(JavaModel): |