From f6314eab4b494bd5b5e9e41c6f582d4f22c0967a Mon Sep 17 00:00:00 2001 From: actuaryzhang Date: Tue, 14 Mar 2017 00:50:38 -0700 Subject: [SPARK-19391][SPARKR][ML] Tweedie GLM API for SparkR ## What changes were proposed in this pull request? Port Tweedie GLM #16344 to SparkR felixcheung yanboliang ## How was this patch tested? new test in SparkR Author: actuaryzhang Closes #16729 from actuaryzhang/sparkRTweedie. --- .../ml/r/GeneralizedLinearRegressionWrapper.scala | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) (limited to 'mllib') diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala index cbd6cd1c79..c49416b240 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala @@ -71,7 +71,9 @@ private[r] object GeneralizedLinearRegressionWrapper tol: Double, maxIter: Int, weightCol: String, - regParam: Double): GeneralizedLinearRegressionWrapper = { + regParam: Double, + variancePower: Double, + linkPower: Double): GeneralizedLinearRegressionWrapper = { val rFormula = new RFormula().setFormula(formula) checkDataColumns(rFormula, data) val rFormulaModel = rFormula.fit(data) @@ -83,13 +85,17 @@ private[r] object GeneralizedLinearRegressionWrapper // assemble and fit the pipeline val glr = new GeneralizedLinearRegression() .setFamily(family) - .setLink(link) .setFitIntercept(rFormula.hasIntercept) .setTol(tol) .setMaxIter(maxIter) .setRegParam(regParam) .setFeaturesCol(rFormula.getFeaturesCol) - + // set variancePower and linkPower if family is tweedie; otherwise, set link function + if (family.toLowerCase == "tweedie") { + glr.setVariancePower(variancePower).setLinkPower(linkPower) + } else { + glr.setLink(link) + } if (weightCol != null) glr.setWeightCol(weightCol) val pipeline = new Pipeline() @@ -145,7 +151,12 @@ private[r] object GeneralizedLinearRegressionWrapper val rDeviance: Double = summary.deviance val rResidualDegreeOfFreedomNull: Long = summary.residualDegreeOfFreedomNull val rResidualDegreeOfFreedom: Long = summary.residualDegreeOfFreedom - val rAic: Double = summary.aic + val rAic: Double = if (family.toLowerCase == "tweedie" && + !Array(0.0, 1.0, 2.0).exists(x => math.abs(x - variancePower) < 1e-8)) { + 0.0 + } else { + summary.aic + } val rNumIterations: Int = summary.numIterations new GeneralizedLinearRegressionWrapper(pipeline, rFeatures, rCoefficients, rDispersion, -- cgit v1.2.3