aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
Diffstat (limited to 'mllib')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala | 21
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala | 69
2 files changed, 83 insertions, 7 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala
index 8617722ae5..797870eb8c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala
@@ -86,6 +86,24 @@ private[ml] class WeightedLeastSquares(
val aaBar = summary.aaBar
val aaValues = aaBar.values
+ if (bStd == 0) {
+ if (fitIntercept) {
+ logWarning(s"The standard deviation of the label is zero, so the coefficients will be " +
+ s"zeros and the intercept will be the mean of the label; as a result, " +
+ s"training is not needed.")
+ val coefficients = new DenseVector(Array.ofDim(k-1))
+ val intercept = bBar
+ val diagInvAtWA = new DenseVector(Array(0D))
+ return new WeightedLeastSquaresModel(coefficients, intercept, diagInvAtWA)
+ } else {
+ require(!(regParam > 0.0 && standardizeLabel),
+ "The standard deviation of the label is zero. " +
+ "Model cannot be regularized with standardization=true")
+ logWarning(s"The standard deviation of the label is zero. " +
+ "Consider setting fitIntercept=true.")
+ }
+ }
+
// add regularization to diagonals
var i = 0
var j = 2
@@ -94,8 +112,7 @@ private[ml] class WeightedLeastSquares(
if (standardizeFeatures) {
lambda *= aVar(j - 2)
}
- if (standardizeLabel) {
- // TODO: handle the case when bStd = 0
+ if (standardizeLabel && bStd != 0) {
lambda /= bStd
}
aaValues(i) += lambda
diff --git a/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala
index b542ba3dc5..0b58a9821f 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala
@@ -27,6 +27,7 @@ import org.apache.spark.rdd.RDD
class WeightedLeastSquaresSuite extends SparkFunSuite with MLlibTestSparkContext {
private var instances: RDD[Instance] = _
+ private var instancesConstLabel: RDD[Instance] = _
override def beforeAll(): Unit = {
super.beforeAll()
@@ -43,6 +44,20 @@ class WeightedLeastSquaresSuite extends SparkFunSuite with MLlibTestSparkContext
Instance(23.0, 3.0, Vectors.dense(2.0, 11.0)),
Instance(29.0, 4.0, Vectors.dense(3.0, 13.0))
), 2)
+
+ /*
+ R code:
+
+ A <- matrix(c(0, 1, 2, 3, 5, 7, 11, 13), 4, 2)
+ b.const <- c(17, 17, 17, 17)
+ w <- c(1, 2, 3, 4)
+ */
+ instancesConstLabel = sc.parallelize(Seq(
+ Instance(17.0, 1.0, Vectors.dense(0.0, 5.0).toSparse),
+ Instance(17.0, 2.0, Vectors.dense(1.0, 7.0)),
+ Instance(17.0, 3.0, Vectors.dense(2.0, 11.0)),
+ Instance(17.0, 4.0, Vectors.dense(3.0, 13.0))
+ ), 2)
}
test("WLS against lm") {
@@ -65,15 +80,59 @@ class WeightedLeastSquaresSuite extends SparkFunSuite with MLlibTestSparkContext
var idx = 0
for (fitIntercept <- Seq(false, true)) {
- val wls = new WeightedLeastSquares(
- fitIntercept, regParam = 0.0, standardizeFeatures = false, standardizeLabel = false)
- .fit(instances)
- val actual = Vectors.dense(wls.intercept, wls.coefficients(0), wls.coefficients(1))
- assert(actual ~== expected(idx) absTol 1e-4)
+ for (standardization <- Seq(false, true)) {
+ val wls = new WeightedLeastSquares(
+ fitIntercept, regParam = 0.0, standardizeFeatures = standardization,
+ standardizeLabel = standardization).fit(instances)
+ val actual = Vectors.dense(wls.intercept, wls.coefficients(0), wls.coefficients(1))
+ assert(actual ~== expected(idx) absTol 1e-4)
+ }
+ idx += 1
+ }
+ }
+
+ test("WLS against lm when label is constant and no regularization") {
+ /*
+ R code:
+
+ df.const.label <- as.data.frame(cbind(A, b.const))
+ for (formula in c(b.const ~ . -1, b.const ~ .)) {
+ model <- lm(formula, data=df.const.label, weights=w)
+ print(as.vector(coef(model)))
+ }
+
+ [1] -9.221298 3.394343
+ [1] 17 0 0
+ */
+
+ val expected = Seq(
+ Vectors.dense(0.0, -9.221298, 3.394343),
+ Vectors.dense(17.0, 0.0, 0.0))
+
+ var idx = 0
+ for (fitIntercept <- Seq(false, true)) {
+ for (standardization <- Seq(false, true)) {
+ val wls = new WeightedLeastSquares(
+ fitIntercept, regParam = 0.0, standardizeFeatures = standardization,
+ standardizeLabel = standardization).fit(instancesConstLabel)
+ val actual = Vectors.dense(wls.intercept, wls.coefficients(0), wls.coefficients(1))
+ assert(actual ~== expected(idx) absTol 1e-4)
+ }
idx += 1
}
}
+ test("WLS with regularization when label is constant") {
+ // The label is constant, so its standard deviation is zero. With regParam > 0 and
+ // standardizeLabel = true the problem is ill-defined and fit must throw.
+ val wls = new WeightedLeastSquares(
+ fitIntercept = false, regParam = 0.1, standardizeFeatures = true,
+ standardizeLabel = true)
+ intercept[IllegalArgumentException]{
+ wls.fit(instancesConstLabel)
+ }
+ }
+
test("WLS against glmnet") {
/*
R code: