author    Imran Younus <iyounus@us.ibm.com>    2016-01-20 11:16:59 -0800
committer Xiangrui Meng <meng@databricks.com>  2016-01-20 11:16:59 -0800
commit    9753835cf3acc135e61bf668223046e29306c80d (patch)
tree      0bb568468afea635647eeb7df8e30453a2f95489 /mllib
parent    9bb35c5b59e58dbebbdc6856d611bff73dd35a91 (diff)
[SPARK-12230][ML] WeightedLeastSquares.fit() should handle division by zero properly if standard deviation of target variable is zero.
This fixes the behavior of WeightedLeastSquares.fit() when the standard deviation of the target variable is zero. If fitIntercept is true, no training is needed: the coefficients are all zero and the intercept equals the mean of the label.

Author: Imran Younus <iyounus@us.ibm.com>

Closes #10274 from iyounus/SPARK-12230_bug_fix_in_weighted_least_squares.
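For illustration only, here is a minimal, self-contained Scala sketch of the early return this patch introduces. It does not use Spark's internal summarizer classes, and the names (ConstantLabelSketch, fitConstantLabel, k) are invented for the example: when the weighted standard deviation of the label is zero and an intercept is fit, all coefficients are zero and the intercept is the weighted label mean, so no normal-equation solve is required.

object ConstantLabelSketch {
  // Weighted mean and standard deviation of the label, then the trivial solution.
  // Returns (coefficients, intercept): k zero coefficients and the weighted label mean.
  def fitConstantLabel(labels: Seq[Double], weights: Seq[Double], k: Int): (Array[Double], Double) = {
    val wSum = weights.sum
    val bBar = labels.zip(weights).map { case (b, w) => w * b }.sum / wSum
    val bStd = math.sqrt(labels.zip(weights).map { case (b, w) => w * math.pow(b - bBar, 2) }.sum / wSum)
    require(bStd == 0.0, "this sketch only covers the constant-label case")
    (Array.fill(k)(0.0), bBar)
  }

  def main(args: Array[String]): Unit = {
    // Same data as the constant-label test added below: b = 17 for every row, weights 1..4.
    val (coef, intercept) = fitConstantLabel(Seq(17.0, 17.0, 17.0, 17.0), Seq(1.0, 2.0, 3.0, 4.0), k = 2)
    println(s"coefficients = ${coef.mkString(", ")}, intercept = $intercept") // 0.0, 0.0 and 17.0
  }
}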
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala      | 21
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala | 69
2 files changed, 83 insertions, 7 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala
index 8617722ae5..797870eb8c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala
@@ -86,6 +86,24 @@ private[ml] class WeightedLeastSquares(
val aaBar = summary.aaBar
val aaValues = aaBar.values
+ if (bStd == 0) {
+ if (fitIntercept) {
+ logWarning(s"The standard deviation of the label is zero, so the coefficients will be " +
+ s"zeros and the intercept will be the mean of the label; as a result, " +
+ s"training is not needed.")
+ val coefficients = new DenseVector(Array.ofDim(k-1))
+ val intercept = bBar
+ val diagInvAtWA = new DenseVector(Array(0D))
+ return new WeightedLeastSquaresModel(coefficients, intercept, diagInvAtWA)
+ } else {
+ require(!(regParam > 0.0 && standardizeLabel),
+ "The standard deviation of the label is zero. " +
+ "Model cannot be regularized with standardization=true")
+ logWarning(s"The standard deviation of the label is zero. " +
+ "Consider setting fitIntercept=true.")
+ }
+ }
+
// add regularization to diagonals
var i = 0
var j = 2
@@ -94,8 +112,7 @@ private[ml] class WeightedLeastSquares(
if (standardizeFeatures) {
lambda *= aVar(j - 2)
}
- if (standardizeLabel) {
- // TODO: handle the case when bStd = 0
+ if (standardizeLabel && bStd != 0) {
lambda /= bStd
}
aaValues(i) += lambda
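The second hunk only guards the existing label-standardization division. As a rough, illustrative sketch (the object and method names and the parameter list are invented here, not Spark API), the effective L2 term added to each diagonal now behaves like this:

object RegularizationSketch {
  // Sketch of the effective regularization term added to one diagonal of the A^T W A matrix.
  // aVarJ is the variance of the corresponding feature; bStd is the label's standard deviation.
  def effectiveLambda(
      regParam: Double,
      aVarJ: Double,
      bStd: Double,
      standardizeFeatures: Boolean,
      standardizeLabel: Boolean): Double = {
    var lambda = regParam
    if (standardizeFeatures) lambda *= aVarJ
    // With the fix, the division by bStd is skipped when bStd == 0, so a constant label
    // no longer turns lambda into Infinity or NaN; the require() in the first hunk already
    // rejects the ill-defined combination of regParam > 0 with label standardization.
    if (standardizeLabel && bStd != 0) lambda /= bStd
    lambda
  }
}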
diff --git a/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala
index b542ba3dc5..0b58a9821f 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala
@@ -27,6 +27,7 @@ import org.apache.spark.rdd.RDD
class WeightedLeastSquaresSuite extends SparkFunSuite with MLlibTestSparkContext {
private var instances: RDD[Instance] = _
+ private var instancesConstLabel: RDD[Instance] = _
override def beforeAll(): Unit = {
super.beforeAll()
@@ -43,6 +44,20 @@ class WeightedLeastSquaresSuite extends SparkFunSuite with MLlibTestSparkContext
Instance(23.0, 3.0, Vectors.dense(2.0, 11.0)),
Instance(29.0, 4.0, Vectors.dense(3.0, 13.0))
), 2)
+
+ /*
+ R code:
+
+ A <- matrix(c(0, 1, 2, 3, 5, 7, 11, 13), 4, 2)
+ b.const <- c(17, 17, 17, 17)
+ w <- c(1, 2, 3, 4)
+ */
+ instancesConstLabel = sc.parallelize(Seq(
+ Instance(17.0, 1.0, Vectors.dense(0.0, 5.0).toSparse),
+ Instance(17.0, 2.0, Vectors.dense(1.0, 7.0)),
+ Instance(17.0, 3.0, Vectors.dense(2.0, 11.0)),
+ Instance(17.0, 4.0, Vectors.dense(3.0, 13.0))
+ ), 2)
}
test("WLS against lm") {
@@ -65,15 +80,59 @@ class WeightedLeastSquaresSuite extends SparkFunSuite with MLlibTestSparkContext
var idx = 0
for (fitIntercept <- Seq(false, true)) {
- val wls = new WeightedLeastSquares(
- fitIntercept, regParam = 0.0, standardizeFeatures = false, standardizeLabel = false)
- .fit(instances)
- val actual = Vectors.dense(wls.intercept, wls.coefficients(0), wls.coefficients(1))
- assert(actual ~== expected(idx) absTol 1e-4)
+ for (standardization <- Seq(false, true)) {
+ val wls = new WeightedLeastSquares(
+ fitIntercept, regParam = 0.0, standardizeFeatures = standardization,
+ standardizeLabel = standardization).fit(instances)
+ val actual = Vectors.dense(wls.intercept, wls.coefficients(0), wls.coefficients(1))
+ assert(actual ~== expected(idx) absTol 1e-4)
+ }
+ idx += 1
+ }
+ }
+
+ test("WLS against lm when label is constant and no regularization") {
+ /*
+ R code:
+
+ df.const.label <- as.data.frame(cbind(A, b.const))
+ for (formula in c(b.const ~ . -1, b.const ~ .)) {
+ model <- lm(formula, data=df.const.label, weights=w)
+ print(as.vector(coef(model)))
+ }
+
+ [1] -9.221298 3.394343
+ [1] 17 0 0
+ */
+
+ val expected = Seq(
+ Vectors.dense(0.0, -9.221298, 3.394343),
+ Vectors.dense(17.0, 0.0, 0.0))
+
+ var idx = 0
+ for (fitIntercept <- Seq(false, true)) {
+ for (standardization <- Seq(false, true)) {
+ val wls = new WeightedLeastSquares(
+ fitIntercept, regParam = 0.0, standardizeFeatures = standardization,
+ standardizeLabel = standardization).fit(instancesConstLabel)
+ val actual = Vectors.dense(wls.intercept, wls.coefficients(0), wls.coefficients(1))
+ assert(actual ~== expected(idx) absTol 1e-4)
+ }
idx += 1
}
}
+ test("WLS with regularization when label is constant") {
+ // if regParam is non-zero and standardization is true, the problem is ill-defined and
+ // an exception is thrown.
+ val wls = new WeightedLeastSquares(
+ fitIntercept = false, regParam = 0.1, standardizeFeatures = true,
+ standardizeLabel = true)
+ intercept[IllegalArgumentException]{
+ wls.fit(instancesConstLabel)
+ }
+ }
+
test("WLS against glmnet") {
/*
R code: