diff options
author | Joseph K. Bradley <joseph@databricks.com> | 2015-08-12 14:27:13 -0700 |
---|---|---|
committer | Joseph K. Bradley <joseph@databricks.com> | 2015-08-12 14:27:13 -0700 |
commit | 551def5d6972440365bd7436d484a67138d9a8f3 (patch) | |
tree | af2280c3849497b4236099ec84fe7b4b64d63f2e /mllib/src/test/java/org | |
parent | 762bacc16ac5e74c8b05a7c1e3e367d1d1633cef (diff) | |
download | spark-551def5d6972440365bd7436d484a67138d9a8f3.tar.gz spark-551def5d6972440365bd7436d484a67138d9a8f3.tar.bz2 spark-551def5d6972440365bd7436d484a67138d9a8f3.zip |
[SPARK-9789] [ML] Added logreg threshold param back
Reinstated LogisticRegression.threshold Param for binary compatibility. Param thresholds overrides threshold, if set.
CC: mengxr dbtsai feynmanliang
Author: Joseph K. Bradley <joseph@databricks.com>
Closes #8079 from jkbradley/logreg-reinstate-threshold.
Diffstat (limited to 'mllib/src/test/java/org')
-rw-r--r-- | mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java | 7 |
1 files changed, 2 insertions, 5 deletions
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java index 7e9aa38372..618b95b9bd 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java @@ -100,9 +100,7 @@ public class JavaLogisticRegressionSuite implements Serializable { assert(r.getDouble(0) == 0.0); } // Call transform with params, and check that the params worked. - double[] thresholds = {1.0, 0.0}; - model.transform( - dataset, model.thresholds().w(thresholds), model.probabilityCol().w("myProb")) + model.transform(dataset, model.threshold().w(0.0), model.probabilityCol().w("myProb")) .registerTempTable("predNotAllZero"); DataFrame predNotAllZero = jsql.sql("SELECT prediction, myProb FROM predNotAllZero"); boolean foundNonZero = false; @@ -112,9 +110,8 @@ public class JavaLogisticRegressionSuite implements Serializable { assert(foundNonZero); // Call fit() with new params, and check as many params as we can. - double[] thresholds2 = {0.6, 0.4}; LogisticRegressionModel model2 = lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1), - lr.thresholds().w(thresholds2), lr.probabilityCol().w("theProb")); + lr.threshold().w(0.4), lr.probabilityCol().w("theProb")); LogisticRegression parent2 = (LogisticRegression) model2.parent(); assert(parent2.getMaxIter() == 5); assert(parent2.getRegParam() == 0.1); |