diff options
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala | 52 |
1 files changed, 46 insertions, 6 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala index 475a308385..f66323e36c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala @@ -30,19 +30,59 @@ private[r] class GeneralizedLinearRegressionWrapper private ( private val glm: GeneralizedLinearRegressionModel = pipeline.stages(1).asInstanceOf[GeneralizedLinearRegressionModel] + lazy val rFeatures: Array[String] = if (glm.getFitIntercept) { + Array("(Intercept)") ++ features + } else { + features + } + lazy val rCoefficients: Array[Double] = if (glm.getFitIntercept) { - Array(glm.intercept) ++ glm.coefficients.toArray + Array(glm.intercept) ++ glm.coefficients.toArray ++ + rCoefficientStandardErrors ++ rTValues ++ rPValues } else { - glm.coefficients.toArray + glm.coefficients.toArray ++ rCoefficientStandardErrors ++ rTValues ++ rPValues } - lazy val rFeatures: Array[String] = if (glm.getFitIntercept) { - Array("(Intercept)") ++ features + private lazy val rCoefficientStandardErrors = if (glm.getFitIntercept) { + Array(glm.summary.coefficientStandardErrors.last) ++ + glm.summary.coefficientStandardErrors.dropRight(1) } else { - features + glm.summary.coefficientStandardErrors + } + + private lazy val rTValues = if (glm.getFitIntercept) { + Array(glm.summary.tValues.last) ++ glm.summary.tValues.dropRight(1) + } else { + glm.summary.tValues } - def transform(dataset: DataFrame): DataFrame = { + private lazy val rPValues = if (glm.getFitIntercept) { + Array(glm.summary.pValues.last) ++ glm.summary.pValues.dropRight(1) + } else { + glm.summary.pValues + } + + lazy val rDispersion: Double = glm.summary.dispersion + + lazy val rNullDeviance: Double = glm.summary.nullDeviance + + lazy val rDeviance: Double = glm.summary.deviance + + lazy val rResidualDegreeOfFreedomNull: Long = glm.summary.residualDegreeOfFreedomNull + + lazy val rResidualDegreeOfFreedom: Long = glm.summary.residualDegreeOfFreedom + + lazy val rAic: Double = glm.summary.aic + + lazy val rNumIterations: Int = glm.summary.numIterations + + lazy val rDevianceResiduals: DataFrame = glm.summary.residuals() + + lazy val rFamily: String = glm.getFamily + + def residuals(residualsType: String): DataFrame = glm.summary.residuals(residualsType) + + def transform(dataset: Dataset[_]): DataFrame = { pipeline.transform(dataset).drop(glm.getFeaturesCol) } } |