aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main
diff options
context:
space:
mode:
authorYanbo Liang <ybliang8@gmail.com>2016-03-24 22:29:34 -0700
committerXiangrui Meng <meng@databricks.com>2016-03-24 22:29:34 -0700
commit13cbb2de709d0ec2707eebf36c5c97f7d44fb84f (patch)
treeec2d0aa7a5579c64dc1fc2703467f7a7a720a266 /mllib/src/main
parent05f652d6c2bbd764a1dd5a45301811e14519486f (diff)
downloadspark-13cbb2de709d0ec2707eebf36c5c97f7d44fb84f.tar.gz
spark-13cbb2de709d0ec2707eebf36c5c97f7d44fb84f.tar.bz2
spark-13cbb2de709d0ec2707eebf36c5c97f7d44fb84f.zip
[SPARK-13010][ML][SPARKR] Implement a simple wrapper of AFTSurvivalRegression in SparkR
## What changes were proposed in this pull request? This PR continues the work in #11447, we implemented the wrapper of ```AFTSurvivalRegression``` named ```survreg``` in SparkR. ## How was this patch tested? Test against output from R package survival's survreg. cc mengxr felixcheung Close #11447 Author: Yanbo Liang <ybliang8@gmail.com> Closes #11932 from yanboliang/spark-13010-new.
Diffstat (limited to 'mllib/src/main')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala99
1 files changed, 99 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala
new file mode 100644
index 0000000000..40590e71c4
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala
@@ -0,0 +1,99 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.r
+
+import org.apache.spark.SparkException
+import org.apache.spark.ml.{Pipeline, PipelineModel}
+import org.apache.spark.ml.attribute.AttributeGroup
+import org.apache.spark.ml.feature.RFormula
+import org.apache.spark.ml.regression.{AFTSurvivalRegression, AFTSurvivalRegressionModel}
+import org.apache.spark.sql.DataFrame
+
+private[r] class AFTSurvivalRegressionWrapper private (
+ pipeline: PipelineModel,
+ features: Array[String]) {
+
+ private val aftModel: AFTSurvivalRegressionModel =
+ pipeline.stages(1).asInstanceOf[AFTSurvivalRegressionModel]
+
+ lazy val rCoefficients: Array[Double] = if (aftModel.getFitIntercept) {
+ Array(aftModel.intercept) ++ aftModel.coefficients.toArray ++ Array(math.log(aftModel.scale))
+ } else {
+ aftModel.coefficients.toArray ++ Array(math.log(aftModel.scale))
+ }
+
+ lazy val rFeatures: Array[String] = if (aftModel.getFitIntercept) {
+ Array("(Intercept)") ++ features ++ Array("Log(scale)")
+ } else {
+ features ++ Array("Log(scale)")
+ }
+
+ def transform(dataset: DataFrame): DataFrame = {
+ pipeline.transform(dataset)
+ }
+}
+
+private[r] object AFTSurvivalRegressionWrapper {
+
+ private def formulaRewrite(formula: String): (String, String) = {
+ var rewritedFormula: String = null
+ var censorCol: String = null
+
+ val regex = """Surv\(([^,]+), ([^,]+)\) ~ (.+)""".r
+ try {
+ val regex(label, censor, features) = formula
+ // TODO: Support dot operator.
+ if (features.contains(".")) {
+ throw new UnsupportedOperationException(
+ "Terms of survreg formula can not support dot operator.")
+ }
+ rewritedFormula = label.trim + "~" + features.trim
+ censorCol = censor.trim
+ } catch {
+ case e: MatchError =>
+ throw new SparkException(s"Could not parse formula: $formula")
+ }
+
+ (rewritedFormula, censorCol)
+ }
+
+
+ def fit(formula: String, data: DataFrame): AFTSurvivalRegressionWrapper = {
+
+ val (rewritedFormula, censorCol) = formulaRewrite(formula)
+
+ val rFormula = new RFormula().setFormula(rewritedFormula)
+ val rFormulaModel = rFormula.fit(data)
+
+ // get feature names from output schema
+ val schema = rFormulaModel.transform(data).schema
+ val featureAttrs = AttributeGroup.fromStructField(schema(rFormula.getFeaturesCol))
+ .attributes.get
+ val features = featureAttrs.map(_.name.get)
+
+ val aft = new AFTSurvivalRegression()
+ .setCensorCol(censorCol)
+ .setFitIntercept(rFormula.hasIntercept)
+
+ val pipeline = new Pipeline()
+ .setStages(Array(rFormulaModel, aft))
+ .fit(data)
+
+ new AFTSurvivalRegressionWrapper(pipeline, features)
+ }
+}