diff options
Diffstat (limited to 'mllib')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala | 52 | ||||
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala | 2 |
2 files changed, 51 insertions, 3 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala index 7835468626..a442469e4d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala @@ -17,16 +17,23 @@ package org.apache.spark.ml.r +import org.apache.hadoop.fs.Path +import org.json4s._ +import org.json4s.DefaultFormats +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + import org.apache.spark.SparkException import org.apache.spark.ml.{Pipeline, PipelineModel} import org.apache.spark.ml.attribute.AttributeGroup import org.apache.spark.ml.feature.RFormula import org.apache.spark.ml.regression.{AFTSurvivalRegression, AFTSurvivalRegressionModel} +import org.apache.spark.ml.util._ import org.apache.spark.sql.{DataFrame, Dataset} private[r] class AFTSurvivalRegressionWrapper private ( - pipeline: PipelineModel, - features: Array[String]) { + val pipeline: PipelineModel, + val features: Array[String]) extends MLWritable { private val aftModel: AFTSurvivalRegressionModel = pipeline.stages(1).asInstanceOf[AFTSurvivalRegressionModel] @@ -46,9 +53,12 @@ private[r] class AFTSurvivalRegressionWrapper private ( def transform(dataset: Dataset[_]): DataFrame = { pipeline.transform(dataset).drop(aftModel.getFeaturesCol) } + + override def write: MLWriter = + new AFTSurvivalRegressionWrapper.AFTSurvivalRegressionWrapperWriter(this) } -private[r] object AFTSurvivalRegressionWrapper { +private[r] object AFTSurvivalRegressionWrapper extends MLReadable[AFTSurvivalRegressionWrapper] { private def formulaRewrite(formula: String): (String, String) = { var rewritedFormula: String = null @@ -96,4 +106,40 @@ private[r] object AFTSurvivalRegressionWrapper { new AFTSurvivalRegressionWrapper(pipeline, features) } + + override def read: MLReader[AFTSurvivalRegressionWrapper] = new AFTSurvivalRegressionWrapperReader + + override def load(path: String): AFTSurvivalRegressionWrapper = super.load(path) + + class AFTSurvivalRegressionWrapperWriter(instance: AFTSurvivalRegressionWrapper) + extends MLWriter { + + override protected def saveImpl(path: String): Unit = { + val rMetadataPath = new Path(path, "rMetadata").toString + val pipelinePath = new Path(path, "pipeline").toString + + val rMetadata = ("class" -> instance.getClass.getName) ~ + ("features" -> instance.features.toSeq) + val rMetadataJson: String = compact(render(rMetadata)) + sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath) + + instance.pipeline.save(pipelinePath) + } + } + + class AFTSurvivalRegressionWrapperReader extends MLReader[AFTSurvivalRegressionWrapper] { + + override def load(path: String): AFTSurvivalRegressionWrapper = { + implicit val format = DefaultFormats + val rMetadataPath = new Path(path, "rMetadata").toString + val pipelinePath = new Path(path, "pipeline").toString + + val rMetadataStr = sc.textFile(rMetadataPath, 1).first() + val rMetadata = parse(rMetadataStr) + val features = (rMetadata \ "features").extract[Array[String]] + + val pipeline = PipelineModel.load(pipelinePath) + new AFTSurvivalRegressionWrapper(pipeline, features) + } + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala index 7f6f147532..06baedf2a2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala @@ -38,6 +38,8 @@ private[r] object RWrappers extends MLReader[Object] { val className = (rMetadata \ "class").extract[String] className match { case "org.apache.spark.ml.r.NaiveBayesWrapper" => NaiveBayesWrapper.load(path) + case "org.apache.spark.ml.r.AFTSurvivalRegressionWrapper" => + AFTSurvivalRegressionWrapper.load(path) case _ => throw new SparkException(s"SparkR ml.load does not support load $className") } |