From 92f66331b4ba3634f54f57ddb5e7962b14aa4ca1 Mon Sep 17 00:00:00 2001
From: Yanbo Liang
Date: Tue, 26 Apr 2016 10:30:24 -0700
Subject: [SPARK-14313][ML][SPARKR] AFTSurvivalRegression model persistence in
 SparkR

## What changes were proposed in this pull request?

```AFTSurvivalRegressionModel``` supports ```save/load``` in SparkR.

## How was this patch tested?

Unit tests.

Author: Yanbo Liang

Closes #12685 from yanboliang/spark-14313.
---
 R/pkg/R/mllib.R                                    | 27 +++++++++++
 R/pkg/inst/tests/testthat/test_mllib.R             | 13 ++++++
 .../spark/ml/r/AFTSurvivalRegressionWrapper.scala  | 52 ++++++++++++++++++++--
 .../scala/org/apache/spark/ml/r/RWrappers.scala    |  2 +
 4 files changed, 91 insertions(+), 3 deletions(-)

diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R
index cda6100e79..480301192d 100644
--- a/R/pkg/R/mllib.R
+++ b/R/pkg/R/mllib.R
@@ -364,6 +364,31 @@ setMethod("ml.save", signature(object = "NaiveBayesModel", path = "character"),
             invisible(callJMethod(writer, "save", path))
           })
 
+#' Save the AFT survival regression model to the input path.
+#'
+#' @param object A fitted AFT survival regression model
+#' @param path The directory where the model is saved
+#' @param overwrite Whether to overwrite the output path if it already exists. Default is FALSE,
+#'  which means an exception is thrown if the output path exists.
+#'
+#' @rdname ml.save
+#' @name ml.save
+#' @export
+#' @examples
+#' \dontrun{
+#' model <- survreg(Surv(futime, fustat) ~ ecog_ps + rx, trainingData)
+#' path <- "path/to/model"
+#' ml.save(model, path)
+#' }
+setMethod("ml.save", signature(object = "AFTSurvivalRegressionModel", path = "character"),
+          function(object, path, overwrite = FALSE) {
+            writer <- callJMethod(object@jobj, "write")
+            if (overwrite) {
+              writer <- callJMethod(writer, "overwrite")
+            }
+            invisible(callJMethod(writer, "save", path))
+          })
+
 #' Load a fitted MLlib model from the input path.
 #'
 #' @param path Path of the model to read.
@@ -381,6 +406,8 @@ ml.load <- function(path) {
   jobj <- callJStatic("org.apache.spark.ml.r.RWrappers", "load", path)
   if (isInstanceOf(jobj, "org.apache.spark.ml.r.NaiveBayesWrapper")) {
     return(new("NaiveBayesModel", jobj = jobj))
+  } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.AFTSurvivalRegressionWrapper")) {
+    return(new("AFTSurvivalRegressionModel", jobj = jobj))
   } else {
     stop(paste("Unsupported model: ", jobj))
   }
diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R
index 63ec84e497..954abb00d4 100644
--- a/R/pkg/inst/tests/testthat/test_mllib.R
+++ b/R/pkg/inst/tests/testthat/test_mllib.R
@@ -261,6 +261,19 @@ test_that("survreg", {
   expect_equal(p$prediction, c(3.724591, 2.545368, 3.079035, 3.079035,
                2.390146, 2.891269, 2.891269), tolerance = 1e-4)
 
+  # Test model save/load
+  modelPath <- tempfile(pattern = "survreg", fileext = ".tmp")
+  ml.save(model, modelPath)
+  expect_error(ml.save(model, modelPath))
+  ml.save(model, modelPath, overwrite = TRUE)
+  model2 <- ml.load(modelPath)
+  stats2 <- summary(model2)
+  coefs2 <- as.vector(stats2$coefficients[, 1])
+  expect_equal(coefs, coefs2)
+  expect_equal(rownames(stats$coefficients), rownames(stats2$coefficients))
+
+  unlink(modelPath)
+
   # Test survival::survreg
   if (requireNamespace("survival", quietly = TRUE)) {
     rData <- list(time = c(4, 3, 1, 1, 2, 2, 3), status = c(1, 1, 1, 0, 1, 1, 0),
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala
index 7835468626..a442469e4d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala
@@ -17,16 +17,23 @@
 
 package org.apache.spark.ml.r
 
+import org.apache.hadoop.fs.Path
+import org.json4s._
+import org.json4s.DefaultFormats
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
 import org.apache.spark.SparkException
 import org.apache.spark.ml.{Pipeline, PipelineModel}
 import org.apache.spark.ml.attribute.AttributeGroup
 import org.apache.spark.ml.feature.RFormula
 import org.apache.spark.ml.regression.{AFTSurvivalRegression, AFTSurvivalRegressionModel}
+import org.apache.spark.ml.util._
 import org.apache.spark.sql.{DataFrame, Dataset}
 
 private[r] class AFTSurvivalRegressionWrapper private (
-    pipeline: PipelineModel,
-    features: Array[String]) {
+    val pipeline: PipelineModel,
+    val features: Array[String]) extends MLWritable {
 
   private val aftModel: AFTSurvivalRegressionModel =
     pipeline.stages(1).asInstanceOf[AFTSurvivalRegressionModel]
@@ -46,9 +53,12 @@ private[r] class AFTSurvivalRegressionWrapper private (
   def transform(dataset: Dataset[_]): DataFrame = {
     pipeline.transform(dataset).drop(aftModel.getFeaturesCol)
   }
+
+  override def write: MLWriter =
+    new AFTSurvivalRegressionWrapper.AFTSurvivalRegressionWrapperWriter(this)
 }
 
-private[r] object AFTSurvivalRegressionWrapper {
+private[r] object AFTSurvivalRegressionWrapper extends MLReadable[AFTSurvivalRegressionWrapper] {
 
   private def formulaRewrite(formula: String): (String, String) = {
     var rewritedFormula: String = null
@@ -96,4 +106,40 @@ private[r] object AFTSurvivalRegressionWrapper {
 
     new AFTSurvivalRegressionWrapper(pipeline, features)
   }
+
+  override def read: MLReader[AFTSurvivalRegressionWrapper] = new AFTSurvivalRegressionWrapperReader
+
+  override def load(path: String): AFTSurvivalRegressionWrapper = super.load(path)
+
+  class AFTSurvivalRegressionWrapperWriter(instance: AFTSurvivalRegressionWrapper)
+    extends MLWriter {
+
+    override protected def saveImpl(path: String): Unit = {
+      val rMetadataPath = new Path(path, "rMetadata").toString
+      val pipelinePath = new Path(path, "pipeline").toString
+
+      val rMetadata = ("class" -> instance.getClass.getName) ~
+        ("features" -> instance.features.toSeq)
+      val rMetadataJson: String = compact(render(rMetadata))
+      sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)
+
+      instance.pipeline.save(pipelinePath)
+    }
+  }
+
+  class AFTSurvivalRegressionWrapperReader extends MLReader[AFTSurvivalRegressionWrapper] {
+
+    override def load(path: String): AFTSurvivalRegressionWrapper = {
+      implicit val format = DefaultFormats
+      val rMetadataPath = new Path(path, "rMetadata").toString
+      val pipelinePath = new Path(path, "pipeline").toString
+
+      val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
+      val rMetadata = parse(rMetadataStr)
+      val features = (rMetadata \ "features").extract[Array[String]]
+
+      val pipeline = PipelineModel.load(pipelinePath)
+      new AFTSurvivalRegressionWrapper(pipeline, features)
+    }
+  }
 }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
index 7f6f147532..06baedf2a2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
@@ -38,6 +38,8 @@ private[r] object RWrappers extends MLReader[Object] {
     val className = (rMetadata \ "class").extract[String]
     className match {
       case "org.apache.spark.ml.r.NaiveBayesWrapper" => NaiveBayesWrapper.load(path)
+      case "org.apache.spark.ml.r.AFTSurvivalRegressionWrapper" =>
+        AFTSurvivalRegressionWrapper.load(path)
       case _ => throw new SparkException(s"SparkR ml.load does not support load $className")
     }
-- 
cgit v1.2.3
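
Usage note: the sketch below pulls the new API together into one save/load round trip, mirroring the unit test in this patch. It is illustrative only; `trainingData` (a SparkR DataFrame with futime/fustat/ecog_ps/rx columns, e.g. built from the survival package's ovarian data) and the temp path are assumptions, not part of the patch.

```
# Fit an AFT survival regression model, as in the roxygen example above.
model <- survreg(Surv(futime, fustat) ~ ecog_ps + rx, trainingData)

# Save the model; saving again to the same path errors unless overwrite = TRUE.
modelPath <- tempfile(pattern = "survreg", fileext = ".tmp")
ml.save(model, modelPath)
ml.save(model, modelPath, overwrite = TRUE)

# Reload; ml.load dispatches on the stored class name to
# AFTSurvivalRegressionWrapper and returns an AFTSurvivalRegressionModel.
model2 <- ml.load(modelPath)

# Coefficients should survive the round trip.
stopifnot(isTRUE(all.equal(as.vector(summary(model)$coefficients[, 1]),
                           as.vector(summary(model2)$coefficients[, 1]))))

unlink(modelPath)
```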