aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala')
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala9
1 files changed, 9 insertions, 0 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index b3d1bfcfbe..35d8c2e16c 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -46,6 +46,7 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext {
assert(lr.getPredictionCol == "prediction")
assert(lr.getRawPredictionCol == "rawPrediction")
assert(lr.getProbabilityCol == "probability")
+ assert(lr.getFitIntercept == true)
val model = lr.fit(dataset)
model.transform(dataset)
.select("label", "probability", "prediction", "rawPrediction")
@@ -55,6 +56,14 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext {
assert(model.getPredictionCol == "prediction")
assert(model.getRawPredictionCol == "rawPrediction")
assert(model.getProbabilityCol == "probability")
+ assert(model.intercept !== 0.0)
+ }
+
+ test("logistic regression doesn't fit intercept when fitIntercept is off") {
+ val lr = new LogisticRegression
+ lr.setFitIntercept(false)
+ val model = lr.fit(dataset)
+ assert(model.intercept === 0.0)
}
test("logistic regression with setters") {