diff options
author | Omede Firouz <ofirouz@palantir.com> | 2015-04-07 23:36:31 -0400 |
---|---|---|
committer | Joseph K. Bradley <joseph@databricks.com> | 2015-04-07 23:36:31 -0400 |
commit | d138aa8ee23f4450242da3ac70a493229a90c76b (patch) | |
tree | 059bc1504106aba35d1b6cac5e8d428066f022cd /mllib/src/test | |
parent | c83e03948b184ffb3a9418fecc4d2c26ae33b057 (diff) | |
download | spark-d138aa8ee23f4450242da3ac70a493229a90c76b.tar.gz spark-d138aa8ee23f4450242da3ac70a493229a90c76b.tar.bz2 spark-d138aa8ee23f4450242da3ac70a493229a90c76b.zip |
[SPARK-6705][MLLIB] Add fit intercept api to ml logisticregression
I have the fit intercept enabled by default for logistic regression, I
wonder what others think here. I understand that it enables allocation
by default which is undesirable, but one needs to have a very strong
reason for not having an intercept term enabled so it is the safer
default from a statistical sense.
Explicitly modeling the intercept by adding a column of all 1s does not
work. I believe the reason is that since the API for
LogisticRegressionWithLBFGS forces column normalization, and a column of all
1s has 0 variance so dividing by 0 kills it.
Author: Omede Firouz <ofirouz@palantir.com>
Closes #5301 from oefirouz/addIntercept and squashes the following commits:
9f1286b [Omede Firouz] [SPARK-6705][MLLIB] Add fitInterceptTerm to LogisticRegression
1d6bd6f [Omede Firouz] [SPARK-6705][MLLIB] Add a fit intercept term to ML LogisticRegression
9963509 [Omede Firouz] [MLLIB] Add fitIntercept to LogisticRegression
2257fca [Omede Firouz] [MLLIB] Add fitIntercept param to logistic regression
329c1e2 [Omede Firouz] [MLLIB] Add fit intercept term
bd9663c [Omede Firouz] [MLLIB] Add fit intercept api to ml logisticregression
Diffstat (limited to 'mllib/src/test')
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala | 9 |
1 files changed, 9 insertions, 0 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index b3d1bfcfbe..35d8c2e16c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -46,6 +46,7 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext { assert(lr.getPredictionCol == "prediction") assert(lr.getRawPredictionCol == "rawPrediction") assert(lr.getProbabilityCol == "probability") + assert(lr.getFitIntercept == true) val model = lr.fit(dataset) model.transform(dataset) .select("label", "probability", "prediction", "rawPrediction") @@ -55,6 +56,14 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext { assert(model.getPredictionCol == "prediction") assert(model.getRawPredictionCol == "rawPrediction") assert(model.getProbabilityCol == "probability") + assert(model.intercept !== 0.0) + } + + test("logistic regression doesn't fit intercept when fitIntercept is off") { + val lr = new LogisticRegression + lr.setFitIntercept(false) + val model = lr.fit(dataset) + assert(model.intercept === 0.0) } test("logistic regression with setters") { |