From c4e19b3819df4cd7a1c495a00bd2844cf55f4dbd Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Mon, 9 Nov 2015 21:06:01 -0800 Subject: [SPARK-11587][SPARKR] Fix the summary generic to match base R The signature is summary(object, ...) as defined in https://stat.ethz.ch/R-manual/R-devel/library/base/html/summary.html Author: Shivaram Venkataraman Closes #9582 from shivaram/summary-fix. --- R/pkg/R/DataFrame.R | 6 +++--- R/pkg/R/generics.R | 2 +- R/pkg/R/mllib.R | 12 ++++++------ R/pkg/inst/tests/test_mllib.R | 6 ++++++ 4 files changed, 16 insertions(+), 10 deletions(-) (limited to 'R') diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 44ce9414da..e9013aa34a 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -1944,9 +1944,9 @@ setMethod("describe", #' @rdname summary #' @name summary setMethod("summary", - signature(x = "DataFrame"), - function(x) { - describe(x) + signature(object = "DataFrame"), + function(object, ...) { + describe(object) }) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 083d37fee2..efef7d66b5 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -561,7 +561,7 @@ setGeneric("summarize", function(x,...) { standardGeneric("summarize") }) #' @rdname summary #' @export -setGeneric("summary", function(x, ...) { standardGeneric("summary") }) +setGeneric("summary", function(object, ...) { standardGeneric("summary") }) # @rdname tojson # @export diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R index 7ff859741b..7126b7cde4 100644 --- a/R/pkg/R/mllib.R +++ b/R/pkg/R/mllib.R @@ -89,17 +89,17 @@ setMethod("predict", signature(object = "PipelineModel"), #' model <- glm(y ~ x, trainingData) #' summary(model) #'} -setMethod("summary", signature(x = "PipelineModel"), - function(x, ...) { +setMethod("summary", signature(object = "PipelineModel"), + function(object, ...) { modelName <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", - "getModelName", x@model) + "getModelName", object@model) features <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", - "getModelFeatures", x@model) + "getModelFeatures", object@model) coefficients <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", - "getModelCoefficients", x@model) + "getModelCoefficients", object@model) if (modelName == "LinearRegressionModel") { devianceResiduals <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", - "getModelDevianceResiduals", x@model) + "getModelDevianceResiduals", object@model) devianceResiduals <- matrix(devianceResiduals, nrow = 1) colnames(devianceResiduals) <- c("Min", "Max") rownames(devianceResiduals) <- rep("", times = 1) diff --git a/R/pkg/inst/tests/test_mllib.R b/R/pkg/inst/tests/test_mllib.R index 2606407bdc..42287ea19a 100644 --- a/R/pkg/inst/tests/test_mllib.R +++ b/R/pkg/inst/tests/test_mllib.R @@ -113,3 +113,9 @@ test_that("summary coefficients match with native glm of family 'binomial'", { rownames(stats$Coefficients) == c("(Intercept)", "Sepal_Length", "Sepal_Width"))) }) + +test_that("summary works on base GLM models", { + baseModel <- stats::glm(Sepal.Width ~ Sepal.Length + Species, data = iris) + baseSummary <- summary(baseModel) + expect_true(abs(baseSummary$deviance - 12.19313) < 1e-4) +}) -- cgit v1.2.3