aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/linalg/CholeskyDecomposition.scala19
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala20
2 files changed, 35 insertions, 4 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/CholeskyDecomposition.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/CholeskyDecomposition.scala
index e4494792bb..08f8f19c1e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/CholeskyDecomposition.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/CholeskyDecomposition.scala
@@ -36,8 +36,7 @@ private[spark] object CholeskyDecomposition {
val k = bx.length
val info = new intW(0)
lapack.dppsv("U", k, 1, A, bx, k, info)
- val code = info.`val`
- assert(code == 0, s"lapack.dppsv returned $code.")
+ checkReturnValue(info, "dppsv")
bx
}
@@ -52,8 +51,20 @@ private[spark] object CholeskyDecomposition {
def inverse(UAi: Array[Double], k: Int): Array[Double] = {
val info = new intW(0)
lapack.dpptri("U", k, UAi, info)
- val code = info.`val`
- assert(code == 0, s"lapack.dpptri returned $code.")
+ checkReturnValue(info, "dpptri")
UAi
}
+
+ private def checkReturnValue(info: intW, method: String): Unit = {
+ info.`val` match {
+ case code if code < 0 =>
+ throw new IllegalStateException(s"LAPACK.$method returned $code; arg ${-code} is illegal")
+ case code if code > 0 =>
+ throw new IllegalArgumentException(
+ s"LAPACK.$method returned $code because A is not positive definite. Is A derived from " +
+ "a singular matrix (e.g. collinear column values)?")
+ case _ => // do nothing
+ }
+ }
+
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala
index c8de796b2d..2cb1af0dee 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala
@@ -60,6 +60,26 @@ class WeightedLeastSquaresSuite extends SparkFunSuite with MLlibTestSparkContext
), 2)
}
+ test("two collinear features result in error with no regularization") {
+ val singularInstances = sc.parallelize(Seq(
+ Instance(1.0, 1.0, Vectors.dense(1.0, 2.0)),
+ Instance(2.0, 1.0, Vectors.dense(2.0, 4.0)),
+ Instance(3.0, 1.0, Vectors.dense(3.0, 6.0)),
+ Instance(4.0, 1.0, Vectors.dense(4.0, 8.0))
+ ), 2)
+
+ intercept[IllegalArgumentException] {
+ new WeightedLeastSquares(
+ false, regParam = 0.0, standardizeFeatures = false,
+ standardizeLabel = false).fit(singularInstances)
+ }
+
+ // Should not throw an exception
+ new WeightedLeastSquares(
+ false, regParam = 1.0, standardizeFeatures = false,
+ standardizeLabel = false).fit(singularInstances)
+ }
+
test("WLS against lm") {
/*
R code: