aboutsummaryrefslogtreecommitdiff
path: root/R/pkg
diff options
context:
space:
mode:
authorEric Liang <ekl@databricks.com>2015-07-30 16:15:43 -0700
committerXiangrui Meng <meng@databricks.com>2015-07-30 16:15:43 -0700
commite7905a9395c1a002f50bab29e16a729e14d4ed6f (patch)
tree37758d36fd51f330ca7b4ce2b9f9bb47784a2dcb /R/pkg
parentbe7be6d4c7d978c20e601d1f5f56ecb3479814cb (diff)
downloadspark-e7905a9395c1a002f50bab29e16a729e14d4ed6f.tar.gz
spark-e7905a9395c1a002f50bab29e16a729e14d4ed6f.tar.bz2
spark-e7905a9395c1a002f50bab29e16a729e14d4ed6f.zip
[SPARK-9463] [ML] Expose model coefficients with names in SparkR RFormula
Preview: ``` > summary(m) features coefficients 1 (Intercept) 1.6765001 2 Sepal_Length 0.3498801 3 Species.versicolor -0.9833885 4 Species.virginica -1.0075104 ``` Design doc from umbrella task: https://docs.google.com/document/d/10NZNSEurN2EdWM31uFYsgayIPfCFHiuIu3pCWrUmP_c/edit cc mengxr Author: Eric Liang <ekl@databricks.com> Closes #7771 from ericl/summary and squashes the following commits: ccd54c3 [Eric Liang] second pass a5ca93b [Eric Liang] comments 2772111 [Eric Liang] clean up 70483ef [Eric Liang] fix test 7c247d4 [Eric Liang] Merge branch 'master' into summary 3c55024 [Eric Liang] working 8c539aa [Eric Liang] first pass
Diffstat (limited to 'R/pkg')
-rw-r--r--R/pkg/NAMESPACE3
-rw-r--r--R/pkg/R/mllib.R26
-rw-r--r--R/pkg/inst/tests/test_mllib.R11
3 files changed, 39 insertions, 1 deletions
diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index 7f7a8a2e4d..a329e14f25 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -12,7 +12,8 @@ export("print.jobj")
# MLlib integration
exportMethods("glm",
- "predict")
+ "predict",
+ "summary")
# Job group lifecycle management methods
export("setJobGroup",
diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R
index 6a8bacaa55..efddcc1d8d 100644
--- a/R/pkg/R/mllib.R
+++ b/R/pkg/R/mllib.R
@@ -71,3 +71,29 @@ setMethod("predict", signature(object = "PipelineModel"),
function(object, newData) {
return(dataFrame(callJMethod(object@model, "transform", newData@sdf)))
})
+
+#' Get the summary of a model
+#'
+#' Returns the summary of a model produced by glm(), similarly to R's summary().
+#'
+#' @param model A fitted MLlib model
+#' @return a list with a 'coefficient' component, which is the matrix of coefficients. See
+#' summary.glm for more information.
+#' @rdname glm
+#' @export
+#' @examples
+#'\dontrun{
+#' model <- glm(y ~ x, trainingData)
+#' summary(model)
+#'}
+setMethod("summary", signature(object = "PipelineModel"),
+ function(object) {
+ features <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
+ "getModelFeatures", object@model)
+ weights <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
+ "getModelWeights", object@model)
+ coefficients <- as.matrix(unlist(weights))
+ colnames(coefficients) <- c("Estimate")
+ rownames(coefficients) <- unlist(features)
+ return(list(coefficients = coefficients))
+ })
diff --git a/R/pkg/inst/tests/test_mllib.R b/R/pkg/inst/tests/test_mllib.R
index 3bef693247..f272de78ad 100644
--- a/R/pkg/inst/tests/test_mllib.R
+++ b/R/pkg/inst/tests/test_mllib.R
@@ -48,3 +48,14 @@ test_that("dot minus and intercept vs native glm", {
rVals <- predict(glm(Sepal.Width ~ . - Species + 0, data = iris), iris)
expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals)
})
+
+test_that("summary coefficients match with native glm", {
+ training <- createDataFrame(sqlContext, iris)
+ stats <- summary(glm(Sepal_Width ~ Sepal_Length + Species, data = training))
+ coefs <- as.vector(stats$coefficients)
+ rCoefs <- as.vector(coef(glm(Sepal.Width ~ Sepal.Length + Species, data = iris)))
+ expect_true(all(abs(rCoefs - coefs) < 1e-6))
+ expect_true(all(
+ as.character(stats$features) ==
+ c("(Intercept)", "Sepal_Length", "Species__versicolor", "Species__virginica")))
+})