author    Joseph K. Bradley <joseph@databricks.com>    2015-08-12 14:27:13 -0700
committer Joseph K. Bradley <joseph@databricks.com>    2015-08-12 14:27:13 -0700
commit    551def5d6972440365bd7436d484a67138d9a8f3 (patch)
tree      af2280c3849497b4236099ec84fe7b4b64d63f2e /mllib/src/test
parent    762bacc16ac5e74c8b05a7c1e3e367d1d1633cef (diff)
[SPARK-9789] [ML] Added logreg threshold param back
Reinstated LogisticRegression.threshold Param for binary compatibility. Param thresholds overrides threshold, if set.

CC: mengxr dbtsai feynmanliang

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #8079 from jkbradley/logreg-reinstate-threshold.
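For context, a minimal sketch of how the reinstated threshold Param relates to the thresholds Param, based on the behavior exercised in the test changes below (assuming the Spark 1.5-era spark.ml API); illustration only, not part of the patch:

    // Sketch only: for binary classification, threshold t corresponds to
    // thresholds Array(1 - t, t); thresholds takes precedence when it is set.
    import org.apache.spark.ml.classification.LogisticRegression

    val lr = new LogisticRegression()                       // getThreshold defaults to 0.5
    lr.setThreshold(1.0)
    // threshold is mirrored into thresholds as Array(1 - t, t)
    assert(lr.getThresholds.sameElements(Array(0.0, 1.0)))

    val lr2 = new LogisticRegression()
    lr2.setThresholds(Array(0.3, 0.7))
    // when thresholds is set, getThreshold is derived from it: 1 / (1 + t0 / t1)
    assert(math.abs(lr2.getThreshold - 1.0 / (1.0 + 0.3 / 0.7)) < 1e-7)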
Diffstat (limited to 'mllib/src/test')
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java  |  7
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala    | 33
2 files changed, 27 insertions, 13 deletions
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
index 7e9aa38372..618b95b9bd 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
@@ -100,9 +100,7 @@ public class JavaLogisticRegressionSuite implements Serializable {
assert(r.getDouble(0) == 0.0);
}
// Call transform with params, and check that the params worked.
- double[] thresholds = {1.0, 0.0};
- model.transform(
- dataset, model.thresholds().w(thresholds), model.probabilityCol().w("myProb"))
+ model.transform(dataset, model.threshold().w(0.0), model.probabilityCol().w("myProb"))
.registerTempTable("predNotAllZero");
DataFrame predNotAllZero = jsql.sql("SELECT prediction, myProb FROM predNotAllZero");
boolean foundNonZero = false;
@@ -112,9 +110,8 @@ public class JavaLogisticRegressionSuite implements Serializable {
assert(foundNonZero);
// Call fit() with new params, and check as many params as we can.
- double[] thresholds2 = {0.6, 0.4};
LogisticRegressionModel model2 = lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1),
- lr.thresholds().w(thresholds2), lr.probabilityCol().w("theProb"));
+ lr.threshold().w(0.4), lr.probabilityCol().w("theProb"));
LogisticRegression parent2 = (LogisticRegression) model2.parent();
assert(parent2.getMaxIter() == 5);
assert(parent2.getRegParam() == 0.1);
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index 8c3d4590f5..e354e161c6 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -94,12 +94,13 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
test("setThreshold, getThreshold") {
val lr = new LogisticRegression
// default
- withClue("LogisticRegression should not have thresholds set by default") {
- intercept[java.util.NoSuchElementException] {
+ assert(lr.getThreshold === 0.5, "LogisticRegression.threshold should default to 0.5")
+ withClue("LogisticRegression should not have thresholds set by default.") {
+ intercept[java.util.NoSuchElementException] { // Note: The exception type may change in future
lr.getThresholds
}
}
- // Set via thresholds.
+ // Set via threshold.
// Intuition: Large threshold or large thresholds(1) makes class 0 more likely.
lr.setThreshold(1.0)
assert(lr.getThresholds === Array(0.0, 1.0))
@@ -107,10 +108,26 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(lr.getThresholds === Array(1.0, 0.0))
lr.setThreshold(0.5)
assert(lr.getThresholds === Array(0.5, 0.5))
- // Test getThreshold
- lr.setThresholds(Array(0.3, 0.7))
+ // Set via thresholds
+ val lr2 = new LogisticRegression
+ lr2.setThresholds(Array(0.3, 0.7))
val expectedThreshold = 1.0 / (1.0 + 0.3 / 0.7)
- assert(lr.getThreshold ~== expectedThreshold relTol 1E-7)
+ assert(lr2.getThreshold ~== expectedThreshold relTol 1E-7)
+ // thresholds and threshold must be consistent
+ lr2.setThresholds(Array(0.1, 0.2, 0.3))
+ withClue("getThreshold should throw error if thresholds has length != 2.") {
+ intercept[IllegalArgumentException] {
+ lr2.getThreshold
+ }
+ }
+ // thresholds and threshold must be consistent: values
+ withClue("fit with ParamMap should throw error if threshold, thresholds do not match.") {
+ intercept[IllegalArgumentException] {
+ val lr2model = lr2.fit(dataset,
+ lr2.thresholds -> Array(0.3, 0.7), lr2.threshold -> (expectedThreshold / 2.0))
+ lr2model.getThreshold
+ }
+ }
}
test("logistic regression doesn't fit intercept when fitIntercept is off") {
@@ -145,7 +162,7 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
s" ${predAllZero.count(_ === 0)} of ${dataset.count()} were 0.")
// Call transform with params, and check that the params worked.
val predNotAllZero =
- model.transform(dataset, model.thresholds -> Array(1.0, 0.0),
+ model.transform(dataset, model.threshold -> 0.0,
model.probabilityCol -> "myProb")
.select("prediction", "myProb")
.collect()
@@ -153,8 +170,8 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(predNotAllZero.exists(_ !== 0.0))
// Call fit() with new params, and check as many params as we can.
+ lr.setThresholds(Array(0.6, 0.4))
val model2 = lr.fit(dataset, lr.maxIter -> 5, lr.regParam -> 0.1,
- lr.thresholds -> Array(0.6, 0.4),
lr.probabilityCol -> "theProb")
val parent2 = model2.parent.asInstanceOf[LogisticRegression]
assert(parent2.getMaxIter === 5)