about | summary | refs | log | tree | commit | diff
path: root/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
diff options
context:
space:
mode:
author: Joseph K. Bradley <joseph@databricks.com> 2015-08-12 14:27:13 -0700
committer: Joseph K. Bradley <joseph@databricks.com> 2015-08-12 14:27:13 -0700
commit: 551def5d6972440365bd7436d484a67138d9a8f3 (patch)
tree: af2280c3849497b4236099ec84fe7b4b64d63f2e /mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
parent: 762bacc16ac5e74c8b05a7c1e3e367d1d1633cef (diff)
download: spark-551def5d6972440365bd7436d484a67138d9a8f3.tar.gz
spark-551def5d6972440365bd7436d484a67138d9a8f3.tar.bz2
spark-551def5d6972440365bd7436d484a67138d9a8f3.zip
[SPARK-9789] [ML] Added logreg threshold param back
Reinstated LogisticRegression.threshold Param for binary compatibility. Param thresholds overrides threshold, if set.
CC: mengxr dbtsai feynmanliang
Author: Joseph K. Bradley <joseph@databricks.com>
Closes #8079 from jkbradley/logreg-reinstate-threshold.
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala | 127
1 files changed, 104 insertions, 23 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index f55134d258..5bcd7117b6 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -34,8 +34,7 @@ import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Row, SQLContext}
-import org.apache.spark.sql.functions.{col, udf}
+import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.storage.StorageLevel
/**
@@ -43,44 +42,115 @@ import org.apache.spark.storage.StorageLevel
*/
private[classification] trait LogisticRegressionParams extends ProbabilisticClassifierParams
with HasRegParam with HasElasticNetParam with HasMaxIter with HasFitIntercept with HasTol
- with HasStandardization {
+ with HasStandardization with HasThreshold {
/**
- * Version of setThresholds() for binary classification, available for backwards
- * compatibility.
+ * Set threshold in binary classification, in range [0, 1].
*
- * Calling this with threshold p will effectively call `setThresholds(Array(1-p, p))`.
+ * If the estimated probability of class label 1 is > threshold, then predict 1, else 0.
+ * A high threshold encourages the model to predict 0 more often;
+ * a low threshold encourages the model to predict 1 more often.
+ *
+ * Note: Calling this with threshold p is equivalent to calling `setThresholds(Array(1-p, p))`.
+ * When [[setThreshold()]] is called, any user-set value for [[thresholds]] will be cleared.
+ * If both [[threshold]] and [[thresholds]] are set in a ParamMap, then they must be
+ * equivalent.
+ *
+ * Default is 0.5.
+ * @group setParam
+ */
+ def setThreshold(value: Double): this.type = {
+ if (isSet(thresholds)) clear(thresholds)
+ set(threshold, value)
+ }
+
+ /**
+ * Get threshold for binary classification.
+ *
+ * If [[threshold]] is set, returns that value.
+ * Otherwise, if [[thresholds]] is set with length 2 (i.e., binary classification),
+ * this returns the equivalent threshold: {{{1 / (1 + thresholds(0) / thresholds(1))}}}.
+ * Otherwise, returns [[threshold]] default value.
+ *
+ * @group getParam
+ * @throws IllegalArgumentException if [[thresholds]] is set to an array of length other than 2.
+ */
+ override def getThreshold: Double = {
+ checkThresholdConsistency()
+ if (isSet(thresholds)) {
+ val ts = $(thresholds)
+ require(ts.length == 2, "Logistic Regression getThreshold only applies to" +
+ " binary classification, but thresholds has length != 2. thresholds: " + ts.mkString(","))
+ 1.0 / (1.0 + ts(0) / ts(1))
+ } else {
+ $(threshold)
+ }
+ }
+
+ /**
+ * Set thresholds in multiclass (or binary) classification to adjust the probability of
+ * predicting each class. Array must have length equal to the number of classes, with values >= 0.
+ * The class with largest value p/t is predicted, where p is the original probability of that
+ * class and t is the class' threshold.
+ *
+ * Note: When [[setThresholds()]] is called, any user-set value for [[threshold]] will be cleared.
+ * If both [[threshold]] and [[thresholds]] are set in a ParamMap, then they must be
+ * equivalent.
*
- * Default is effectively 0.5.
* @group setParam
*/
- def setThreshold(value: Double): this.type = set(thresholds, Array(1.0 - value, value))
+ def setThresholds(value: Array[Double]): this.type = {
+ if (isSet(threshold)) clear(threshold)
+ set(thresholds, value)
+ }
/**
- * Version of [[getThresholds()]] for binary classification, available for backwards
- * compatibility.
+ * Get thresholds for binary or multiclass classification.
+ *
+ * If [[thresholds]] is set, return its value.
+ * Otherwise, if [[threshold]] is set, return the equivalent thresholds for binary
+ * classification: (1-threshold, threshold).
+ * If neither are set, throw an exception.
*
- * Param thresholds must have length 2 (or not be specified).
- * This returns {{{1 / (1 + thresholds(0) / thresholds(1))}}}.
* @group getParam
*/
- def getThreshold: Double = {
- if (isDefined(thresholds)) {
- val thresholdValues = $(thresholds)
- assert(thresholdValues.length == 2, "Logistic Regression getThreshold only applies to" +
- " binary classification, but thresholds has length != 2." +
- s" thresholds: ${thresholdValues.mkString(",")}")
- 1.0 / (1.0 + thresholdValues(0) / thresholdValues(1))
+ override def getThresholds: Array[Double] = {
+ checkThresholdConsistency()
+ if (!isSet(thresholds) && isSet(threshold)) {
+ val t = $(threshold)
+ Array(1-t, t)
} else {
- 0.5
+ $(thresholds)
+ }
+ }
+
+ /**
+ * If [[threshold]] and [[thresholds]] are both set, ensures they are consistent.
+ * @throws IllegalArgumentException if [[threshold]] and [[thresholds]] are not equivalent
+ */
+ protected def checkThresholdConsistency(): Unit = {
+ if (isSet(threshold) && isSet(thresholds)) {
+ val ts = $(thresholds)
+ require(ts.length == 2, "Logistic Regression found inconsistent values for threshold and" +
+ s" thresholds. Param threshold is set (${$(threshold)}), indicating binary" +
+ s" classification, but Param thresholds is set with length ${ts.length}." +
+ " Clear one Param value to fix this problem.")
+ val t = 1.0 / (1.0 + ts(0) / ts(1))
+ require(math.abs($(threshold) - t) < 1E-5, "Logistic Regression getThreshold found" +
+ s" inconsistent values for threshold (${$(threshold)}) and thresholds (equivalent to $t)")
}
}
+
+ override def validateParams(): Unit = {
+ checkThresholdConsistency()
+ }
}
/**
* :: Experimental ::
* Logistic regression.
- * Currently, this class only supports binary classification.
+ * Currently, this class only supports binary classification. It will support multiclass
+ * in the future.
*/
@Experimental
class LogisticRegression(override val uid: String)
@@ -128,7 +198,7 @@ class LogisticRegression(override val uid: String)
* Whether to fit an intercept term.
* Default is true.
* @group setParam
- * */
+ */
def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value)
setDefault(fitIntercept -> true)
@@ -140,7 +210,7 @@ class LogisticRegression(override val uid: String)
* is applied. In R's GLMNET package, the default behavior is true as well.
* Default is true.
* @group setParam
- * */
+ */
def setStandardization(value: Boolean): this.type = set(standardization, value)
setDefault(standardization -> true)
@@ -148,6 +218,10 @@ class LogisticRegression(override val uid: String)
override def getThreshold: Double = super.getThreshold
+ override def setThresholds(value: Array[Double]): this.type = super.setThresholds(value)
+
+ override def getThresholds: Array[Double] = super.getThresholds
+
override protected def train(dataset: DataFrame): LogisticRegressionModel = {
// Extract columns from data. If dataset is persisted, do not persist oldDataset.
val instances = extractLabeledPoints(dataset).map {
@@ -314,6 +388,10 @@ class LogisticRegressionModel private[ml] (
override def getThreshold: Double = super.getThreshold
+ override def setThresholds(value: Array[Double]): this.type = super.setThresholds(value)
+
+ override def getThresholds: Array[Double] = super.getThresholds
+
/** Margin (rawPrediction) for class label 1. For binary classification only. */
private val margin: Vector => Double = (features) => {
BLAS.dot(features, weights) + intercept
@@ -364,6 +442,7 @@ class LogisticRegressionModel private[ml] (
* The behavior of this can be adjusted using [[thresholds]].
*/
override protected def predict(features: Vector): Double = {
+ // Note: We should use getThreshold instead of $(threshold) since getThreshold is overridden.
if (score(features) > getThreshold) 1 else 0
}
@@ -393,6 +472,7 @@ class LogisticRegressionModel private[ml] (
}
override protected def raw2prediction(rawPrediction: Vector): Double = {
+ // Note: We should use getThreshold instead of $(threshold) since getThreshold is overridden.
val t = getThreshold
val rawThreshold = if (t == 0.0) {
Double.NegativeInfinity
@@ -405,6 +485,7 @@ class LogisticRegressionModel private[ml] (
}
override protected def probability2prediction(probability: Vector): Double = {
+ // Note: We should use getThreshold instead of $(threshold) since getThreshold is overridden.
if (probability(1) > getThreshold) 1 else 0
}
}