aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authoractuaryzhang <actuaryzhang10@gmail.com>2017-03-14 00:50:38 -0700
committerFelix Cheung <felixcheung@apache.org>2017-03-14 00:50:38 -0700
commitf6314eab4b494bd5b5e9e41c6f582d4f22c0967a (patch)
treeff067df4be9eb6f3b660abf8332136d778201146 /mllib
parent415f9f3423aacc395097e40427364c921a2ed7f1 (diff)
downloadspark-f6314eab4b494bd5b5e9e41c6f582d4f22c0967a.tar.gz
spark-f6314eab4b494bd5b5e9e41c6f582d4f22c0967a.tar.bz2
spark-f6314eab4b494bd5b5e9e41c6f582d4f22c0967a.zip
[SPARK-19391][SPARKR][ML] Tweedie GLM API for SparkR
## What changes were proposed in this pull request? Port Tweedie GLM #16344 to SparkR felixcheung yanboliang ## How was this patch tested? new test in SparkR Author: actuaryzhang <actuaryzhang10@gmail.com> Closes #16729 from actuaryzhang/sparkRTweedie.
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala19
1 files changed, 15 insertions, 4 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala
index cbd6cd1c79..c49416b240 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala
@@ -71,7 +71,9 @@ private[r] object GeneralizedLinearRegressionWrapper
tol: Double,
maxIter: Int,
weightCol: String,
- regParam: Double): GeneralizedLinearRegressionWrapper = {
+ regParam: Double,
+ variancePower: Double,
+ linkPower: Double): GeneralizedLinearRegressionWrapper = {
val rFormula = new RFormula().setFormula(formula)
checkDataColumns(rFormula, data)
val rFormulaModel = rFormula.fit(data)
@@ -83,13 +85,17 @@ private[r] object GeneralizedLinearRegressionWrapper
// assemble and fit the pipeline
val glr = new GeneralizedLinearRegression()
.setFamily(family)
- .setLink(link)
.setFitIntercept(rFormula.hasIntercept)
.setTol(tol)
.setMaxIter(maxIter)
.setRegParam(regParam)
.setFeaturesCol(rFormula.getFeaturesCol)
-
+ // set variancePower and linkPower if family is tweedie; otherwise, set link function
+ if (family.toLowerCase == "tweedie") {
+ glr.setVariancePower(variancePower).setLinkPower(linkPower)
+ } else {
+ glr.setLink(link)
+ }
if (weightCol != null) glr.setWeightCol(weightCol)
val pipeline = new Pipeline()
@@ -145,7 +151,12 @@ private[r] object GeneralizedLinearRegressionWrapper
val rDeviance: Double = summary.deviance
val rResidualDegreeOfFreedomNull: Long = summary.residualDegreeOfFreedomNull
val rResidualDegreeOfFreedom: Long = summary.residualDegreeOfFreedom
- val rAic: Double = summary.aic
+ val rAic: Double = if (family.toLowerCase == "tweedie" &&
+ !Array(0.0, 1.0, 2.0).exists(x => math.abs(x - variancePower) < 1e-8)) {
+ 0.0
+ } else {
+ summary.aic
+ }
val rNumIterations: Int = summary.numIterations
new GeneralizedLinearRegressionWrapper(pipeline, rFeatures, rCoefficients, rDispersion,