aboutsummaryrefslogtreecommitdiff
path: root/R
diff options
context:
space:
mode:
authorYanbo Liang <ybliang8@gmail.com>2016-08-10 10:53:48 -0700
committerShivaram Venkataraman <shivaram@cs.berkeley.edu>2016-08-10 10:53:48 -0700
commitd4a9122430d6c3aeaaee32aa09d314016ff6ddc7 (patch)
tree19e191c481ca385c3fa93b62b5b573c44e5b637c /R
parent19af298bb6d264adcf02f6f84c8dc1542b408507 (diff)
downloadspark-d4a9122430d6c3aeaaee32aa09d314016ff6ddc7.tar.gz
spark-d4a9122430d6c3aeaaee32aa09d314016ff6ddc7.tar.bz2
spark-d4a9122430d6c3aeaaee32aa09d314016ff6ddc7.zip
[SPARK-16710][SPARKR][ML] spark.glm should support weightCol
## What changes were proposed in this pull request? Training GLMs on weighted dataset is very important use cases, but it is not supported by SparkR currently. Users can pass argument ```weights``` to specify the weights vector in native R. For ```spark.glm```, we can pass in the ```weightCol``` which is consistent with MLlib. ## How was this patch tested? Unit test. Author: Yanbo Liang <ybliang8@gmail.com> Closes #14346 from yanboliang/spark-16710.
Diffstat (limited to 'R')
-rw-r--r--R/pkg/R/mllib.R15
-rw-r--r--R/pkg/inst/tests/testthat/test_mllib.R22
2 files changed, 33 insertions, 4 deletions
diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R
index 50c601fcd9..25d9f077b4 100644
--- a/R/pkg/R/mllib.R
+++ b/R/pkg/R/mllib.R
@@ -91,6 +91,8 @@ NULL
#' \url{https://stat.ethz.ch/R-manual/R-devel/library/stats/html/family.html}.
#' @param tol Positive convergence tolerance of iterations.
#' @param maxIter Integer giving the maximal number of IRLS iterations.
+#' @param weightCol The weight column name. If this is not set or NULL, we treat all instance
+#' weights as 1.0.
#' @aliases spark.glm,SparkDataFrame,formula-method
#' @return \code{spark.glm} returns a fitted generalized linear model
#' @rdname spark.glm
@@ -119,7 +121,7 @@ NULL
#' @note spark.glm since 2.0.0
#' @seealso \link{glm}, \link{read.ml}
setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"),
- function(data, formula, family = gaussian, tol = 1e-6, maxIter = 25) {
+ function(data, formula, family = gaussian, tol = 1e-6, maxIter = 25, weightCol = NULL) {
if (is.character(family)) {
family <- get(family, mode = "function", envir = parent.frame())
}
@@ -132,10 +134,13 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"),
}
formula <- paste(deparse(formula), collapse = "")
+ if (is.null(weightCol)) {
+ weightCol <- ""
+ }
jobj <- callJStatic("org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper",
"fit", formula, data@sdf, family$family, family$link,
- tol, as.integer(maxIter))
+ tol, as.integer(maxIter), weightCol)
return(new("GeneralizedLinearRegressionModel", jobj = jobj))
})
@@ -151,6 +156,8 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"),
#' \url{https://stat.ethz.ch/R-manual/R-devel/library/stats/html/family.html}.
#' @param epsilon Positive convergence tolerance of iterations.
#' @param maxit Integer giving the maximal number of IRLS iterations.
+#' @param weightCol The weight column name. If this is not set or NULL, we treat all instance
+#' weights as 1.0.
#' @return \code{glm} returns a fitted generalized linear model.
#' @rdname glm
#' @export
@@ -165,8 +172,8 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"),
#' @note glm since 1.5.0
#' @seealso \link{spark.glm}
setMethod("glm", signature(formula = "formula", family = "ANY", data = "SparkDataFrame"),
- function(formula, family = gaussian, data, epsilon = 1e-6, maxit = 25) {
- spark.glm(data, formula, family, tol = epsilon, maxIter = maxit)
+ function(formula, family = gaussian, data, epsilon = 1e-6, maxit = 25, weightCol = NULL) {
+ spark.glm(data, formula, family, tol = epsilon, maxIter = maxit, weightCol = weightCol)
})
# Returns the summary of a model produced by glm() or spark.glm(), similarly to R's summary().
diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R
index ab390a86d1..bc18224680 100644
--- a/R/pkg/inst/tests/testthat/test_mllib.R
+++ b/R/pkg/inst/tests/testthat/test_mllib.R
@@ -118,6 +118,28 @@ test_that("spark.glm summary", {
expect_equal(stats$df.residual, rStats$df.residual)
expect_equal(stats$aic, rStats$aic)
+ # Test spark.glm works with weighted dataset
+ a1 <- c(0, 1, 2, 3)
+ a2 <- c(5, 2, 1, 3)
+ w <- c(1, 2, 3, 4)
+ b <- c(1, 0, 1, 0)
+ data <- as.data.frame(cbind(a1, a2, w, b))
+ df <- suppressWarnings(createDataFrame(data))
+
+ stats <- summary(spark.glm(df, b ~ a1 + a2, family = "binomial", weightCol = "w"))
+ rStats <- summary(glm(b ~ a1 + a2, family = "binomial", data = data, weights = w))
+
+ coefs <- unlist(stats$coefficients)
+ rCoefs <- unlist(rStats$coefficients)
+ expect_true(all(abs(rCoefs - coefs) < 1e-3))
+ expect_true(all(rownames(stats$coefficients) == c("(Intercept)", "a1", "a2")))
+ expect_equal(stats$dispersion, rStats$dispersion)
+ expect_equal(stats$null.deviance, rStats$null.deviance)
+ expect_equal(stats$deviance, rStats$deviance)
+ expect_equal(stats$df.null, rStats$df.null)
+ expect_equal(stats$df.residual, rStats$df.residual)
+ expect_equal(stats$aic, rStats$aic)
+
# Test summary works on base GLM models
baseModel <- stats::glm(Sepal.Width ~ Sepal.Length + Species, data = iris)
baseSummary <- summary(baseModel)