From d4a9122430d6c3aeaaee32aa09d314016ff6ddc7 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Wed, 10 Aug 2016 10:53:48 -0700 Subject: [SPARK-16710][SPARKR][ML] spark.glm should support weightCol ## What changes were proposed in this pull request? Training GLMs on weighted datasets is a very important use case, but it is not currently supported by SparkR. Users can pass the argument ```weights``` to specify the weights vector in native R. For ```spark.glm```, we can pass in ```weightCol```, which is consistent with MLlib. ## How was this patch tested? Unit test. Author: Yanbo Liang Closes #14346 from yanboliang/spark-16710. --- R/pkg/inst/tests/testthat/test_mllib.R | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) (limited to 'R/pkg/inst') diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R index ab390a86d1..bc18224680 100644 --- a/R/pkg/inst/tests/testthat/test_mllib.R +++ b/R/pkg/inst/tests/testthat/test_mllib.R @@ -118,6 +118,28 @@ test_that("spark.glm summary", { expect_equal(stats$df.residual, rStats$df.residual) expect_equal(stats$aic, rStats$aic) + # Test spark.glm works with weighted dataset + a1 <- c(0, 1, 2, 3) + a2 <- c(5, 2, 1, 3) + w <- c(1, 2, 3, 4) + b <- c(1, 0, 1, 0) + data <- as.data.frame(cbind(a1, a2, w, b)) + df <- suppressWarnings(createDataFrame(data)) + + stats <- summary(spark.glm(df, b ~ a1 + a2, family = "binomial", weightCol = "w")) + rStats <- summary(glm(b ~ a1 + a2, family = "binomial", data = data, weights = w)) + + coefs <- unlist(stats$coefficients) + rCoefs <- unlist(rStats$coefficients) + expect_true(all(abs(rCoefs - coefs) < 1e-3)) + expect_true(all(rownames(stats$coefficients) == c("(Intercept)", "a1", "a2"))) + expect_equal(stats$dispersion, rStats$dispersion) + expect_equal(stats$null.deviance, rStats$null.deviance) + expect_equal(stats$deviance, rStats$deviance) + expect_equal(stats$df.null, rStats$df.null) + expect_equal(stats$df.residual, rStats$df.residual) + 
expect_equal(stats$aic, rStats$aic) + # Test summary works on base GLM models baseModel <- stats::glm(Sepal.Width ~ Sepal.Length + Species, data = iris) baseSummary <- summary(baseModel) -- cgit v1.2.3