aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSean Owen <sowen@cloudera.com>2016-09-24 08:15:55 +0100
committerSean Owen <sowen@cloudera.com>2016-09-24 08:15:55 +0100
commit248916f5589155c0c3e93c3874781f17b08d598d (patch)
tree4e3183ffc5d59e09edc8b54ddc2af4fc67abb05b
parentf3fe55439e4c865c26502487a1bccf255da33f4a (diff)
downloadspark-248916f5589155c0c3e93c3874781f17b08d598d.tar.gz
spark-248916f5589155c0c3e93c3874781f17b08d598d.tar.bz2
spark-248916f5589155c0c3e93c3874781f17b08d598d.zip
[SPARK-17057][ML] ProbabilisticClassifierModels' thresholds should have at most one 0
## What changes were proposed in this pull request? Match ProbabilisticClassifier.thresholds requirements to R randomForest cutoff, requiring all > 0 ## How was this patch tested? Jenkins tests plus new test cases Author: Sean Owen <sowen@cloudera.com> Closes #15149 from srowen/SPARK-17057.
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala5
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala20
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala8
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala4
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala35
-rw-r--r--python/pyspark/ml/param/_shared_params_code_gen.py5
-rw-r--r--python/pyspark/ml/param/shared.py4
7 files changed, 52 insertions, 29 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 343d50c790..5ab63d1de9 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -123,9 +123,10 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas
/**
* Set thresholds in multiclass (or binary) classification to adjust the probability of
- * predicting each class. Array must have length equal to the number of classes, with values >= 0.
+ * predicting each class. Array must have length equal to the number of classes, with values > 0,
+ * excepting that at most one value may be 0.
* The class with largest value p/t is predicted, where p is the original probability of that
- * class and t is the class' threshold.
+ * class and t is the class's threshold.
*
* Note: When [[setThresholds()]] is called, any user-set value for [[threshold]] will be cleared.
* If both [[threshold]] and [[thresholds]] are set in a ParamMap, then they must be
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
index 1b6e77542c..e89da6ff8b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
@@ -18,7 +18,7 @@
package org.apache.spark.ml.classification
import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.ml.linalg.{DenseVector, Vector, Vectors, VectorUDT}
+import org.apache.spark.ml.linalg.{DenseVector, Vector, VectorUDT}
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util.SchemaUtils
import org.apache.spark.sql.{DataFrame, Dataset}
@@ -200,22 +200,20 @@ abstract class ProbabilisticClassificationModel[
if (!isDefined(thresholds)) {
probability.argmax
} else {
- val thresholds: Array[Double] = getThresholds
- val probabilities = probability.toArray
+ val thresholds = getThresholds
var argMax = 0
var max = Double.NegativeInfinity
var i = 0
val probabilitySize = probability.size
while (i < probabilitySize) {
- if (thresholds(i) == 0.0) {
- max = Double.PositiveInfinity
+ // Thresholds are all > 0, excepting that at most one may be 0.
+ // The single class whose threshold is 0, if any, will always be predicted
+ // ('scaled' = +Infinity). However in the case that this class also has
+ // 0 probability, the class will not be selected ('scaled' is NaN).
+ val scaled = probability(i) / thresholds(i)
+ if (scaled > max) {
+ max = scaled
argMax = i
- } else {
- val scaled = probabilities(i) / thresholds(i)
- if (scaled > max) {
- max = scaled
- argMax = i
- }
}
i += 1
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
index 480b03d0f3..c94b8b4e9d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
@@ -50,10 +50,12 @@ private[shared] object SharedParamsCodeGen {
isValid = "ParamValidators.inRange(0, 1)", finalMethods = false),
ParamDesc[Array[Double]]("thresholds", "Thresholds in multi-class classification" +
" to adjust the probability of predicting each class." +
- " Array must have length equal to the number of classes, with values >= 0." +
+ " Array must have length equal to the number of classes, with values > 0" +
+ " excepting that at most one value may be 0." +
" The class with largest value p/t is predicted, where p is the original probability" +
- " of that class and t is the class' threshold",
- isValid = "(t: Array[Double]) => t.forall(_ >= 0)", finalMethods = false),
+ " of that class and t is the class's threshold",
+ isValid = "(t: Array[Double]) => t.forall(_ >= 0) && t.count(_ == 0) <= 1",
+ finalMethods = false),
ParamDesc[String]("inputCol", "input column name"),
ParamDesc[Array[String]]("inputCols", "input column names"),
ParamDesc[String]("outputCol", "output column name", Some("uid + \"__output\"")),
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
index 9125d9e19b..fa4530927e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
@@ -176,10 +176,10 @@ private[ml] trait HasThreshold extends Params {
private[ml] trait HasThresholds extends Params {
/**
- * Param for Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold.
+ * Param for Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values > 0 excepting that at most one value may be 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class's threshold.
* @group param
*/
- final val thresholds: DoubleArrayParam = new DoubleArrayParam(this, "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold", (t: Array[Double]) => t.forall(_ >= 0))
+ final val thresholds: DoubleArrayParam = new DoubleArrayParam(this, "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values > 0 excepting that at most one value may be 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class's threshold", (t: Array[Double]) => t.forall(_ >= 0) && t.count(_ == 0) <= 1)
/** @group getParam */
def getThresholds: Array[Double] = $(thresholds)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala
index b3bd2b3e57..172c64aab9 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala
@@ -36,8 +36,8 @@ final class TestProbabilisticClassificationModel(
rawPrediction
}
- def friendlyPredict(input: Vector): Double = {
- predict(input)
+ def friendlyPredict(values: Double*): Double = {
+ predict(Vectors.dense(values.toArray))
}
}
@@ -45,16 +45,37 @@ final class TestProbabilisticClassificationModel(
class ProbabilisticClassifierSuite extends SparkFunSuite {
test("test thresholding") {
- val thresholds = Array(0.5, 0.2)
val testModel = new TestProbabilisticClassificationModel("myuid", 2, 2)
- .setThresholds(thresholds)
- assert(testModel.friendlyPredict(Vectors.dense(Array(1.0, 1.0))) === 1.0)
- assert(testModel.friendlyPredict(Vectors.dense(Array(1.0, 0.2))) === 0.0)
+ .setThresholds(Array(0.5, 0.2))
+ assert(testModel.friendlyPredict(1.0, 1.0) === 1.0)
+ assert(testModel.friendlyPredict(1.0, 0.2) === 0.0)
}
test("test thresholding not required") {
val testModel = new TestProbabilisticClassificationModel("myuid", 2, 2)
- assert(testModel.friendlyPredict(Vectors.dense(Array(1.0, 2.0))) === 1.0)
+ assert(testModel.friendlyPredict(1.0, 2.0) === 1.0)
+ }
+
+ test("test tiebreak") {
+ val testModel = new TestProbabilisticClassificationModel("myuid", 2, 2)
+ .setThresholds(Array(0.4, 0.4))
+ assert(testModel.friendlyPredict(0.6, 0.6) === 0.0)
+ }
+
+ test("test one zero threshold") {
+ val testModel = new TestProbabilisticClassificationModel("myuid", 2, 2)
+ .setThresholds(Array(0.0, 0.1))
+ assert(testModel.friendlyPredict(1.0, 10.0) === 0.0)
+ assert(testModel.friendlyPredict(0.0, 10.0) === 1.0)
+ }
+
+ test("bad thresholds") {
+ intercept[IllegalArgumentException] {
+ new TestProbabilisticClassificationModel("myuid", 2, 2).setThresholds(Array(0.0, 0.0))
+ }
+ intercept[IllegalArgumentException] {
+ new TestProbabilisticClassificationModel("myuid", 2, 2).setThresholds(Array(-0.1, 0.1))
+ }
}
}
diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py b/python/pyspark/ml/param/_shared_params_code_gen.py
index 4f4328bcad..929591236d 100644
--- a/python/pyspark/ml/param/_shared_params_code_gen.py
+++ b/python/pyspark/ml/param/_shared_params_code_gen.py
@@ -139,8 +139,9 @@ if __name__ == "__main__":
"model.", "True", "TypeConverters.toBoolean"),
("thresholds", "Thresholds in multi-class classification to adjust the probability of " +
"predicting each class. Array must have length equal to the number of classes, with " +
- "values >= 0. The class with largest value p/t is predicted, where p is the original " +
- "probability of that class and t is the class' threshold.", None,
+ "values > 0, excepting that at most one value may be 0. " +
+ "The class with largest value p/t is predicted, where p is the original " +
+ "probability of that class and t is the class's threshold.", None,
"TypeConverters.toListFloat"),
("weightCol", "weight column name. If this is not set or empty, we treat " +
"all instance weights as 1.0.", None, "TypeConverters.toString"),
diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py
index 24af07afc7..cc596936d8 100644
--- a/python/pyspark/ml/param/shared.py
+++ b/python/pyspark/ml/param/shared.py
@@ -469,10 +469,10 @@ class HasStandardization(Params):
class HasThresholds(Params):
"""
- Mixin for param thresholds: Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold.
+ Mixin for param thresholds: Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values > 0, excepting that at most one value may be 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class's threshold.
"""
- thresholds = Param(Params._dummy(), "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold.", typeConverter=TypeConverters.toListFloat)
+ thresholds = Param(Params._dummy(), "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values > 0, excepting that at most one value may be 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class's threshold.", typeConverter=TypeConverters.toListFloat)
def __init__(self):
super(HasThresholds, self).__init__()