aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/test')
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala105
1 files changed, 105 insertions, 0 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
index 273c882c2a..81fc6603cc 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
@@ -37,6 +37,8 @@ class LinearRegressionSuite
@transient var datasetWithDenseFeatureWithoutIntercept: DataFrame = _
@transient var datasetWithSparseFeature: DataFrame = _
@transient var datasetWithWeight: DataFrame = _
+ @transient var datasetWithWeightConstantLabel: DataFrame = _
+ @transient var datasetWithWeightZeroLabel: DataFrame = _
/*
In `LinearRegressionSuite`, we will make sure that the model trained by SparkML
@@ -92,6 +94,29 @@ class LinearRegressionSuite
Instance(23.0, 3.0, Vectors.dense(2.0, 11.0)),
Instance(29.0, 4.0, Vectors.dense(3.0, 13.0))
), 2))
+
+ /*
+ R code:
+
+ A <- matrix(c(0, 1, 2, 3, 5, 7, 11, 13), 4, 2)
+ b.const <- c(17, 17, 17, 17)
+ w <- c(1, 2, 3, 4)
+ df.const.label <- as.data.frame(cbind(A, b.const))
+ */
+ datasetWithWeightConstantLabel = sqlContext.createDataFrame(
+ sc.parallelize(Seq(
+ Instance(17.0, 1.0, Vectors.dense(0.0, 5.0).toSparse),
+ Instance(17.0, 2.0, Vectors.dense(1.0, 7.0)),
+ Instance(17.0, 3.0, Vectors.dense(2.0, 11.0)),
+ Instance(17.0, 4.0, Vectors.dense(3.0, 13.0))
+ ), 2))
+ datasetWithWeightZeroLabel = sqlContext.createDataFrame(
+ sc.parallelize(Seq(
+ Instance(0.0, 1.0, Vectors.dense(0.0, 5.0).toSparse),
+ Instance(0.0, 2.0, Vectors.dense(1.0, 7.0)),
+ Instance(0.0, 3.0, Vectors.dense(2.0, 11.0)),
+ Instance(0.0, 4.0, Vectors.dense(3.0, 13.0))
+ ), 2))
}
test("params") {
@@ -558,6 +583,86 @@ class LinearRegressionSuite
}
}
+ test("linear regression model with constant label") {
+ /*
+ R code:
+ for (formula in c(b.const ~ . -1, b.const ~ .)) {
+ model <- lm(formula, data=df.const.label, weights=w)
+ print(as.vector(coef(model)))
+ }
+ [1] -9.221298 3.394343
+ [1] 17 0 0
+ */
+ val expected = Seq(
+ Vectors.dense(0.0, -9.221298, 3.394343),
+ Vectors.dense(17.0, 0.0, 0.0))
+
+ Seq("auto", "l-bfgs", "normal").foreach { solver =>
+ var idx = 0
+ for (fitIntercept <- Seq(false, true)) {
+ val model1 = new LinearRegression()
+ .setFitIntercept(fitIntercept)
+ .setWeightCol("weight")
+ .setSolver(solver)
+ .fit(datasetWithWeightConstantLabel)
+ val actual1 = Vectors.dense(model1.intercept, model1.coefficients(0),
+ model1.coefficients(1))
+ assert(actual1 ~== expected(idx) absTol 1e-4)
+
+ val model2 = new LinearRegression()
+ .setFitIntercept(fitIntercept)
+ .setWeightCol("weight")
+ .setSolver(solver)
+ .fit(datasetWithWeightZeroLabel)
+ val actual2 = Vectors.dense(model2.intercept, model2.coefficients(0),
+ model2.coefficients(1))
+ assert(actual2 ~== Vectors.dense(0.0, 0.0, 0.0) absTol 1e-4)
+ idx += 1
+ }
+ }
+ }
+
+ test("regularized linear regression through origin with constant label") {
+ // The problem is ill-defined if fitIntercept=false, regParam is non-zero.
+ // An exception is thrown in this case.
+ Seq("auto", "l-bfgs", "normal").foreach { solver =>
+ for (standardization <- Seq(false, true)) {
+ val model = new LinearRegression().setFitIntercept(false)
+ .setRegParam(0.1).setStandardization(standardization).setSolver(solver)
+ intercept[IllegalArgumentException] {
+ model.fit(datasetWithWeightConstantLabel)
+ }
+ }
+ }
+ }
+
+ test("linear regression with l-bfgs when training is not needed") {
+ // When label is constant, l-bfgs solver returns results without training.
+ // There are two possibilities: If the label is non-zero but constant,
+ // and fitIntercept is true, then the model return yMean as intercept without training.
+ // If label is all zeros, then all coefficients are zero regardless of fitIntercept, so
+ // no training is needed.
+ for (fitIntercept <- Seq(false, true)) {
+ for (standardization <- Seq(false, true)) {
+ val model1 = new LinearRegression()
+ .setFitIntercept(fitIntercept)
+ .setStandardization(standardization)
+ .setWeightCol("weight")
+ .setSolver("l-bfgs")
+ .fit(datasetWithWeightConstantLabel)
+ if (fitIntercept) {
+ assert(model1.summary.objectiveHistory(0) ~== 0.0 absTol 1e-4)
+ }
+ val model2 = new LinearRegression()
+ .setFitIntercept(fitIntercept)
+ .setWeightCol("weight")
+ .setSolver("l-bfgs")
+ .fit(datasetWithWeightZeroLabel)
+ assert(model2.summary.objectiveHistory(0) ~== 0.0 absTol 1e-4)
+ }
+ }
+ }
+
test("linear regression model training summary") {
Seq("auto", "l-bfgs", "normal").foreach { solver =>
val trainer = new LinearRegression().setSolver(solver)