diff options
author | Yuhao Yang <hhbyyh@gmail.com> | 2016-06-23 20:43:19 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2016-06-23 20:43:19 -0700 |
commit | 14bc5a7f36bed19cd714a4c725a83feaccac3468 (patch) | |
tree | 8df540b45fef2526e891b4ce54c04185ffc78304 /mllib/src/main/scala | |
parent | 6a3c6276f5cef26b0a4fef44c8ad99bbecfe006d (diff) | |
download | spark-14bc5a7f36bed19cd714a4c725a83feaccac3468.tar.gz spark-14bc5a7f36bed19cd714a4c725a83feaccac3468.tar.bz2 spark-14bc5a7f36bed19cd714a4c725a83feaccac3468.zip |
[SPARK-16177][ML] model loading backward compatibility for ml.regression
## What changes were proposed in this pull request?
jira: https://issues.apache.org/jira/browse/SPARK-16177
model loading backward compatibility for ml.regression
## How was this patch tested?
existing ut and manual test for loading 1.6 models.
Author: Yuhao Yang <hhbyyh@gmail.com>
Closes #13879 from hhbyyh/regreComp.
Diffstat (limited to 'mllib/src/main/scala')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala | 9 | ||||
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala | 8 |
2 files changed, 10 insertions, 7 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala index 2dbac49ccf..7c51845a25 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala @@ -33,6 +33,7 @@ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.VectorImplicits._ import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer +import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ @@ -389,10 +390,10 @@ object AFTSurvivalRegressionModel extends MLReadable[AFTSurvivalRegressionModel] val dataPath = new Path(path, "data").toString val data = sparkSession.read.parquet(dataPath) - .select("coefficients", "intercept", "scale").head() - val coefficients = data.getAs[Vector](0) - val intercept = data.getDouble(1) - val scale = data.getDouble(2) + val Row(coefficients: Vector, intercept: Double, scale: Double) = + MLUtils.convertVectorColumnsToML(data, "coefficients") + .select("coefficients", "intercept", "scale") + .head() val model = new AFTSurvivalRegressionModel(metadata.uid, coefficients, intercept, scale) DefaultParamsReader.getAndSetParams(model, metadata) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 2723f74724..0a4d98cab6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -39,6 +39,7 @@ import org.apache.spark.mllib.evaluation.RegressionMetrics import org.apache.spark.mllib.linalg.{Vectors => OldVectors} import org.apache.spark.mllib.linalg.VectorImplicits._ import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer +import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ @@ -500,9 +501,10 @@ object LinearRegressionModel extends MLReadable[LinearRegressionModel] { val dataPath = new Path(path, "data").toString val data = sparkSession.read.format("parquet").load(dataPath) - .select("intercept", "coefficients").head() - val intercept = data.getDouble(0) - val coefficients = data.getAs[Vector](1) + val Row(intercept: Double, coefficients: Vector) = + MLUtils.convertVectorColumnsToML(data, "coefficients") + .select("intercept", "coefficients") + .head() val model = new LinearRegressionModel(metadata.uid, coefficients, intercept) DefaultParamsReader.getAndSetParams(model, metadata) |