diff options
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala | 6 |
1 files changed, 3 insertions, 3 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala index 40590e71c4..7835468626 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala @@ -22,7 +22,7 @@ import org.apache.spark.ml.{Pipeline, PipelineModel} import org.apache.spark.ml.attribute.AttributeGroup import org.apache.spark.ml.feature.RFormula import org.apache.spark.ml.regression.{AFTSurvivalRegression, AFTSurvivalRegressionModel} -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} private[r] class AFTSurvivalRegressionWrapper private ( pipeline: PipelineModel, @@ -43,8 +43,8 @@ private[r] class AFTSurvivalRegressionWrapper private ( features ++ Array("Log(scale)") } - def transform(dataset: DataFrame): DataFrame = { - pipeline.transform(dataset) + def transform(dataset: Dataset[_]): DataFrame = { + pipeline.transform(dataset).drop(aftModel.getFeaturesCol) } } |