about | summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala | 8
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala | 12
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala | 9
3 files changed, 27 insertions, 2 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 49c00f7748..34625745dd 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -31,7 +31,7 @@ import org.apache.spark.storage.StorageLevel
* Params for logistic regression.
*/
private[classification] trait LogisticRegressionParams extends ProbabilisticClassifierParams
- with HasRegParam with HasMaxIter with HasThreshold
+ with HasRegParam with HasMaxIter with HasFitIntercept with HasThreshold
/**
@@ -56,6 +56,9 @@ class LogisticRegression
def setMaxIter(value: Int): this.type = set(maxIter, value)
/** @group setParam */
+ def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value)
+
+ /** @group setParam */
def setThreshold(value: Double): this.type = set(threshold, value)
override protected def train(dataset: DataFrame, paramMap: ParamMap): LogisticRegressionModel = {
@@ -67,7 +70,8 @@ class LogisticRegression
}
// Train model
- val lr = new LogisticRegressionWithLBFGS
+ val lr = new LogisticRegressionWithLBFGS()
+ .setIntercept(paramMap(fitIntercept))
lr.optimizer
.setRegParam(paramMap(regParam))
.setNumIterations(paramMap(maxIter))
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala
index 5d660d1e15..0739fdbfcb 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala
@@ -106,6 +106,18 @@ private[ml] trait HasProbabilityCol extends Params {
def getProbabilityCol: String = get(probabilityCol)
}
+private[ml] trait HasFitIntercept extends Params {
+ /**
+ * param for whether to fit an intercept term; defaults to true
+ * @group param
+ */
+ val fitIntercept: BooleanParam =
+ new BooleanParam(this, "fitIntercept", "indicates whether to fit an intercept term", Some(true))
+
+ /** @group getParam */
+ def getFitIntercept: Boolean = get(fitIntercept)
+}
+
private[ml] trait HasThreshold extends Params {
/**
* param for threshold in (binary) prediction
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index b3d1bfcfbe..35d8c2e16c 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -46,6 +46,7 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext {
assert(lr.getPredictionCol == "prediction")
assert(lr.getRawPredictionCol == "rawPrediction")
assert(lr.getProbabilityCol == "probability")
+ assert(lr.getFitIntercept == true)
val model = lr.fit(dataset)
model.transform(dataset)
.select("label", "probability", "prediction", "rawPrediction")
@@ -55,6 +56,14 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext {
assert(model.getPredictionCol == "prediction")
assert(model.getRawPredictionCol == "rawPrediction")
assert(model.getProbabilityCol == "probability")
+ assert(model.intercept !== 0.0)
+ }
+
+ test("logistic regression doesn't fit intercept when fitIntercept is off") {
+ val lr = new LogisticRegression
+ lr.setFitIntercept(false) // override the default (true) so no intercept term is fit
+ val model = lr.fit(dataset)
+ assert(model.intercept === 0.0) // with fitting disabled, the intercept is expected to be 0.0
}
test("logistic regression with setters") {