From 14bc5a7f36bed19cd714a4c725a83feaccac3468 Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Thu, 23 Jun 2016 20:43:19 -0700 Subject: [SPARK-16177][ML] model loading backward compatibility for ml.regression ## What changes were proposed in this pull request? jira: https://issues.apache.org/jira/browse/SPARK-16177 model loading backward compatibility for ml.regression ## How was this patch tested? existing ut and manual test for loading 1.6 models. Author: Yuhao Yang Closes #13879 from hhbyyh/regreComp. --- .../org/apache/spark/ml/regression/AFTSurvivalRegression.scala | 9 +++++---- .../scala/org/apache/spark/ml/regression/LinearRegression.scala | 8 +++++--- 2 files changed, 10 insertions(+), 7 deletions(-) (limited to 'mllib/src') diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala index 2dbac49ccf..7c51845a25 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala @@ -33,6 +33,7 @@ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.VectorImplicits._ import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer +import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ @@ -389,10 +390,10 @@ object AFTSurvivalRegressionModel extends MLReadable[AFTSurvivalRegressionModel] val dataPath = new Path(path, "data").toString val data = sparkSession.read.parquet(dataPath) - .select("coefficients", "intercept", "scale").head() - val coefficients = data.getAs[Vector](0) - val intercept = data.getDouble(1) - val scale = data.getDouble(2) + val Row(coefficients: Vector, intercept: Double, scale: Double) = + MLUtils.convertVectorColumnsToML(data, "coefficients") + .select("coefficients", "intercept", "scale") + .head() val model = new AFTSurvivalRegressionModel(metadata.uid, coefficients, intercept, scale) DefaultParamsReader.getAndSetParams(model, metadata) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 2723f74724..0a4d98cab6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -39,6 +39,7 @@ import org.apache.spark.mllib.evaluation.RegressionMetrics import org.apache.spark.mllib.linalg.{Vectors => OldVectors} import org.apache.spark.mllib.linalg.VectorImplicits._ import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer +import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ @@ -500,9 +501,10 @@ object LinearRegressionModel extends MLReadable[LinearRegressionModel] { val dataPath = new Path(path, "data").toString val data = sparkSession.read.format("parquet").load(dataPath) - .select("intercept", "coefficients").head() - val intercept = data.getDouble(0) - val coefficients = data.getAs[Vector](1) + val Row(intercept: Double, coefficients: Vector) = + MLUtils.convertVectorColumnsToML(data, "coefficients") + .select("intercept", "coefficients") + .head() val model = new LinearRegressionModel(metadata.uid, coefficients, intercept) DefaultParamsReader.getAndSetParams(model, metadata) -- cgit v1.2.3