aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
Diffstat (limited to 'mllib')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala | 52
1 file changed, 46 insertions, 6 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala
index 475a308385..f66323e36c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala
@@ -30,19 +30,59 @@ private[r] class GeneralizedLinearRegressionWrapper private (
private val glm: GeneralizedLinearRegressionModel =
pipeline.stages(1).asInstanceOf[GeneralizedLinearRegressionModel]
+ lazy val rFeatures: Array[String] = if (glm.getFitIntercept) {
+ Array("(Intercept)") ++ features
+ } else {
+ features
+ }
+
lazy val rCoefficients: Array[Double] = if (glm.getFitIntercept) {
- Array(glm.intercept) ++ glm.coefficients.toArray
+ Array(glm.intercept) ++ glm.coefficients.toArray ++
+ rCoefficientStandardErrors ++ rTValues ++ rPValues
} else {
- glm.coefficients.toArray
+ glm.coefficients.toArray ++ rCoefficientStandardErrors ++ rTValues ++ rPValues
}
- lazy val rFeatures: Array[String] = if (glm.getFitIntercept) {
- Array("(Intercept)") ++ features
+ private lazy val rCoefficientStandardErrors = if (glm.getFitIntercept) {
+ Array(glm.summary.coefficientStandardErrors.last) ++
+ glm.summary.coefficientStandardErrors.dropRight(1)
} else {
- features
+ glm.summary.coefficientStandardErrors
+ }
+
+ private lazy val rTValues = if (glm.getFitIntercept) {
+ Array(glm.summary.tValues.last) ++ glm.summary.tValues.dropRight(1)
+ } else {
+ glm.summary.tValues
}
- def transform(dataset: DataFrame): DataFrame = {
+ private lazy val rPValues = if (glm.getFitIntercept) {
+ Array(glm.summary.pValues.last) ++ glm.summary.pValues.dropRight(1)
+ } else {
+ glm.summary.pValues
+ }
+
+ lazy val rDispersion: Double = glm.summary.dispersion
+
+ lazy val rNullDeviance: Double = glm.summary.nullDeviance
+
+ lazy val rDeviance: Double = glm.summary.deviance
+
+ lazy val rResidualDegreeOfFreedomNull: Long = glm.summary.residualDegreeOfFreedomNull
+
+ lazy val rResidualDegreeOfFreedom: Long = glm.summary.residualDegreeOfFreedom
+
+ lazy val rAic: Double = glm.summary.aic
+
+ lazy val rNumIterations: Int = glm.summary.numIterations
+
+ lazy val rDevianceResiduals: DataFrame = glm.summary.residuals()
+
+ lazy val rFamily: String = glm.getFamily
+
+ def residuals(residualsType: String): DataFrame = glm.summary.residuals(residualsType)
+
+ def transform(dataset: Dataset[_]): DataFrame = {
pipeline.transform(dataset).drop(glm.getFeaturesCol)
}
}