aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test/scala/org/apache
diff options
context:
space:
mode:
authorImran Younus <iyounus@us.ibm.com>2016-02-02 20:38:53 -0800
committerDB Tsai <dbt@netflix.com>2016-02-02 20:38:53 -0800
commit0557146619868002e2f7ec3c121c30bbecc918fc (patch)
tree184e05350cfe43f7b57c2b62cdf58398e4f6ea73 /mllib/src/test/scala/org/apache
parent99a6e3c1e8d580ce1cc497bd9362eaf16c597f77 (diff)
downloadspark-0557146619868002e2f7ec3c121c30bbecc918fc.tar.gz
spark-0557146619868002e2f7ec3c121c30bbecc918fc.tar.bz2
spark-0557146619868002e2f7ec3c121c30bbecc918fc.zip
[SPARK-12732][ML] bug fix in linear regression train
Fixed the bug in linear regression train for the case when the target variable is constant. The two cases for `fitIntercept=true` or `fitIntercept=false` should be treated differently. Author: Imran Younus <iyounus@us.ibm.com> Closes #10702 from iyounus/SPARK-12732_bug_fix_in_linear_regression_train.
Diffstat (limited to 'mllib/src/test/scala/org/apache')
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala105
1 files changed, 105 insertions, 0 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
index 273c882c2a..81fc6603cc 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
@@ -37,6 +37,8 @@ class LinearRegressionSuite
@transient var datasetWithDenseFeatureWithoutIntercept: DataFrame = _
@transient var datasetWithSparseFeature: DataFrame = _
@transient var datasetWithWeight: DataFrame = _
+ @transient var datasetWithWeightConstantLabel: DataFrame = _
+ @transient var datasetWithWeightZeroLabel: DataFrame = _
/*
In `LinearRegressionSuite`, we will make sure that the model trained by SparkML
@@ -92,6 +94,29 @@ class LinearRegressionSuite
Instance(23.0, 3.0, Vectors.dense(2.0, 11.0)),
Instance(29.0, 4.0, Vectors.dense(3.0, 13.0))
), 2))
+
+ /*
+ R code:
+
+ A <- matrix(c(0, 1, 2, 3, 5, 7, 11, 13), 4, 2)
+ b.const <- c(17, 17, 17, 17)
+ w <- c(1, 2, 3, 4)
+ df.const.label <- as.data.frame(cbind(A, b.const))
+ */
+ datasetWithWeightConstantLabel = sqlContext.createDataFrame(
+ sc.parallelize(Seq(
+ Instance(17.0, 1.0, Vectors.dense(0.0, 5.0).toSparse),
+ Instance(17.0, 2.0, Vectors.dense(1.0, 7.0)),
+ Instance(17.0, 3.0, Vectors.dense(2.0, 11.0)),
+ Instance(17.0, 4.0, Vectors.dense(3.0, 13.0))
+ ), 2))
+ datasetWithWeightZeroLabel = sqlContext.createDataFrame(
+ sc.parallelize(Seq(
+ Instance(0.0, 1.0, Vectors.dense(0.0, 5.0).toSparse),
+ Instance(0.0, 2.0, Vectors.dense(1.0, 7.0)),
+ Instance(0.0, 3.0, Vectors.dense(2.0, 11.0)),
+ Instance(0.0, 4.0, Vectors.dense(3.0, 13.0))
+ ), 2))
}
test("params") {
@@ -558,6 +583,86 @@ class LinearRegressionSuite
}
}
+ test("linear regression model with constant label") {
+ /*
+ R code:
+ for (formula in c(b.const ~ . -1, b.const ~ .)) {
+ model <- lm(formula, data=df.const.label, weights=w)
+ print(as.vector(coef(model)))
+ }
+ [1] -9.221298 3.394343
+ [1] 17 0 0
+ */
+ val expected = Seq(
+ Vectors.dense(0.0, -9.221298, 3.394343),
+ Vectors.dense(17.0, 0.0, 0.0))
+
+ Seq("auto", "l-bfgs", "normal").foreach { solver =>
+ var idx = 0
+ for (fitIntercept <- Seq(false, true)) {
+ val model1 = new LinearRegression()
+ .setFitIntercept(fitIntercept)
+ .setWeightCol("weight")
+ .setSolver(solver)
+ .fit(datasetWithWeightConstantLabel)
+ val actual1 = Vectors.dense(model1.intercept, model1.coefficients(0),
+ model1.coefficients(1))
+ assert(actual1 ~== expected(idx) absTol 1e-4)
+
+ val model2 = new LinearRegression()
+ .setFitIntercept(fitIntercept)
+ .setWeightCol("weight")
+ .setSolver(solver)
+ .fit(datasetWithWeightZeroLabel)
+ val actual2 = Vectors.dense(model2.intercept, model2.coefficients(0),
+ model2.coefficients(1))
+ assert(actual2 ~== Vectors.dense(0.0, 0.0, 0.0) absTol 1e-4)
+ idx += 1
+ }
+ }
+ }
+
+ test("regularized linear regression through origin with constant label") {
+ // The problem is ill-defined if fitIntercept=false, regParam is non-zero.
+ // An exception is thrown in this case.
+ Seq("auto", "l-bfgs", "normal").foreach { solver =>
+ for (standardization <- Seq(false, true)) {
+ val model = new LinearRegression().setFitIntercept(false)
+ .setRegParam(0.1).setStandardization(standardization).setSolver(solver)
+ intercept[IllegalArgumentException] {
+ model.fit(datasetWithWeightConstantLabel)
+ }
+ }
+ }
+ }
+
+ test("linear regression with l-bfgs when training is not needed") {
+ // When label is constant, l-bfgs solver returns results without training.
+ // There are two possibilities: If the label is non-zero but constant,
+ // and fitIntercept is true, then the model return yMean as intercept without training.
+ // If label is all zeros, then all coefficients are zero regardless of fitIntercept, so
+ // no training is needed.
+ for (fitIntercept <- Seq(false, true)) {
+ for (standardization <- Seq(false, true)) {
+ val model1 = new LinearRegression()
+ .setFitIntercept(fitIntercept)
+ .setStandardization(standardization)
+ .setWeightCol("weight")
+ .setSolver("l-bfgs")
+ .fit(datasetWithWeightConstantLabel)
+ if (fitIntercept) {
+ assert(model1.summary.objectiveHistory(0) ~== 0.0 absTol 1e-4)
+ }
+ val model2 = new LinearRegression()
+ .setFitIntercept(fitIntercept)
+ .setWeightCol("weight")
+ .setSolver("l-bfgs")
+ .fit(datasetWithWeightZeroLabel)
+ assert(model2.summary.objectiveHistory(0) ~== 0.0 absTol 1e-4)
+ }
+ }
+ }
+
test("linear regression model training summary") {
Seq("auto", "l-bfgs", "normal").foreach { solver =>
val trainer = new LinearRegression().setSolver(solver)