aboutsummaryrefslogtreecommitdiff
path: root/mllib/src
diff options
context:
space:
mode:
author Yanbo Liang <ybliang8@gmail.com> 2016-11-22 19:17:48 -0800
committer Yanbo Liang <ybliang8@gmail.com> 2016-11-22 19:17:48 -0800
commit 982b82e32e0fc7d30c5d557944a79eb3e6d2da59 (patch)
tree ad9a9b28a282ffe8a4fe06c65abba8acd3339fb3 /mllib/src
parent d0212eb0f22473ee5482fe98dafc24e16ffcfc63 (diff)
download spark-982b82e32e0fc7d30c5d557944a79eb3e6d2da59.tar.gz
spark-982b82e32e0fc7d30c5d557944a79eb3e6d2da59.tar.bz2
spark-982b82e32e0fc7d30c5d557944a79eb3e6d2da59.zip
[SPARK-18501][ML][SPARKR] Fix spark.glm errors when fitting on collinear data
## What changes were proposed in this pull request?
* Fix SparkR `spark.glm` errors when fitting on collinear data, since the standard errors of coefficients, t-values and p-values are not available in this condition.
* The Scala/Python GLM summary should throw an exception if users request the standard errors of coefficients, t-values or p-values when the underlying WLS was solved by local "l-bfgs".

