author     Joseph K. Bradley <joseph@databricks.com>    2015-08-12 14:27:13 -0700
committer  Joseph K. Bradley <joseph@databricks.com>    2015-08-12 14:27:13 -0700
commit     551def5d6972440365bd7436d484a67138d9a8f3 (patch)
tree       af2280c3849497b4236099ec84fe7b4b64d63f2e
parent     762bacc16ac5e74c8b05a7c1e3e367d1d1633cef (diff)
[SPARK-9789] [ML] Added logreg threshold param back
Reinstated LogisticRegression.threshold Param for binary compatibility.
Param thresholds overrides threshold, if set.

CC: mengxr dbtsai feynmanliang

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #8079 from jkbradley/logreg-reinstate-threshold.
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala        | 127
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala         |   4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala                |   6
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java |   7
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala   |  33
-rw-r--r--  python/pyspark/ml/classification.py                                                      |  98
6 files changed, 199 insertions(+), 76 deletions(-)
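
The following sketch (not part of the patch) illustrates the reinstated API described in the commit message, assuming a Spark 1.5-style setup with spark-mllib on the classpath; the object name and toy data are illustrative only. Setting threshold clears any user-set thresholds and vice versa, and for a length-2 thresholds array getThreshold returns 1 / (1 + thresholds(0) / thresholds(1)).

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.mllib.linalg.Vectors

object ThresholdExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("ThresholdExample").setMaster("local[2]"))
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    // Toy binary-classification data: (label, features).
    val training = sc.parallelize(Seq(
      (0.0, Vectors.dense(0.0, 1.1)),
      (1.0, Vectors.dense(2.0, 1.0)),
      (0.0, Vectors.dense(0.1, 1.2)),
      (1.0, Vectors.dense(1.9, 0.8)))).toDF("label", "features")

    val lr = new LogisticRegression().setMaxIter(10)

    // threshold defaults to 0.5; setting it clears any user-set thresholds.
    lr.setThreshold(0.7)
    // Setting thresholds clears threshold; if both appear in a ParamMap they must be equivalent.
    lr.setThresholds(Array(0.3, 0.7))
    // For a length-2 thresholds array, getThreshold returns 1 / (1 + thresholds(0) / thresholds(1)).
    println(lr.getThreshold)  // ≈ 0.7

    val model = lr.fit(training)
    model.transform(training).select("features", "probability", "prediction").show()

    sc.stop()
  }
}
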
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index f55134d258..5bcd7117b6 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -34,8 +34,7 @@ import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Row, SQLContext}
-import org.apache.spark.sql.functions.{col, udf}
+import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.storage.StorageLevel
/**
@@ -43,44 +42,115 @@ import org.apache.spark.storage.StorageLevel
*/
private[classification] trait LogisticRegressionParams extends ProbabilisticClassifierParams
with HasRegParam with HasElasticNetParam with HasMaxIter with HasFitIntercept with HasTol
- with HasStandardization {
+ with HasStandardization with HasThreshold {
/**
- * Version of setThresholds() for binary classification, available for backwards
- * compatibility.
+ * Set threshold in binary classification, in range [0, 1].
*
- * Calling this with threshold p will effectively call `setThresholds(Array(1-p, p))`.
+ * If the estimated probability of class label 1 is > threshold, then predict 1, else 0.
+ * A high threshold encourages the model to predict 0 more often;
+ * a low threshold encourages the model to predict 1 more often.
+ *
+ * Note: Calling this with threshold p is equivalent to calling `setThresholds(Array(1-p, p))`.
+ * When [[setThreshold()]] is called, any user-set value for [[thresholds]] will be cleared.
+ * If both [[threshold]] and [[thresholds]] are set in a ParamMap, then they must be
+ * equivalent.
+ *
+ * Default is 0.5.
+ * @group setParam
+ */
+ def setThreshold(value: Double): this.type = {
+ if (isSet(thresholds)) clear(thresholds)
+ set(threshold, value)
+ }
+
+ /**
+ * Get threshold for binary classification.
+ *
+ * If [[threshold]] is set, returns that value.
+ * Otherwise, if [[thresholds]] is set with length 2 (i.e., binary classification),
+ * this returns the equivalent threshold: {{{1 / (1 + thresholds(0) / thresholds(1))}}}.
+ * Otherwise, returns [[threshold]] default value.
+ *
+ * @group getParam
+ * @throws IllegalArgumentException if [[thresholds]] is set to an array of length other than 2.
+ */
+ override def getThreshold: Double = {
+ checkThresholdConsistency()
+ if (isSet(thresholds)) {
+ val ts = $(thresholds)
+ require(ts.length == 2, "Logistic Regression getThreshold only applies to" +
+ " binary classification, but thresholds has length != 2. thresholds: " + ts.mkString(","))
+ 1.0 / (1.0 + ts(0) / ts(1))
+ } else {
+ $(threshold)
+ }
+ }
+
+ /**
+ * Set thresholds in multiclass (or binary) classification to adjust the probability of
+ * predicting each class. Array must have length equal to the number of classes, with values >= 0.
+ * The class with largest value p/t is predicted, where p is the original probability of that
+ * class and t is the class' threshold.
+ *
+ * Note: When [[setThresholds()]] is called, any user-set value for [[threshold]] will be cleared.
+ * If both [[threshold]] and [[thresholds]] are set in a ParamMap, then they must be
+ * equivalent.
*
- * Default is effectively 0.5.
* @group setParam
*/
- def setThreshold(value: Double): this.type = set(thresholds, Array(1.0 - value, value))
+ def setThresholds(value: Array[Double]): this.type = {
+ if (isSet(threshold)) clear(threshold)
+ set(thresholds, value)
+ }
/**
- * Version of [[getThresholds()]] for binary classification, available for backwards
- * compatibility.
+ * Get thresholds for binary or multiclass classification.
+ *
+ * If [[thresholds]] is set, return its value.
+ * Otherwise, if [[threshold]] is set, return the equivalent thresholds for binary
+ * classification: (1-threshold, threshold).
+ * If neither are set, throw an exception.
*
- * Param thresholds must have length 2 (or not be specified).
- * This returns {{{1 / (1 + thresholds(0) / thresholds(1))}}}.
* @group getParam
*/
- def getThreshold: Double = {
- if (isDefined(thresholds)) {
- val thresholdValues = $(thresholds)
- assert(thresholdValues.length == 2, "Logistic Regression getThreshold only applies to" +
- " binary classification, but thresholds has length != 2." +
- s" thresholds: ${thresholdValues.mkString(",")}")
- 1.0 / (1.0 + thresholdValues(0) / thresholdValues(1))
+ override def getThresholds: Array[Double] = {
+ checkThresholdConsistency()
+ if (!isSet(thresholds) && isSet(threshold)) {
+ val t = $(threshold)
+ Array(1-t, t)
} else {
- 0.5
+ $(thresholds)
+ }
+ }
+
+ /**
+ * If [[threshold]] and [[thresholds]] are both set, ensures they are consistent.
+ * @throws IllegalArgumentException if [[threshold]] and [[thresholds]] are not equivalent
+ */
+ protected def checkThresholdConsistency(): Unit = {
+ if (isSet(threshold) && isSet(thresholds)) {
+ val ts = $(thresholds)
+ require(ts.length == 2, "Logistic Regression found inconsistent values for threshold and" +
+ s" thresholds. Param threshold is set (${$(threshold)}), indicating binary" +
+ s" classification, but Param thresholds is set with length ${ts.length}." +
+ " Clear one Param value to fix this problem.")
+ val t = 1.0 / (1.0 + ts(0) / ts(1))
+ require(math.abs($(threshold) - t) < 1E-5, "Logistic Regression getThreshold found" +
+ s" inconsistent values for threshold (${$(threshold)}) and thresholds (equivalent to $t)")
}
}
+
+ override def validateParams(): Unit = {
+ checkThresholdConsistency()
+ }
}
/**
* :: Experimental ::
* Logistic regression.
- * Currently, this class only supports binary classification.
+ * Currently, this class only supports binary classification. It will support multiclass
+ * in the future.
*/
@Experimental
class LogisticRegression(override val uid: String)
@@ -128,7 +198,7 @@ class LogisticRegression(override val uid: String)
* Whether to fit an intercept term.
* Default is true.
* @group setParam
- * */
+ */
def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value)
setDefault(fitIntercept -> true)
@@ -140,7 +210,7 @@ class LogisticRegression(override val uid: String)
* is applied. In R's GLMNET package, the default behavior is true as well.
* Default is true.
* @group setParam
- * */
+ */
def setStandardization(value: Boolean): this.type = set(standardization, value)
setDefault(standardization -> true)
@@ -148,6 +218,10 @@ class LogisticRegression(override val uid: String)
override def getThreshold: Double = super.getThreshold
+ override def setThresholds(value: Array[Double]): this.type = super.setThresholds(value)
+
+ override def getThresholds: Array[Double] = super.getThresholds
+
override protected def train(dataset: DataFrame): LogisticRegressionModel = {
// Extract columns from data. If dataset is persisted, do not persist oldDataset.
val instances = extractLabeledPoints(dataset).map {
@@ -314,6 +388,10 @@ class LogisticRegressionModel private[ml] (
override def getThreshold: Double = super.getThreshold
+ override def setThresholds(value: Array[Double]): this.type = super.setThresholds(value)
+
+ override def getThresholds: Array[Double] = super.getThresholds
+
/** Margin (rawPrediction) for class label 1. For binary classification only. */
private val margin: Vector => Double = (features) => {
BLAS.dot(features, weights) + intercept
@@ -364,6 +442,7 @@ class LogisticRegressionModel private[ml] (
* The behavior of this can be adjusted using [[thresholds]].
*/
override protected def predict(features: Vector): Double = {
+ // Note: We should use getThreshold instead of $(threshold) since getThreshold is overridden.
if (score(features) > getThreshold) 1 else 0
}
@@ -393,6 +472,7 @@ class LogisticRegressionModel private[ml] (
}
override protected def raw2prediction(rawPrediction: Vector): Double = {
+ // Note: We should use getThreshold instead of $(threshold) since getThreshold is overridden.
val t = getThreshold
val rawThreshold = if (t == 0.0) {
Double.NegativeInfinity
@@ -405,6 +485,7 @@ class LogisticRegressionModel private[ml] (
}
override protected def probability2prediction(probability: Vector): Double = {
+ // Note: We should use getThreshold instead of $(threshold) since getThreshold is overridden.
if (probability(1) > getThreshold) 1 else 0
}
}
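
For reference, a small standalone sketch (not part of the patch) of the threshold/thresholds conversion that getThreshold and checkThresholdConsistency above rely on; the object and method names are illustrative only.

object ThresholdConversion {
  // A length-2 thresholds array (t0, t1) maps to the single binary threshold 1 / (1 + t0 / t1).
  def toThreshold(ts: Array[Double]): Double = {
    require(ts.length == 2, s"expected 2 thresholds, got ${ts.length}")
    1.0 / (1.0 + ts(0) / ts(1))
  }

  // A single binary threshold t maps back to the thresholds pair (1 - t, t).
  def toThresholds(t: Double): Array[Double] = Array(1.0 - t, t)

  def main(args: Array[String]): Unit = {
    val t = toThreshold(Array(0.3, 0.7))
    println(t)                               // ≈ 0.7
    println(toThresholds(t).mkString(", "))  // ≈ 0.3, 0.7 (round-trip, up to floating-point error)
    // checkThresholdConsistency accepts a (threshold, thresholds) pair iff the two agree within 1e-5.
    println(math.abs(0.7 - t) < 1e-5)        // true
  }
}
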
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
index da4c076830..9e12f1856a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
@@ -45,14 +45,14 @@ private[shared] object SharedParamsCodeGen {
" These probabilities should be treated as confidences, not precise probabilities.",
Some("\"probability\"")),
ParamDesc[Double]("threshold",
- "threshold in binary classification prediction, in range [0, 1]",
+ "threshold in binary classification prediction, in range [0, 1]", Some("0.5"),
isValid = "ParamValidators.inRange(0, 1)", finalMethods = false),
ParamDesc[Array[Double]]("thresholds", "Thresholds in multi-class classification" +
" to adjust the probability of predicting each class." +
" Array must have length equal to the number of classes, with values >= 0." +
" The class with largest value p/t is predicted, where p is the original probability" +
" of that class and t is the class' threshold.",
- isValid = "(t: Array[Double]) => t.forall(_ >= 0)"),
+ isValid = "(t: Array[Double]) => t.forall(_ >= 0)", finalMethods = false),
ParamDesc[String]("inputCol", "input column name"),
ParamDesc[Array[String]]("inputCols", "input column names"),
ParamDesc[String]("outputCol", "output column name", Some("uid + \"__output\"")),
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
index 23e2b6cc43..a17d4ea960 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
@@ -139,7 +139,7 @@ private[ml] trait HasProbabilityCol extends Params {
}
/**
- * Trait for shared param threshold.
+ * Trait for shared param threshold (default: 0.5).
*/
private[ml] trait HasThreshold extends Params {
@@ -149,6 +149,8 @@ private[ml] trait HasThreshold extends Params {
*/
final val threshold: DoubleParam = new DoubleParam(this, "threshold", "threshold in binary classification prediction, in range [0, 1]", ParamValidators.inRange(0, 1))
+ setDefault(threshold, 0.5)
+
/** @group getParam */
def getThreshold: Double = $(threshold)
}
@@ -165,7 +167,7 @@ private[ml] trait HasThresholds extends Params {
final val thresholds: DoubleArrayParam = new DoubleArrayParam(this, "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold.", (t: Array[Double]) => t.forall(_ >= 0))
/** @group getParam */
- final def getThresholds: Array[Double] = $(thresholds)
+ def getThresholds: Array[Double] = $(thresholds)
}
/**
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
index 7e9aa38372..618b95b9bd 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
@@ -100,9 +100,7 @@ public class JavaLogisticRegressionSuite implements Serializable {
assert(r.getDouble(0) == 0.0);
}
// Call transform with params, and check that the params worked.
- double[] thresholds = {1.0, 0.0};
- model.transform(
- dataset, model.thresholds().w(thresholds), model.probabilityCol().w("myProb"))
+ model.transform(dataset, model.threshold().w(0.0), model.probabilityCol().w("myProb"))
.registerTempTable("predNotAllZero");
DataFrame predNotAllZero = jsql.sql("SELECT prediction, myProb FROM predNotAllZero");
boolean foundNonZero = false;
@@ -112,9 +110,8 @@ public class JavaLogisticRegressionSuite implements Serializable {
assert(foundNonZero);
// Call fit() with new params, and check as many params as we can.
- double[] thresholds2 = {0.6, 0.4};
LogisticRegressionModel model2 = lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1),
- lr.thresholds().w(thresholds2), lr.probabilityCol().w("theProb"));
+ lr.threshold().w(0.4), lr.probabilityCol().w("theProb"));
LogisticRegression parent2 = (LogisticRegression) model2.parent();
assert(parent2.getMaxIter() == 5);
assert(parent2.getRegParam() == 0.1);
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index 8c3d4590f5..e354e161c6 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -94,12 +94,13 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
test("setThreshold, getThreshold") {
val lr = new LogisticRegression
// default
- withClue("LogisticRegression should not have thresholds set by default") {
- intercept[java.util.NoSuchElementException] {
+ assert(lr.getThreshold === 0.5, "LogisticRegression.threshold should default to 0.5")
+ withClue("LogisticRegression should not have thresholds set by default.") {
+ intercept[java.util.NoSuchElementException] { // Note: The exception type may change in future
lr.getThresholds
}
}
- // Set via thresholds.
+ // Set via threshold.
// Intuition: Large threshold or large thresholds(1) makes class 0 more likely.
lr.setThreshold(1.0)
assert(lr.getThresholds === Array(0.0, 1.0))
@@ -107,10 +108,26 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(lr.getThresholds === Array(1.0, 0.0))
lr.setThreshold(0.5)
assert(lr.getThresholds === Array(0.5, 0.5))
- // Test getThreshold
- lr.setThresholds(Array(0.3, 0.7))
+ // Set via thresholds
+ val lr2 = new LogisticRegression
+ lr2.setThresholds(Array(0.3, 0.7))
val expectedThreshold = 1.0 / (1.0 + 0.3 / 0.7)
- assert(lr.getThreshold ~== expectedThreshold relTol 1E-7)
+ assert(lr2.getThreshold ~== expectedThreshold relTol 1E-7)
+ // thresholds and threshold must be consistent
+ lr2.setThresholds(Array(0.1, 0.2, 0.3))
+ withClue("getThreshold should throw error if thresholds has length != 2.") {
+ intercept[IllegalArgumentException] {
+ lr2.getThreshold
+ }
+ }
+ // thresholds and threshold must be consistent: values
+ withClue("fit with ParamMap should throw error if threshold, thresholds do not match.") {
+ intercept[IllegalArgumentException] {
+ val lr2model = lr2.fit(dataset,
+ lr2.thresholds -> Array(0.3, 0.7), lr2.threshold -> (expectedThreshold / 2.0))
+ lr2model.getThreshold
+ }
+ }
}
test("logistic regression doesn't fit intercept when fitIntercept is off") {
@@ -145,7 +162,7 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
s" ${predAllZero.count(_ === 0)} of ${dataset.count()} were 0.")
// Call transform with params, and check that the params worked.
val predNotAllZero =
- model.transform(dataset, model.thresholds -> Array(1.0, 0.0),
+ model.transform(dataset, model.threshold -> 0.0,
model.probabilityCol -> "myProb")
.select("prediction", "myProb")
.collect()
@@ -153,8 +170,8 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(predNotAllZero.exists(_ !== 0.0))
// Call fit() with new params, and check as many params as we can.
+ lr.setThresholds(Array(0.6, 0.4))
val model2 = lr.fit(dataset, lr.maxIter -> 5, lr.regParam -> 0.1,
- lr.thresholds -> Array(0.6, 0.4),
lr.probabilityCol -> "theProb")
val parent2 = model2.parent.asInstanceOf[LogisticRegression]
assert(parent2.getMaxIter === 5)
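
The transform-time override exercised by the Java and Scala test changes above can be sketched as follows (not part of the patch; object name and toy data are illustrative): a threshold passed as a ParamPair to transform() takes precedence over the value stored in the fitted model, so threshold = 1.0 forces every prediction to 0 and threshold = 0.0 forces every prediction to 1.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.mllib.linalg.Vectors

object TransformOverrideExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("TransformOverrideExample").setMaster("local[2]"))
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    val data = sc.parallelize(Seq(
      (0.0, Vectors.dense(0.0)), (0.0, Vectors.dense(0.2)),
      (1.0, Vectors.dense(0.9)), (1.0, Vectors.dense(1.0)))).toDF("label", "features")

    val model = new LogisticRegression().setMaxIter(10).fit(data)

    // Override at transform time: threshold = 1.0 predicts all 0s ...
    model.transform(data, model.threshold -> 1.0).select("prediction").show()
    // ... and threshold = 0.0 predicts all 1s, with the probability column renamed.
    model.transform(data, model.threshold -> 0.0, model.probabilityCol -> "myProb")
      .select("prediction", "myProb").show()

    sc.stop()
  }
}
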
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index 6702dce554..83f808efc3 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -76,19 +76,21 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
" Array must have length equal to the number of classes, with values >= 0." +
" The class with largest value p/t is predicted, where p is the original" +
" probability of that class and t is the class' threshold.")
+ threshold = Param(Params._dummy(), "threshold",
+ "Threshold in binary classification prediction, in range [0, 1]." +
+ " If threshold and thresholds are both set, they must match.")
@keyword_only
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
- threshold=None, thresholds=None,
+ threshold=0.5, thresholds=None,
probabilityCol="probability", rawPredictionCol="rawPrediction"):
"""
__init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \
- threshold=None, thresholds=None, \
+ threshold=0.5, thresholds=None, \
probabilityCol="probability", rawPredictionCol="rawPrediction")
- Param thresholds overrides Param threshold; threshold is provided
- for backwards compatibility and only applies to binary classification.
+ If the threshold and thresholds Params are both set, they must be equivalent.
"""
super(LogisticRegression, self).__init__()
self._java_obj = self._new_java_obj(
@@ -101,7 +103,11 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
"the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.")
#: param for whether to fit an intercept term.
self.fitIntercept = Param(self, "fitIntercept", "whether to fit an intercept term.")
- #: param for threshold in binary classification prediction, in range [0, 1].
+ #: param for threshold in binary classification, in range [0, 1].
+ self.threshold = Param(self, "threshold",
+ "Threshold in binary classification prediction, in range [0, 1]." +
+ " If threshold and thresholds are both set, they must match.")
+ #: param for thresholds or cutoffs in binary or multiclass classification
self.thresholds = \
Param(self, "thresholds",
"Thresholds in multi-class classification" +
@@ -110,29 +116,28 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
" The class with largest value p/t is predicted, where p is the original" +
" probability of that class and t is the class' threshold.")
self._setDefault(maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1E-6,
- fitIntercept=True)
+ fitIntercept=True, threshold=0.5)
kwargs = self.__init__._input_kwargs
self.setParams(**kwargs)
+ self._checkThresholdConsistency()
@keyword_only
def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
- threshold=None, thresholds=None,
+ threshold=0.5, thresholds=None,
probabilityCol="probability", rawPredictionCol="rawPrediction"):
"""
setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \
- threshold=None, thresholds=None, \
+ threshold=0.5, thresholds=None, \
probabilityCol="probability", rawPredictionCol="rawPrediction")
Sets params for logistic regression.
- Param thresholds overrides Param threshold; threshold is provided
- for backwards compatibility and only applies to binary classification.
+ If the threshold and thresholds Params are both set, they must be equivalent.
"""
- # Under the hood we use thresholds so translate threshold to thresholds if applicable
- if thresholds is None and threshold is not None:
- kwargs[thresholds] = [1-threshold, threshold]
kwargs = self.setParams._input_kwargs
- return self._set(**kwargs)
+ self._set(**kwargs)
+ self._checkThresholdConsistency()
+ return self
def _create_model(self, java_model):
return LogisticRegressionModel(java_model)
@@ -165,44 +170,65 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
def setThreshold(self, value):
"""
- Sets the value of :py:attr:`thresholds` using [1-value, value].
+ Sets the value of :py:attr:`threshold`.
+ Clears value of :py:attr:`thresholds` if it has been set.
+ """
+ self._paramMap[self.threshold] = value
+ if self.isSet(self.thresholds):
+ del self._paramMap[self.thresholds]
+ return self
- >>> lr = LogisticRegression()
- >>> lr.getThreshold()
- 0.5
- >>> lr.setThreshold(0.6)
- LogisticRegression_...
- >>> abs(lr.getThreshold() - 0.6) < 1e-5
- True
+ def getThreshold(self):
+ """
+ Gets the value of threshold or its default value.
"""
- return self.setThresholds([1-value, value])
+ self._checkThresholdConsistency()
+ if self.isSet(self.thresholds):
+ ts = self.getOrDefault(self.thresholds)
+ if len(ts) != 2:
+ raise ValueError("Logistic Regression getThreshold only applies to" +
+ " binary classification, but thresholds has length != 2." +
+ " thresholds: " + ",".join(ts))
+ return 1.0/(1.0 + ts[0]/ts[1])
+ else:
+ return self.getOrDefault(self.threshold)
def setThresholds(self, value):
"""
Sets the value of :py:attr:`thresholds`.
+ Clears value of :py:attr:`threshold` if it has been set.
"""
self._paramMap[self.thresholds] = value
+ if self.isSet(self.threshold):
+ del self._paramMap[self.threshold]
return self
def getThresholds(self):
"""
- Gets the value of thresholds or its default value.
+ If :py:attr:`thresholds` is set, return its value.
+ Otherwise, if :py:attr:`threshold` is set, return the equivalent thresholds for binary
+ classification: (1-threshold, threshold).
+ If neither are set, throw an error.
"""
- return self.getOrDefault(self.thresholds)
+ self._checkThresholdConsistency()
+ if not self.isSet(self.thresholds) and self.isSet(self.threshold):
+ t = self.getOrDefault(self.threshold)
+ return [1.0-t, t]
+ else:
+ return self.getOrDefault(self.thresholds)
- def getThreshold(self):
- """
- Gets the value of threshold or its default value.
- """
- if self.isDefined(self.thresholds):
- thresholds = self.getOrDefault(self.thresholds)
- if len(thresholds) != 2:
+ def _checkThresholdConsistency(self):
+ if self.isSet(self.threshold) and self.isSet(self.thresholds):
+ ts = self.getOrDefault(self.thresholds)
+ if len(ts) != 2:
raise ValueError("Logistic Regression getThreshold only applies to" +
" binary classification, but thresholds has length != 2." +
- " thresholds: " + ",".join(thresholds))
- return 1.0/(1.0+thresholds[0]/thresholds[1])
- else:
- return 0.5
+ " thresholds: " + ",".join(ts))
+ t = 1.0/(1.0 + ts[0]/ts[1])
+ t2 = self.getOrDefault(self.threshold)
+ if abs(t2 - t) >= 1E-5:
+ raise ValueError("Logistic Regression getThreshold found inconsistent values for" +
+ " threshold (%g) and thresholds (equivalent to %g)" % (t2, t))
class LogisticRegressionModel(JavaModel):