aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDB Tsai <dbt@netflix.com>2015-07-08 15:21:58 -0700
committerDB Tsai <dbt@netflix.com>2015-07-08 15:21:58 -0700
commit57221934e0376e5bb8421dc35d4bf91db4deeca1 (patch)
treed7736dda417fa4dae7b61c1bfa63da62413cb030
parent00b265f12c0f0271b7036f831fee09b694908b29 (diff)
downloadspark-57221934e0376e5bb8421dc35d4bf91db4deeca1.tar.gz
spark-57221934e0376e5bb8421dc35d4bf91db4deeca1.tar.bz2
spark-57221934e0376e5bb8421dc35d4bf91db4deeca1.zip
[SPARK-8700][ML] Disable feature scaling in Logistic Regression
All compressed sensing applications, and some of the regression use-cases will have better result by turning the feature scaling off. However, if we implement this naively by training the dataset without doing any standardization, the rate of convergency will not be good. This can be implemented by still standardizing the training dataset but we penalize each component differently to get effectively the same objective function but a better numerical problem. As a result, for those columns with high variances, they will be penalized less, and vice versa. Without this, since all the features are standardized, so they will be penalized the same. In R, there is an option for this. `standardize` Logical flag for x variable standardization, prior to fitting the model sequence. The coefficients are always returned on the original scale. Default is standardize=TRUE. If variables are in the same units already, you might not wish to standardize. See details below for y standardization with family="gaussian". +cc holdenk mengxr jkbradley Author: DB Tsai <dbt@netflix.com> Closes #7080 from dbtsai/lors and squashes the following commits: 877e6c7 [DB Tsai] repahse the doc 7cf45f2 [DB Tsai] address feedback 78d75c9 [DB Tsai] small change c2c9e60 [DB Tsai] style 6e1a8e0 [DB Tsai] first commit
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala89
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala3
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala4
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala403
-rw-r--r--project/MimaExcludes.scala2
5 files changed, 384 insertions, 117 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 3967151f76..8fc9199fb4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -19,7 +19,7 @@ package org.apache.spark.ml.classification
import scala.collection.mutable
-import breeze.linalg.{DenseVector => BDV, norm => brzNorm}
+import breeze.linalg.{DenseVector => BDV}
import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN}
import org.apache.spark.{Logging, SparkException}
@@ -41,7 +41,7 @@ import org.apache.spark.storage.StorageLevel
*/
private[classification] trait LogisticRegressionParams extends ProbabilisticClassifierParams
with HasRegParam with HasElasticNetParam with HasMaxIter with HasFitIntercept with HasTol
- with HasThreshold
+ with HasThreshold with HasStandardization
/**
* :: Experimental ::
@@ -98,6 +98,18 @@ class LogisticRegression(override val uid: String)
def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value)
setDefault(fitIntercept -> true)
+ /**
+ * Whether to standardize the training features before fitting the model.
+ * The coefficients of models will be always returned on the original scale,
+ * so it will be transparent for users. Note that when no regularization,
+ * with or without standardization, the models should be always converged to
+ * the same solution.
+ * Default is true.
+ * @group setParam
+ * */
+ def setStandardization(value: Boolean): this.type = set(standardization, value)
+ setDefault(standardization -> true)
+
/** @group setParam */
def setThreshold(value: Double): this.type = set(threshold, value)
setDefault(threshold -> 0.5)
@@ -149,15 +161,28 @@ class LogisticRegression(override val uid: String)
val regParamL1 = $(elasticNetParam) * $(regParam)
val regParamL2 = (1.0 - $(elasticNetParam)) * $(regParam)
- val costFun = new LogisticCostFun(instances, numClasses, $(fitIntercept),
+ val costFun = new LogisticCostFun(instances, numClasses, $(fitIntercept), $(standardization),
featuresStd, featuresMean, regParamL2)
val optimizer = if ($(elasticNetParam) == 0.0 || $(regParam) == 0.0) {
new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol))
} else {
- // Remove the L1 penalization on the intercept
def regParamL1Fun = (index: Int) => {
- if (index == numFeatures) 0.0 else regParamL1
+ // Remove the L1 penalization on the intercept
+ if (index == numFeatures) {
+ 0.0
+ } else {
+ if ($(standardization)) {
+ regParamL1
+ } else {
+ // If `standardization` is false, we still standardize the data
+ // to improve the rate of convergence; as a result, we have to
+ // perform this reverse standardization by penalizing each component
+ // differently to get effectively the same objective function when
+ // the training dataset is not standardized.
+ if (featuresStd(index) != 0.0) regParamL1 / featuresStd(index) else 0.0
+ }
+ }
}
new BreezeOWLQN[Int, BDV[Double]]($(maxIter), 10, regParamL1Fun, $(tol))
}
@@ -523,11 +548,13 @@ private class LogisticCostFun(
data: RDD[(Double, Vector)],
numClasses: Int,
fitIntercept: Boolean,
+ standardization: Boolean,
featuresStd: Array[Double],
featuresMean: Array[Double],
regParamL2: Double) extends DiffFunction[BDV[Double]] {
override def calculate(weights: BDV[Double]): (Double, BDV[Double]) = {
+ val numFeatures = featuresStd.length
val w = Vectors.fromBreeze(weights)
val logisticAggregator = data.treeAggregate(new LogisticAggregator(w, numClasses, fitIntercept,
@@ -539,27 +566,43 @@ private class LogisticCostFun(
case (aggregator1, aggregator2) => aggregator1.merge(aggregator2)
})
- // regVal is the sum of weight squares for L2 regularization
- val norm = if (regParamL2 == 0.0) {
- 0.0
- } else if (fitIntercept) {
- brzNorm(Vectors.dense(weights.toArray.slice(0, weights.size -1)).toBreeze, 2.0)
- } else {
- brzNorm(weights, 2.0)
- }
- val regVal = 0.5 * regParamL2 * norm * norm
+ val totalGradientArray = logisticAggregator.gradient.toArray
- val loss = logisticAggregator.loss + regVal
- val gradient = logisticAggregator.gradient
-
- if (fitIntercept) {
- val wArray = w.toArray.clone()
- wArray(wArray.length - 1) = 0.0
- axpy(regParamL2, Vectors.dense(wArray), gradient)
+ // regVal is the sum of weight squares excluding intercept for L2 regularization.
+ val regVal = if (regParamL2 == 0.0) {
+ 0.0
} else {
- axpy(regParamL2, w, gradient)
+ var sum = 0.0
+ w.foreachActive { (index, value) =>
+ // If `fitIntercept` is true, the last term which is intercept doesn't
+ // contribute to the regularization.
+ if (index != numFeatures) {
+ // The following code will compute the loss of the regularization; also
+ // the gradient of the regularization, and add back to totalGradientArray.
+ sum += {
+ if (standardization) {
+ totalGradientArray(index) += regParamL2 * value
+ value * value
+ } else {
+ if (featuresStd(index) != 0.0) {
+ // If `standardization` is false, we still standardize the data
+ // to improve the rate of convergence; as a result, we have to
+ // perform this reverse standardization by penalizing each component
+ // differently to get effectively the same objective function when
+ // the training dataset is not standardized.
+ val temp = value / (featuresStd(index) * featuresStd(index))
+ totalGradientArray(index) += regParamL2 * temp
+ value * temp
+ } else {
+ 0.0
+ }
+ }
+ }
+ }
+ }
+ 0.5 * regParamL2 * sum
}
- (loss, gradient.toBreeze.asInstanceOf[BDV[Double]])
+ (logisticAggregator.loss + regVal, new BDV(totalGradientArray))
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
index b0a6af171c..66b751a1b0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
@@ -54,8 +54,7 @@ private[shared] object SharedParamsCodeGen {
isValid = "ParamValidators.gtEq(1)"),
ParamDesc[Boolean]("fitIntercept", "whether to fit an intercept term", Some("true")),
ParamDesc[Boolean]("standardization", "whether to standardize the training features" +
- " prior to fitting the model sequence. Note that the coefficients of models are" +
- " always returned on the original scale.", Some("true")),
+ " before fitting the model.", Some("true")),
ParamDesc[Long]("seed", "random seed", Some("this.getClass.getName.hashCode.toLong")),
ParamDesc[Double]("elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]." +
" For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.",
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
index bbe08939b6..f81bd76c22 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
@@ -239,10 +239,10 @@ private[ml] trait HasFitIntercept extends Params {
private[ml] trait HasStandardization extends Params {
/**
- * Param for whether to standardize the training features prior to fitting the model sequence. Note that the coefficients of models are always returned on the original scale..
+ * Param for whether to standardize the training features before fitting the model..
* @group param
*/
- final val standardization: BooleanParam = new BooleanParam(this, "standardization", "whether to standardize the training features prior to fitting the model sequence. Note that the coefficients of models are always returned on the original scale.")
+ final val standardization: BooleanParam = new BooleanParam(this, "standardization", "whether to standardize the training features before fitting the model.")
setDefault(standardization, true)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index ba8fbee841..27253c1db2 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -77,6 +77,7 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(lr.getRawPredictionCol === "rawPrediction")
assert(lr.getProbabilityCol === "probability")
assert(lr.getFitIntercept)
+ assert(lr.getStandardization)
val model = lr.fit(dataset)
model.transform(dataset)
.select("label", "probability", "prediction", "rawPrediction")
@@ -208,8 +209,11 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
}
test("binary logistic regression with intercept without regularization") {
- val trainer = (new LogisticRegression).setFitIntercept(true)
- val model = trainer.fit(binaryDataset)
+ val trainer1 = (new LogisticRegression).setFitIntercept(true).setStandardization(true)
+ val trainer2 = (new LogisticRegression).setFitIntercept(true).setStandardization(false)
+
+ val model1 = trainer1.fit(binaryDataset)
+ val model2 = trainer2.fit(binaryDataset)
/*
Using the following R code to load the data and train the model using glmnet package.
@@ -232,16 +236,26 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
val interceptR = 2.8366423
val weightsR = Array(-0.5895848, 0.8931147, -0.3925051, -0.7996864)
- assert(model.intercept ~== interceptR relTol 1E-3)
- assert(model.weights(0) ~== weightsR(0) relTol 1E-3)
- assert(model.weights(1) ~== weightsR(1) relTol 1E-3)
- assert(model.weights(2) ~== weightsR(2) relTol 1E-3)
- assert(model.weights(3) ~== weightsR(3) relTol 1E-3)
+ assert(model1.intercept ~== interceptR relTol 1E-3)
+ assert(model1.weights(0) ~== weightsR(0) relTol 1E-3)
+ assert(model1.weights(1) ~== weightsR(1) relTol 1E-3)
+ assert(model1.weights(2) ~== weightsR(2) relTol 1E-3)
+ assert(model1.weights(3) ~== weightsR(3) relTol 1E-3)
+
+ // Without regularization, with or without standardization will converge to the same solution.
+ assert(model2.intercept ~== interceptR relTol 1E-3)
+ assert(model2.weights(0) ~== weightsR(0) relTol 1E-3)
+ assert(model2.weights(1) ~== weightsR(1) relTol 1E-3)
+ assert(model2.weights(2) ~== weightsR(2) relTol 1E-3)
+ assert(model2.weights(3) ~== weightsR(3) relTol 1E-3)
}
test("binary logistic regression without intercept without regularization") {
- val trainer = (new LogisticRegression).setFitIntercept(false)
- val model = trainer.fit(binaryDataset)
+ val trainer1 = (new LogisticRegression).setFitIntercept(false).setStandardization(true)
+ val trainer2 = (new LogisticRegression).setFitIntercept(false).setStandardization(false)
+
+ val model1 = trainer1.fit(binaryDataset)
+ val model2 = trainer2.fit(binaryDataset)
/*
Using the following R code to load the data and train the model using glmnet package.
@@ -265,17 +279,28 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
val interceptR = 0.0
val weightsR = Array(-0.3534996, 1.2964482, -0.3571741, -0.7407946)
- assert(model.intercept ~== interceptR relTol 1E-3)
- assert(model.weights(0) ~== weightsR(0) relTol 1E-2)
- assert(model.weights(1) ~== weightsR(1) relTol 1E-2)
- assert(model.weights(2) ~== weightsR(2) relTol 1E-3)
- assert(model.weights(3) ~== weightsR(3) relTol 1E-3)
+ assert(model1.intercept ~== interceptR relTol 1E-3)
+ assert(model1.weights(0) ~== weightsR(0) relTol 1E-2)
+ assert(model1.weights(1) ~== weightsR(1) relTol 1E-2)
+ assert(model1.weights(2) ~== weightsR(2) relTol 1E-3)
+ assert(model1.weights(3) ~== weightsR(3) relTol 1E-3)
+
+ // Without regularization, with or without standardization should converge to the same solution.
+ assert(model2.intercept ~== interceptR relTol 1E-3)
+ assert(model2.weights(0) ~== weightsR(0) relTol 1E-2)
+ assert(model2.weights(1) ~== weightsR(1) relTol 1E-2)
+ assert(model2.weights(2) ~== weightsR(2) relTol 1E-3)
+ assert(model2.weights(3) ~== weightsR(3) relTol 1E-3)
}
test("binary logistic regression with intercept with L1 regularization") {
- val trainer = (new LogisticRegression).setFitIntercept(true)
- .setElasticNetParam(1.0).setRegParam(0.12)
- val model = trainer.fit(binaryDataset)
+ val trainer1 = (new LogisticRegression).setFitIntercept(true)
+ .setElasticNetParam(1.0).setRegParam(0.12).setStandardization(true)
+ val trainer2 = (new LogisticRegression).setFitIntercept(true)
+ .setElasticNetParam(1.0).setRegParam(0.12).setStandardization(false)
+
+ val model1 = trainer1.fit(binaryDataset)
+ val model2 = trainer2.fit(binaryDataset)
/*
Using the following R code to load the data and train the model using glmnet package.
@@ -295,20 +320,52 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
data.V4 -0.04325749
data.V5 -0.02481551
*/
- val interceptR = -0.05627428
- val weightsR = Array(0.0, 0.0, -0.04325749, -0.02481551)
-
- assert(model.intercept ~== interceptR relTol 1E-2)
- assert(model.weights(0) ~== weightsR(0) relTol 1E-3)
- assert(model.weights(1) ~== weightsR(1) relTol 1E-3)
- assert(model.weights(2) ~== weightsR(2) relTol 1E-2)
- assert(model.weights(3) ~== weightsR(3) relTol 2E-2)
+ val interceptR1 = -0.05627428
+ val weightsR1 = Array(0.0, 0.0, -0.04325749, -0.02481551)
+
+ assert(model1.intercept ~== interceptR1 relTol 1E-2)
+ assert(model1.weights(0) ~== weightsR1(0) absTol 1E-3)
+ assert(model1.weights(1) ~== weightsR1(1) absTol 1E-3)
+ assert(model1.weights(2) ~== weightsR1(2) relTol 1E-2)
+ assert(model1.weights(3) ~== weightsR1(3) relTol 2E-2)
+
+ /*
+ Using the following R code to load the data and train the model using glmnet package.
+
+ library("glmnet")
+ data <- read.csv("path", header=FALSE)
+ label = factor(data$V1)
+ features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
+ weights = coef(glmnet(features,label, family="binomial", alpha = 1, lambda = 0.12,
+ standardize=FALSE))
+ weights
+
+ 5 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ (Intercept) 0.3722152
+ data.V2 .
+ data.V3 .
+ data.V4 -0.1665453
+ data.V5 .
+ */
+ val interceptR2 = 0.3722152
+ val weightsR2 = Array(0.0, 0.0, -0.1665453, 0.0)
+
+ assert(model2.intercept ~== interceptR2 relTol 1E-2)
+ assert(model2.weights(0) ~== weightsR2(0) absTol 1E-3)
+ assert(model2.weights(1) ~== weightsR2(1) absTol 1E-3)
+ assert(model2.weights(2) ~== weightsR2(2) relTol 1E-2)
+ assert(model2.weights(3) ~== weightsR2(3) absTol 1E-3)
}
test("binary logistic regression without intercept with L1 regularization") {
- val trainer = (new LogisticRegression).setFitIntercept(false)
- .setElasticNetParam(1.0).setRegParam(0.12)
- val model = trainer.fit(binaryDataset)
+ val trainer1 = (new LogisticRegression).setFitIntercept(false)
+ .setElasticNetParam(1.0).setRegParam(0.12).setStandardization(true)
+ val trainer2 = (new LogisticRegression).setFitIntercept(false)
+ .setElasticNetParam(1.0).setRegParam(0.12).setStandardization(false)
+
+ val model1 = trainer1.fit(binaryDataset)
+ val model2 = trainer2.fit(binaryDataset)
/*
Using the following R code to load the data and train the model using glmnet package.
@@ -329,20 +386,52 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
data.V4 -0.05189203
data.V5 -0.03891782
*/
- val interceptR = 0.0
- val weightsR = Array(0.0, 0.0, -0.05189203, -0.03891782)
+ val interceptR1 = 0.0
+ val weightsR1 = Array(0.0, 0.0, -0.05189203, -0.03891782)
+
+ assert(model1.intercept ~== interceptR1 relTol 1E-3)
+ assert(model1.weights(0) ~== weightsR1(0) absTol 1E-3)
+ assert(model1.weights(1) ~== weightsR1(1) absTol 1E-3)
+ assert(model1.weights(2) ~== weightsR1(2) relTol 1E-2)
+ assert(model1.weights(3) ~== weightsR1(3) relTol 1E-2)
+
+ /*
+ Using the following R code to load the data and train the model using glmnet package.
+
+ library("glmnet")
+ data <- read.csv("path", header=FALSE)
+ label = factor(data$V1)
+ features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
+ weights = coef(glmnet(features,label, family="binomial", alpha = 1, lambda = 0.12,
+ intercept=FALSE, standardize=FALSE))
+ weights
- assert(model.intercept ~== interceptR relTol 1E-3)
- assert(model.weights(0) ~== weightsR(0) relTol 1E-3)
- assert(model.weights(1) ~== weightsR(1) relTol 1E-3)
- assert(model.weights(2) ~== weightsR(2) relTol 1E-2)
- assert(model.weights(3) ~== weightsR(3) relTol 1E-2)
+ 5 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ (Intercept) .
+ data.V2 .
+ data.V3 .
+ data.V4 -0.08420782
+ data.V5 .
+ */
+ val interceptR2 = 0.0
+ val weightsR2 = Array(0.0, 0.0, -0.08420782, 0.0)
+
+ assert(model2.intercept ~== interceptR2 relTol 1E-3)
+ assert(model2.weights(0) ~== weightsR2(0) absTol 1E-3)
+ assert(model2.weights(1) ~== weightsR2(1) absTol 1E-3)
+ assert(model2.weights(2) ~== weightsR2(2) relTol 1E-2)
+ assert(model2.weights(3) ~== weightsR2(3) absTol 1E-3)
}
test("binary logistic regression with intercept with L2 regularization") {
- val trainer = (new LogisticRegression).setFitIntercept(true)
- .setElasticNetParam(0.0).setRegParam(1.37)
- val model = trainer.fit(binaryDataset)
+ val trainer1 = (new LogisticRegression).setFitIntercept(true)
+ .setElasticNetParam(0.0).setRegParam(1.37).setStandardization(true)
+ val trainer2 = (new LogisticRegression).setFitIntercept(true)
+ .setElasticNetParam(0.0).setRegParam(1.37).setStandardization(false)
+
+ val model1 = trainer1.fit(binaryDataset)
+ val model2 = trainer2.fit(binaryDataset)
/*
Using the following R code to load the data and train the model using glmnet package.
@@ -362,20 +451,52 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
data.V4 -0.04865309
data.V5 -0.10062872
*/
- val interceptR = 0.15021751
- val weightsR = Array(-0.07251837, 0.10724191, -0.04865309, -0.10062872)
-
- assert(model.intercept ~== interceptR relTol 1E-3)
- assert(model.weights(0) ~== weightsR(0) relTol 1E-3)
- assert(model.weights(1) ~== weightsR(1) relTol 1E-3)
- assert(model.weights(2) ~== weightsR(2) relTol 1E-3)
- assert(model.weights(3) ~== weightsR(3) relTol 1E-3)
+ val interceptR1 = 0.15021751
+ val weightsR1 = Array(-0.07251837, 0.10724191, -0.04865309, -0.10062872)
+
+ assert(model1.intercept ~== interceptR1 relTol 1E-3)
+ assert(model1.weights(0) ~== weightsR1(0) relTol 1E-3)
+ assert(model1.weights(1) ~== weightsR1(1) relTol 1E-3)
+ assert(model1.weights(2) ~== weightsR1(2) relTol 1E-3)
+ assert(model1.weights(3) ~== weightsR1(3) relTol 1E-3)
+
+ /*
+ Using the following R code to load the data and train the model using glmnet package.
+
+ library("glmnet")
+ data <- read.csv("path", header=FALSE)
+ label = factor(data$V1)
+ features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
+ weights = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 1.37,
+ standardize=FALSE))
+ weights
+
+ 5 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ (Intercept) 0.48657516
+ data.V2 -0.05155371
+ data.V3 0.02301057
+ data.V4 -0.11482896
+ data.V5 -0.06266838
+ */
+ val interceptR2 = 0.48657516
+ val weightsR2 = Array(-0.05155371, 0.02301057, -0.11482896, -0.06266838)
+
+ assert(model2.intercept ~== interceptR2 relTol 1E-3)
+ assert(model2.weights(0) ~== weightsR2(0) relTol 1E-3)
+ assert(model2.weights(1) ~== weightsR2(1) relTol 1E-3)
+ assert(model2.weights(2) ~== weightsR2(2) relTol 1E-3)
+ assert(model2.weights(3) ~== weightsR2(3) relTol 1E-3)
}
test("binary logistic regression without intercept with L2 regularization") {
- val trainer = (new LogisticRegression).setFitIntercept(false)
- .setElasticNetParam(0.0).setRegParam(1.37)
- val model = trainer.fit(binaryDataset)
+ val trainer1 = (new LogisticRegression).setFitIntercept(false)
+ .setElasticNetParam(0.0).setRegParam(1.37).setStandardization(true)
+ val trainer2 = (new LogisticRegression).setFitIntercept(false)
+ .setElasticNetParam(0.0).setRegParam(1.37).setStandardization(false)
+
+ val model1 = trainer1.fit(binaryDataset)
+ val model2 = trainer2.fit(binaryDataset)
/*
Using the following R code to load the data and train the model using glmnet package.
@@ -396,20 +517,52 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
data.V4 -0.04708770
data.V5 -0.09799775
*/
- val interceptR = 0.0
- val weightsR = Array(-0.06099165, 0.12857058, -0.04708770, -0.09799775)
+ val interceptR1 = 0.0
+ val weightsR1 = Array(-0.06099165, 0.12857058, -0.04708770, -0.09799775)
+
+ assert(model1.intercept ~== interceptR1 relTol 1E-3)
+ assert(model1.weights(0) ~== weightsR1(0) relTol 1E-2)
+ assert(model1.weights(1) ~== weightsR1(1) relTol 1E-2)
+ assert(model1.weights(2) ~== weightsR1(2) relTol 1E-3)
+ assert(model1.weights(3) ~== weightsR1(3) relTol 1E-3)
+
+ /*
+ Using the following R code to load the data and train the model using glmnet package.
+
+ library("glmnet")
+ data <- read.csv("path", header=FALSE)
+ label = factor(data$V1)
+ features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
+ weights = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 1.37,
+ intercept=FALSE, standardize=FALSE))
+ weights
- assert(model.intercept ~== interceptR relTol 1E-3)
- assert(model.weights(0) ~== weightsR(0) relTol 1E-2)
- assert(model.weights(1) ~== weightsR(1) relTol 1E-2)
- assert(model.weights(2) ~== weightsR(2) relTol 1E-3)
- assert(model.weights(3) ~== weightsR(3) relTol 1E-3)
+ 5 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ (Intercept) .
+ data.V2 -0.005679651
+ data.V3 0.048967094
+ data.V4 -0.093714016
+ data.V5 -0.053314311
+ */
+ val interceptR2 = 0.0
+ val weightsR2 = Array(-0.005679651, 0.048967094, -0.093714016, -0.053314311)
+
+ assert(model2.intercept ~== interceptR2 relTol 1E-3)
+ assert(model2.weights(0) ~== weightsR2(0) relTol 1E-2)
+ assert(model2.weights(1) ~== weightsR2(1) relTol 1E-2)
+ assert(model2.weights(2) ~== weightsR2(2) relTol 1E-3)
+ assert(model2.weights(3) ~== weightsR2(3) relTol 1E-3)
}
test("binary logistic regression with intercept with ElasticNet regularization") {
- val trainer = (new LogisticRegression).setFitIntercept(true)
- .setElasticNetParam(0.38).setRegParam(0.21)
- val model = trainer.fit(binaryDataset)
+ val trainer1 = (new LogisticRegression).setFitIntercept(true)
+ .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(true)
+ val trainer2 = (new LogisticRegression).setFitIntercept(true)
+ .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(false)
+
+ val model1 = trainer1.fit(binaryDataset)
+ val model2 = trainer2.fit(binaryDataset)
/*
Using the following R code to load the data and train the model using glmnet package.
@@ -429,20 +582,52 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
data.V4 -0.08849250
data.V5 -0.15458796
*/
- val interceptR = 0.57734851
- val weightsR = Array(-0.05310287, 0.0, -0.08849250, -0.15458796)
-
- assert(model.intercept ~== interceptR relTol 6E-3)
- assert(model.weights(0) ~== weightsR(0) relTol 5E-3)
- assert(model.weights(1) ~== weightsR(1) relTol 1E-3)
- assert(model.weights(2) ~== weightsR(2) relTol 5E-3)
- assert(model.weights(3) ~== weightsR(3) relTol 1E-3)
+ val interceptR1 = 0.57734851
+ val weightsR1 = Array(-0.05310287, 0.0, -0.08849250, -0.15458796)
+
+ assert(model1.intercept ~== interceptR1 relTol 6E-3)
+ assert(model1.weights(0) ~== weightsR1(0) relTol 5E-3)
+ assert(model1.weights(1) ~== weightsR1(1) absTol 1E-3)
+ assert(model1.weights(2) ~== weightsR1(2) relTol 5E-3)
+ assert(model1.weights(3) ~== weightsR1(3) relTol 1E-3)
+
+ /*
+ Using the following R code to load the data and train the model using glmnet package.
+
+ library("glmnet")
+ data <- read.csv("path", header=FALSE)
+ label = factor(data$V1)
+ features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
+ weights = coef(glmnet(features,label, family="binomial", alpha = 0.38, lambda = 0.21,
+ standardize=FALSE))
+ weights
+
+ 5 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ (Intercept) 0.51555993
+ data.V2 .
+ data.V3 .
+ data.V4 -0.18807395
+ data.V5 -0.05350074
+ */
+ val interceptR2 = 0.51555993
+ val weightsR2 = Array(0.0, 0.0, -0.18807395, -0.05350074)
+
+ assert(model2.intercept ~== interceptR2 relTol 6E-3)
+ assert(model2.weights(0) ~== weightsR2(0) absTol 1E-3)
+ assert(model2.weights(1) ~== weightsR2(1) absTol 1E-3)
+ assert(model2.weights(2) ~== weightsR2(2) relTol 5E-3)
+ assert(model2.weights(3) ~== weightsR2(3) relTol 1E-2)
}
test("binary logistic regression without intercept with ElasticNet regularization") {
- val trainer = (new LogisticRegression).setFitIntercept(false)
- .setElasticNetParam(0.38).setRegParam(0.21)
- val model = trainer.fit(binaryDataset)
+ val trainer1 = (new LogisticRegression).setFitIntercept(false)
+ .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(true)
+ val trainer2 = (new LogisticRegression).setFitIntercept(false)
+ .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(false)
+
+ val model1 = trainer1.fit(binaryDataset)
+ val model2 = trainer2.fit(binaryDataset)
/*
Using the following R code to load the data and train the model using glmnet package.
@@ -463,20 +648,52 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
data.V4 -0.081203769
data.V5 -0.142534158
*/
- val interceptR = 0.0
- val weightsR = Array(-0.001005743, 0.072577857, -0.081203769, -0.142534158)
+ val interceptR1 = 0.0
+ val weightsR1 = Array(-0.001005743, 0.072577857, -0.081203769, -0.142534158)
+
+ assert(model1.intercept ~== interceptR1 relTol 1E-3)
+ assert(model1.weights(0) ~== weightsR1(0) absTol 1E-2)
+ assert(model1.weights(1) ~== weightsR1(1) absTol 1E-2)
+ assert(model1.weights(2) ~== weightsR1(2) relTol 1E-3)
+ assert(model1.weights(3) ~== weightsR1(3) relTol 1E-2)
+
+ /*
+ Using the following R code to load the data and train the model using glmnet package.
+
+ library("glmnet")
+ data <- read.csv("path", header=FALSE)
+ label = factor(data$V1)
+ features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
+ weights = coef(glmnet(features,label, family="binomial", alpha = 0.38, lambda = 0.21,
+ intercept=FALSE, standardize=FALSE))
+ weights
- assert(model.intercept ~== interceptR relTol 1E-3)
- assert(model.weights(0) ~== weightsR(0) absTol 1E-3)
- assert(model.weights(1) ~== weightsR(1) absTol 1E-2)
- assert(model.weights(2) ~== weightsR(2) relTol 1E-3)
- assert(model.weights(3) ~== weightsR(3) relTol 1E-2)
+ 5 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ (Intercept) .
+ data.V2 .
+ data.V3 0.03345223
+ data.V4 -0.11304532
+ data.V5 .
+ */
+ val interceptR2 = 0.0
+ val weightsR2 = Array(0.0, 0.03345223, -0.11304532, 0.0)
+
+ assert(model2.intercept ~== interceptR2 relTol 1E-3)
+ assert(model2.weights(0) ~== weightsR2(0) absTol 1E-3)
+ assert(model2.weights(1) ~== weightsR2(1) relTol 1E-2)
+ assert(model2.weights(2) ~== weightsR2(2) relTol 1E-2)
+ assert(model2.weights(3) ~== weightsR2(3) absTol 1E-3)
}
test("binary logistic regression with intercept with strong L1 regularization") {
- val trainer = (new LogisticRegression).setFitIntercept(true)
- .setElasticNetParam(1.0).setRegParam(6.0)
- val model = trainer.fit(binaryDataset)
+ val trainer1 = (new LogisticRegression).setFitIntercept(true)
+ .setElasticNetParam(1.0).setRegParam(6.0).setStandardization(true)
+ val trainer2 = (new LogisticRegression).setFitIntercept(true)
+ .setElasticNetParam(1.0).setRegParam(6.0).setStandardization(false)
+
+ val model1 = trainer1.fit(binaryDataset)
+ val model2 = trainer2.fit(binaryDataset)
val histogram = binaryDataset.map { case Row(label: Double, features: Vector) => label }
.treeAggregate(new MultiClassSummarizer)(
@@ -502,11 +719,17 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
val interceptTheory = math.log(histogram(1).toDouble / histogram(0).toDouble)
val weightsTheory = Array(0.0, 0.0, 0.0, 0.0)
- assert(model.intercept ~== interceptTheory relTol 1E-5)
- assert(model.weights(0) ~== weightsTheory(0) absTol 1E-6)
- assert(model.weights(1) ~== weightsTheory(1) absTol 1E-6)
- assert(model.weights(2) ~== weightsTheory(2) absTol 1E-6)
- assert(model.weights(3) ~== weightsTheory(3) absTol 1E-6)
+ assert(model1.intercept ~== interceptTheory relTol 1E-5)
+ assert(model1.weights(0) ~== weightsTheory(0) absTol 1E-6)
+ assert(model1.weights(1) ~== weightsTheory(1) absTol 1E-6)
+ assert(model1.weights(2) ~== weightsTheory(2) absTol 1E-6)
+ assert(model1.weights(3) ~== weightsTheory(3) absTol 1E-6)
+
+ assert(model2.intercept ~== interceptTheory relTol 1E-5)
+ assert(model2.weights(0) ~== weightsTheory(0) absTol 1E-6)
+ assert(model2.weights(1) ~== weightsTheory(1) absTol 1E-6)
+ assert(model2.weights(2) ~== weightsTheory(2) absTol 1E-6)
+ assert(model2.weights(3) ~== weightsTheory(3) absTol 1E-6)
/*
Using the following R code to load the data and train the model using glmnet package.
@@ -529,10 +752,10 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
val interceptR = -0.248065
val weightsR = Array(0.0, 0.0, 0.0, 0.0)
- assert(model.intercept ~== interceptR relTol 1E-5)
- assert(model.weights(0) ~== weightsR(0) absTol 1E-6)
- assert(model.weights(1) ~== weightsR(1) absTol 1E-6)
- assert(model.weights(2) ~== weightsR(2) absTol 1E-6)
- assert(model.weights(3) ~== weightsR(3) absTol 1E-6)
+ assert(model1.intercept ~== interceptR relTol 1E-5)
+ assert(model1.weights(0) ~== weightsR(0) absTol 1E-6)
+ assert(model1.weights(1) ~== weightsR(1) absTol 1E-6)
+ assert(model1.weights(2) ~== weightsR(2) absTol 1E-6)
+ assert(model1.weights(3) ~== weightsR(3) absTol 1E-6)
}
}
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 680b699e9e..41e19fd9cc 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -58,6 +58,8 @@ object MimaExcludes {
"org.apache.spark.ml.regression.LeastSquaresAggregator.this"),
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.ml.regression.LeastSquaresCostFun.this"),
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.ml.classification.LogisticCostFun.this"),
// SQL execution is considered private.
excludePackage("org.apache.spark.sql.execution"),
// NanoTime and CatalystTimestampConverter is only used inside catalyst,