aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala52
1 files changed, 49 insertions, 3 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala
index 7835468626..a442469e4d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala
@@ -17,16 +17,23 @@
package org.apache.spark.ml.r
+import org.apache.hadoop.fs.Path
+import org.json4s._
+import org.json4s.DefaultFormats
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
import org.apache.spark.SparkException
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.attribute.AttributeGroup
import org.apache.spark.ml.feature.RFormula
import org.apache.spark.ml.regression.{AFTSurvivalRegression, AFTSurvivalRegressionModel}
+import org.apache.spark.ml.util._
import org.apache.spark.sql.{DataFrame, Dataset}
private[r] class AFTSurvivalRegressionWrapper private (
- pipeline: PipelineModel,
- features: Array[String]) {
+ val pipeline: PipelineModel,
+ val features: Array[String]) extends MLWritable {
private val aftModel: AFTSurvivalRegressionModel =
pipeline.stages(1).asInstanceOf[AFTSurvivalRegressionModel]
@@ -46,9 +53,12 @@ private[r] class AFTSurvivalRegressionWrapper private (
def transform(dataset: Dataset[_]): DataFrame = {
pipeline.transform(dataset).drop(aftModel.getFeaturesCol)
}
+
+ override def write: MLWriter =
+ new AFTSurvivalRegressionWrapper.AFTSurvivalRegressionWrapperWriter(this)
}
-private[r] object AFTSurvivalRegressionWrapper {
+private[r] object AFTSurvivalRegressionWrapper extends MLReadable[AFTSurvivalRegressionWrapper] {
private def formulaRewrite(formula: String): (String, String) = {
var rewritedFormula: String = null
@@ -96,4 +106,40 @@ private[r] object AFTSurvivalRegressionWrapper {
new AFTSurvivalRegressionWrapper(pipeline, features)
}
+
+ override def read: MLReader[AFTSurvivalRegressionWrapper] = new AFTSurvivalRegressionWrapperReader
+
+ override def load(path: String): AFTSurvivalRegressionWrapper = super.load(path)
+
+ class AFTSurvivalRegressionWrapperWriter(instance: AFTSurvivalRegressionWrapper)
+ extends MLWriter {
+
+ override protected def saveImpl(path: String): Unit = {
+ val rMetadataPath = new Path(path, "rMetadata").toString
+ val pipelinePath = new Path(path, "pipeline").toString
+
+ val rMetadata = ("class" -> instance.getClass.getName) ~
+ ("features" -> instance.features.toSeq)
+ val rMetadataJson: String = compact(render(rMetadata))
+ sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)
+
+ instance.pipeline.save(pipelinePath)
+ }
+ }
+
+ class AFTSurvivalRegressionWrapperReader extends MLReader[AFTSurvivalRegressionWrapper] {
+
+ override def load(path: String): AFTSurvivalRegressionWrapper = {
+ implicit val format = DefaultFormats
+ val rMetadataPath = new Path(path, "rMetadata").toString
+ val pipelinePath = new Path(path, "pipeline").toString
+
+ val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
+ val rMetadata = parse(rMetadataStr)
+ val features = (rMetadata \ "features").extract[Array[String]]
+
+ val pipeline = PipelineModel.load(pipelinePath)
+ new AFTSurvivalRegressionWrapper(pipeline, features)
+ }
+ }
}