author     lewuathe <lewuathe@me.com>                      2015-10-19 10:46:10 -0700
committer  DB Tsai <dbt@netflix.com>                       2015-10-19 10:46:10 -0700
commit     4c33a34ba3167ae67fdb4978ea2166ce65638fb9 (patch)
tree       5a5dbae89a230ad0c82acab25ba98e9121b1af6b /mllib
parent     dfa41e63b98c28b087c56f94658b5e99e8a7758c (diff)
[SPARK-10668] [ML] Use WeightedLeastSquares in LinearRegression with L2 regularization if the number of features is small

Author: lewuathe <lewuathe@me.com>
Author: Lewuathe <sasaki@treasure-data.com>
Author: Kai Sasaki <sasaki@treasure-data.com>
Author: Lewuathe <lewuathe@me.com>

Closes #8884 from Lewuathe/SPARK-10668.
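For context, the patch is user-visible only through the new "solver" param on LinearRegression. A minimal sketch of how a caller might exercise it; the `training` DataFrame with "label" and "features" columns is an assumption, not part of this patch:

    import org.apache.spark.ml.regression.LinearRegression

    // `training` is an assumed DataFrame of (label: Double, features: Vector).
    val lr = new LinearRegression()
      .setRegParam(0.3)   // pure L2 penalty; elasticNetParam stays at its 0.0 default
      .setSolver("auto")  // "auto" picks the one-pass normal-equations path for <= 4096 features
    val model = lr.fit(training)
    println(s"intercept=${model.intercept} weights=${model.weights}")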
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala    |    4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala           |   17
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala                    |    4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala         |   50
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java  |    3
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala    | 1045
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala          |    2
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala    |    2
8 files changed, 636 insertions(+), 491 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
index 8cb6b5493c..c7bca12430 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
@@ -73,7 +73,9 @@ private[shared] object SharedParamsCodeGen {
ParamDesc[Double]("tol", "the convergence tolerance for iterative algorithms"),
ParamDesc[Double]("stepSize", "Step size to be used for each iteration of optimization."),
ParamDesc[String]("weightCol", "weight column name. If this is not set or empty, we treat " +
- "all instance weights as 1.0."))
+ "all instance weights as 1.0."),
+ ParamDesc[String]("solver", "the solver algorithm for optimization. If this is not set or " +
+ "empty, default value is 'auto'.", Some("\"auto\"")))
val code = genSharedParams(params)
val file = "src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala"
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
index e3625212e5..cb2a060a34 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
@@ -357,4 +357,21 @@ private[ml] trait HasWeightCol extends Params {
/** @group getParam */
final def getWeightCol: String = $(weightCol)
}
+
+/**
+ * Trait for shared param solver (default: "auto").
+ */
+private[ml] trait HasSolver extends Params {
+
+ /**
+ * Param for the solver algorithm for optimization. If this is not set or empty, default value is 'auto'.
+ * @group param
+ */
+ final val solver: Param[String] = new Param[String](this, "solver", "the solver algorithm for optimization. If this is not set or empty, default value is 'auto'.")
+
+ setDefault(solver, "auto")
+
+ /** @group getParam */
+ final def getSolver: String = $(solver)
+}
// scalastyle:on
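As a sanity check on the generated trait: HasSolver installs "auto" as the default, and the setter performs no validation up front. A hedged sketch of the observable behavior (illustrative, not part of the patch):

    import org.apache.spark.ml.regression.LinearRegression

    val lr = new LinearRegression()
    assert(lr.getSolver == "auto")    // default installed by HasSolver's setDefault
    lr.setSolver("normal")            // accepted here; compatibility is checked at fit() time
    assert(lr.getSolver == "normal")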
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala
index f5a022c31e..fec61fed3c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala
@@ -30,13 +30,15 @@ private[r] object SparkRWrappers {
df: DataFrame,
family: String,
lambda: Double,
- alpha: Double): PipelineModel = {
+ alpha: Double,
+ solver: String): PipelineModel = {
val formula = new RFormula().setFormula(value)
val estimator = family match {
case "gaussian" => new LinearRegression()
.setRegParam(lambda)
.setElasticNetParam(alpha)
.setFitIntercept(formula.hasIntercept)
+ .setSolver(solver)
case "binomial" => new LogisticRegression()
.setRegParam(lambda)
.setElasticNetParam(alpha)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index dd09667ef5..573a61a6ea 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -25,6 +25,7 @@ import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS,
import org.apache.spark.{Logging, SparkException}
import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.feature.Instance
+import org.apache.spark.ml.optim.WeightedLeastSquares
import org.apache.spark.ml.PredictorParams
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.param.shared._
@@ -43,7 +44,7 @@ import org.apache.spark.storage.StorageLevel
*/
private[regression] trait LinearRegressionParams extends PredictorParams
with HasRegParam with HasElasticNetParam with HasMaxIter with HasTol
- with HasFitIntercept with HasStandardization with HasWeightCol
+ with HasFitIntercept with HasStandardization with HasWeightCol with HasSolver
/**
* :: Experimental ::
@@ -130,9 +131,53 @@ class LinearRegression(override val uid: String)
def setWeightCol(value: String): this.type = set(weightCol, value)
setDefault(weightCol -> "")
+ /**
+ * Set the solver algorithm used for optimization.
+ * For linear regression, this can be "l-bfgs", "normal", or "auto".
+ * The default value is "auto", which means the solver algorithm is
+ * selected automatically.
+ * @group setParam
+ */
+ def setSolver(value: String): this.type = set(solver, value)
+ setDefault(solver -> "auto")
+
override protected def train(dataset: DataFrame): LinearRegressionModel = {
- // Extract columns from data. If dataset is persisted, do not persist instances.
+ // Extract the number of features before deciding on the optimization solver.
+ val numFeatures = dataset.select(col($(featuresCol))).limit(1).map {
+ case Row(features: Vector) => features.size
+ }.first()
val w = if ($(weightCol).isEmpty) lit(1.0) else col($(weightCol))
+
+ if (($(solver) == "auto" && $(elasticNetParam) == 0.0 && numFeatures <= 4096) ||
+ $(solver) == "normal") {
+ require($(elasticNetParam) == 0.0, "Only L2 regularization can be used when the " +
+ "normal solver is used.")
+ // For low-dimensional data, WeightedLeastSquares is more efficient since the
+ // training algorithm requires only one pass through the data. (SPARK-10668)
+ val instances: RDD[WeightedLeastSquares.Instance] = dataset.select(
+ col($(labelCol)), w, col($(featuresCol))).map {
+ case Row(label: Double, weight: Double, features: Vector) =>
+ WeightedLeastSquares.Instance(weight, features, label)
+ }
+
+ val optimizer = new WeightedLeastSquares($(fitIntercept), $(regParam),
+ $(standardization), true)
+ val model = optimizer.fit(instances)
+ // When the model is trained by WeightedLeastSquares, a training summary is
+ // not attached to the returned model, so one is constructed here.
+ val lrModel = copyValues(new LinearRegressionModel(uid, model.coefficients, model.intercept))
+ // WeightedLeastSquares is not iterative, so it does not generate an
+ // objective history; record a single zero entry instead.
+ val (summaryModel, predictionColName) = lrModel.findSummaryModelAndPredictionCol()
+ val trainingSummary = new LinearRegressionTrainingSummary(
+ summaryModel.transform(dataset),
+ predictionColName,
+ $(labelCol),
+ $(featuresCol),
+ Array(0D))
+ return lrModel.setSummary(trainingSummary)
+ }
+
val instances: RDD[Instance] = dataset.select(col($(labelCol)), w, col($(featuresCol))).map {
case Row(label: Double, weight: Double, features: Vector) =>
Instance(label, weight, features)
@@ -155,7 +200,6 @@ class LinearRegression(override val uid: String)
new MultivariateOnlineSummarizer, new MultivariateOnlineSummarizer)(seqOp, combOp)
}
- val numFeatures = featuresSummarizer.mean.size
val yMean = ySummarizer.mean(0)
val yStd = math.sqrt(ySummarizer.variance(0))
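The condition at the top of train() is the heart of the patch. Restated as a standalone predicate (the name is illustrative, not patch code), the one-pass WeightedLeastSquares path is taken when:

    // Illustrative restatement of the solver-dispatch rule in train().
    def choosesNormalSolver(solver: String, elasticNetParam: Double, numFeatures: Int): Boolean =
      solver == "normal" ||
        (solver == "auto" && elasticNetParam == 0.0 && numFeatures <= 4096)

    assert(choosesNormalSolver("auto", 0.0, 100))     // small L2-only problem: one pass
    assert(!choosesNormalSolver("auto", 0.5, 100))    // any L1 mixing falls back to L-BFGS
    assert(!choosesNormalSolver("auto", 0.0, 10000))  // too many features for the normal path

Note that "normal" wins unconditionally, and the require() inside the branch then rejects a non-zero elasticNetParam.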
diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java
index 91c589d00a..4fb0b0d109 100644
--- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java
@@ -61,6 +61,7 @@ public class JavaLinearRegressionSuite implements Serializable {
public void linearRegressionDefaultParams() {
LinearRegression lr = new LinearRegression();
assertEquals("label", lr.getLabelCol());
+ assertEquals("auto", lr.getSolver());
LinearRegressionModel model = lr.fit(dataset);
model.transform(dataset).registerTempTable("prediction");
DataFrame predictions = jsql.sql("SELECT label, prediction FROM prediction");
@@ -75,7 +76,7 @@ public class JavaLinearRegressionSuite implements Serializable {
// Set params, train, and check as many params as we can.
LinearRegression lr = new LinearRegression()
.setMaxIter(10)
- .setRegParam(1.0);
+ .setRegParam(1.0)
+ .setSolver("l-bfgs");
LinearRegressionModel model = lr.fit(dataset);
LinearRegression parent = (LinearRegression) model.parent();
assertEquals(10, parent.getMaxIter());
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
index 73a0a5caf8..a6e0c72ba9 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
@@ -31,6 +31,7 @@ import org.apache.spark.sql.{DataFrame, Row}
class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
+ private val seed: Int = 42
@transient var dataset: DataFrame = _
@transient var datasetWithoutIntercept: DataFrame = _
@@ -50,15 +51,14 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
super.beforeAll()
dataset = sqlContext.createDataFrame(
sc.parallelize(LinearDataGenerator.generateLinearInput(
- 6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 10000, 42, 0.1), 2))
+ 6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 10000, seed, 0.1), 2))
/*
datasetWithoutIntercept is not needed for correctness testing but is useful for illustrating
training model without intercept
*/
datasetWithoutIntercept = sqlContext.createDataFrame(
sc.parallelize(LinearDataGenerator.generateLinearInput(
- 0.0, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 10000, 42, 0.1), 2))
-
+ 0.0, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 10000, seed, 0.1), 2))
}
test("params") {
@@ -76,6 +76,7 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(lir.getElasticNetParam === 0.0)
assert(lir.getFitIntercept)
assert(lir.getStandardization)
+ assert(lir.getSolver == "auto")
val model = lir.fit(dataset)
// copied model must have the same parent.
@@ -93,525 +94,603 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
}
test("linear regression with intercept without regularization") {
- val trainer1 = new LinearRegression
- // The result should be the same regardless of standardization without regularization
- val trainer2 = (new LinearRegression).setStandardization(false)
- val model1 = trainer1.fit(dataset)
- val model2 = trainer2.fit(dataset)
-
- /*
- Using the following R code to load the data and train the model using glmnet package.
-
- library("glmnet")
- data <- read.csv("path", header=FALSE, stringsAsFactors=FALSE)
- features <- as.matrix(data.frame(as.numeric(data$V2), as.numeric(data$V3)))
- label <- as.numeric(data$V1)
- weights <- coef(glmnet(features, label, family="gaussian", alpha = 0, lambda = 0))
- > weights
- 3 x 1 sparse Matrix of class "dgCMatrix"
- s0
- (Intercept) 6.298698
- as.numeric.data.V2. 4.700706
- as.numeric.data.V3. 7.199082
- */
- val interceptR = 6.298698
- val weightsR = Vectors.dense(4.700706, 7.199082)
-
- assert(model1.intercept ~== interceptR relTol 1E-3)
- assert(model1.weights ~= weightsR relTol 1E-3)
- assert(model2.intercept ~== interceptR relTol 1E-3)
- assert(model2.weights ~= weightsR relTol 1E-3)
-
-
- model1.transform(dataset).select("features", "prediction").collect().foreach {
- case Row(features: DenseVector, prediction1: Double) =>
- val prediction2 =
- features(0) * model1.weights(0) + features(1) * model1.weights(1) + model1.intercept
- assert(prediction1 ~== prediction2 relTol 1E-5)
+ Seq("auto", "l-bfgs", "normal").foreach { solver =>
+ val trainer1 = new LinearRegression().setSolver(solver)
+ // The result should be the same regardless of standardization without regularization
+ val trainer2 = (new LinearRegression).setStandardization(false).setSolver(solver)
+ val model1 = trainer1.fit(dataset)
+ val model2 = trainer2.fit(dataset)
+
+ /*
+ Using the following R code to load the data and train the model using glmnet package.
+
+ library("glmnet")
+ data <- read.csv("path", header=FALSE, stringsAsFactors=FALSE)
+ features <- as.matrix(data.frame(as.numeric(data$V2), as.numeric(data$V3)))
+ label <- as.numeric(data$V1)
+ weights <- coef(glmnet(features, label, family="gaussian", alpha = 0, lambda = 0))
+ > weights
+ 3 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ (Intercept) 6.298698
+ as.numeric.data.V2. 4.700706
+ as.numeric.data.V3. 7.199082
+ */
+ val interceptR = 6.298698
+ val weightsR = Vectors.dense(4.700706, 7.199082)
+
+ assert(model1.intercept ~== interceptR relTol 1E-3)
+ assert(model1.weights ~= weightsR relTol 1E-3)
+ assert(model2.intercept ~== interceptR relTol 1E-3)
+ assert(model2.weights ~= weightsR relTol 1E-3)
+
+ model1.transform(dataset).select("features", "prediction").collect().foreach {
+ case Row(features: DenseVector, prediction1: Double) =>
+ val prediction2 =
+ features(0) * model1.weights(0) + features(1) * model1.weights(1) + model1.intercept
+ assert(prediction1 ~== prediction2 relTol 1E-5)
+ }
}
}
test("linear regression without intercept without regularization") {
- val trainer1 = (new LinearRegression).setFitIntercept(false)
- // Without regularization the results should be the same
- val trainer2 = (new LinearRegression).setFitIntercept(false).setStandardization(false)
- val model1 = trainer1.fit(dataset)
- val modelWithoutIntercept1 = trainer1.fit(datasetWithoutIntercept)
- val model2 = trainer2.fit(dataset)
- val modelWithoutIntercept2 = trainer2.fit(datasetWithoutIntercept)
-
-
- /*
- weights <- coef(glmnet(features, label, family="gaussian", alpha = 0, lambda = 0,
- intercept = FALSE))
- > weights
- 3 x 1 sparse Matrix of class "dgCMatrix"
- s0
- (Intercept) .
- as.numeric.data.V2. 6.995908
- as.numeric.data.V3. 5.275131
- */
- val weightsR = Vectors.dense(6.995908, 5.275131)
-
- assert(model1.intercept ~== 0 absTol 1E-3)
- assert(model1.weights ~= weightsR relTol 1E-3)
- assert(model2.intercept ~== 0 absTol 1E-3)
- assert(model2.weights ~= weightsR relTol 1E-3)
-
- /*
- Then again with the data with no intercept:
- > weightsWithoutIntercept
- 3 x 1 sparse Matrix of class "dgCMatrix"
- s0
- (Intercept) .
- as.numeric.data3.V2. 4.70011
- as.numeric.data3.V3. 7.19943
- */
- val weightsWithoutInterceptR = Vectors.dense(4.70011, 7.19943)
-
- assert(modelWithoutIntercept1.intercept ~== 0 absTol 1E-3)
- assert(modelWithoutIntercept1.weights ~= weightsWithoutInterceptR relTol 1E-3)
- assert(modelWithoutIntercept2.intercept ~== 0 absTol 1E-3)
- assert(modelWithoutIntercept2.weights ~= weightsWithoutInterceptR relTol 1E-3)
+ Seq("auto", "l-bfgs", "normal").foreach { solver =>
+ val trainer1 = (new LinearRegression).setFitIntercept(false).setSolver(solver)
+ // Without regularization the results should be the same
+ val trainer2 = (new LinearRegression).setFitIntercept(false).setStandardization(false)
+ .setSolver(solver)
+ val model1 = trainer1.fit(dataset)
+ val modelWithoutIntercept1 = trainer1.fit(datasetWithoutIntercept)
+ val model2 = trainer2.fit(dataset)
+ val modelWithoutIntercept2 = trainer2.fit(datasetWithoutIntercept)
+
+ /*
+ weights <- coef(glmnet(features, label, family="gaussian", alpha = 0, lambda = 0,
+ intercept = FALSE))
+ > weights
+ 3 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ (Intercept) .
+ as.numeric.data.V2. 6.995908
+ as.numeric.data.V3. 5.275131
+ */
+ val weightsR = Vectors.dense(6.995908, 5.275131)
+
+ assert(model1.intercept ~== 0 absTol 1E-3)
+ assert(model1.weights ~= weightsR relTol 1E-3)
+ assert(model2.intercept ~== 0 absTol 1E-3)
+ assert(model2.weights ~= weightsR relTol 1E-3)
+
+ /*
+ Then again with the data with no intercept:
+ > weightsWithoutIntercept
+ 3 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ (Intercept) .
+ as.numeric.data3.V2. 4.70011
+ as.numeric.data3.V3. 7.19943
+ */
+ val weightsWithoutInterceptR = Vectors.dense(4.70011, 7.19943)
+
+ assert(modelWithoutIntercept1.intercept ~== 0 absTol 1E-3)
+ assert(modelWithoutIntercept1.weights ~= weightsWithoutInterceptR relTol 1E-3)
+ assert(modelWithoutIntercept2.intercept ~== 0 absTol 1E-3)
+ assert(modelWithoutIntercept2.weights ~= weightsWithoutInterceptR relTol 1E-3)
+ }
}
test("linear regression with intercept with L1 regularization") {
- val trainer1 = (new LinearRegression).setElasticNetParam(1.0).setRegParam(0.57)
- val trainer2 = (new LinearRegression).setElasticNetParam(1.0).setRegParam(0.57)
- .setStandardization(false)
- val model1 = trainer1.fit(dataset)
- val model2 = trainer2.fit(dataset)
-
- /*
- weights <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, lambda = 0.57))
- > weights
- 3 x 1 sparse Matrix of class "dgCMatrix"
- s0
- (Intercept) 6.24300
- as.numeric.data.V2. 4.024821
- as.numeric.data.V3. 6.679841
- */
- val interceptR1 = 6.24300
- val weightsR1 = Vectors.dense(4.024821, 6.679841)
-
- assert(model1.intercept ~== interceptR1 relTol 1E-3)
- assert(model1.weights ~= weightsR1 relTol 1E-3)
-
- /*
- weights <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, lambda = 0.57,
- standardize=FALSE))
- > weights
- 3 x 1 sparse Matrix of class "dgCMatrix"
- s0
- (Intercept) 6.416948
- as.numeric.data.V2. 3.893869
- as.numeric.data.V3. 6.724286
- */
- val interceptR2 = 6.416948
- val weightsR2 = Vectors.dense(3.893869, 6.724286)
-
- assert(model2.intercept ~== interceptR2 relTol 1E-3)
- assert(model2.weights ~= weightsR2 relTol 1E-3)
-
-
- model1.transform(dataset).select("features", "prediction").collect().foreach {
- case Row(features: DenseVector, prediction1: Double) =>
- val prediction2 =
- features(0) * model1.weights(0) + features(1) * model1.weights(1) + model1.intercept
- assert(prediction1 ~== prediction2 relTol 1E-5)
+ Seq("auto", "l-bfgs", "normal").foreach { solver =>
+ val trainer1 = (new LinearRegression).setElasticNetParam(1.0).setRegParam(0.57)
+ .setSolver(solver)
+ val trainer2 = (new LinearRegression).setElasticNetParam(1.0).setRegParam(0.57)
+ .setSolver(solver).setStandardization(false)
+
+ var model1: LinearRegressionModel = null
+ var model2: LinearRegressionModel = null
+
+ // The normal solver does not support L1 regularization.
+ if (solver == "normal") {
+ intercept[IllegalArgumentException] {
+ trainer1.fit(dataset)
+ trainer2.fit(dataset)
+ }
+ } else {
+ model1 = trainer1.fit(dataset)
+ model2 = trainer2.fit(dataset)
+
+ /*
+ weights <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, lambda = 0.57))
+ > weights
+ 3 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ (Intercept) 6.24300
+ as.numeric.data.V2. 4.024821
+ as.numeric.data.V3. 6.679841
+ */
+ val interceptR1 = 6.24300
+ val weightsR1 = Vectors.dense(4.024821, 6.679841)
+ assert(model1.intercept ~== interceptR1 relTol 1E-3)
+ assert(model1.weights ~= weightsR1 relTol 1E-3)
+
+ /*
+ weights <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, lambda = 0.57,
+ standardize=FALSE))
+ > weights
+ 3 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ (Intercept) 6.416948
+ as.numeric.data.V2. 3.893869
+ as.numeric.data.V3. 6.724286
+ */
+ val interceptR2 = 6.416948
+ val weightsR2 = Vectors.dense(3.893869, 6.724286)
+
+ assert(model2.intercept ~== interceptR2 relTol 1E-3)
+ assert(model2.weights ~= weightsR2 relTol 1E-3)
+
+ model1.transform(dataset).select("features", "prediction").collect().foreach {
+ case Row(features: DenseVector, prediction1: Double) =>
+ val prediction2 =
+ features(0) * model1.weights(0) + features(1) * model1.weights(1) + model1.intercept
+ assert(prediction1 ~== prediction2 relTol 1E-5)
+ }
+ }
}
}
test("linear regression without intercept with L1 regularization") {
- val trainer1 = (new LinearRegression).setElasticNetParam(1.0).setRegParam(0.57)
- .setFitIntercept(false)
- val trainer2 = (new LinearRegression).setElasticNetParam(1.0).setRegParam(0.57)
- .setFitIntercept(false).setStandardization(false)
- val model1 = trainer1.fit(dataset)
- val model2 = trainer2.fit(dataset)
-
- /*
- weights <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, lambda = 0.57,
- intercept=FALSE))
- > weights
- 3 x 1 sparse Matrix of class "dgCMatrix"
- s0
- (Intercept) .
- as.numeric.data.V2. 6.299752
- as.numeric.data.V3. 4.772913
- */
- val interceptR1 = 0.0
- val weightsR1 = Vectors.dense(6.299752, 4.772913)
-
- assert(model1.intercept ~== interceptR1 absTol 1E-3)
- assert(model1.weights ~= weightsR1 relTol 1E-3)
-
- /*
- weights <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, lambda = 0.57,
- intercept=FALSE, standardize=FALSE))
- > weights
- 3 x 1 sparse Matrix of class "dgCMatrix"
- s0
- (Intercept) .
- as.numeric.data.V2. 6.232193
- as.numeric.data.V3. 4.764229
- */
- val interceptR2 = 0.0
- val weightsR2 = Vectors.dense(6.232193, 4.764229)
-
- assert(model2.intercept ~== interceptR2 absTol 1E-3)
- assert(model2.weights ~= weightsR2 relTol 1E-3)
-
-
- model1.transform(dataset).select("features", "prediction").collect().foreach {
- case Row(features: DenseVector, prediction1: Double) =>
- val prediction2 =
- features(0) * model1.weights(0) + features(1) * model1.weights(1) + model1.intercept
- assert(prediction1 ~== prediction2 relTol 1E-5)
+ Seq("auto", "l-bfgs", "normal").foreach { solver =>
+ val trainer1 = (new LinearRegression).setElasticNetParam(1.0).setRegParam(0.57)
+ .setFitIntercept(false).setSolver(solver)
+ val trainer2 = (new LinearRegression).setElasticNetParam(1.0).setRegParam(0.57)
+ .setFitIntercept(false).setStandardization(false).setSolver(solver)
+
+ var model1: LinearRegressionModel = null
+ var model2: LinearRegressionModel = null
+
+ // The normal solver does not support L1 regularization.
+ if (solver == "normal") {
+ intercept[IllegalArgumentException] {
+ trainer1.fit(dataset)
+ trainer2.fit(dataset)
+ }
+ } else {
+ model1 = trainer1.fit(dataset)
+ model2 = trainer2.fit(dataset)
+
+ /*
+ weights <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, lambda = 0.57,
+ intercept=FALSE))
+ > weights
+ 3 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ (Intercept) .
+ as.numeric.data.V2. 6.299752
+ as.numeric.data.V3. 4.772913
+ */
+ val interceptR1 = 0.0
+ val weightsR1 = Vectors.dense(6.299752, 4.772913)
+
+ assert(model1.intercept ~== interceptR1 absTol 1E-3)
+ assert(model1.weights ~= weightsR1 relTol 1E-3)
+
+ /*
+ weights <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, lambda = 0.57,
+ intercept=FALSE, standardize=FALSE))
+ > weights
+ 3 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ (Intercept) .
+ as.numeric.data.V2. 6.232193
+ as.numeric.data.V3. 4.764229
+ */
+ val interceptR2 = 0.0
+ val weightsR2 = Vectors.dense(6.232193, 4.764229)
+
+ assert(model2.intercept ~== interceptR2 absTol 1E-3)
+ assert(model2.weights ~= weightsR2 relTol 1E-3)
+
+ model1.transform(dataset).select("features", "prediction").collect().foreach {
+ case Row(features: DenseVector, prediction1: Double) =>
+ val prediction2 =
+ features(0) * model1.weights(0) + features(1) * model1.weights(1) + model1.intercept
+ assert(prediction1 ~== prediction2 relTol 1E-5)
+ }
+ }
}
}
test("linear regression with intercept with L2 regularization") {
- val trainer1 = (new LinearRegression).setElasticNetParam(0.0).setRegParam(2.3)
- val trainer2 = (new LinearRegression).setElasticNetParam(0.0).setRegParam(2.3)
- .setStandardization(false)
- val model1 = trainer1.fit(dataset)
- val model2 = trainer2.fit(dataset)
-
- /*
- weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3))
- > weights
- 3 x 1 sparse Matrix of class "dgCMatrix"
- s0
- (Intercept) 5.269376
- as.numeric.data.V2. 3.736216
- as.numeric.data.V3. 5.712356)
- */
- val interceptR1 = 5.269376
- val weightsR1 = Vectors.dense(3.736216, 5.712356)
-
- assert(model1.intercept ~== interceptR1 relTol 1E-3)
- assert(model1.weights ~= weightsR1 relTol 1E-3)
-
- /*
- weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3,
- standardize=FALSE))
- > weights
- 3 x 1 sparse Matrix of class "dgCMatrix"
- s0
- (Intercept) 5.791109
- as.numeric.data.V2. 3.435466
- as.numeric.data.V3. 5.910406
- */
- val interceptR2 = 5.791109
- val weightsR2 = Vectors.dense(3.435466, 5.910406)
-
- assert(model2.intercept ~== interceptR2 relTol 1E-3)
- assert(model2.weights ~= weightsR2 relTol 1E-3)
-
- model1.transform(dataset).select("features", "prediction").collect().foreach {
- case Row(features: DenseVector, prediction1: Double) =>
- val prediction2 =
- features(0) * model1.weights(0) + features(1) * model1.weights(1) + model1.intercept
- assert(prediction1 ~== prediction2 relTol 1E-5)
+ Seq("auto", "l-bfgs", "normal").foreach { solver =>
+ val trainer1 = (new LinearRegression).setElasticNetParam(0.0).setRegParam(2.3)
+ .setSolver(solver)
+ val trainer2 = (new LinearRegression).setElasticNetParam(0.0).setRegParam(2.3)
+ .setStandardization(false).setSolver(solver)
+ val model1 = trainer1.fit(dataset)
+ val model2 = trainer2.fit(dataset)
+
+ /*
+ weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3))
+ > weights
+ 3 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ (Intercept) 5.269376
+ as.numeric.data.V2. 3.736216
+ as.numeric.data.V3. 5.712356)
+ */
+ val interceptR1 = 5.269376
+ val weightsR1 = Vectors.dense(3.736216, 5.712356)
+
+ assert(model1.intercept ~== interceptR1 relTol 1E-3)
+ assert(model1.weights ~= weightsR1 relTol 1E-3)
+
+ /*
+ weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3,
+ standardize=FALSE))
+ > weights
+ 3 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ (Intercept) 5.791109
+ as.numeric.data.V2. 3.435466
+ as.numeric.data.V3. 5.910406
+ */
+ val interceptR2 = 5.791109
+ val weightsR2 = Vectors.dense(3.435466, 5.910406)
+
+ assert(model2.intercept ~== interceptR2 relTol 1E-3)
+ assert(model2.weights ~= weightsR2 relTol 1E-3)
+
+ model1.transform(dataset).select("features", "prediction").collect().foreach {
+ case Row(features: DenseVector, prediction1: Double) =>
+ val prediction2 =
+ features(0) * model1.weights(0) + features(1) * model1.weights(1) + model1.intercept
+ assert(prediction1 ~== prediction2 relTol 1E-5)
+ }
}
}
test("linear regression without intercept with L2 regularization") {
- val trainer1 = (new LinearRegression).setElasticNetParam(0.0).setRegParam(2.3)
- .setFitIntercept(false)
- val trainer2 = (new LinearRegression).setElasticNetParam(0.0).setRegParam(2.3)
- .setFitIntercept(false).setStandardization(false)
- val model1 = trainer1.fit(dataset)
- val model2 = trainer2.fit(dataset)
-
- /*
- weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3,
- intercept = FALSE))
- > weights
- 3 x 1 sparse Matrix of class "dgCMatrix"
- s0
- (Intercept) .
- as.numeric.data.V2. 5.522875
- as.numeric.data.V3. 4.214502
- */
- val interceptR1 = 0.0
- val weightsR1 = Vectors.dense(5.522875, 4.214502)
-
- assert(model1.intercept ~== interceptR1 absTol 1E-3)
- assert(model1.weights ~= weightsR1 relTol 1E-3)
-
- /*
- weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3,
- intercept = FALSE, standardize=FALSE))
- > weights
- 3 x 1 sparse Matrix of class "dgCMatrix"
- s0
- (Intercept) .
- as.numeric.data.V2. 5.263704
- as.numeric.data.V3. 4.187419
- */
- val interceptR2 = 0.0
- val weightsR2 = Vectors.dense(5.263704, 4.187419)
-
- assert(model2.intercept ~== interceptR2 absTol 1E-3)
- assert(model2.weights ~= weightsR2 relTol 1E-3)
-
- model1.transform(dataset).select("features", "prediction").collect().foreach {
- case Row(features: DenseVector, prediction1: Double) =>
- val prediction2 =
- features(0) * model1.weights(0) + features(1) * model1.weights(1) + model1.intercept
- assert(prediction1 ~== prediction2 relTol 1E-5)
+ Seq("auto", "l-bfgs", "normal").foreach { solver =>
+ val trainer1 = (new LinearRegression).setElasticNetParam(0.0).setRegParam(2.3)
+ .setFitIntercept(false).setSolver(solver)
+ val trainer2 = (new LinearRegression).setElasticNetParam(0.0).setRegParam(2.3)
+ .setFitIntercept(false).setStandardization(false).setSolver(solver)
+ val model1 = trainer1.fit(dataset)
+ val model2 = trainer2.fit(dataset)
+
+ /*
+ weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3,
+ intercept = FALSE))
+ > weights
+ 3 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ (Intercept) .
+ as.numeric.data.V2. 5.522875
+ as.numeric.data.V3. 4.214502
+ */
+ val interceptR1 = 0.0
+ val weightsR1 = Vectors.dense(5.522875, 4.214502)
+
+ assert(model1.intercept ~== interceptR1 absTol 1E-3)
+ assert(model1.weights ~= weightsR1 relTol 1E-3)
+
+ /*
+ weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3,
+ intercept = FALSE, standardize=FALSE))
+ > weights
+ 3 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ (Intercept) .
+ as.numeric.data.V2. 5.263704
+ as.numeric.data.V3. 4.187419
+ */
+ val interceptR2 = 0.0
+ val weightsR2 = Vectors.dense(5.263704, 4.187419)
+
+ assert(model2.intercept ~== interceptR2 absTol 1E-3)
+ assert(model2.weights ~= weightsR2 relTol 1E-3)
+
+ model1.transform(dataset).select("features", "prediction").collect().foreach {
+ case Row(features: DenseVector, prediction1: Double) =>
+ val prediction2 =
+ features(0) * model1.weights(0) + features(1) * model1.weights(1) + model1.intercept
+ assert(prediction1 ~== prediction2 relTol 1E-5)
+ }
}
}
test("linear regression with intercept with ElasticNet regularization") {
- val trainer1 = (new LinearRegression).setElasticNetParam(0.3).setRegParam(1.6)
- val trainer2 = (new LinearRegression).setElasticNetParam(0.3).setRegParam(1.6)
- .setStandardization(false)
- val model1 = trainer1.fit(dataset)
- val model2 = trainer2.fit(dataset)
-
- /*
- weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6))
- > weights
- 3 x 1 sparse Matrix of class "dgCMatrix"
- s0
- (Intercept) 6.324108
- as.numeric.data.V2. 3.168435
- as.numeric.data.V3. 5.200403
- */
- val interceptR1 = 5.696056
- val weightsR1 = Vectors.dense(3.670489, 6.001122)
-
- assert(model1.intercept ~== interceptR1 relTol 1E-3)
- assert(model1.weights ~= weightsR1 relTol 1E-3)
-
- /*
- weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6
- standardize=FALSE))
- > weights
- 3 x 1 sparse Matrix of class "dgCMatrix"
- s0
- (Intercept) 6.114723
- as.numeric.data.V2. 3.409937
- as.numeric.data.V3. 6.146531
- */
- val interceptR2 = 6.114723
- val weightsR2 = Vectors.dense(3.409937, 6.146531)
-
- assert(model2.intercept ~== interceptR2 relTol 1E-3)
- assert(model2.weights ~= weightsR2 relTol 1E-3)
-
- model1.transform(dataset).select("features", "prediction").collect().foreach {
- case Row(features: DenseVector, prediction1: Double) =>
- val prediction2 =
- features(0) * model1.weights(0) + features(1) * model1.weights(1) + model1.intercept
- assert(prediction1 ~== prediction2 relTol 1E-5)
+ Seq("auto", "l-bfgs", "normal").foreach { solver =>
+ val trainer1 = (new LinearRegression).setElasticNetParam(0.3).setRegParam(1.6)
+ .setSolver(solver)
+ val trainer2 = (new LinearRegression).setElasticNetParam(0.3).setRegParam(1.6)
+ .setStandardization(false).setSolver(solver)
+
+ var model1: LinearRegressionModel = null
+ var model2: LinearRegressionModel = null
+
+ // The normal solver does not support a non-zero elastic net parameter.
+ if (solver == "normal") {
+ intercept[IllegalArgumentException] {
+ trainer1.fit(dataset)
+ trainer2.fit(dataset)
+ }
+ } else {
+ model1 = trainer1.fit(dataset)
+ model2 = trainer2.fit(dataset)
+
+ /*
+ weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6))
+ > weights
+ 3 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ (Intercept) 6.324108
+ as.numeric.data.V2. 3.168435
+ as.numeric.data.V3. 5.200403
+ */
+ val interceptR1 = 5.696056
+ val weightsR1 = Vectors.dense(3.670489, 6.001122)
+
+ assert(model1.intercept ~== interceptR1 relTol 1E-3)
+ assert(model1.weights ~= weightsR1 relTol 1E-3)
+
+ /*
+ weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6
+ standardize=FALSE))
+ > weights
+ 3 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ (Intercept) 6.114723
+ as.numeric.data.V2. 3.409937
+ as.numeric.data.V3. 6.146531
+ */
+ val interceptR2 = 6.114723
+ val weightsR2 = Vectors.dense(3.409937, 6.146531)
+
+ assert(model2.intercept ~== interceptR2 relTol 1E-3)
+ assert(model2.weights ~= weightsR2 relTol 1E-3)
+
+ model1.transform(dataset).select("features", "prediction").collect().foreach {
+ case Row(features: DenseVector, prediction1: Double) =>
+ val prediction2 =
+ features(0) * model1.weights(0) + features(1) * model1.weights(1) + model1.intercept
+ assert(prediction1 ~== prediction2 relTol 1E-5)
+ }
+ }
}
}
test("linear regression without intercept with ElasticNet regularization") {
- val trainer1 = (new LinearRegression).setElasticNetParam(0.3).setRegParam(1.6)
- .setFitIntercept(false)
- val trainer2 = (new LinearRegression).setElasticNetParam(0.3).setRegParam(1.6)
- .setFitIntercept(false).setStandardization(false)
- val model1 = trainer1.fit(dataset)
- val model2 = trainer2.fit(dataset)
-
- /*
- weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6,
- intercept=FALSE))
- > weights
- 3 x 1 sparse Matrix of class "dgCMatrix"
- s0
- (Intercept) .
- as.numeric.dataM.V2. 5.673348
- as.numeric.dataM.V3. 4.322251
- */
- val interceptR1 = 0.0
- val weightsR1 = Vectors.dense(5.673348, 4.322251)
-
- assert(model1.intercept ~== interceptR1 absTol 1E-3)
- assert(model1.weights ~= weightsR1 relTol 1E-3)
-
- /*
- weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6,
- intercept=FALSE, standardize=FALSE))
- > weights
- 3 x 1 sparse Matrix of class "dgCMatrix"
- s0
- (Intercept) .
- as.numeric.data.V2. 5.477988
- as.numeric.data.V3. 4.297622
- */
- val interceptR2 = 0.0
- val weightsR2 = Vectors.dense(5.477988, 4.297622)
-
- assert(model2.intercept ~== interceptR2 absTol 1E-3)
- assert(model2.weights ~= weightsR2 relTol 1E-3)
-
- model1.transform(dataset).select("features", "prediction").collect().foreach {
- case Row(features: DenseVector, prediction1: Double) =>
- val prediction2 =
- features(0) * model1.weights(0) + features(1) * model1.weights(1) + model1.intercept
- assert(prediction1 ~== prediction2 relTol 1E-5)
+ Seq("auto", "l-bfgs", "normal").foreach { solver =>
+ val trainer1 = (new LinearRegression).setElasticNetParam(0.3).setRegParam(1.6)
+ .setFitIntercept(false).setSolver(solver)
+ val trainer2 = (new LinearRegression).setElasticNetParam(0.3).setRegParam(1.6)
+ .setFitIntercept(false).setStandardization(false).setSolver(solver)
+
+ var model1: LinearRegressionModel = null
+ var model2: LinearRegressionModel = null
+
+ // The normal solver does not support a non-zero elastic net parameter.
+ if (solver == "normal") {
+ intercept[IllegalArgumentException] {
+ trainer1.fit(dataset)
+ trainer2.fit(dataset)
+ }
+ } else {
+ model1 = trainer1.fit(dataset)
+ model2 = trainer2.fit(dataset)
+
+ /*
+ weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6,
+ intercept=FALSE))
+ > weights
+ 3 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ (Intercept) .
+ as.numeric.dataM.V2. 5.673348
+ as.numeric.dataM.V3. 4.322251
+ */
+ val interceptR1 = 0.0
+ val weightsR1 = Vectors.dense(5.673348, 4.322251)
+
+ assert(model1.intercept ~== interceptR1 absTol 1E-3)
+ assert(model1.weights ~= weightsR1 relTol 1E-3)
+
+ /*
+ weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6,
+ intercept=FALSE, standardize=FALSE))
+ > weights
+ 3 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ (Intercept) .
+ as.numeric.data.V2. 5.477988
+ as.numeric.data.V3. 4.297622
+ */
+ val interceptR2 = 0.0
+ val weightsR2 = Vectors.dense(5.477988, 4.297622)
+
+ assert(model2.intercept ~== interceptR2 absTol 1E-3)
+ assert(model2.weights ~= weightsR2 relTol 1E-3)
+
+ model1.transform(dataset).select("features", "prediction").collect().foreach {
+ case Row(features: DenseVector, prediction1: Double) =>
+ val prediction2 =
+ features(0) * model1.weights(0) + features(1) * model1.weights(1) + model1.intercept
+ assert(prediction1 ~== prediction2 relTol 1E-5)
+ }
+ }
}
}
test("linear regression model training summary") {
- val trainer = new LinearRegression
- val model = trainer.fit(dataset)
- val trainerNoPredictionCol = trainer.setPredictionCol("")
- val modelNoPredictionCol = trainerNoPredictionCol.fit(dataset)
-
-
- // Training results for the model should be available
- assert(model.hasSummary)
- assert(modelNoPredictionCol.hasSummary)
-
- // Schema should be a superset of the input dataset
- assert((dataset.schema.fieldNames.toSet + "prediction").subsetOf(
- model.summary.predictions.schema.fieldNames.toSet))
- // Validate that we re-insert a prediction column for evaluation
- val modelNoPredictionColFieldNames = modelNoPredictionCol.summary.predictions.schema.fieldNames
- assert((dataset.schema.fieldNames.toSet).subsetOf(
- modelNoPredictionColFieldNames.toSet))
- assert(modelNoPredictionColFieldNames.exists(s => s.startsWith("prediction_")))
-
- // Residuals in [[LinearRegressionResults]] should equal those manually computed
- val expectedResiduals = dataset.select("features", "label")
- .map { case Row(features: DenseVector, label: Double) =>
- val prediction =
- features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept
- label - prediction
- }
- .zip(model.summary.residuals.map(_.getDouble(0)))
- .collect()
- .foreach { case (manualResidual: Double, resultResidual: Double) =>
- assert(manualResidual ~== resultResidual relTol 1E-5)
- }
+ Seq("auto", "l-bfgs", "normal").foreach { solver =>
+ val trainer = new LinearRegression().setSolver(solver)
+ val model = trainer.fit(dataset)
+ val trainerNoPredictionCol = trainer.setPredictionCol("")
+ val modelNoPredictionCol = trainerNoPredictionCol.fit(dataset)
+
+ // Training results for the model should be available
+ assert(model.hasSummary)
+ assert(modelNoPredictionCol.hasSummary)
+
+ // Schema should be a superset of the input dataset
+ assert((dataset.schema.fieldNames.toSet + "prediction").subsetOf(
+ model.summary.predictions.schema.fieldNames.toSet))
+ // Validate that we re-insert a prediction column for evaluation
+ val modelNoPredictionColFieldNames =
+ modelNoPredictionCol.summary.predictions.schema.fieldNames
+ assert((dataset.schema.fieldNames.toSet).subsetOf(
+ modelNoPredictionColFieldNames.toSet))
+ assert(modelNoPredictionColFieldNames.exists(s => s.startsWith("prediction_")))
+
+ // Residuals in [[LinearRegressionResults]] should equal those manually computed
+ val expectedResiduals = dataset.select("features", "label")
+ .map { case Row(features: DenseVector, label: Double) =>
+ val prediction =
+ features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept
+ label - prediction
+ }
+ .zip(model.summary.residuals.map(_.getDouble(0)))
+ .collect()
+ .foreach { case (manualResidual: Double, resultResidual: Double) =>
+ assert(manualResidual ~== resultResidual relTol 1E-5)
+ }
- /*
- Use the following R code to generate model training results.
-
- predictions <- predict(fit, newx=features)
- residuals <- label - predictions
- > mean(residuals^2) # MSE
- [1] 0.009720325
- > mean(abs(residuals)) # MAD
- [1] 0.07863206
- > cor(predictions, label)^2# r^2
- [,1]
- s0 0.9998749
- */
- assert(model.summary.meanSquaredError ~== 0.00972035 relTol 1E-5)
- assert(model.summary.meanAbsoluteError ~== 0.07863206 relTol 1E-5)
- assert(model.summary.r2 ~== 0.9998749 relTol 1E-5)
-
- // Objective function should be monotonically decreasing for linear regression
- assert(
- model.summary
- .objectiveHistory
- .sliding(2)
- .forall(x => x(0) >= x(1)))
+ /*
+ Use the following R code to generate model training results.
+
+ predictions <- predict(fit, newx=features)
+ residuals <- label - predictions
+ > mean(residuals^2) # MSE
+ [1] 0.009720325
+ > mean(abs(residuals)) # MAD
+ [1] 0.07863206
+ > cor(predictions, label)^2 # r^2
+ [,1]
+ s0 0.9998749
+ */
+ assert(model.summary.meanSquaredError ~== 0.00972035 relTol 1E-5)
+ assert(model.summary.meanAbsoluteError ~== 0.07863206 relTol 1E-5)
+ assert(model.summary.r2 ~== 0.9998749 relTol 1E-5)
+
+ // The normal solver uses WeightedLeastSquares, which does not generate an
+ // objective history because it is not iterative.
+ if (solver == "l-bfgs") {
+ // Objective function should be monotonically decreasing for linear regression
+ assert(
+ model.summary
+ .objectiveHistory
+ .sliding(2)
+ .forall(x => x(0) >= x(1)))
+ }
+ }
}
test("linear regression model testset evaluation summary") {
- val trainer = new LinearRegression
- val model = trainer.fit(dataset)
-
- // Evaluating on training dataset should yield results summary equal to training summary
- val testSummary = model.evaluate(dataset)
- assert(model.summary.meanSquaredError ~== testSummary.meanSquaredError relTol 1E-5)
- assert(model.summary.r2 ~== testSummary.r2 relTol 1E-5)
- model.summary.residuals.select("residuals").collect()
- .zip(testSummary.residuals.select("residuals").collect())
- .forall { case (Row(r1: Double), Row(r2: Double)) => r1 ~== r2 relTol 1E-5 }
+ Seq("auto", "l-bfgs", "normal").foreach { solver =>
+ val trainer = new LinearRegression().setSolver(solver)
+ val model = trainer.fit(dataset)
+
+ // Evaluating on training dataset should yield results summary equal to training summary
+ val testSummary = model.evaluate(dataset)
+ assert(model.summary.meanSquaredError ~== testSummary.meanSquaredError relTol 1E-5)
+ assert(model.summary.r2 ~== testSummary.r2 relTol 1E-5)
+ model.summary.residuals.select("residuals").collect()
+ .zip(testSummary.residuals.select("residuals").collect())
+ .forall { case (Row(r1: Double), Row(r2: Double)) => r1 ~== r2 relTol 1E-5 }
+ }
}
- test("linear regression with weighted samples"){
- val (data, weightedData) = {
- val activeData = LinearDataGenerator.generateLinearInput(
- 6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 500, 1, 0.1)
-
- val rnd = new Random(8392)
- val signedData = activeData.map { case p: LabeledPoint =>
- (rnd.nextGaussian() > 0.0, p)
- }
-
- val data1 = signedData.flatMap {
- case (true, p) => Iterator(p, p)
- case (false, p) => Iterator(p)
- }
-
- val weightedSignedData = signedData.flatMap {
- case (true, LabeledPoint(label, features)) =>
- Iterator(
- Instance(label, weight = 1.2, features),
- Instance(label, weight = 0.8, features)
- )
- case (false, LabeledPoint(label, features)) =>
- Iterator(
- Instance(label, weight = 0.3, features),
- Instance(label, weight = 0.1, features),
- Instance(label, weight = 0.6, features)
- )
+ test("linear regression with weighted samples") {
+ Seq("auto", "l-bfgs", "normal").foreach { solver =>
+ val (data, weightedData) = {
+ val activeData = LinearDataGenerator.generateLinearInput(
+ 6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 500, 1, 0.1)
+
+ val rnd = new Random(8392)
+ val signedData = activeData.map { case p: LabeledPoint =>
+ (rnd.nextGaussian() > 0.0, p)
+ }
+
+ val data1 = signedData.flatMap {
+ case (true, p) => Iterator(p, p)
+ case (false, p) => Iterator(p)
+ }
+
+ val weightedSignedData = signedData.flatMap {
+ case (true, LabeledPoint(label, features)) =>
+ Iterator(
+ Instance(label, weight = 1.2, features),
+ Instance(label, weight = 0.8, features)
+ )
+ case (false, LabeledPoint(label, features)) =>
+ Iterator(
+ Instance(label, weight = 0.3, features),
+ Instance(label, weight = 0.1, features),
+ Instance(label, weight = 0.6, features)
+ )
+ }
+
+ val noiseData = LinearDataGenerator.generateLinearInput(
+ 2, Array(1, 3), Array(0.9, -1.3), Array(0.7, 1.2), 500, 1, 0.1)
+ val weightedNoiseData = noiseData.map {
+ case LabeledPoint(label, features) => Instance(label, weight = 0, features)
+ }
+ val data2 = weightedSignedData ++ weightedNoiseData
+
+ (sqlContext.createDataFrame(sc.parallelize(data1, 4)),
+ sqlContext.createDataFrame(sc.parallelize(data2, 4)))
}
- val noiseData = LinearDataGenerator.generateLinearInput(
- 2, Array(1, 3), Array(0.9, -1.3), Array(0.7, 1.2), 500, 1, 0.1)
- val weightedNoiseData = noiseData.map {
- case LabeledPoint(label, features) => Instance(label, weight = 0, features)
- }
- val data2 = weightedSignedData ++ weightedNoiseData
-
- (sqlContext.createDataFrame(sc.parallelize(data1, 4)),
- sqlContext.createDataFrame(sc.parallelize(data2, 4)))
+ val trainer1a = (new LinearRegression).setFitIntercept(true)
+ .setElasticNetParam(0.0).setRegParam(0.21).setStandardization(true).setSolver(solver)
+ val trainer1b = (new LinearRegression).setFitIntercept(true).setWeightCol("weight")
+ .setElasticNetParam(0.0).setRegParam(0.21).setStandardization(true).setSolver(solver)
+
+ // elasticNetParam is set to 0.0 so that the normal solver can also be
+ // exercised; it does not support a non-zero elastic net parameter.
+ val model1a0 = trainer1a.fit(data)
+ val model1a1 = trainer1a.fit(weightedData)
+ val model1b = trainer1b.fit(weightedData)
+
+ assert(model1a0.weights !~= model1a1.weights absTol 1E-3)
+ assert(model1a0.intercept !~= model1a1.intercept absTol 1E-3)
+ assert(model1a0.weights ~== model1b.weights absTol 1E-3)
+ assert(model1a0.intercept ~== model1b.intercept absTol 1E-3)
+
+ val trainer2a = (new LinearRegression).setFitIntercept(true)
+ .setElasticNetParam(0.0).setRegParam(0.21).setStandardization(false).setSolver(solver)
+ val trainer2b = (new LinearRegression).setFitIntercept(true).setWeightCol("weight")
+ .setElasticNetParam(0.0).setRegParam(0.21).setStandardization(false).setSolver(solver)
+ val model2a0 = trainer2a.fit(data)
+ val model2a1 = trainer2a.fit(weightedData)
+ val model2b = trainer2b.fit(weightedData)
+ assert(model2a0.weights !~= model2a1.weights absTol 1E-3)
+ assert(model2a0.intercept !~= model2a1.intercept absTol 1E-3)
+ assert(model2a0.weights ~== model2b.weights absTol 1E-3)
+ assert(model2a0.intercept ~== model2b.intercept absTol 1E-3)
+
+ val trainer3a = (new LinearRegression).setFitIntercept(false)
+ .setElasticNetParam(0.0).setRegParam(0.21).setStandardization(true).setSolver(solver)
+ val trainer3b = (new LinearRegression).setFitIntercept(false).setWeightCol("weight")
+ .setElasticNetParam(0.0).setRegParam(0.21).setStandardization(true).setSolver(solver)
+ val model3a0 = trainer3a.fit(data)
+ val model3a1 = trainer3a.fit(weightedData)
+ val model3b = trainer3b.fit(weightedData)
+ assert(model3a0.weights !~= model3a1.weights absTol 1E-3)
+ assert(model3a0.weights ~== model3b.weights absTol 1E-3)
+
+ val trainer4a = (new LinearRegression).setFitIntercept(false)
+ .setElasticNetParam(0.0).setRegParam(0.21).setStandardization(false).setSolver(solver)
+ val trainer4b = (new LinearRegression).setFitIntercept(false).setWeightCol("weight")
+ .setElasticNetParam(0.0).setRegParam(0.21).setStandardization(false).setSolver(solver)
+ val model4a0 = trainer4a.fit(data)
+ val model4a1 = trainer4a.fit(weightedData)
+ val model4b = trainer4b.fit(weightedData)
+ assert(model4a0.weights !~= model4a1.weights absTol 1E-3)
+ assert(model4a0.weights ~== model4b.weights absTol 1E-3)
}
-
- val trainer1a = (new LinearRegression).setFitIntercept(true)
- .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(true)
- val trainer1b = (new LinearRegression).setFitIntercept(true).setWeightCol("weight")
- .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(true)
- val model1a0 = trainer1a.fit(data)
- val model1a1 = trainer1a.fit(weightedData)
- val model1b = trainer1b.fit(weightedData)
- assert(model1a0.weights !~= model1a1.weights absTol 1E-3)
- assert(model1a0.intercept !~= model1a1.intercept absTol 1E-3)
- assert(model1a0.weights ~== model1b.weights absTol 1E-3)
- assert(model1a0.intercept ~== model1b.intercept absTol 1E-3)
-
- val trainer2a = (new LinearRegression).setFitIntercept(true)
- .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(false)
- val trainer2b = (new LinearRegression).setFitIntercept(true).setWeightCol("weight")
- .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(false)
- val model2a0 = trainer2a.fit(data)
- val model2a1 = trainer2a.fit(weightedData)
- val model2b = trainer2b.fit(weightedData)
- assert(model2a0.weights !~= model2a1.weights absTol 1E-3)
- assert(model2a0.intercept !~= model2a1.intercept absTol 1E-3)
- assert(model2a0.weights ~== model2b.weights absTol 1E-3)
- assert(model2a0.intercept ~== model2b.intercept absTol 1E-3)
-
- val trainer3a = (new LinearRegression).setFitIntercept(false)
- .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(true)
- val trainer3b = (new LinearRegression).setFitIntercept(false).setWeightCol("weight")
- .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(true)
- val model3a0 = trainer3a.fit(data)
- val model3a1 = trainer3a.fit(weightedData)
- val model3b = trainer3b.fit(weightedData)
- assert(model3a0.weights !~= model3a1.weights absTol 1E-3)
- assert(model3a0.weights ~== model3b.weights absTol 1E-3)
-
- val trainer4a = (new LinearRegression).setFitIntercept(false)
- .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(false)
- val trainer4b = (new LinearRegression).setFitIntercept(false).setWeightCol("weight")
- .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(false)
- val model4a0 = trainer4a.fit(data)
- val model4a1 = trainer4a.fit(weightedData)
- val model4b = trainer4b.fit(weightedData)
- assert(model4a0.weights !~= model4a1.weights absTol 1E-3)
- assert(model4a0.weights ~== model4b.weights absTol 1E-3)
}
}
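One caveat in the suites above: when two fits share a single intercept[IllegalArgumentException] block, the second fit never runs once the first throws, so trainer2 is effectively untested on the "normal" path. A tighter pattern would assert each fit independently (a suggestion, not what the patch does; trainer1, trainer2, and dataset as in the tests):

    if (solver == "normal") {
      intercept[IllegalArgumentException] { trainer1.fit(dataset) }
      intercept[IllegalArgumentException] { trainer2.fit(dataset) }
    }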
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
index fde02e0c84..cbe09292a0 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
@@ -69,7 +69,7 @@ class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext {
sc.parallelize(LinearDataGenerator.generateLinearInput(
6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 100, 42, 0.1), 2))
- val trainer = new LinearRegression
+ val trainer = new LinearRegression().setSolver("l-bfgs")
val lrParamMaps = new ParamGridBuilder()
.addGrid(trainer.regParam, Array(1000.0, 0.001))
.addGrid(trainer.maxIter, Array(0, 10))
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
index ef24e6fb6b..5fb80091d0 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
@@ -58,7 +58,7 @@ class TrainValidationSplitSuite extends SparkFunSuite with MLlibTestSparkContext
sc.parallelize(LinearDataGenerator.generateLinearInput(
6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 100, 42, 0.1), 2))
- val trainer = new LinearRegression
+ val trainer = new LinearRegression().setSolver("l-bfgs")
val lrParamMaps = new ParamGridBuilder()
.addGrid(trainer.regParam, Array(1000.0, 0.001))
.addGrid(trainer.maxIter, Array(0, 10))
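The tuning suites pin the solver to "l-bfgs", presumably because their grids sweep maxIter, which only matters for the iterative solver; with two features and elasticNetParam at its 0.0 default, "auto" would otherwise route to the normal path and ignore that axis. A sketch of the grid they build (same API as in the diff):

    import org.apache.spark.ml.regression.LinearRegression
    import org.apache.spark.ml.tuning.ParamGridBuilder

    val trainer = new LinearRegression().setSolver("l-bfgs")
    val lrParamMaps = new ParamGridBuilder()
      .addGrid(trainer.regParam, Array(1000.0, 0.001))
      .addGrid(trainer.maxIter, Array(0, 10))
      .build()  // Array[ParamMap]: 4 candidate settings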