aboutsummaryrefslogtreecommitdiff
path: root/R
diff options
context:
space:
mode:
authorYanbo Liang <ybliang8@gmail.com>2016-03-24 22:29:34 -0700
committerXiangrui Meng <meng@databricks.com>2016-03-24 22:29:34 -0700
commit13cbb2de709d0ec2707eebf36c5c97f7d44fb84f (patch)
treeec2d0aa7a5579c64dc1fc2703467f7a7a720a266 /R
parent05f652d6c2bbd764a1dd5a45301811e14519486f (diff)
downloadspark-13cbb2de709d0ec2707eebf36c5c97f7d44fb84f.tar.gz
spark-13cbb2de709d0ec2707eebf36c5c97f7d44fb84f.tar.bz2
spark-13cbb2de709d0ec2707eebf36c5c97f7d44fb84f.zip
[SPARK-13010][ML][SPARKR] Implement a simple wrapper of AFTSurvivalRegression in SparkR
## What changes were proposed in this pull request? This PR continues the work in #11447, we implemented the wrapper of ```AFTSurvivalRegression``` named ```survreg``` in SparkR. ## How was this patch tested? Test against output from R package survival's survreg. cc mengxr felixcheung Close #11447 Author: Yanbo Liang <ybliang8@gmail.com> Closes #11932 from yanboliang/spark-13010-new.
Diffstat (limited to 'R')
-rw-r--r--R/pkg/DESCRIPTION3
-rw-r--r--R/pkg/NAMESPACE3
-rw-r--r--R/pkg/R/generics.R4
-rw-r--r--R/pkg/R/mllib.R75
-rw-r--r--R/pkg/inst/tests/testthat/test_mllib.R49
5 files changed, 132 insertions, 2 deletions
diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION
index e26f9a7a2a..7179438efc 100644
--- a/R/pkg/DESCRIPTION
+++ b/R/pkg/DESCRIPTION
@@ -12,7 +12,8 @@ Depends:
methods,
Suggests:
testthat,
- e1071
+ e1071,
+ survival
Description: R frontend for Spark
License: Apache License (== 2.0)
Collate:
diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index 5d8a4b1d6e..fa3fb0b09a 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -16,7 +16,8 @@ exportMethods("glm",
"summary",
"kmeans",
"fitted",
- "naiveBayes")
+ "naiveBayes",
+ "survreg")
# Job group lifecycle management methods
export("setJobGroup",
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index 46b115f45e..c6990f4748 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -1179,3 +1179,7 @@ setGeneric("fitted")
#' @rdname naiveBayes
#' @export
setGeneric("naiveBayes", function(formula, data, ...) { standardGeneric("naiveBayes") })
+
+#' @rdname survreg
+#' @export
+setGeneric("survreg", function(formula, data, ...) { standardGeneric("survreg") })
diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R
index 2555019369..33654d5216 100644
--- a/R/pkg/R/mllib.R
+++ b/R/pkg/R/mllib.R
@@ -27,6 +27,11 @@ setClass("PipelineModel", representation(model = "jobj"))
#' @export
setClass("NaiveBayesModel", representation(jobj = "jobj"))
+#' @title S4 class that represents a AFTSurvivalRegressionModel
+#' @param jobj a Java object reference to the backing Scala AFTSurvivalRegressionWrapper
+#' @export
+setClass("AFTSurvivalRegressionModel", representation(jobj = "jobj"))
+
#' Fits a generalized linear model
#'
#' Fits a generalized linear model, similarly to R's glm(). Also see the glmnet package.
@@ -273,3 +278,73 @@ setMethod("naiveBayes", signature(formula = "formula", data = "DataFrame"),
formula, data@sdf, laplace)
return(new("NaiveBayesModel", jobj = jobj))
})
+
+#' Fit an accelerated failure time (AFT) survival regression model.
+#'
+#' Fit an accelerated failure time (AFT) survival regression model, similarly to R's survreg().
+#'
+#' @param formula A symbolic description of the model to be fitted. Currently only a few formula
+#' operators are supported, including '~', ':', '+', and '-'.
+#' Note that operator '.' is not supported currently.
+#' @param data DataFrame for training.
+#' @return a fitted AFT survival regression model
+#' @rdname survreg
+#' @seealso survival: \url{https://cran.r-project.org/web/packages/survival/}
+#' @export
+#' @examples
+#' \dontrun{
+#' df <- createDataFrame(sqlContext, ovarian)
+#' model <- survreg(Surv(futime, fustat) ~ ecog_ps + rx, df)
+#' }
+setMethod("survreg", signature(formula = "formula", data = "DataFrame"),
+ function(formula, data, ...) {
+ formula <- paste(deparse(formula), collapse = "")
+ jobj <- callJStatic("org.apache.spark.ml.r.AFTSurvivalRegressionWrapper",
+ "fit", formula, data@sdf)
+ return(new("AFTSurvivalRegressionModel", jobj = jobj))
+ })
+
+#' Get the summary of an AFT survival regression model
+#'
+#' Returns the summary of an AFT survival regression model produced by survreg(),
+#' similarly to R's summary().
+#'
+#' @param object a fitted AFT survival regression model
+#' @return coefficients the model's coefficients, intercept and log(scale).
+#' @rdname summary
+#' @export
+#' @examples
+#' \dontrun{
+#' model <- survreg(Surv(futime, fustat) ~ ecog_ps + rx, trainingData)
+#' summary(model)
+#' }
+setMethod("summary", signature(object = "AFTSurvivalRegressionModel"),
+ function(object, ...) {
+ jobj <- object@jobj
+ features <- callJMethod(jobj, "rFeatures")
+ coefficients <- callJMethod(jobj, "rCoefficients")
+ coefficients <- as.matrix(unlist(coefficients))
+ colnames(coefficients) <- c("Value")
+ rownames(coefficients) <- unlist(features)
+ return(list(coefficients = coefficients))
+ })
+
+#' Make predictions from an AFT survival regression model
+#'
+#' Make predictions from a model produced by survreg(), similarly to R package survival's predict.
+#'
+#' @param object A fitted AFT survival regression model
+#' @param newData DataFrame for testing
+#' @return DataFrame containing predicted labels in a column named "prediction"
+#' @rdname predict
+#' @export
+#' @examples
+#' \dontrun{
+#' model <- survreg(Surv(futime, fustat) ~ ecog_ps + rx, trainingData)
+#' predicted <- predict(model, testData)
+#' showDF(predicted)
+#' }
+setMethod("predict", signature(object = "AFTSurvivalRegressionModel"),
+ function(object, newData) {
+ return(dataFrame(callJMethod(object@jobj, "transform", newData@sdf)))
+ })
diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R
index 44b48369ef..fdb591756e 100644
--- a/R/pkg/inst/tests/testthat/test_mllib.R
+++ b/R/pkg/inst/tests/testthat/test_mllib.R
@@ -200,3 +200,52 @@ test_that("naiveBayes", {
expect_equal(as.character(predict(m, t1[1, ])), "Yes")
}
})
+
+test_that("survreg", {
+ # R code to reproduce the result.
+ #
+ #' rData <- list(time = c(4, 3, 1, 1, 2, 2, 3), status = c(1, 1, 1, 0, 1, 1, 0),
+ #' x = c(0, 2, 1, 1, 1, 0, 0), sex = c(0, 0, 0, 0, 1, 1, 1))
+ #' library(survival)
+ #' model <- survreg(Surv(time, status) ~ x + sex, rData)
+ #' summary(model)
+ #' predict(model, data)
+ #
+ # -- output of 'summary(model)'
+ #
+ # Value Std. Error z p
+ # (Intercept) 1.315 0.270 4.88 1.07e-06
+ # x -0.190 0.173 -1.10 2.72e-01
+ # sex -0.253 0.329 -0.77 4.42e-01
+ # Log(scale) -1.160 0.396 -2.93 3.41e-03
+ #
+ # -- output of 'predict(model, data)'
+ #
+ # 1 2 3 4 5 6 7
+ # 3.724591 2.545368 3.079035 3.079035 2.390146 2.891269 2.891269
+ #
+ data <- list(list(4, 1, 0, 0), list(3, 1, 2, 0), list(1, 1, 1, 0),
+ list(1, 0, 1, 0), list(2, 1, 1, 1), list(2, 1, 0, 1), list(3, 0, 0, 1))
+ df <- createDataFrame(sqlContext, data, c("time", "status", "x", "sex"))
+ model <- survreg(Surv(time, status) ~ x + sex, df)
+ stats <- summary(model)
+ coefs <- as.vector(stats$coefficients[, 1])
+ rCoefs <- c(1.3149571, -0.1903409, -0.2532618, -1.1599800)
+ expect_equal(coefs, rCoefs, tolerance = 1e-4)
+ expect_true(all(
+ rownames(stats$coefficients) ==
+ c("(Intercept)", "x", "sex", "Log(scale)")))
+ p <- collect(select(predict(model, df), "prediction"))
+ expect_equal(p$prediction, c(3.724591, 2.545368, 3.079035, 3.079035,
+ 2.390146, 2.891269, 2.891269), tolerance = 1e-4)
+
+ # Test survival::survreg
+ if (requireNamespace("survival", quietly = TRUE)) {
+ rData <- list(time = c(4, 3, 1, 1, 2, 2, 3), status = c(1, 1, 1, 0, 1, 1, 0),
+ x = c(0, 2, 1, 1, 1, 0, 0), sex = c(0, 0, 0, 0, 1, 1, 1))
+ expect_that(
+ model <- survival::survreg(formula = survival::Surv(time, status) ~ x + sex, data = rData),
+ not(throws_error()))
+ expect_equal(predict(model, rData)[[1]], 3.724591, tolerance = 1e-4)
+ }
+})