about | summary | refs | log | tree | commit | diff
path: root/R
diff options
context:
space:
mode:
author: Yanbo Liang <ybliang8@gmail.com> 2015-11-09 08:56:22 -0800
committer: Xiangrui Meng <meng@databricks.com> 2015-11-09 08:56:22 -0800
commit: 8c0e1b50e960d3e8e51d0618c462eed2bb4936f0 (patch)
tree: 467a738e59a86e39c1f59f00b1c0bbfffba55e1c /R
parent: b541b31630b1b85b48d6096079d073ccf46a62e8 (diff)
download: spark-8c0e1b50e960d3e8e51d0618c462eed2bb4936f0.tar.gz
          spark-8c0e1b50e960d3e8e51d0618c462eed2bb4936f0.tar.bz2
          spark-8c0e1b50e960d3e8e51d0618c462eed2bb4936f0.zip
[SPARK-11494][ML][R] Expose R-like summary statistics in SparkR::glm for linear regression
Expose R-like summary statistics in SparkR::glm for linear regression. The output of `summary` looks like:

```
$DevianceResiduals
 Min        Max
 -0.9509607 0.7291832

$Coefficients
                   Estimate   Std. Error t value   Pr(>|t|)
(Intercept)        1.6765     0.2353597  7.123139  4.456124e-11
Sepal_Length       0.3498801  0.04630128 7.556598  4.187317e-12
Species_versicolor -0.9833885 0.07207471 -13.64402 0
Species_virginica  -1.00751   0.09330565 -10.79796 0
```

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #9561 from yanboliang/spark-11494.
Diffstat (limited to 'R')
-rw-r--r-- R/pkg/R/mllib.R 22
-rw-r--r-- R/pkg/inst/tests/test_mllib.R 31
2 files changed, 42 insertions, 11 deletions
diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R
index b0d73dd93a..7ff859741b 100644
--- a/R/pkg/R/mllib.R
+++ b/R/pkg/R/mllib.R
@@ -91,12 +91,26 @@ setMethod("predict", signature(object = "PipelineModel"),
#'}
setMethod("summary", signature(x = "PipelineModel"),
function(x, ...) {
+ modelName <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
+ "getModelName", x@model)
features <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
"getModelFeatures", x@model)
coefficients <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
"getModelCoefficients", x@model)
- coefficients <- as.matrix(unlist(coefficients))
- colnames(coefficients) <- c("Estimate")
- rownames(coefficients) <- unlist(features)
- return(list(coefficients = coefficients))
+ if (modelName == "LinearRegressionModel") {
+ devianceResiduals <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
+ "getModelDevianceResiduals", x@model)
+ devianceResiduals <- matrix(devianceResiduals, nrow = 1)
+ colnames(devianceResiduals) <- c("Min", "Max")
+ rownames(devianceResiduals) <- rep("", times = 1)
+ coefficients <- matrix(coefficients, ncol = 4)
+ colnames(coefficients) <- c("Estimate", "Std. Error", "t value", "Pr(>|t|)")
+ rownames(coefficients) <- unlist(features)
+ return(list(DevianceResiduals = devianceResiduals, Coefficients = coefficients))
+ } else {
+ coefficients <- as.matrix(unlist(coefficients))
+ colnames(coefficients) <- c("Estimate")
+ rownames(coefficients) <- unlist(features)
+ return(list(coefficients = coefficients))
+ }
})
diff --git a/R/pkg/inst/tests/test_mllib.R b/R/pkg/inst/tests/test_mllib.R
index 4761e285a2..2606407bdc 100644
--- a/R/pkg/inst/tests/test_mllib.R
+++ b/R/pkg/inst/tests/test_mllib.R
@@ -71,12 +71,23 @@ test_that("feature interaction vs native glm", {
test_that("summary coefficients match with native glm", {
training <- createDataFrame(sqlContext, iris)
- stats <- summary(glm(Sepal_Width ~ Sepal_Length + Species, data = training, solver = "l-bfgs"))
- coefs <- as.vector(stats$coefficients)
+ stats <- summary(glm(Sepal_Width ~ Sepal_Length + Species, data = training, solver = "normal"))
+ coefs <- unlist(stats$Coefficients)
+ devianceResiduals <- unlist(stats$DevianceResiduals)
+
rCoefs <- as.vector(coef(glm(Sepal.Width ~ Sepal.Length + Species, data = iris)))
- expect_true(all(abs(rCoefs - coefs) < 1e-6))
+ rStdError <- c(0.23536, 0.04630, 0.07207, 0.09331)
+ rTValue <- c(7.123, 7.557, -13.644, -10.798)
+ rPValue <- c(0.0, 0.0, 0.0, 0.0)
+ rDevianceResiduals <- c(-0.95096, 0.72918)
+
+ expect_true(all(abs(rCoefs - coefs[1:4]) < 1e-6))
+ expect_true(all(abs(rStdError - coefs[5:8]) < 1e-5))
+ expect_true(all(abs(rTValue - coefs[9:12]) < 1e-3))
+ expect_true(all(abs(rPValue - coefs[13:16]) < 1e-6))
+ expect_true(all(abs(rDevianceResiduals - devianceResiduals) < 1e-5))
expect_true(all(
- as.character(stats$features) ==
+ rownames(stats$Coefficients) ==
c("(Intercept)", "Sepal_Length", "Species_versicolor", "Species_virginica")))
})
@@ -85,14 +96,20 @@ test_that("summary coefficients match with native glm of family 'binomial'", {
training <- filter(df, df$Species != "setosa")
stats <- summary(glm(Species ~ Sepal_Length + Sepal_Width, data = training,
family = "binomial"))
- coefs <- as.vector(stats$coefficients)
+ coefs <- as.vector(stats$Coefficients)
rTraining <- iris[iris$Species %in% c("versicolor","virginica"),]
rCoefs <- as.vector(coef(glm(Species ~ Sepal.Length + Sepal.Width, data = rTraining,
family = binomial(link = "logit"))))
+ rStdError <- c(3.0974, 0.5169, 0.8628)
+ rTValue <- c(-4.212, 3.680, 0.469)
+ rPValue <- c(0.000, 0.000, 0.639)
- expect_true(all(abs(rCoefs - coefs) < 1e-4))
+ expect_true(all(abs(rCoefs - coefs[1:3]) < 1e-4))
+ expect_true(all(abs(rStdError - coefs[4:6]) < 1e-4))
+ expect_true(all(abs(rTValue - coefs[7:9]) < 1e-3))
+ expect_true(all(abs(rPValue - coefs[10:12]) < 1e-3))
expect_true(all(
- as.character(stats$features) ==
+ rownames(stats$Coefficients) ==
c("(Intercept)", "Sepal_Length", "Sepal_Width")))
})