diff options
author | Eric Liang <ekl@databricks.com> | 2015-07-30 16:15:43 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2015-07-30 16:15:43 -0700 |
commit | e7905a9395c1a002f50bab29e16a729e14d4ed6f (patch) | |
tree | 37758d36fd51f330ca7b4ce2b9f9bb47784a2dcb /R/pkg | |
parent | be7be6d4c7d978c20e601d1f5f56ecb3479814cb (diff) | |
download | spark-e7905a9395c1a002f50bab29e16a729e14d4ed6f.tar.gz spark-e7905a9395c1a002f50bab29e16a729e14d4ed6f.tar.bz2 spark-e7905a9395c1a002f50bab29e16a729e14d4ed6f.zip |
[SPARK-9463] [ML] Expose model coefficients with names in SparkR RFormula
Preview:
```
> summary(m)
features coefficients
1 (Intercept) 1.6765001
2 Sepal_Length 0.3498801
3 Species.versicolor -0.9833885
4 Species.virginica -1.0075104
```
Design doc from umbrella task: https://docs.google.com/document/d/10NZNSEurN2EdWM31uFYsgayIPfCFHiuIu3pCWrUmP_c/edit
cc mengxr
Author: Eric Liang <ekl@databricks.com>
Closes #7771 from ericl/summary and squashes the following commits:
ccd54c3 [Eric Liang] second pass
a5ca93b [Eric Liang] comments
2772111 [Eric Liang] clean up
70483ef [Eric Liang] fix test
7c247d4 [Eric Liang] Merge branch 'master' into summary
3c55024 [Eric Liang] working
8c539aa [Eric Liang] first pass
Diffstat (limited to 'R/pkg')
-rw-r--r-- | R/pkg/NAMESPACE | 3 | ||||
-rw-r--r-- | R/pkg/R/mllib.R | 26 | ||||
-rw-r--r-- | R/pkg/inst/tests/test_mllib.R | 11 |
3 files changed, 39 insertions, 1 deletions
```diff
diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index 7f7a8a2e4d..a329e14f25 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -12,7 +12,8 @@ export("print.jobj")
 
 # MLlib integration
 exportMethods("glm",
-              "predict")
+              "predict",
+              "summary")
 
 # Job group lifecycle management methods
 export("setJobGroup",
diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R
index 6a8bacaa55..efddcc1d8d 100644
--- a/R/pkg/R/mllib.R
+++ b/R/pkg/R/mllib.R
@@ -71,3 +71,29 @@ setMethod("predict", signature(object = "PipelineModel"),
           function(object, newData) {
             return(dataFrame(callJMethod(object@model, "transform", newData@sdf)))
           })
+
+#' Get the summary of a model
+#'
+#' Returns the summary of a model produced by glm(), similarly to R's summary().
+#'
+#' @param model A fitted MLlib model
+#' @return a list with a 'coefficient' component, which is the matrix of coefficients. See
+#' summary.glm for more information.
+#' @rdname glm
+#' @export
+#' @examples
+#'\dontrun{
+#' model <- glm(y ~ x, trainingData)
+#' summary(model)
+#'}
+setMethod("summary", signature(object = "PipelineModel"),
+          function(object) {
+            features <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
+                                   "getModelFeatures", object@model)
+            weights <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
+                                   "getModelWeights", object@model)
+            coefficients <- as.matrix(unlist(weights))
+            colnames(coefficients) <- c("Estimate")
+            rownames(coefficients) <- unlist(features)
+            return(list(coefficients = coefficients))
+          })
diff --git a/R/pkg/inst/tests/test_mllib.R b/R/pkg/inst/tests/test_mllib.R
index 3bef693247..f272de78ad 100644
--- a/R/pkg/inst/tests/test_mllib.R
+++ b/R/pkg/inst/tests/test_mllib.R
@@ -48,3 +48,14 @@ test_that("dot minus and intercept vs native glm", {
   rVals <- predict(glm(Sepal.Width ~ . - Species + 0, data = iris), iris)
   expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals)
 })
+
+test_that("summary coefficients match with native glm", {
+  training <- createDataFrame(sqlContext, iris)
+  stats <- summary(glm(Sepal_Width ~ Sepal_Length + Species, data = training))
+  coefs <- as.vector(stats$coefficients)
+  rCoefs <- as.vector(coef(glm(Sepal.Width ~ Sepal.Length + Species, data = iris)))
+  expect_true(all(abs(rCoefs - coefs) < 1e-6))
+  expect_true(all(
+    as.character(stats$features) ==
+    c("(Intercept)", "Sepal_Length", "Species__versicolor", "Species__virginica")))
+})
```