aboutsummaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
authorHolden Karau <holden@pigscanfly.ca>2015-08-04 10:12:22 -0700
committerJoseph K. Bradley <joseph@databricks.com>2015-08-04 10:12:22 -0700
commit5a23213c148bfe362514f9c71f5273ebda0a848a (patch)
tree1e2646c72d94b36387581ee8b5d99e14305fe650 /python
parent34a0eb2e89d59b0823efc035ddf2dc93f19540c1 (diff)
downloadspark-5a23213c148bfe362514f9c71f5273ebda0a848a.tar.gz
spark-5a23213c148bfe362514f9c71f5273ebda0a848a.tar.bz2
spark-5a23213c148bfe362514f9c71f5273ebda0a848a.zip
[SPARK-8069] [ML] Add multiclass thresholds for ProbabilisticClassifier
This PR replaces the old "threshold" with a generalized "thresholds" Param. We keep getThreshold,setThreshold for backwards compatibility for binary classification. Note that the primary author of this PR is holdenk Author: Holden Karau <holden@pigscanfly.ca> Author: Joseph K. Bradley <joseph@databricks.com> Closes #7909 from jkbradley/holdenk-SPARK-8069-add-cutoff-aka-threshold-to-random-forest and squashes the following commits: 3952977 [Joseph K. Bradley] fixed pyspark doc test 85febc8 [Joseph K. Bradley] made python unit tests a little more robust 7eb1d86 [Joseph K. Bradley] small cleanups 6cc2ed8 [Joseph K. Bradley] Fixed remaining merge issues. 0255e44 [Joseph K. Bradley] Many cleanups for thresholds, some more tests 7565a60 [Holden Karau] fix pep8 style checks, add a getThreshold method similar to our LogisticRegression.scala one for API compat be87f26 [Holden Karau] Convert threshold to thresholds in the python code, add specialized support for Array[Double] to shared parems codegen, etc. 6747dad [Holden Karau] Override raw2prediction for ProbabilisticClassifier, fix some tests 25df168 [Holden Karau] Fix handling of thresholds in LogisticRegression c02d6c0 [Holden Karau] No default for thresholds 5e43628 [Holden Karau] CR feedback and fixed the renamed test f3fbbd1 [Holden Karau] revert the changes to random forest :( 51f581c [Holden Karau] Add explicit types to public methods, fix long line f7032eb [Holden Karau] Fix a java test bug, remove some unecessary changes adf15b4 [Holden Karau] rename the classifier suite test to ProbabilisticClassifierSuite now that we only have it in Probabilistic 398078a [Holden Karau] move the thresholding around a bunch based on the design doc 4893bdc [Holden Karau] Use numtrees of 3 since previous result was tied (one tree for each) and the switch from different max methods picked a different element (since they were equal I think this is ok) 638854c [Holden Karau] Add a scala RandomForestClassifierSuite test based on corresponding python test e09919c [Holden Karau] Fix return type, I need more coffee.... 8d92cac [Holden Karau] Use ClassifierParams as the head 3456ed3 [Holden Karau] Add explicit return types even though just test a0f3b0c [Holden Karau] scala style fixes 6f14314 [Holden Karau] Since hasthreshold/hasthresholds is in root classifier now ffc8dab [Holden Karau] Update the sharedParams 0420290 [Holden Karau] Allow us to override the get methods selectively 978e77a [Holden Karau] Move HasThreshold into classifier params and start defining the overloaded getThreshold/getThresholds functions 1433e52 [Holden Karau] Revert "try and hide threshold but chainges the API so no dice there" 1f09a2e [Holden Karau] try and hide threshold but chainges the API so no dice there efb9084 [Holden Karau] move setThresholds only to where its used 6b34809 [Holden Karau] Add a test with thresholding for the RFCS 74f54c3 [Holden Karau] Fix creation of vote array 1986fa8 [Holden Karau] Setting the thresholds only makes sense if the underlying class hasn't overridden predict, so lets push it down. 2f44b18 [Holden Karau] Add a global default of null for thresholds param f338cfc [Holden Karau] Wait that wasn't a good idea, Revert "Some progress towards unifying threshold and thresholds" 634b06f [Holden Karau] Some progress towards unifying threshold and thresholds 85c9e01 [Holden Karau] Test passes again... little fnur 099c0f3 [Holden Karau] Move thresholds around some more (set on model not trainer) 0f46836 [Holden Karau] Start adding a classifiersuite f70eb5e [Holden Karau] Fix test compile issues a7d59c8 [Holden Karau] Move thresholding into Classifier trait 5d999d2 [Holden Karau] Some more progress, start adding a test (maybe try and see if we can find a better thing to use for the base of the test) 1fed644 [Holden Karau] Use thresholds to scale scores in random forest classifcation 31d6bf2 [Holden Karau] Start threading the threshold info through 0ef228c [Holden Karau] Add hasthresholds
Diffstat (limited to 'python')
-rw-r--r--python/pyspark/ml/classification.py72
1 files changed, 60 insertions, 12 deletions
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index b5814f76de..291320f881 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -69,17 +69,25 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
"the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, " +
"the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.")
fitIntercept = Param(Params._dummy(), "fitIntercept", "whether to fit an intercept term.")
- threshold = Param(Params._dummy(), "threshold",
- "threshold in binary classification prediction, in range [0, 1].")
+ thresholds = Param(Params._dummy(), "thresholds",
+ "Thresholds in multi-class classification" +
+ " to adjust the probability of predicting each class." +
+ " Array must have length equal to the number of classes, with values >= 0." +
+ " The class with largest value p/t is predicted, where p is the original" +
+ " probability of that class and t is the class' threshold.")
@keyword_only
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
- threshold=0.5, probabilityCol="probability", rawPredictionCol="rawPrediction"):
+ threshold=None, thresholds=None,
+ probabilityCol="probability", rawPredictionCol="rawPrediction"):
"""
__init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \
- threshold=0.5, probabilityCol="probability", rawPredictionCol="rawPrediction")
+ threshold=None, thresholds=None, \
+ probabilityCol="probability", rawPredictionCol="rawPrediction")
+ Param thresholds overrides Param threshold; threshold is provided
+ for backwards compatibility and only applies to binary classification.
"""
super(LogisticRegression, self).__init__()
self._java_obj = self._new_java_obj(
@@ -93,23 +101,35 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
#: param for whether to fit an intercept term.
self.fitIntercept = Param(self, "fitIntercept", "whether to fit an intercept term.")
#: param for threshold in binary classification prediction, in range [0, 1].
- self.threshold = Param(self, "threshold",
- "threshold in binary classification prediction, in range [0, 1].")
+ self.thresholds = \
+ Param(self, "thresholds",
+ "Thresholds in multi-class classification" +
+ " to adjust the probability of predicting each class." +
+ " Array must have length equal to the number of classes, with values >= 0." +
+ " The class with largest value p/t is predicted, where p is the original" +
+ " probability of that class and t is the class' threshold.")
self._setDefault(maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1E-6,
- fitIntercept=True, threshold=0.5)
+ fitIntercept=True)
kwargs = self.__init__._input_kwargs
self.setParams(**kwargs)
@keyword_only
def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
- threshold=0.5, probabilityCol="probability", rawPredictionCol="rawPrediction"):
+ threshold=None, thresholds=None,
+ probabilityCol="probability", rawPredictionCol="rawPrediction"):
"""
setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \
- threshold=0.5, probabilityCol="probability", rawPredictionCol="rawPrediction")
+ threshold=None, thresholds=None, \
+ probabilityCol="probability", rawPredictionCol="rawPrediction")
Sets params for logistic regression.
+ Param thresholds overrides Param threshold; threshold is provided
+ for backwards compatibility and only applies to binary classification.
"""
+ # Under the hood we use thresholds so translate threshold to thresholds if applicable
+ if thresholds is None and threshold is not None:
+ kwargs[thresholds] = [1-threshold, threshold]
kwargs = self.setParams._input_kwargs
return self._set(**kwargs)
@@ -144,16 +164,44 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
def setThreshold(self, value):
"""
- Sets the value of :py:attr:`threshold`.
+ Sets the value of :py:attr:`thresholds` using [1-value, value].
+
+ >>> lr = LogisticRegression()
+ >>> lr.getThreshold()
+ 0.5
+ >>> lr.setThreshold(0.6)
+ LogisticRegression_...
+ >>> abs(lr.getThreshold() - 0.6) < 1e-5
+ True
+ """
+ return self.setThresholds([1-value, value])
+
+ def setThresholds(self, value):
+ """
+ Sets the value of :py:attr:`thresholds`.
"""
- self._paramMap[self.threshold] = value
+ self._paramMap[self.thresholds] = value
return self
+ def getThresholds(self):
+ """
+ Gets the value of thresholds or its default value.
+ """
+ return self.getOrDefault(self.thresholds)
+
def getThreshold(self):
"""
Gets the value of threshold or its default value.
"""
- return self.getOrDefault(self.threshold)
+ if self.isDefined(self.thresholds):
+ thresholds = self.getOrDefault(self.thresholds)
+ if len(thresholds) != 2:
+ raise ValueError("Logistic Regression getThreshold only applies to" +
+ " binary classification, but thresholds has length != 2." +
+ " thresholds: " + ",".join(thresholds))
+ return 1.0/(1.0+thresholds[0]/thresholds[1])
+ else:
+ return 0.5
class LogisticRegressionModel(JavaModel):