diff options
Diffstat (limited to 'R/pkg/R/mllib_regression.R')
-rw-r--r-- | R/pkg/R/mllib_regression.R | 55 |
1 files changed, 47 insertions, 8 deletions
diff --git a/R/pkg/R/mllib_regression.R b/R/pkg/R/mllib_regression.R index 648d363f1a..d59c890f3e 100644 --- a/R/pkg/R/mllib_regression.R +++ b/R/pkg/R/mllib_regression.R @@ -53,12 +53,23 @@ setClass("IsotonicRegressionModel", representation(jobj = "jobj")) #' the result of a call to a family function. Refer R family at #' \url{https://stat.ethz.ch/R-manual/R-devel/library/stats/html/family.html}. #' Currently these families are supported: \code{binomial}, \code{gaussian}, -#' \code{Gamma}, and \code{poisson}. +#' \code{Gamma}, \code{poisson} and \code{tweedie}. +#' +#' Note that there are two ways to specify the tweedie family. +#' \itemize{ +#' \item Set \code{family = "tweedie"} and specify the var.power and link.power; +#' \item When package \code{statmod} is loaded, the tweedie family is specified using the +#' family definition therein, i.e., \code{tweedie(var.power, link.power)}. +#' } #' @param tol positive convergence tolerance of iterations. #' @param maxIter integer giving the maximal number of IRLS iterations. #' @param weightCol the weight column name. If this is not set or \code{NULL}, we treat all instance #' weights as 1.0. #' @param regParam regularization parameter for L2 regularization. +#' @param var.power the power in the variance function of the Tweedie distribution which provides +#' the relationship between the variance and mean of the distribution. Only +#' applicable to the Tweedie family. +#' @param link.power the index in the power link function. Only applicable to the Tweedie family. #' @param ... additional arguments passed to the method. #' @aliases spark.glm,SparkDataFrame,formula-method #' @return \code{spark.glm} returns a fitted generalized linear model. @@ -84,14 +95,30 @@ setClass("IsotonicRegressionModel", representation(jobj = "jobj")) #' # can also read back the saved model and print #' savedModel <- read.ml(path) #' summary(savedModel) +#' +#' # fit tweedie model +#' model <- spark.glm(df, Freq ~ Sex + Age, family = "tweedie", +#' var.power = 1.2, link.power = 0) +#' summary(model) +#' +#' # use the tweedie family from statmod +#' library(statmod) +#' model <- spark.glm(df, Freq ~ Sex + Age, family = tweedie(1.2, 0)) +#' summary(model) #' } #' @note spark.glm since 2.0.0 #' @seealso \link{glm}, \link{read.ml} setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"), function(data, formula, family = gaussian, tol = 1e-6, maxIter = 25, weightCol = NULL, - regParam = 0.0) { + regParam = 0.0, var.power = 0.0, link.power = 1.0 - var.power) { + if (is.character(family)) { - family <- get(family, mode = "function", envir = parent.frame()) + # Handle when family = "tweedie" + if (tolower(family) == "tweedie") { + family <- list(family = "tweedie", link = NULL) + } else { + family <- get(family, mode = "function", envir = parent.frame()) + } } if (is.function(family)) { family <- family() @@ -100,6 +127,12 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"), print(family) stop("'family' not recognized") } + # Handle when family = statmod::tweedie() + if (tolower(family$family) == "tweedie" && !is.null(family$variance)) { + var.power <- log(family$variance(exp(1))) + link.power <- log(family$linkfun(exp(1))) + family <- list(family = "tweedie", link = NULL) + } formula <- paste(deparse(formula), collapse = "") if (!is.null(weightCol) && weightCol == "") { @@ -111,7 +144,8 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"), # For known families, Gamma is upper-cased jobj <- callJStatic("org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper", "fit", formula, data@sdf, tolower(family$family), family$link, - tol, as.integer(maxIter), weightCol, regParam) + tol, as.integer(maxIter), weightCol, regParam, + as.double(var.power), as.double(link.power)) new("GeneralizedLinearRegressionModel", jobj = jobj) }) @@ -126,11 +160,13 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"), #' the result of a call to a family function. Refer R family at #' \url{https://stat.ethz.ch/R-manual/R-devel/library/stats/html/family.html}. #' Currently these families are supported: \code{binomial}, \code{gaussian}, -#' \code{Gamma}, and \code{poisson}. +#' \code{poisson}, \code{Gamma}, and \code{tweedie}. #' @param weightCol the weight column name. If this is not set or \code{NULL}, we treat all instance #' weights as 1.0. #' @param epsilon positive convergence tolerance of iterations. #' @param maxit integer giving the maximal number of IRLS iterations. +#' @param var.power the index of the power variance function in the Tweedie family. +#' @param link.power the index of the power link function in the Tweedie family. #' @return \code{glm} returns a fitted generalized linear model. #' @rdname glm #' @export @@ -145,8 +181,10 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"), #' @note glm since 1.5.0 #' @seealso \link{spark.glm} setMethod("glm", signature(formula = "formula", family = "ANY", data = "SparkDataFrame"), - function(formula, family = gaussian, data, epsilon = 1e-6, maxit = 25, weightCol = NULL) { - spark.glm(data, formula, family, tol = epsilon, maxIter = maxit, weightCol = weightCol) + function(formula, family = gaussian, data, epsilon = 1e-6, maxit = 25, weightCol = NULL, + var.power = 0.0, link.power = 1.0 - var.power) { + spark.glm(data, formula, family, tol = epsilon, maxIter = maxit, weightCol = weightCol, + var.power = var.power, link.power = link.power) }) # Returns the summary of a model produced by glm() or spark.glm(), similarly to R's summary(). @@ -172,9 +210,10 @@ setMethod("summary", signature(object = "GeneralizedLinearRegressionModel"), deviance <- callJMethod(jobj, "rDeviance") df.null <- callJMethod(jobj, "rResidualDegreeOfFreedomNull") df.residual <- callJMethod(jobj, "rResidualDegreeOfFreedom") - aic <- callJMethod(jobj, "rAic") iter <- callJMethod(jobj, "rNumIterations") family <- callJMethod(jobj, "rFamily") + aic <- callJMethod(jobj, "rAic") + if (family == "tweedie" && aic == 0) aic <- NA deviance.resid <- if (is.loaded) { NULL } else { |