diff options
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala | 9 | ||||
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala | 8 |
2 files changed, 10 insertions, 7 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala index 2dbac49ccf..7c51845a25 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala @@ -33,6 +33,7 @@ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.VectorImplicits._ import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer +import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ @@ -389,10 +390,10 @@ object AFTSurvivalRegressionModel extends MLReadable[AFTSurvivalRegressionModel] val dataPath = new Path(path, "data").toString val data = sparkSession.read.parquet(dataPath) - .select("coefficients", "intercept", "scale").head() - val coefficients = data.getAs[Vector](0) - val intercept = data.getDouble(1) - val scale = data.getDouble(2) + val Row(coefficients: Vector, intercept: Double, scale: Double) = + MLUtils.convertVectorColumnsToML(data, "coefficients") + .select("coefficients", "intercept", "scale") + .head() val model = new AFTSurvivalRegressionModel(metadata.uid, coefficients, intercept, scale) DefaultParamsReader.getAndSetParams(model, metadata) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 2723f74724..0a4d98cab6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -39,6 +39,7 @@ import org.apache.spark.mllib.evaluation.RegressionMetrics import org.apache.spark.mllib.linalg.{Vectors => OldVectors} import org.apache.spark.mllib.linalg.VectorImplicits._ import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer +import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ @@ -500,9 +501,10 @@ object LinearRegressionModel extends MLReadable[LinearRegressionModel] { val dataPath = new Path(path, "data").toString val data = sparkSession.read.format("parquet").load(dataPath) - .select("intercept", "coefficients").head() - val intercept = data.getDouble(0) - val coefficients = data.getAs[Vector](1) + val Row(intercept: Double, coefficients: Vector) = + MLUtils.convertVectorColumnsToML(data, "coefficients") + .select("intercept", "coefficients") + .head() val model = new LinearRegressionModel(metadata.uid, coefficients, intercept) DefaultParamsReader.getAndSetParams(model, metadata) |