author     Imran Younus <iyounus@us.ibm.com>  2016-02-02 20:38:53 -0800
committer  DB Tsai <dbt@netflix.com>          2016-02-02 20:38:53 -0800
commit     0557146619868002e2f7ec3c121c30bbecc918fc (patch)
tree       184e05350cfe43f7b57c2b62cdf58398e4f6ea73 /mllib
parent     99a6e3c1e8d580ce1cc497bd9362eaf16c597f77 (diff)
[SPARK-12732][ML] bug fix in linear regression train
Fixed a bug in linear regression training for the case when the target variable is constant. The two cases, `fitIntercept=true` and `fitIntercept=false`, must be treated differently.

Author: Imran Younus <iyounus@us.ibm.com>

Closes #10702 from iyounus/SPARK-12732_bug_fix_in_linear_regression_train.
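To make the behavior change concrete, here is a minimal usage sketch (illustration only, not part of the patch) that exercises the two cases on a constant-label dataset. It assumes a live `sqlContext` and the `org.apache.spark.mllib.linalg.Vectors` type used by spark.ml at the time of this commit; the data and variable names are hypothetical.

import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.mllib.linalg.Vectors

// Constant-label training data with an explicit weight column.
val constantLabelDF = sqlContext.createDataFrame(Seq(
  (17.0, 1.0, Vectors.dense(0.0, 5.0)),
  (17.0, 2.0, Vectors.dense(1.0, 7.0)),
  (17.0, 3.0, Vectors.dense(2.0, 11.0)),
  (17.0, 4.0, Vectors.dense(3.0, 13.0))
)).toDF("label", "weight", "features")

// Case 1: fitIntercept=true. No optimization is needed: the model is simply
// intercept = yMean (17.0) with all-zero coefficients.
val withIntercept = new LinearRegression()
  .setFitIntercept(true)
  .setWeightCol("weight")
  .fit(constantLabelDF)

// Case 2: fitIntercept=false. Regression through the origin is still well
// defined when regParam == 0, so training proceeds; with a non-zero regParam
// this data is expected to be rejected (see the new test below).
val throughOrigin = new LinearRegression()
  .setFitIntercept(false)
  .setWeightCol("weight")
  .fit(constantLabelDF)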
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala       |  66
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala  | 105
2 files changed, 146 insertions, 25 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index c54e08b2ad..e253f25c0e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -219,33 +219,49 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
}
val yMean = ySummarizer.mean(0)
- val yStd = math.sqrt(ySummarizer.variance(0))
-
- // If the yStd is zero, then the intercept is yMean with zero coefficient;
- // as a result, training is not needed.
- if (yStd == 0.0) {
- logWarning(s"The standard deviation of the label is zero, so the coefficients will be " +
- s"zeros and the intercept will be the mean of the label; as a result, " +
- s"training is not needed.")
- if (handlePersistence) instances.unpersist()
- val coefficients = Vectors.sparse(numFeatures, Seq())
- val intercept = yMean
-
- val model = new LinearRegressionModel(uid, coefficients, intercept)
- // Handle possible missing or invalid prediction columns
- val (summaryModel, predictionColName) = model.findSummaryModelAndPredictionCol()
-
- val trainingSummary = new LinearRegressionTrainingSummary(
- summaryModel.transform(dataset),
- predictionColName,
- $(labelCol),
- model,
- Array(0D),
- $(featuresCol),
- Array(0D))
- return copyValues(model.setSummary(trainingSummary))
+ val rawYStd = math.sqrt(ySummarizer.variance(0))
+ if (rawYStd == 0.0) {
+ if ($(fitIntercept) || yMean == 0.0) {
+ // If rawYStd is zero and fitIntercept=true, then the intercept is yMean with
+ // zero coefficients; as a result, training is not needed.
+ // Also, if yMean==0 and rawYStd==0, all the coefficients are zero regardless of
+ // fitIntercept, so again no training is needed.
+ if (yMean == 0.0) {
+ logWarning(s"Mean and standard deviation of the label are zero, so the coefficients " +
+ s"and the intercept will all be zero; as a result, training is not needed.")
+ } else {
+ logWarning(s"The standard deviation of the label is zero, so the coefficients will be " +
+ s"zeros and the intercept will be the mean of the label; as a result, " +
+ s"training is not needed.")
+ }
+ if (handlePersistence) instances.unpersist()
+ val coefficients = Vectors.sparse(numFeatures, Seq())
+ val intercept = yMean
+
+ val model = new LinearRegressionModel(uid, coefficients, intercept)
+ // Handle possible missing or invalid prediction columns
+ val (summaryModel, predictionColName) = model.findSummaryModelAndPredictionCol()
+
+ val trainingSummary = new LinearRegressionTrainingSummary(
+ summaryModel.transform(dataset),
+ predictionColName,
+ $(labelCol),
+ model,
+ Array(0D),
+ $(featuresCol),
+ Array(0D))
+ return copyValues(model.setSummary(trainingSummary))
+ } else {
+ require($(regParam) == 0.0, "The standard deviation of the label is zero. " +
+ "Model cannot be regularized.")
+ logWarning(s"The standard deviation of the label is zero. " +
+ "Consider setting fitIntercept=true.")
+ }
}
+ // If y is constant (rawYStd is zero), then y cannot be scaled. In this case,
+ // setting yStd=abs(yMean) ensures that y is not scaled anymore in the l-bfgs algorithm.
+ val yStd = if (rawYStd > 0) rawYStd else math.abs(yMean)
val featuresMean = featuresSummarizer.mean.toArray
val featuresStd = featuresSummarizer.variance.toArray.map(math.sqrt)
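As a side note on the branching introduced above, the sketch below is a small standalone illustration (hypothetical names, not Spark's actual internals) of the decision the patch makes when the label standard deviation is zero: skip training when a closed-form answer exists, reject regularized regression through the origin, and otherwise proceed with yStd set to |yMean| so that the label standardization step stays well defined.

// Standalone illustration of the new constant-label handling; names are hypothetical.
sealed trait TrainDecision
case object SkipTraining extends TrainDecision          // closed-form model, no optimization
case class Proceed(yStd: Double) extends TrainDecision  // run l-bfgs / normal equations
case class Reject(reason: String) extends TrainDecision // ill-defined problem

def decideTraining(yMean: Double, rawYStd: Double,
    fitIntercept: Boolean, regParam: Double): TrainDecision = {
  if (rawYStd == 0.0) {
    if (fitIntercept || yMean == 0.0) {
      // Constant label with an intercept, or an all-zero label: the solution is
      // intercept = yMean and coefficients = 0, so training can be skipped.
      SkipTraining
    } else if (regParam != 0.0) {
      // Constant non-zero label, no intercept, non-zero regularization: rejected,
      // mirroring the require(...) added in this patch.
      Reject("The standard deviation of the label is zero. Model cannot be regularized.")
    } else {
      // Constant non-zero label, no intercept, no regularization: train, scaling
      // the label by |yMean| instead of the (zero) standard deviation.
      Proceed(math.abs(yMean))
    }
  } else {
    Proceed(rawYStd)
  }
}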
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
index 273c882c2a..81fc6603cc 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
@@ -37,6 +37,8 @@ class LinearRegressionSuite
@transient var datasetWithDenseFeatureWithoutIntercept: DataFrame = _
@transient var datasetWithSparseFeature: DataFrame = _
@transient var datasetWithWeight: DataFrame = _
+ @transient var datasetWithWeightConstantLabel: DataFrame = _
+ @transient var datasetWithWeightZeroLabel: DataFrame = _
/*
In `LinearRegressionSuite`, we will make sure that the model trained by SparkML
@@ -92,6 +94,29 @@ class LinearRegressionSuite
Instance(23.0, 3.0, Vectors.dense(2.0, 11.0)),
Instance(29.0, 4.0, Vectors.dense(3.0, 13.0))
), 2))
+
+ /*
+ R code:
+
+ A <- matrix(c(0, 1, 2, 3, 5, 7, 11, 13), 4, 2)
+ b.const <- c(17, 17, 17, 17)
+ w <- c(1, 2, 3, 4)
+ df.const.label <- as.data.frame(cbind(A, b.const))
+ */
+ datasetWithWeightConstantLabel = sqlContext.createDataFrame(
+ sc.parallelize(Seq(
+ Instance(17.0, 1.0, Vectors.dense(0.0, 5.0).toSparse),
+ Instance(17.0, 2.0, Vectors.dense(1.0, 7.0)),
+ Instance(17.0, 3.0, Vectors.dense(2.0, 11.0)),
+ Instance(17.0, 4.0, Vectors.dense(3.0, 13.0))
+ ), 2))
+ datasetWithWeightZeroLabel = sqlContext.createDataFrame(
+ sc.parallelize(Seq(
+ Instance(0.0, 1.0, Vectors.dense(0.0, 5.0).toSparse),
+ Instance(0.0, 2.0, Vectors.dense(1.0, 7.0)),
+ Instance(0.0, 3.0, Vectors.dense(2.0, 11.0)),
+ Instance(0.0, 4.0, Vectors.dense(3.0, 13.0))
+ ), 2))
}
test("params") {
@@ -558,6 +583,86 @@ class LinearRegressionSuite
}
}
+ test("linear regression model with constant label") {
+ /*
+ R code:
+ for (formula in c(b.const ~ . -1, b.const ~ .)) {
+ model <- lm(formula, data=df.const.label, weights=w)
+ print(as.vector(coef(model)))
+ }
+ [1] -9.221298 3.394343
+ [1] 17 0 0
+ */
+ val expected = Seq(
+ Vectors.dense(0.0, -9.221298, 3.394343),
+ Vectors.dense(17.0, 0.0, 0.0))
+
+ Seq("auto", "l-bfgs", "normal").foreach { solver =>
+ var idx = 0
+ for (fitIntercept <- Seq(false, true)) {
+ val model1 = new LinearRegression()
+ .setFitIntercept(fitIntercept)
+ .setWeightCol("weight")
+ .setSolver(solver)
+ .fit(datasetWithWeightConstantLabel)
+ val actual1 = Vectors.dense(model1.intercept, model1.coefficients(0),
+ model1.coefficients(1))
+ assert(actual1 ~== expected(idx) absTol 1e-4)
+
+ val model2 = new LinearRegression()
+ .setFitIntercept(fitIntercept)
+ .setWeightCol("weight")
+ .setSolver(solver)
+ .fit(datasetWithWeightZeroLabel)
+ val actual2 = Vectors.dense(model2.intercept, model2.coefficients(0),
+ model2.coefficients(1))
+ assert(actual2 ~== Vectors.dense(0.0, 0.0, 0.0) absTol 1e-4)
+ idx += 1
+ }
+ }
+ }
+
+ test("regularized linear regression through origin with constant label") {
+ // The problem is ill-defined if fitIntercept=false and regParam is non-zero.
+ // An exception is thrown in this case.
+ Seq("auto", "l-bfgs", "normal").foreach { solver =>
+ for (standardization <- Seq(false, true)) {
+ val model = new LinearRegression().setFitIntercept(false)
+ .setRegParam(0.1).setStandardization(standardization).setSolver(solver)
+ intercept[IllegalArgumentException] {
+ model.fit(datasetWithWeightConstantLabel)
+ }
+ }
+ }
+ }
+
+ test("linear regression with l-bfgs when training is not needed") {
+ // When the label is constant, the l-bfgs solver returns a model without training.
+ // There are two possibilities: if the label is a non-zero constant and fitIntercept
+ // is true, then the model returns yMean as the intercept without training.
+ // If the label is all zeros, then all coefficients are zero regardless of fitIntercept,
+ // so no training is needed.
+ for (fitIntercept <- Seq(false, true)) {
+ for (standardization <- Seq(false, true)) {
+ val model1 = new LinearRegression()
+ .setFitIntercept(fitIntercept)
+ .setStandardization(standardization)
+ .setWeightCol("weight")
+ .setSolver("l-bfgs")
+ .fit(datasetWithWeightConstantLabel)
+ if (fitIntercept) {
+ assert(model1.summary.objectiveHistory(0) ~== 0.0 absTol 1e-4)
+ }
+ val model2 = new LinearRegression()
+ .setFitIntercept(fitIntercept)
+ .setWeightCol("weight")
+ .setSolver("l-bfgs")
+ .fit(datasetWithWeightZeroLabel)
+ assert(model2.summary.objectiveHistory(0) ~== 0.0 absTol 1e-4)
+ }
+ }
+ }
+
test("linear regression model training summary") {
Seq("auto", "l-bfgs", "normal").foreach { solver =>
val trainer = new LinearRegression().setSolver(solver)