diff options
Diffstat (limited to 'R')
-rw-r--r-- | R/pkg/R/mllib.R | 15 | ||||
-rw-r--r-- | R/pkg/inst/tests/testthat/test_mllib.R | 22 |
2 files changed, 33 insertions, 4 deletions
diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R index 50c601fcd9..25d9f077b4 100644 --- a/R/pkg/R/mllib.R +++ b/R/pkg/R/mllib.R @@ -91,6 +91,8 @@ NULL #' \url{https://stat.ethz.ch/R-manual/R-devel/library/stats/html/family.html}. #' @param tol Positive convergence tolerance of iterations. #' @param maxIter Integer giving the maximal number of IRLS iterations. +#' @param weightCol The weight column name. If this is not set or NULL, we treat all instance +#' weights as 1.0. #' @aliases spark.glm,SparkDataFrame,formula-method #' @return \code{spark.glm} returns a fitted generalized linear model #' @rdname spark.glm @@ -119,7 +121,7 @@ NULL #' @note spark.glm since 2.0.0 #' @seealso \link{glm}, \link{read.ml} setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"), - function(data, formula, family = gaussian, tol = 1e-6, maxIter = 25) { + function(data, formula, family = gaussian, tol = 1e-6, maxIter = 25, weightCol = NULL) { if (is.character(family)) { family <- get(family, mode = "function", envir = parent.frame()) } @@ -132,10 +134,13 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"), } formula <- paste(deparse(formula), collapse = "") + if (is.null(weightCol)) { + weightCol <- "" + } jobj <- callJStatic("org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper", "fit", formula, data@sdf, family$family, family$link, - tol, as.integer(maxIter)) + tol, as.integer(maxIter), weightCol) return(new("GeneralizedLinearRegressionModel", jobj = jobj)) }) @@ -151,6 +156,8 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"), #' \url{https://stat.ethz.ch/R-manual/R-devel/library/stats/html/family.html}. #' @param epsilon Positive convergence tolerance of iterations. #' @param maxit Integer giving the maximal number of IRLS iterations. +#' @param weightCol The weight column name. If this is not set or NULL, we treat all instance +#' weights as 1.0. #' @return \code{glm} returns a fitted generalized linear model. #' @rdname glm #' @export @@ -165,8 +172,8 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"), #' @note glm since 1.5.0 #' @seealso \link{spark.glm} setMethod("glm", signature(formula = "formula", family = "ANY", data = "SparkDataFrame"), - function(formula, family = gaussian, data, epsilon = 1e-6, maxit = 25) { - spark.glm(data, formula, family, tol = epsilon, maxIter = maxit) + function(formula, family = gaussian, data, epsilon = 1e-6, maxit = 25, weightCol = NULL) { + spark.glm(data, formula, family, tol = epsilon, maxIter = maxit, weightCol = weightCol) }) # Returns the summary of a model produced by glm() or spark.glm(), similarly to R's summary(). diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R index ab390a86d1..bc18224680 100644 --- a/R/pkg/inst/tests/testthat/test_mllib.R +++ b/R/pkg/inst/tests/testthat/test_mllib.R @@ -118,6 +118,28 @@ test_that("spark.glm summary", { expect_equal(stats$df.residual, rStats$df.residual) expect_equal(stats$aic, rStats$aic) + # Test spark.glm works with weighted dataset + a1 <- c(0, 1, 2, 3) + a2 <- c(5, 2, 1, 3) + w <- c(1, 2, 3, 4) + b <- c(1, 0, 1, 0) + data <- as.data.frame(cbind(a1, a2, w, b)) + df <- suppressWarnings(createDataFrame(data)) + + stats <- summary(spark.glm(df, b ~ a1 + a2, family = "binomial", weightCol = "w")) + rStats <- summary(glm(b ~ a1 + a2, family = "binomial", data = data, weights = w)) + + coefs <- unlist(stats$coefficients) + rCoefs <- unlist(rStats$coefficients) + expect_true(all(abs(rCoefs - coefs) < 1e-3)) + expect_true(all(rownames(stats$coefficients) == c("(Intercept)", "a1", "a2"))) + expect_equal(stats$dispersion, rStats$dispersion) + expect_equal(stats$null.deviance, rStats$null.deviance) + expect_equal(stats$deviance, rStats$deviance) + expect_equal(stats$df.null, rStats$df.null) + expect_equal(stats$df.residual, rStats$df.residual) + expect_equal(stats$aic, rStats$aic) + # Test summary works on base GLM models baseModel <- stats::glm(Sepal.Width ~ Sepal.Length + Species, data = iris) baseSummary <- summary(baseModel) |