## How was this patch tested?
Add unit tests.

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #15930 from yanboliang/spark-18501.
Diffstat (limited to 'mllib/src')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala 54
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala 46
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala 21
3 files changed, 90 insertions, 31 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala
index add4d49110..8bcc9fe5d1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala
@@ -144,30 +144,38 @@ private[r] object GeneralizedLinearRegressionWrapper
features
}
- val rCoefficientStandardErrors = if (glm.getFitIntercept) {
- Array(summary.coefficientStandardErrors.last) ++
- summary.coefficientStandardErrors.dropRight(1)
+ val rCoefficients: Array[Double] = if (summary.isNormalSolver) {
+ val rCoefficientStandardErrors = if (glm.getFitIntercept) {
+ Array(summary.coefficientStandardErrors.last) ++
+ summary.coefficientStandardErrors.dropRight(1)
+ } else {
+ summary.coefficientStandardErrors
+ }
+
+ val rTValues = if (glm.getFitIntercept) {
+ Array(summary.tValues.last) ++ summary.tValues.dropRight(1)
+ } else {
+ summary.tValues
+ }
+
+ val rPValues = if (glm.getFitIntercept) {
+ Array(summary.pValues.last) ++ summary.pValues.dropRight(1)
+ } else {
+ summary.pValues
+ }
+
+ if (glm.getFitIntercept) {
+ Array(glm.intercept) ++ glm.coefficients.toArray ++
+ rCoefficientStandardErrors ++ rTValues ++ rPValues
+ } else {
+ glm.coefficients.toArray ++ rCoefficientStandardErrors ++ rTValues ++ rPValues
+ }
} else {
- summary.coefficientStandardErrors
- }
-
- val rTValues = if (glm.getFitIntercept) {
- Array(summary.tValues.last) ++ summary.tValues.dropRight(1)
- } else {
- summary.tValues
- }
-
- val rPValues = if (glm.getFitIntercept) {
- Array(summary.pValues.last) ++ summary.pValues.dropRight(1)
- } else {
- summary.pValues
- }
-
- val rCoefficients: Array[Double] = if (glm.getFitIntercept) {
- Array(glm.intercept) ++ glm.coefficients.toArray ++
- rCoefficientStandardErrors ++ rTValues ++ rPValues
- } else {
- glm.coefficients.toArray ++ rCoefficientStandardErrors ++ rTValues ++ rPValues
+ if (glm.getFitIntercept) {
+ Array(glm.intercept) ++ glm.coefficients.toArray
+ } else {
+ glm.coefficients.toArray
+ }
}
val rDispersion: Double = summary.dispersion
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
index 3f9de1fe74..f33dd0fd29 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
@@ -1064,44 +1064,74 @@ class GeneralizedLinearRegressionTrainingSummary private[regression] (
import GeneralizedLinearRegression._
/**
+ * Whether the underlying [[WeightedLeastSquares]] was solved using the "normal" solver.
+ */
+ private[ml] val isNormalSolver: Boolean = {
+ diagInvAtWA.length != 1 || diagInvAtWA(0) != 0
+ }
+
+ /**
* Standard error of estimated coefficients and intercept.
+ * This value is only available when the underlying [[WeightedLeastSquares]]
+ * is solved using the "normal" solver.
*
* If [[GeneralizedLinearRegression.fitIntercept]] is set to true,
* then the last element returned corresponds to the intercept.
*/
@Since("2.0.0")
lazy val coefficientStandardErrors: Array[Double] = {
- diagInvAtWA.map(_ * dispersion).map(math.sqrt)
+ if (isNormalSolver) {
+ diagInvAtWA.map(_ * dispersion).map(math.sqrt)
+ } else {
+ throw new UnsupportedOperationException(
+ "No Std. Error of coefficients available for this GeneralizedLinearRegressionModel")
+ }
}
/**
* T-statistic of estimated coefficients and intercept.
+ * This value is only available when the underlying [[WeightedLeastSquares]]
+ * is solved using the "normal" solver.
*
* If [[GeneralizedLinearRegression.fitIntercept]] is set to true,
* then the last element returned corresponds to the intercept.
*/
@Since("2.0.0")
lazy val tValues: Array[Double] = {
- val estimate = if (model.getFitIntercept) {
- Array.concat(model.coefficients.toArray, Array(model.intercept))
+ if (isNormalSolver) {
+ val estimate = if (model.getFitIntercept) {
+ Array.concat(model.coefficients.toArray, Array(model.intercept))
+ } else {
+ model.coefficients.toArray
+ }
+ estimate.zip(coefficientStandardErrors).map { x => x._1 / x._2 }
} else {
- model.coefficients.toArray
+ throw new UnsupportedOperationException(
+ "No t-statistic available for this GeneralizedLinearRegressionModel")
}
- estimate.zip(coefficientStandardErrors).map { x => x._1 / x._2 }
}
/**
* Two-sided p-value of estimated coefficients and intercept.
+ * This value is only available when the underlying [[WeightedLeastSquares]]
+ * is solved using the "normal" solver.
*
* If [[GeneralizedLinearRegression.fitIntercept]] is set to true,
* then the last element returned corresponds to the intercept.
*/
@Since("2.0.0")
lazy val pValues: Array[Double] = {
- if (model.getFamily == Binomial.name || model.getFamily == Poisson.name) {
- tValues.map { x => 2.0 * (1.0 - dist.Gaussian(0.0, 1.0).cdf(math.abs(x))) }
+ if (isNormalSolver) {
+ if (model.getFamily == Binomial.name || model.getFamily == Poisson.name) {
+ tValues.map { x => 2.0 * (1.0 - dist.Gaussian(0.0, 1.0).cdf(math.abs(x))) }
+ } else {
+ tValues.map { x =>
+ 2.0 * (1.0 - dist.StudentsT(degreesOfFreedom.toDouble).cdf(math.abs(x)))
+ }
+ }
} else {
- tValues.map { x => 2.0 * (1.0 - dist.StudentsT(degreesOfFreedom.toDouble).cdf(math.abs(x))) }
+ throw new UnsupportedOperationException(
+ "No p-value available for this GeneralizedLinearRegressionModel")
}
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
index 9b0fa67630..4fab216033 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
@@ -1048,6 +1048,27 @@ class GeneralizedLinearRegressionSuite
assert(summary.solver === "irls")
}
+ test("glm handle collinear features") {
+ val collinearInstances = Seq(
+ Instance(1.0, 1.0, Vectors.dense(1.0, 2.0)),
+ Instance(2.0, 1.0, Vectors.dense(2.0, 4.0)),
+ Instance(3.0, 1.0, Vectors.dense(3.0, 6.0)),
+ Instance(4.0, 1.0, Vectors.dense(4.0, 8.0))
+ ).toDF()
+ val trainer = new GeneralizedLinearRegression()
+ val model = trainer.fit(collinearInstances)
+ // to make it clear that underlying WLS did not solve analytically
+ intercept[UnsupportedOperationException] {
+ model.summary.coefficientStandardErrors
+ }
+ intercept[UnsupportedOperationException] {
+ model.summary.pValues
+ }
+ intercept[UnsupportedOperationException] {
+ model.summary.tValues
+ }
+ }
+
test("read/write") {
def checkModelData(
model: GeneralizedLinearRegressionModel,