diff options
author | actuaryzhang <actuaryzhang10@gmail.com> | 2017-03-14 00:50:38 -0700 |
---|---|---|
committer | Felix Cheung <felixcheung@apache.org> | 2017-03-14 00:50:38 -0700 |
commit | f6314eab4b494bd5b5e9e41c6f582d4f22c0967a (patch) | |
tree | ff067df4be9eb6f3b660abf8332136d778201146 /mllib | |
parent | 415f9f3423aacc395097e40427364c921a2ed7f1 (diff) | |
download | spark-f6314eab4b494bd5b5e9e41c6f582d4f22c0967a.tar.gz spark-f6314eab4b494bd5b5e9e41c6f582d4f22c0967a.tar.bz2 spark-f6314eab4b494bd5b5e9e41c6f582d4f22c0967a.zip |
[SPARK-19391][SPARKR][ML] Tweedie GLM API for SparkR
## What changes were proposed in this pull request?
Port the Tweedie GLM support (#16344) to SparkR.
felixcheung yanboliang
## How was this patch tested?
Added a new test in SparkR.
Author: actuaryzhang <actuaryzhang10@gmail.com>
Closes #16729 from actuaryzhang/sparkRTweedie.
Diffstat (limited to 'mllib')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala | 19 |
1 file changed, 15 insertions, 4 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala index cbd6cd1c79..c49416b240 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala @@ -71,7 +71,9 @@ private[r] object GeneralizedLinearRegressionWrapper tol: Double, maxIter: Int, weightCol: String, - regParam: Double): GeneralizedLinearRegressionWrapper = { + regParam: Double, + variancePower: Double, + linkPower: Double): GeneralizedLinearRegressionWrapper = { val rFormula = new RFormula().setFormula(formula) checkDataColumns(rFormula, data) val rFormulaModel = rFormula.fit(data) @@ -83,13 +85,17 @@ private[r] object GeneralizedLinearRegressionWrapper // assemble and fit the pipeline val glr = new GeneralizedLinearRegression() .setFamily(family) - .setLink(link) .setFitIntercept(rFormula.hasIntercept) .setTol(tol) .setMaxIter(maxIter) .setRegParam(regParam) .setFeaturesCol(rFormula.getFeaturesCol) - + // set variancePower and linkPower if family is tweedie; otherwise, set link function + if (family.toLowerCase == "tweedie") { + glr.setVariancePower(variancePower).setLinkPower(linkPower) + } else { + glr.setLink(link) + } if (weightCol != null) glr.setWeightCol(weightCol) val pipeline = new Pipeline() @@ -145,7 +151,12 @@ private[r] object GeneralizedLinearRegressionWrapper val rDeviance: Double = summary.deviance val rResidualDegreeOfFreedomNull: Long = summary.residualDegreeOfFreedomNull val rResidualDegreeOfFreedom: Long = summary.residualDegreeOfFreedom - val rAic: Double = summary.aic + val rAic: Double = if (family.toLowerCase == "tweedie" && + !Array(0.0, 1.0, 2.0).exists(x => math.abs(x - variancePower) < 1e-8)) { + 0.0 + } else { + summary.aic + } val rNumIterations: Int = summary.numIterations new GeneralizedLinearRegressionWrapper(pipeline, 
rFeatures, rCoefficients, rDispersion, |