aboutsummaryrefslogtreecommitdiff
path: root/R/pkg/R/mllib_regression.R
diff options
context:
space:
mode:
Diffstat (limited to 'R/pkg/R/mllib_regression.R')
-rw-r--r--R/pkg/R/mllib_regression.R55
1 files changed, 47 insertions, 8 deletions
diff --git a/R/pkg/R/mllib_regression.R b/R/pkg/R/mllib_regression.R
index 648d363f1a..d59c890f3e 100644
--- a/R/pkg/R/mllib_regression.R
+++ b/R/pkg/R/mllib_regression.R
@@ -53,12 +53,23 @@ setClass("IsotonicRegressionModel", representation(jobj = "jobj"))
#' the result of a call to a family function. Refer R family at
#' \url{https://stat.ethz.ch/R-manual/R-devel/library/stats/html/family.html}.
#' Currently these families are supported: \code{binomial}, \code{gaussian},
-#' \code{Gamma}, and \code{poisson}.
+#' \code{Gamma}, \code{poisson} and \code{tweedie}.
+#'
+#' Note that there are two ways to specify the tweedie family.
+#' \itemize{
+#' \item Set \code{family = "tweedie"} and specify the var.power and link.power;
+#' \item When package \code{statmod} is loaded, the tweedie family is specified using the
+#' family definition therein, i.e., \code{tweedie(var.power, link.power)}.
+#' }
#' @param tol positive convergence tolerance of iterations.
#' @param maxIter integer giving the maximal number of IRLS iterations.
#' @param weightCol the weight column name. If this is not set or \code{NULL}, we treat all instance
#' weights as 1.0.
#' @param regParam regularization parameter for L2 regularization.
+#' @param var.power the power in the variance function of the Tweedie distribution which provides
+#' the relationship between the variance and mean of the distribution. Only
+#' applicable to the Tweedie family.
+#' @param link.power the index in the power link function. Only applicable to the Tweedie family.
#' @param ... additional arguments passed to the method.
#' @aliases spark.glm,SparkDataFrame,formula-method
#' @return \code{spark.glm} returns a fitted generalized linear model.
@@ -84,14 +95,30 @@ setClass("IsotonicRegressionModel", representation(jobj = "jobj"))
#' # can also read back the saved model and print
#' savedModel <- read.ml(path)
#' summary(savedModel)
+#'
+#' # fit tweedie model
+#' model <- spark.glm(df, Freq ~ Sex + Age, family = "tweedie",
+#' var.power = 1.2, link.power = 0)
+#' summary(model)
+#'
+#' # use the tweedie family from statmod
+#' library(statmod)
+#' model <- spark.glm(df, Freq ~ Sex + Age, family = tweedie(1.2, 0))
+#' summary(model)
#' }
#' @note spark.glm since 2.0.0
#' @seealso \link{glm}, \link{read.ml}
setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"),
function(data, formula, family = gaussian, tol = 1e-6, maxIter = 25, weightCol = NULL,
- regParam = 0.0) {
+ regParam = 0.0, var.power = 0.0, link.power = 1.0 - var.power) {
+
if (is.character(family)) {
- family <- get(family, mode = "function", envir = parent.frame())
+ # Handle when family = "tweedie"
+ if (tolower(family) == "tweedie") {
+ family <- list(family = "tweedie", link = NULL)
+ } else {
+ family <- get(family, mode = "function", envir = parent.frame())
+ }
}
if (is.function(family)) {
family <- family()
@@ -100,6 +127,12 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"),
print(family)
stop("'family' not recognized")
}
+ # Handle when family = statmod::tweedie()
+ if (tolower(family$family) == "tweedie" && !is.null(family$variance)) {
+ var.power <- log(family$variance(exp(1)))
+ link.power <- log(family$linkfun(exp(1)))
+ family <- list(family = "tweedie", link = NULL)
+ }
formula <- paste(deparse(formula), collapse = "")
if (!is.null(weightCol) && weightCol == "") {
@@ -111,7 +144,8 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"),
# For known families, Gamma is upper-cased
jobj <- callJStatic("org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper",
"fit", formula, data@sdf, tolower(family$family), family$link,
- tol, as.integer(maxIter), weightCol, regParam)
+ tol, as.integer(maxIter), weightCol, regParam,
+ as.double(var.power), as.double(link.power))
new("GeneralizedLinearRegressionModel", jobj = jobj)
})
@@ -126,11 +160,13 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"),
#' the result of a call to a family function. Refer R family at
#' \url{https://stat.ethz.ch/R-manual/R-devel/library/stats/html/family.html}.
#' Currently these families are supported: \code{binomial}, \code{gaussian},
-#' \code{Gamma}, and \code{poisson}.
+#' \code{poisson}, \code{Gamma}, and \code{tweedie}.
#' @param weightCol the weight column name. If this is not set or \code{NULL}, we treat all instance
#' weights as 1.0.
#' @param epsilon positive convergence tolerance of iterations.
#' @param maxit integer giving the maximal number of IRLS iterations.
+#' @param var.power the index of the power variance function in the Tweedie family.
+#' @param link.power the index of the power link function in the Tweedie family.
#' @return \code{glm} returns a fitted generalized linear model.
#' @rdname glm
#' @export
@@ -145,8 +181,10 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"),
#' @note glm since 1.5.0
#' @seealso \link{spark.glm}
setMethod("glm", signature(formula = "formula", family = "ANY", data = "SparkDataFrame"),
- function(formula, family = gaussian, data, epsilon = 1e-6, maxit = 25, weightCol = NULL) {
- spark.glm(data, formula, family, tol = epsilon, maxIter = maxit, weightCol = weightCol)
+ function(formula, family = gaussian, data, epsilon = 1e-6, maxit = 25, weightCol = NULL,
+ var.power = 0.0, link.power = 1.0 - var.power) {
+ spark.glm(data, formula, family, tol = epsilon, maxIter = maxit, weightCol = weightCol,
+ var.power = var.power, link.power = link.power)
})
# Returns the summary of a model produced by glm() or spark.glm(), similarly to R's summary().
@@ -172,9 +210,10 @@ setMethod("summary", signature(object = "GeneralizedLinearRegressionModel"),
deviance <- callJMethod(jobj, "rDeviance")
df.null <- callJMethod(jobj, "rResidualDegreeOfFreedomNull")
df.residual <- callJMethod(jobj, "rResidualDegreeOfFreedom")
- aic <- callJMethod(jobj, "rAic")
iter <- callJMethod(jobj, "rNumIterations")
family <- callJMethod(jobj, "rFamily")
+ aic <- callJMethod(jobj, "rAic")
+ if (family == "tweedie" && aic == 0) aic <- NA
deviance.resid <- if (is.loaded) {
NULL
} else {