aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test/java/org
diff options
context:
space:
mode:
authorJoseph K. Bradley <joseph@databricks.com>2015-08-12 14:27:13 -0700
committerJoseph K. Bradley <joseph@databricks.com>2015-08-12 14:27:13 -0700
commit551def5d6972440365bd7436d484a67138d9a8f3 (patch)
treeaf2280c3849497b4236099ec84fe7b4b64d63f2e /mllib/src/test/java/org
parent762bacc16ac5e74c8b05a7c1e3e367d1d1633cef (diff)
downloadspark-551def5d6972440365bd7436d484a67138d9a8f3.tar.gz
spark-551def5d6972440365bd7436d484a67138d9a8f3.tar.bz2
spark-551def5d6972440365bd7436d484a67138d9a8f3.zip
[SPARK-9789] [ML] Added logreg threshold param back
Reinstated LogisticRegression.threshold Param for binary compatibility. Param thresholds overrides threshold, if set. CC: mengxr dbtsai feynmanliang Author: Joseph K. Bradley <joseph@databricks.com> Closes #8079 from jkbradley/logreg-reinstate-threshold.
Diffstat (limited to 'mllib/src/test/java/org')
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java7
1 files changed, 2 insertions, 5 deletions
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
index 7e9aa38372..618b95b9bd 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
@@ -100,9 +100,7 @@ public class JavaLogisticRegressionSuite implements Serializable {
assert(r.getDouble(0) == 0.0);
}
// Call transform with params, and check that the params worked.
- double[] thresholds = {1.0, 0.0};
- model.transform(
- dataset, model.thresholds().w(thresholds), model.probabilityCol().w("myProb"))
+ model.transform(dataset, model.threshold().w(0.0), model.probabilityCol().w("myProb"))
.registerTempTable("predNotAllZero");
DataFrame predNotAllZero = jsql.sql("SELECT prediction, myProb FROM predNotAllZero");
boolean foundNonZero = false;
@@ -112,9 +110,8 @@ public class JavaLogisticRegressionSuite implements Serializable {
assert(foundNonZero);
// Call fit() with new params, and check as many params as we can.
- double[] thresholds2 = {0.6, 0.4};
LogisticRegressionModel model2 = lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1),
- lr.thresholds().w(thresholds2), lr.probabilityCol().w("theProb"));
+ lr.threshold().w(0.4), lr.probabilityCol().w("theProb"));
LogisticRegression parent2 = (LogisticRegression) model2.parent();
assert(parent2.getMaxIter() == 5);
assert(parent2.getRegParam() == 0.1);