aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorHolden Karau <holden@pigscanfly.ca>2015-08-04 18:15:26 -0700
committerJoseph K. Bradley <joseph@databricks.com>2015-08-04 18:15:26 -0700
commitd92fa14179287c996407d9c7d249103109f9cdef (patch)
treec7b0cf69980dbafad77fa675d53270d49680ba1a /mllib
parent629e26f7ee916e70f59b017cb6083aa441b26b2c (diff)
downloadspark-d92fa14179287c996407d9c7d249103109f9cdef.tar.gz
spark-d92fa14179287c996407d9c7d249103109f9cdef.tar.bz2
spark-d92fa14179287c996407d9c7d249103109f9cdef.zip
[SPARK-8601] [ML] Add an option to disable standardization for linear regression
All compressed sensing applications, and some of the regression use-cases will have better result by turning the feature scaling off. However, if we implement this naively by training the dataset without doing any standardization, the rate of convergency will not be good. This can be implemented by still standardizing the training dataset but we penalize each component differently to get effectively the same objective function but a better numerical problem. As a result, for those columns with high variances, they will be penalized less, and vice versa. Without this, since all the features are standardized, so they will be penalized the same. In R, there is an option for this. standardize Logical flag for x variable standardization, prior to fitting the model sequence. The coefficients are always returned on the original scale. Default is standardize=TRUE. If variables are in the same units already, you might not wish to standardize. See details below for y standardization with family="gaussian". Note that the primary author for this PR is holdenk Author: Holden Karau <holden@pigscanfly.ca> Author: DB Tsai <dbt@netflix.com> Closes #7875 from dbtsai/SPARK-8522 and squashes the following commits: e856036 [DB Tsai] scala doc 596e96c [DB Tsai] minor bbff347 [DB Tsai] naming baa0805 [DB Tsai] touch up d6234ba [DB Tsai] Merge branch 'master' into SPARK-8522-Disable-Linear_featureScaling-Spark-8601-in-Linear_regression 6b1dc09 [Holden Karau] Merge branch 'master' into SPARK-8522-Disable-Linear_featureScaling-Spark-8601-in-Linear_regression 332f140 [Holden Karau] Merge in master eebe10a [Holden Karau] Use same comparision operator throughout the test 3f92935 [Holden Karau] merge b83a41e [Holden Karau] Expand the tests and make them similar to the other PR also providing an option to disable standardization (but for LoR). 0c334a2 [Holden Karau] Remove extra line 99ce053 [Holden Karau] merge in master e54a8a9 [Holden Karau] Fix long line e47c574 [Holden Karau] Add support for L2 without standardization. 55d3a66 [Holden Karau] Add standardization param for linear regression 00a1dc5 [Holden Karau] Add the param to the linearregression impl
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala6
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala70
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala278
3 files changed, 268 insertions, 86 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index c937b9602b..0d07383925 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -133,9 +133,9 @@ class LogisticRegression(override val uid: String)
/**
* Whether to standardize the training features before fitting the model.
* The coefficients of models will be always returned on the original scale,
- * so it will be transparent for users. Note that when no regularization,
- * with or without standardization, the models should be always converged to
- * the same solution.
+ * so it will be transparent for users. Note that with/without standardization,
+ * the models should be always converged to the same solution when no regularization
+ * is applied. In R's GLMNET package, the default behavior is true as well.
* Default is true.
* @group setParam
* */
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index 3b85ba001b..92d819bad8 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -45,7 +45,7 @@ import org.apache.spark.util.StatCounter
*/
private[regression] trait LinearRegressionParams extends PredictorParams
with HasRegParam with HasElasticNetParam with HasMaxIter with HasTol
- with HasFitIntercept
+ with HasFitIntercept with HasStandardization
/**
* :: Experimental ::
@@ -85,6 +85,18 @@ class LinearRegression(override val uid: String)
setDefault(fitIntercept -> true)
/**
+ * Whether to standardize the training features before fitting the model.
+ * The coefficients of models will be always returned on the original scale,
+ * so it will be transparent for users. Note that with/without standardization,
+ * the models should be always converged to the same solution when no regularization
+ * is applied. In R's GLMNET package, the default behavior is true as well.
+ * Default is true.
+ * @group setParam
+ */
+ def setStandardization(value: Boolean): this.type = set(standardization, value)
+ setDefault(standardization -> true)
+
+ /**
* Set the ElasticNet mixing parameter.
* For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.
* For 0 < alpha < 1, the penalty is a combination of L1 and L2.
@@ -165,12 +177,24 @@ class LinearRegression(override val uid: String)
val effectiveL2RegParam = (1.0 - $(elasticNetParam)) * effectiveRegParam
val costFun = new LeastSquaresCostFun(instances, yStd, yMean, $(fitIntercept),
- featuresStd, featuresMean, effectiveL2RegParam)
+ $(standardization), featuresStd, featuresMean, effectiveL2RegParam)
val optimizer = if ($(elasticNetParam) == 0.0 || effectiveRegParam == 0.0) {
new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol))
} else {
- new BreezeOWLQN[Int, BDV[Double]]($(maxIter), 10, effectiveL1RegParam, $(tol))
+ def effectiveL1RegFun = (index: Int) => {
+ if ($(standardization)) {
+ effectiveL1RegParam
+ } else {
+ // If `standardization` is false, we still standardize the data
+ // to improve the rate of convergence; as a result, we have to
+ // perform this reverse standardization by penalizing each component
+ // differently to get effectively the same objective function when
+ // the training dataset is not standardized.
+ if (featuresStd(index) != 0.0) effectiveL1RegParam / featuresStd(index) else 0.0
+ }
+ }
+ new BreezeOWLQN[Int, BDV[Double]]($(maxIter), 10, effectiveL1RegFun, $(tol))
}
val initialWeights = Vectors.zeros(numFeatures)
@@ -456,6 +480,7 @@ class LinearRegressionSummary private[regression] (
* @param weights The weights/coefficients corresponding to the features.
* @param labelStd The standard deviation value of the label.
* @param labelMean The mean value of the label.
+ * @param fitIntercept Whether to fit an intercept term.
* @param featuresStd The standard deviation values of the features.
* @param featuresMean The mean values of the features.
*/
@@ -568,6 +593,7 @@ private class LeastSquaresCostFun(
labelStd: Double,
labelMean: Double,
fitIntercept: Boolean,
+ standardization: Boolean,
featuresStd: Array[Double],
featuresMean: Array[Double],
effectiveL2regParam: Double) extends DiffFunction[BDV[Double]] {
@@ -584,14 +610,38 @@ private class LeastSquaresCostFun(
case (aggregator1, aggregator2) => aggregator1.merge(aggregator2)
})
- // regVal is the sum of weight squares for L2 regularization
- val norm = brzNorm(weights, 2.0)
- val regVal = 0.5 * effectiveL2regParam * norm * norm
+ val totalGradientArray = leastSquaresAggregator.gradient.toArray
- val loss = leastSquaresAggregator.loss + regVal
- val gradient = leastSquaresAggregator.gradient
- axpy(effectiveL2regParam, w, gradient)
+ val regVal = if (effectiveL2regParam == 0.0) {
+ 0.0
+ } else {
+ var sum = 0.0
+ w.foreachActive { (index, value) =>
+ // The following code will compute the loss of the regularization; also
+ // the gradient of the regularization, and add back to totalGradientArray.
+ sum += {
+ if (standardization) {
+ totalGradientArray(index) += effectiveL2regParam * value
+ value * value
+ } else {
+ if (featuresStd(index) != 0.0) {
+ // If `standardization` is false, we still standardize the data
+ // to improve the rate of convergence; as a result, we have to
+ // perform this reverse standardization by penalizing each component
+ // differently to get effectively the same objective function when
+ // the training dataset is not standardized.
+ val temp = value / (featuresStd(index) * featuresStd(index))
+ totalGradientArray(index) += effectiveL2regParam * temp
+ value * temp
+ } else {
+ 0.0
+ }
+ }
+ }
+ }
+ 0.5 * effectiveL2regParam * sum
+ }
- (loss, gradient.toBreeze.asInstanceOf[BDV[Double]])
+ (leastSquaresAggregator.loss + regVal, new BDV(totalGradientArray))
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
index 7cdda3db88..21ad8225bd 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
@@ -70,6 +70,7 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(lir.getRegParam === 0.0)
assert(lir.getElasticNetParam === 0.0)
assert(lir.getFitIntercept)
+ assert(lir.getStandardization)
val model = lir.fit(dataset)
model.transform(dataset)
.select("label", "prediction")
@@ -81,8 +82,11 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
}
test("linear regression with intercept without regularization") {
- val trainer = new LinearRegression
- val model = trainer.fit(dataset)
+ val trainer1 = new LinearRegression
+ // The result should be the same regardless of standardization without regularization
+ val trainer2 = (new LinearRegression).setStandardization(false)
+ val model1 = trainer1.fit(dataset)
+ val model2 = trainer2.fit(dataset)
/*
Using the following R code to load the data and train the model using glmnet package.
@@ -95,28 +99,36 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
> weights
3 x 1 sparse Matrix of class "dgCMatrix"
s0
- (Intercept) 6.300528
- as.numeric.data.V2. 4.701024
- as.numeric.data.V3. 7.198257
+ (Intercept) 6.298698
+ as.numeric.data.V2. 4.700706
+ as.numeric.data.V3. 7.199082
*/
val interceptR = 6.298698
val weightsR = Vectors.dense(4.700706, 7.199082)
- assert(model.intercept ~== interceptR relTol 1E-3)
- assert(model.weights ~= weightsR relTol 1E-3)
+ assert(model1.intercept ~== interceptR relTol 1E-3)
+ assert(model1.weights ~= weightsR relTol 1E-3)
+ assert(model2.intercept ~== interceptR relTol 1E-3)
+ assert(model2.weights ~= weightsR relTol 1E-3)
- model.transform(dataset).select("features", "prediction").collect().foreach {
+
+ model1.transform(dataset).select("features", "prediction").collect().foreach {
case Row(features: DenseVector, prediction1: Double) =>
val prediction2 =
- features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept
+ features(0) * model1.weights(0) + features(1) * model1.weights(1) + model1.intercept
assert(prediction1 ~== prediction2 relTol 1E-5)
}
}
test("linear regression without intercept without regularization") {
- val trainer = (new LinearRegression).setFitIntercept(false)
- val model = trainer.fit(dataset)
- val modelWithoutIntercept = trainer.fit(datasetWithoutIntercept)
+ val trainer1 = (new LinearRegression).setFitIntercept(false)
+ // Without regularization the results should be the same
+ val trainer2 = (new LinearRegression).setFitIntercept(false).setStandardization(false)
+ val model1 = trainer1.fit(dataset)
+ val modelWithoutIntercept1 = trainer1.fit(datasetWithoutIntercept)
+ val model2 = trainer2.fit(dataset)
+ val modelWithoutIntercept2 = trainer2.fit(datasetWithoutIntercept)
+
/*
weights <- coef(glmnet(features, label, family="gaussian", alpha = 0, lambda = 0,
@@ -130,26 +142,34 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
*/
val weightsR = Vectors.dense(6.995908, 5.275131)
- assert(model.intercept ~== 0 absTol 1E-3)
- assert(model.weights ~= weightsR relTol 1E-3)
+ assert(model1.intercept ~== 0 absTol 1E-3)
+ assert(model1.weights ~= weightsR relTol 1E-3)
+ assert(model2.intercept ~== 0 absTol 1E-3)
+ assert(model2.weights ~= weightsR relTol 1E-3)
+
/*
Then again with the data with no intercept:
> weightsWithoutIntercept
3 x 1 sparse Matrix of class "dgCMatrix"
- s0
+ s0
(Intercept) .
as.numeric.data3.V2. 4.70011
as.numeric.data3.V3. 7.19943
*/
val weightsWithoutInterceptR = Vectors.dense(4.70011, 7.19943)
- assert(modelWithoutIntercept.intercept ~== 0 absTol 1E-3)
- assert(modelWithoutIntercept.weights ~= weightsWithoutInterceptR relTol 1E-3)
+ assert(modelWithoutIntercept1.intercept ~== 0 absTol 1E-3)
+ assert(modelWithoutIntercept1.weights ~= weightsWithoutInterceptR relTol 1E-3)
+ assert(modelWithoutIntercept2.intercept ~== 0 absTol 1E-3)
+ assert(modelWithoutIntercept2.weights ~= weightsWithoutInterceptR relTol 1E-3)
}
test("linear regression with intercept with L1 regularization") {
- val trainer = (new LinearRegression).setElasticNetParam(1.0).setRegParam(0.57)
- val model = trainer.fit(dataset)
+ val trainer1 = (new LinearRegression).setElasticNetParam(1.0).setRegParam(0.57)
+ val trainer2 = (new LinearRegression).setElasticNetParam(1.0).setRegParam(0.57)
+ .setStandardization(false)
+ val model1 = trainer1.fit(dataset)
+ val model2 = trainer2.fit(dataset)
/*
weights <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, lambda = 0.57))
@@ -160,24 +180,44 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
as.numeric.data.V2. 4.024821
as.numeric.data.V3. 6.679841
*/
- val interceptR = 6.24300
- val weightsR = Vectors.dense(4.024821, 6.679841)
+ val interceptR1 = 6.24300
+ val weightsR1 = Vectors.dense(4.024821, 6.679841)
- assert(model.intercept ~== interceptR relTol 1E-3)
- assert(model.weights ~= weightsR relTol 1E-3)
+ assert(model1.intercept ~== interceptR1 relTol 1E-3)
+ assert(model1.weights ~= weightsR1 relTol 1E-3)
- model.transform(dataset).select("features", "prediction").collect().foreach {
+ /*
+ weights <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, lambda = 0.57,
+ standardize=FALSE))
+ > weights
+ 3 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ (Intercept) 6.416948
+ as.numeric.data.V2. 3.893869
+ as.numeric.data.V3. 6.724286
+ */
+ val interceptR2 = 6.416948
+ val weightsR2 = Vectors.dense(3.893869, 6.724286)
+
+ assert(model2.intercept ~== interceptR2 relTol 1E-3)
+ assert(model2.weights ~= weightsR2 relTol 1E-3)
+
+
+ model1.transform(dataset).select("features", "prediction").collect().foreach {
case Row(features: DenseVector, prediction1: Double) =>
val prediction2 =
- features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept
+ features(0) * model1.weights(0) + features(1) * model1.weights(1) + model1.intercept
assert(prediction1 ~== prediction2 relTol 1E-5)
}
}
test("linear regression without intercept with L1 regularization") {
- val trainer = (new LinearRegression).setElasticNetParam(1.0).setRegParam(0.57)
+ val trainer1 = (new LinearRegression).setElasticNetParam(1.0).setRegParam(0.57)
.setFitIntercept(false)
- val model = trainer.fit(dataset)
+ val trainer2 = (new LinearRegression).setElasticNetParam(1.0).setRegParam(0.57)
+ .setFitIntercept(false).setStandardization(false)
+ val model1 = trainer1.fit(dataset)
+ val model2 = trainer2.fit(dataset)
/*
weights <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, lambda = 0.57,
@@ -189,51 +229,90 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
as.numeric.data.V2. 6.299752
as.numeric.data.V3. 4.772913
*/
- val interceptR = 0.0
- val weightsR = Vectors.dense(6.299752, 4.772913)
+ val interceptR1 = 0.0
+ val weightsR1 = Vectors.dense(6.299752, 4.772913)
- assert(model.intercept ~== interceptR absTol 1E-5)
- assert(model.weights ~= weightsR relTol 1E-3)
+ assert(model1.intercept ~== interceptR1 absTol 1E-3)
+ assert(model1.weights ~= weightsR1 relTol 1E-3)
- model.transform(dataset).select("features", "prediction").collect().foreach {
+ /*
+ weights <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, lambda = 0.57,
+ intercept=FALSE, standardize=FALSE))
+ > weights
+ 3 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ (Intercept) .
+ as.numeric.data.V2. 6.232193
+ as.numeric.data.V3. 4.764229
+ */
+ val interceptR2 = 0.0
+ val weightsR2 = Vectors.dense(6.232193, 4.764229)
+
+ assert(model2.intercept ~== interceptR2 absTol 1E-3)
+ assert(model2.weights ~= weightsR2 relTol 1E-3)
+
+
+ model1.transform(dataset).select("features", "prediction").collect().foreach {
case Row(features: DenseVector, prediction1: Double) =>
val prediction2 =
- features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept
+ features(0) * model1.weights(0) + features(1) * model1.weights(1) + model1.intercept
assert(prediction1 ~== prediction2 relTol 1E-5)
}
}
test("linear regression with intercept with L2 regularization") {
- val trainer = (new LinearRegression).setElasticNetParam(0.0).setRegParam(2.3)
- val model = trainer.fit(dataset)
+ val trainer1 = (new LinearRegression).setElasticNetParam(0.0).setRegParam(2.3)
+ val trainer2 = (new LinearRegression).setElasticNetParam(0.0).setRegParam(2.3)
+ .setStandardization(false)
+ val model1 = trainer1.fit(dataset)
+ val model2 = trainer2.fit(dataset)
/*
- weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3))
- > weights
- 3 x 1 sparse Matrix of class "dgCMatrix"
- s0
- (Intercept) 6.328062
- as.numeric.data.V2. 3.222034
- as.numeric.data.V3. 4.926260
+ weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3))
+ > weights
+ 3 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ (Intercept) 5.269376
+ as.numeric.data.V2. 3.736216
+ as.numeric.data.V3. 5.712356)
*/
- val interceptR = 5.269376
- val weightsR = Vectors.dense(3.736216, 5.712356)
+ val interceptR1 = 5.269376
+ val weightsR1 = Vectors.dense(3.736216, 5.712356)
- assert(model.intercept ~== interceptR relTol 1E-3)
- assert(model.weights ~= weightsR relTol 1E-3)
+ assert(model1.intercept ~== interceptR1 relTol 1E-3)
+ assert(model1.weights ~= weightsR1 relTol 1E-3)
- model.transform(dataset).select("features", "prediction").collect().foreach {
+ /*
+ weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3,
+ standardize=FALSE))
+ > weights
+ 3 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ (Intercept) 5.791109
+ as.numeric.data.V2. 3.435466
+ as.numeric.data.V3. 5.910406
+ */
+ val interceptR2 = 5.791109
+ val weightsR2 = Vectors.dense(3.435466, 5.910406)
+
+ assert(model2.intercept ~== interceptR2 relTol 1E-3)
+ assert(model2.weights ~= weightsR2 relTol 1E-3)
+
+ model1.transform(dataset).select("features", "prediction").collect().foreach {
case Row(features: DenseVector, prediction1: Double) =>
val prediction2 =
- features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept
+ features(0) * model1.weights(0) + features(1) * model1.weights(1) + model1.intercept
assert(prediction1 ~== prediction2 relTol 1E-5)
}
}
test("linear regression without intercept with L2 regularization") {
- val trainer = (new LinearRegression).setElasticNetParam(0.0).setRegParam(2.3)
+ val trainer1 = (new LinearRegression).setElasticNetParam(0.0).setRegParam(2.3)
.setFitIntercept(false)
- val model = trainer.fit(dataset)
+ val trainer2 = (new LinearRegression).setElasticNetParam(0.0).setRegParam(2.3)
+ .setFitIntercept(false).setStandardization(false)
+ val model1 = trainer1.fit(dataset)
+ val model2 = trainer2.fit(dataset)
/*
weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3,
@@ -245,23 +324,42 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
as.numeric.data.V2. 5.522875
as.numeric.data.V3. 4.214502
*/
- val interceptR = 0.0
- val weightsR = Vectors.dense(5.522875, 4.214502)
+ val interceptR1 = 0.0
+ val weightsR1 = Vectors.dense(5.522875, 4.214502)
- assert(model.intercept ~== interceptR absTol 1E-3)
- assert(model.weights ~== weightsR relTol 1E-3)
+ assert(model1.intercept ~== interceptR1 absTol 1E-3)
+ assert(model1.weights ~= weightsR1 relTol 1E-3)
- model.transform(dataset).select("features", "prediction").collect().foreach {
+ /*
+ weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3,
+ intercept = FALSE, standardize=FALSE))
+ > weights
+ 3 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ (Intercept) .
+ as.numeric.data.V2. 5.263704
+ as.numeric.data.V3. 4.187419
+ */
+ val interceptR2 = 0.0
+ val weightsR2 = Vectors.dense(5.263704, 4.187419)
+
+ assert(model2.intercept ~== interceptR2 absTol 1E-3)
+ assert(model2.weights ~= weightsR2 relTol 1E-3)
+
+ model1.transform(dataset).select("features", "prediction").collect().foreach {
case Row(features: DenseVector, prediction1: Double) =>
val prediction2 =
- features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept
+ features(0) * model1.weights(0) + features(1) * model1.weights(1) + model1.intercept
assert(prediction1 ~== prediction2 relTol 1E-5)
}
}
test("linear regression with intercept with ElasticNet regularization") {
- val trainer = (new LinearRegression).setElasticNetParam(0.3).setRegParam(1.6)
- val model = trainer.fit(dataset)
+ val trainer1 = (new LinearRegression).setElasticNetParam(0.3).setRegParam(1.6)
+ val trainer2 = (new LinearRegression).setElasticNetParam(0.3).setRegParam(1.6)
+ .setStandardization(false)
+ val model1 = trainer1.fit(dataset)
+ val model2 = trainer2.fit(dataset)
/*
weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6))
@@ -272,24 +370,43 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
as.numeric.data.V2. 3.168435
as.numeric.data.V3. 5.200403
*/
- val interceptR = 5.696056
- val weightsR = Vectors.dense(3.670489, 6.001122)
+ val interceptR1 = 5.696056
+ val weightsR1 = Vectors.dense(3.670489, 6.001122)
- assert(model.intercept ~== interceptR relTol 1E-3)
- assert(model.weights ~== weightsR relTol 1E-3)
+ assert(model1.intercept ~== interceptR1 relTol 1E-3)
+ assert(model1.weights ~= weightsR1 relTol 1E-3)
- model.transform(dataset).select("features", "prediction").collect().foreach {
+ /*
+ weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6
+ standardize=FALSE))
+ > weights
+ 3 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ (Intercept) 6.114723
+ as.numeric.data.V2. 3.409937
+ as.numeric.data.V3. 6.146531
+ */
+ val interceptR2 = 6.114723
+ val weightsR2 = Vectors.dense(3.409937, 6.146531)
+
+ assert(model2.intercept ~== interceptR2 relTol 1E-3)
+ assert(model2.weights ~= weightsR2 relTol 1E-3)
+
+ model1.transform(dataset).select("features", "prediction").collect().foreach {
case Row(features: DenseVector, prediction1: Double) =>
val prediction2 =
- features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept
+ features(0) * model1.weights(0) + features(1) * model1.weights(1) + model1.intercept
assert(prediction1 ~== prediction2 relTol 1E-5)
}
}
test("linear regression without intercept with ElasticNet regularization") {
- val trainer = (new LinearRegression).setElasticNetParam(0.3).setRegParam(1.6)
+ val trainer1 = (new LinearRegression).setElasticNetParam(0.3).setRegParam(1.6)
.setFitIntercept(false)
- val model = trainer.fit(dataset)
+ val trainer2 = (new LinearRegression).setElasticNetParam(0.3).setRegParam(1.6)
+ .setFitIntercept(false).setStandardization(false)
+ val model1 = trainer1.fit(dataset)
+ val model2 = trainer2.fit(dataset)
/*
weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6,
@@ -301,16 +418,32 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
as.numeric.dataM.V2. 5.673348
as.numeric.dataM.V3. 4.322251
*/
- val interceptR = 0.0
- val weightsR = Vectors.dense(5.673348, 4.322251)
+ val interceptR1 = 0.0
+ val weightsR1 = Vectors.dense(5.673348, 4.322251)
- assert(model.intercept ~== interceptR absTol 1E-3)
- assert(model.weights ~= weightsR relTol 1E-3)
+ assert(model1.intercept ~== interceptR1 absTol 1E-3)
+ assert(model1.weights ~= weightsR1 relTol 1E-3)
- model.transform(dataset).select("features", "prediction").collect().foreach {
+ /*
+ weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6,
+ intercept=FALSE, standardize=FALSE))
+ > weights
+ 3 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ (Intercept) .
+ as.numeric.data.V2. 5.477988
+ as.numeric.data.V3. 4.297622
+ */
+ val interceptR2 = 0.0
+ val weightsR2 = Vectors.dense(5.477988, 4.297622)
+
+ assert(model2.intercept ~== interceptR2 absTol 1E-3)
+ assert(model2.weights ~= weightsR2 relTol 1E-3)
+
+ model1.transform(dataset).select("features", "prediction").collect().foreach {
case Row(features: DenseVector, prediction1: Double) =>
val prediction2 =
- features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept
+ features(0) * model1.weights(0) + features(1) * model1.weights(1) + model1.intercept
assert(prediction1 ~== prediction2 relTol 1E-5)
}
}
@@ -372,5 +505,4 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
.zip(testSummary.residuals.select("residuals").collect())
.forall { case (Row(r1: Double), Row(r2: Double)) => r1 ~== r2 relTol 1E-5 }
}
-
}