diff options
author | Yanbo Liang <ybliang8@gmail.com> | 2015-11-09 08:56:22 -0800 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2015-11-09 08:56:22 -0800 |
commit | 8c0e1b50e960d3e8e51d0618c462eed2bb4936f0 (patch) | |
tree | 467a738e59a86e39c1f59f00b1c0bbfffba55e1c /R | |
parent | b541b31630b1b85b48d6096079d073ccf46a62e8 (diff) | |
download | spark-8c0e1b50e960d3e8e51d0618c462eed2bb4936f0.tar.gz spark-8c0e1b50e960d3e8e51d0618c462eed2bb4936f0.tar.bz2 spark-8c0e1b50e960d3e8e51d0618c462eed2bb4936f0.zip |
[SPARK-11494][ML][R] Expose R-like summary statistics in SparkR::glm for linear regression
Expose R-like summary statistics in SparkR::glm for linear regression. The output of `summary` looks like:
```
$DevianceResiduals
Min Max
-0.9509607 0.7291832
$Coefficients
Estimate Std. Error t value Pr(>|t|)
(Intercept) 1.6765 0.2353597 7.123139 4.456124e-11
Sepal_Length 0.3498801 0.04630128 7.556598 4.187317e-12
Species_versicolor -0.9833885 0.07207471 -13.64402 0
Species_virginica -1.00751 0.09330565 -10.79796 0
```
Author: Yanbo Liang <ybliang8@gmail.com>
Closes #9561 from yanboliang/spark-11494.
Diffstat (limited to 'R')
-rw-r--r-- | R/pkg/R/mllib.R | 22 | ||||
-rw-r--r-- | R/pkg/inst/tests/test_mllib.R | 31 |
2 files changed, 42 insertions, 11 deletions
diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R index b0d73dd93a..7ff859741b 100644 --- a/R/pkg/R/mllib.R +++ b/R/pkg/R/mllib.R @@ -91,12 +91,26 @@ setMethod("predict", signature(object = "PipelineModel"), #'} setMethod("summary", signature(x = "PipelineModel"), function(x, ...) { + modelName <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", + "getModelName", x@model) features <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", "getModelFeatures", x@model) coefficients <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", "getModelCoefficients", x@model) - coefficients <- as.matrix(unlist(coefficients)) - colnames(coefficients) <- c("Estimate") - rownames(coefficients) <- unlist(features) - return(list(coefficients = coefficients)) + if (modelName == "LinearRegressionModel") { + devianceResiduals <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", + "getModelDevianceResiduals", x@model) + devianceResiduals <- matrix(devianceResiduals, nrow = 1) + colnames(devianceResiduals) <- c("Min", "Max") + rownames(devianceResiduals) <- rep("", times = 1) + coefficients <- matrix(coefficients, ncol = 4) + colnames(coefficients) <- c("Estimate", "Std. 
Error", "t value", "Pr(>|t|)") + rownames(coefficients) <- unlist(features) + return(list(DevianceResiduals = devianceResiduals, Coefficients = coefficients)) + } else { + coefficients <- as.matrix(unlist(coefficients)) + colnames(coefficients) <- c("Estimate") + rownames(coefficients) <- unlist(features) + return(list(coefficients = coefficients)) + } }) diff --git a/R/pkg/inst/tests/test_mllib.R b/R/pkg/inst/tests/test_mllib.R index 4761e285a2..2606407bdc 100644 --- a/R/pkg/inst/tests/test_mllib.R +++ b/R/pkg/inst/tests/test_mllib.R @@ -71,12 +71,23 @@ test_that("feature interaction vs native glm", { test_that("summary coefficients match with native glm", { training <- createDataFrame(sqlContext, iris) - stats <- summary(glm(Sepal_Width ~ Sepal_Length + Species, data = training, solver = "l-bfgs")) - coefs <- as.vector(stats$coefficients) + stats <- summary(glm(Sepal_Width ~ Sepal_Length + Species, data = training, solver = "normal")) + coefs <- unlist(stats$Coefficients) + devianceResiduals <- unlist(stats$DevianceResiduals) + rCoefs <- as.vector(coef(glm(Sepal.Width ~ Sepal.Length + Species, data = iris))) - expect_true(all(abs(rCoefs - coefs) < 1e-6)) + rStdError <- c(0.23536, 0.04630, 0.07207, 0.09331) + rTValue <- c(7.123, 7.557, -13.644, -10.798) + rPValue <- c(0.0, 0.0, 0.0, 0.0) + rDevianceResiduals <- c(-0.95096, 0.72918) + + expect_true(all(abs(rCoefs - coefs[1:4]) < 1e-6)) + expect_true(all(abs(rStdError - coefs[5:8]) < 1e-5)) + expect_true(all(abs(rTValue - coefs[9:12]) < 1e-3)) + expect_true(all(abs(rPValue - coefs[13:16]) < 1e-6)) + expect_true(all(abs(rDevianceResiduals - devianceResiduals) < 1e-5)) expect_true(all( - as.character(stats$features) == + rownames(stats$Coefficients) == c("(Intercept)", "Sepal_Length", "Species_versicolor", "Species_virginica"))) }) @@ -85,14 +96,20 @@ test_that("summary coefficients match with native glm of family 'binomial'", { training <- filter(df, df$Species != "setosa") stats <- summary(glm(Species ~ Sepal_Length 
+ Sepal_Width, data = training, family = "binomial")) - coefs <- as.vector(stats$coefficients) + coefs <- as.vector(stats$Coefficients) rTraining <- iris[iris$Species %in% c("versicolor","virginica"),] rCoefs <- as.vector(coef(glm(Species ~ Sepal.Length + Sepal.Width, data = rTraining, family = binomial(link = "logit")))) + rStdError <- c(3.0974, 0.5169, 0.8628) + rTValue <- c(-4.212, 3.680, 0.469) + rPValue <- c(0.000, 0.000, 0.639) - expect_true(all(abs(rCoefs - coefs) < 1e-4)) + expect_true(all(abs(rCoefs - coefs[1:3]) < 1e-4)) + expect_true(all(abs(rStdError - coefs[4:6]) < 1e-4)) + expect_true(all(abs(rTValue - coefs[7:9]) < 1e-3)) + expect_true(all(abs(rPValue - coefs[10:12]) < 1e-3)) expect_true(all( - as.character(stats$features) == + rownames(stats$Coefficients) == c("(Intercept)", "Sepal_Length", "Sepal_Width"))) }) |