author    Holden Karau <holden@pigscanfly.ca>    2015-06-23 12:42:17 -0700
committer DB Tsai <dbt@netflix.com>              2015-06-23 12:42:17 -0700
commit    2b1111dd0b8deb9ad8d43fec792e60e3d0c4de75 (patch)
tree      f0bf6fad08cdd7447a3ecf1b6fb8368e617cb520 /mllib
parent    6f4cadf5ee81467d077febc53d36571dd232295d (diff)
[SPARK-7888] Be able to disable intercept in linear regression in ml package
Author: Holden Karau <holden@pigscanfly.ca>

Closes #6927 from holdenk/SPARK-7888-Be-able-to-disable-intercept-in-Linear-Regression-in-ML-package and squashes the following commits:

0ad384c [Holden Karau] Add MiMa excludes
4016fac [Holden Karau] Switch to wild card import, remove extra blank lines
ae5baa8 [Holden Karau] CR feedback, move the fitIntercept down rather than changing ymean and etc above
f34971c [Holden Karau] Fix some more long lines
319bd3f [Holden Karau] Fix long lines
3bb9ee1 [Holden Karau] Update the regression suite tests
7015b9f [Holden Karau] Our code performs the same with R, except we need more than one data point but that seems reasonable
0b0c8c0 [Holden Karau] fix the issue with the sample R code
e2140ba [Holden Karau] Add a test, it fails!
5e84a0b [Holden Karau] Write out thoughts and use the correct trait
91ffc0a [Holden Karau] more murh
006246c [Holden Karau] murp?
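For orientation before the diff: the patch adds a `fitIntercept` param so a caller can force the regression line through the origin. A minimal usage sketch, assuming a DataFrame named `training` with the usual label/features columns (the sketch itself is not part of the patch):

```scala
import org.apache.spark.ml.regression.LinearRegression

// Fit a model through the origin by disabling the intercept (default is true).
val lr = new LinearRegression()
  .setFitIntercept(false)

val model = lr.fit(training)  // `training` is a hypothetical DataFrame
println(s"weights: ${model.weights}, intercept: ${model.intercept}")  // intercept == 0.0
```

With `fitIntercept` left at its default of true, behavior is unchanged.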
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala       |  30
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala  | 149
2 files changed, 167 insertions(+), 12 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index 01306545fc..1b1d7299fb 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -26,7 +26,7 @@ import org.apache.spark.Logging
import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.PredictorParams
import org.apache.spark.ml.param.ParamMap
-import org.apache.spark.ml.param.shared.{HasElasticNetParam, HasMaxIter, HasRegParam, HasTol}
+import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.linalg.BLAS._
@@ -41,7 +41,8 @@ import org.apache.spark.util.StatCounter
* Params for linear regression.
*/
private[regression] trait LinearRegressionParams extends PredictorParams
- with HasRegParam with HasElasticNetParam with HasMaxIter with HasTol
+ with HasRegParam with HasElasticNetParam with HasMaxIter with HasTol
+ with HasFitIntercept
/**
* :: Experimental ::
@@ -73,6 +74,14 @@ class LinearRegression(override val uid: String)
setDefault(regParam -> 0.0)
/**
+ * Set whether we should fit the intercept.
+ * Default is true.
+ * @group setParam
+ */
+ def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value)
+ setDefault(fitIntercept -> true)
+
+ /**
* Set the ElasticNet mixing parameter.
* For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.
* For 0 < alpha < 1, the penalty is a combination of L1 and L2.
@@ -123,6 +132,7 @@ class LinearRegression(override val uid: String)
val numFeatures = summarizer.mean.size
val yMean = statCounter.mean
val yStd = math.sqrt(statCounter.variance)
+ // See glmnet5.m around L761; it may have relevant details.
// If the yStd is zero, then the intercept is yMean with zero weights;
// as a result, training is not needed.
@@ -142,7 +152,7 @@ class LinearRegression(override val uid: String)
val effectiveL1RegParam = $(elasticNetParam) * effectiveRegParam
val effectiveL2RegParam = (1.0 - $(elasticNetParam)) * effectiveRegParam
- val costFun = new LeastSquaresCostFun(instances, yStd, yMean,
+ val costFun = new LeastSquaresCostFun(instances, yStd, yMean, $(fitIntercept),
featuresStd, featuresMean, effectiveL2RegParam)
val optimizer = if ($(elasticNetParam) == 0.0 || effectiveRegParam == 0.0) {
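For reference, the two effective parameters computed just above split the elastic-net penalty in the usual glmnet fashion; written out (the ½ on the L2 term is the glmnet convention applied elsewhere in this file, which this hunk does not show):

```latex
\lambda_1 = \alpha\,\lambda_{\text{eff}}, \qquad
\lambda_2 = (1-\alpha)\,\lambda_{\text{eff}}, \qquad
\text{regTerm}(w) = \lambda_1 \lVert w \rVert_1 + \tfrac{1}{2}\,\lambda_2 \lVert w \rVert_2^2
```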
@@ -180,7 +190,7 @@ class LinearRegression(override val uid: String)
// The intercept in R's GLMNET is computed using closed form after the coefficients are
// converged. See the following discussion for detail.
// http://stats.stackexchange.com/questions/13617/how-is-the-intercept-computed-in-glmnet
- val intercept = yMean - dot(weights, Vectors.dense(featuresMean))
+ val intercept = if ($(fitIntercept)) yMean - dot(weights, Vectors.dense(featuresMean)) else 0.0
if (handlePersistence) instances.unpersist()
// TODO: Converts to sparse format based on the storage, but may base on the scoring speed.
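The closed-form step above is simply intercept = mean(label) minus the dot product of the converged weights with the per-feature means, and it is skipped entirely when fitIntercept is false. A standalone Scala sketch with toy numbers (hypothetical values, not taken from the patch):

```scala
val fitIntercept = true              // toggle to force the line through the origin
val weights = Array(4.7, 7.2)        // toy converged coefficients
val featuresMean = Array(0.9, -1.3)  // toy per-feature means
val yMean = 6.3                      // toy label mean

// Closed-form intercept, as in GLMNET: mean(label) minus weights . featuresMean.
val dotWM = weights.zip(featuresMean).map { case (w, m) => w * m }.sum
val intercept = if (fitIntercept) yMean - dotWM else 0.0
```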
@@ -234,6 +244,7 @@ class LinearRegressionModel private[ml] (
* See this discussion for detail.
* http://stats.stackexchange.com/questions/13617/how-is-the-intercept-computed-in-glmnet
*
+ * When training with the intercept enabled,
 * the objective function in the scaled space is given by
* {{{
* L = 1/2n ||\sum_i w_i(x_i - \bar{x_i}) / \hat{x_i} - (y - \bar{y}) / \hat{y}||^2,
@@ -241,6 +252,10 @@ class LinearRegressionModel private[ml] (
* where \bar{x_i} is the mean of x_i, \hat{x_i} is the standard deviation of x_i,
* \bar{y} is the mean of label, and \hat{y} is the standard deviation of label.
*
+ * If we fit with the intercept disabled (that is, the intercept is forced to 0.0),
+ * we can use the same equation, except that we set \bar{y} and \bar{x_i} to 0
+ * instead of the respective means.
+ *
* This can be rewritten as
* {{{
* L = 1/2n ||\sum_i (w_i/\hat{x_i})x_i - \sum_i (w_i/\hat{x_i})\bar{x_i} - y / \hat{y}
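Concretely, substituting \bar{y} = \bar{x_i} = 0 into the objective above gives the intercept-free form (a direct specialization of the comment's own formula, regularization terms omitted):

```latex
L = \frac{1}{2n}\,\Bigl\lVert \sum_i \frac{w_i}{\hat{x_i}}\, x_i \;-\; \frac{y}{\hat{y}} \Bigr\rVert^2
```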
@@ -255,6 +270,7 @@ class LinearRegressionModel private[ml] (
* \sum_i w_i^\prime x_i - y / \hat{y} + offset
* }}}
*
+ *
* Note that the effective weights and offset don't depend on training dataset,
* so they can be precomputed.
*
@@ -301,6 +317,7 @@ private class LeastSquaresAggregator(
weights: Vector,
labelStd: Double,
labelMean: Double,
+ fitIntercept: Boolean,
featuresStd: Array[Double],
featuresMean: Array[Double]) extends Serializable {
@@ -321,7 +338,7 @@ private class LeastSquaresAggregator(
}
i += 1
}
- (weightsArray, -sum + labelMean / labelStd, weightsArray.length)
+ (weightsArray, if (fitIntercept) labelMean / labelStd - sum else 0.0, weightsArray.length)
}
private val effectiveWeightsVector = Vectors.dense(effectiveWeightsArray)
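A self-contained restatement of the precomputation in the hunk above, using plain arrays instead of MLlib vectors (a sketch mirroring the patch, including the new fitIntercept branch):

```scala
// Effective weights w_i' = w_i / xStd_i, plus the constant offset term.
// With fitIntercept = false, the offset is 0.0 because the means drop out.
def effectiveWeightsAndOffset(
    weights: Array[Double],
    featuresMean: Array[Double],
    featuresStd: Array[Double],
    labelMean: Double,
    labelStd: Double,
    fitIntercept: Boolean): (Array[Double], Double) = {
  val effective = new Array[Double](weights.length)
  var sum = 0.0
  var i = 0
  while (i < weights.length) {
    if (featuresStd(i) != 0.0) {  // skip constant features
      effective(i) = weights(i) / featuresStd(i)
      sum += effective(i) * featuresMean(i)
    }
    i += 1
  }
  val offset = if (fitIntercept) labelMean / labelStd - sum else 0.0
  (effective, offset)
}
```

Because neither the effective weights nor the offset depends on the training data points themselves, this runs once up front rather than inside the aggregation.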
@@ -404,6 +421,7 @@ private class LeastSquaresCostFun(
data: RDD[(Double, Vector)],
labelStd: Double,
labelMean: Double,
+ fitIntercept: Boolean,
featuresStd: Array[Double],
featuresMean: Array[Double],
effectiveL2regParam: Double) extends DiffFunction[BDV[Double]] {
@@ -412,7 +430,7 @@ private class LeastSquaresCostFun(
val w = Vectors.fromBreeze(weights)
val leastSquaresAggregator = data.treeAggregate(new LeastSquaresAggregator(w, labelStd,
- labelMean, featuresStd, featuresMean))(
+ labelMean, fitIntercept, featuresStd, featuresMean))(
seqOp = (c, v) => (c, v) match {
case (aggregator, (label, features)) => aggregator.add(label, features)
},
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
index 732e2c42be..ad1e9da692 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
@@ -26,6 +26,7 @@ import org.apache.spark.sql.{DataFrame, Row}
class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
@transient var dataset: DataFrame = _
+ @transient var datasetWithoutIntercept: DataFrame = _
/**
* In `LinearRegressionSuite`, we will make sure that the model trained by SparkML
@@ -34,14 +35,24 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
*
* import org.apache.spark.mllib.util.LinearDataGenerator
* val data =
- * sc.parallelize(LinearDataGenerator.generateLinearInput(6.3, Array(4.7, 7.2), 10000, 42), 2)
- * data.map(x=> x.label + ", " + x.features(0) + ", " + x.features(1)).saveAsTextFile("path")
+ * sc.parallelize(LinearDataGenerator.generateLinearInput(6.3, Array(4.7, 7.2),
+ * Array(0.9, -1.3), Array(0.7, 1.2), 10000, 42, 0.1), 2)
+ * data.map(x => x.label + ", " + x.features(0) + ", " + x.features(1)).coalesce(1)
+ * .saveAsTextFile("path")
*/
override def beforeAll(): Unit = {
super.beforeAll()
dataset = sqlContext.createDataFrame(
sc.parallelize(LinearDataGenerator.generateLinearInput(
6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 10000, 42, 0.1), 2))
+ /**
+ * datasetWithoutIntercept is not needed for correctness testing, but is useful
+ * for illustrating training a model without an intercept.
+ */
+ datasetWithoutIntercept = sqlContext.createDataFrame(
+ sc.parallelize(LinearDataGenerator.generateLinearInput(
+ 0.0, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 10000, 42, 0.1), 2))
+
}
test("linear regression with intercept without regularization") {
@@ -78,6 +89,42 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
}
}
+ test("linear regression without intercept without regularization") {
+ val trainer = (new LinearRegression).setFitIntercept(false)
+ val model = trainer.fit(dataset)
+ val modelWithoutIntercept = trainer.fit(datasetWithoutIntercept)
+
+ /**
+ * weights <- coef(glmnet(features, label, family="gaussian", alpha = 0, lambda = 0,
+ * intercept = FALSE))
+ * > weights
+ * 3 x 1 sparse Matrix of class "dgCMatrix"
+ * s0
+ * (Intercept) .
+ * as.numeric.data.V2. 6.995908
+ * as.numeric.data.V3. 5.275131
+ */
+ val weightsR = Array(6.995908, 5.275131)
+
+ assert(model.intercept ~== 0 relTol 1E-3)
+ assert(model.weights(0) ~== weightsR(0) relTol 1E-3)
+ assert(model.weights(1) ~== weightsR(1) relTol 1E-3)
+ /**
+ * Then again, with the data that was generated without an intercept:
+ * > weightsWithoutIntercept
+ * 3 x 1 sparse Matrix of class "dgCMatrix"
+ * s0
+ * (Intercept) .
+ * as.numeric.data3.V2. 4.70011
+ * as.numeric.data3.V3. 7.19943
+ */
+ val weightsWithoutInterceptR = Array(4.70011, 7.19943)
+
+ assert(modelWithoutIntercept.intercept ~== 0 relTol 1E-3)
+ assert(modelWithoutIntercept.weights(0) ~== weightsWithoutInterceptR(0) relTol 1E-3)
+ assert(modelWithoutIntercept.weights(1) ~== weightsWithoutInterceptR(1) relTol 1E-3)
+ }
+
test("linear regression with intercept with L1 regularization") {
val trainer = (new LinearRegression).setElasticNetParam(1.0).setRegParam(0.57)
val model = trainer.fit(dataset)
@@ -87,11 +134,11 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
* > weights
* 3 x 1 sparse Matrix of class "dgCMatrix"
* s0
- * (Intercept) 6.311546
- * as.numeric.data.V2. 2.123522
- * as.numeric.data.V3. 4.605651
+ * (Intercept) 6.24300
+ * as.numeric.data.V2. 4.024821
+ * as.numeric.data.V3. 6.679841
*/
- val interceptR = 6.243000
+ val interceptR = 6.24300
val weightsR = Array(4.024821, 6.679841)
assert(model.intercept ~== interceptR relTol 1E-3)
@@ -106,6 +153,36 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
}
}
+ test("linear regression without intercept with L1 regularization") {
+ val trainer = (new LinearRegression).setElasticNetParam(1.0).setRegParam(0.57)
+ .setFitIntercept(false)
+ val model = trainer.fit(dataset)
+
+ /**
+ * weights <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, lambda = 0.57,
+ * intercept=FALSE))
+ * > weights
+ * 3 x 1 sparse Matrix of class "dgCMatrix"
+ * s0
+ * (Intercept) .
+ * as.numeric.data.V2. 6.299752
+ * as.numeric.data.V3. 4.772913
+ */
+ val interceptR = 0.0
+ val weightsR = Array(6.299752, 4.772913)
+
+ assert(model.intercept ~== interceptR relTol 1E-3)
+ assert(model.weights(0) ~== weightsR(0) relTol 1E-3)
+ assert(model.weights(1) ~== weightsR(1) relTol 1E-3)
+
+ model.transform(dataset).select("features", "prediction").collect().foreach {
+ case Row(features: DenseVector, prediction1: Double) =>
+ val prediction2 =
+ features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept
+ assert(prediction1 ~== prediction2 relTol 1E-5)
+ }
+ }
+
test("linear regression with intercept with L2 regularization") {
val trainer = (new LinearRegression).setElasticNetParam(0.0).setRegParam(2.3)
val model = trainer.fit(dataset)
@@ -134,6 +211,36 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
}
}
+ test("linear regression without intercept with L2 regularization") {
+ val trainer = (new LinearRegression).setElasticNetParam(0.0).setRegParam(2.3)
+ .setFitIntercept(false)
+ val model = trainer.fit(dataset)
+
+ /**
+ * weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3,
+ * intercept = FALSE))
+ * > weights
+ * 3 x 1 sparse Matrix of class "dgCMatrix"
+ * s0
+ * (Intercept) .
+ * as.numeric.data.V2. 5.522875
+ * as.numeric.data.V3. 4.214502
+ */
+ val interceptR = 0.0
+ val weightsR = Array(5.522875, 4.214502)
+
+ assert(model.intercept ~== interceptR relTol 1E-3)
+ assert(model.weights(0) ~== weightsR(0) relTol 1E-3)
+ assert(model.weights(1) ~== weightsR(1) relTol 1E-3)
+
+ model.transform(dataset).select("features", "prediction").collect().foreach {
+ case Row(features: DenseVector, prediction1: Double) =>
+ val prediction2 =
+ features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept
+ assert(prediction1 ~== prediction2 relTol 1E-5)
+ }
+ }
+
test("linear regression with intercept with ElasticNet regularization") {
val trainer = (new LinearRegression).setElasticNetParam(0.3).setRegParam(1.6)
val model = trainer.fit(dataset)
@@ -161,4 +268,34 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(prediction1 ~== prediction2 relTol 1E-5)
}
}
+
+ test("linear regression without intercept with ElasticNet regularization") {
+ val trainer = (new LinearRegression).setElasticNetParam(0.3).setRegParam(1.6)
+ .setFitIntercept(false)
+ val model = trainer.fit(dataset)
+
+ /**
+ * weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6,
+ * intercept=FALSE))
+ * > weights
+ * 3 x 1 sparse Matrix of class "dgCMatrix"
+ * s0
+ * (Intercept) .
+ * as.numeric.dataM.V2. 5.673348
+ * as.numeric.dataM.V3. 4.322251
+ */
+ val interceptR = 0.0
+ val weightsR = Array(5.673348, 4.322251)
+
+ assert(model.intercept ~== interceptR relTol 1E-3)
+ assert(model.weights(0) ~== weightsR(0) relTol 1E-3)
+ assert(model.weights(1) ~== weightsR(1) relTol 1E-3)
+
+ model.transform(dataset).select("features", "prediction").collect().foreach {
+ case Row(features: DenseVector, prediction1: Double) =>
+ val prediction2 =
+ features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept
+ assert(prediction1 ~== prediction2 relTol 1E-5)
+ }
+ }
}