aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
author: Yuhao Yang <hhbyyh@gmail.com> 2016-06-23 20:43:19 -0700
committer: Xiangrui Meng <meng@databricks.com> 2016-06-23 20:43:19 -0700
commit: 14bc5a7f36bed19cd714a4c725a83feaccac3468 (patch)
tree: 8df540b45fef2526e891b4ce54c04185ffc78304 /mllib
parent: 6a3c6276f5cef26b0a4fef44c8ad99bbecfe006d (diff)
download: spark-14bc5a7f36bed19cd714a4c725a83feaccac3468.tar.gz
spark-14bc5a7f36bed19cd714a4c725a83feaccac3468.tar.bz2
spark-14bc5a7f36bed19cd714a4c725a83feaccac3468.zip
[SPARK-16177][ML] model loading backward compatibility for ml.regression
## What changes were proposed in this pull request?

jira: https://issues.apache.org/jira/browse/SPARK-16177

Model loading backward compatibility for ml.regression.

## How was this patch tested?

Existing unit tests and a manual test for loading 1.6 models.

Author: Yuhao Yang <hhbyyh@gmail.com>

Closes #13879 from hhbyyh/regreComp.
Diffstat (limited to 'mllib')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala | 9
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala | 8
2 files changed, 10 insertions, 7 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
index 2dbac49ccf..7c51845a25 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
@@ -33,6 +33,7 @@ import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util._
import org.apache.spark.mllib.linalg.VectorImplicits._
import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
+import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions._
@@ -389,10 +390,10 @@ object AFTSurvivalRegressionModel extends MLReadable[AFTSurvivalRegressionModel]
val dataPath = new Path(path, "data").toString
val data = sparkSession.read.parquet(dataPath)
- .select("coefficients", "intercept", "scale").head()
- val coefficients = data.getAs[Vector](0)
- val intercept = data.getDouble(1)
- val scale = data.getDouble(2)
+ val Row(coefficients: Vector, intercept: Double, scale: Double) =
+ MLUtils.convertVectorColumnsToML(data, "coefficients")
+ .select("coefficients", "intercept", "scale")
+ .head()
val model = new AFTSurvivalRegressionModel(metadata.uid, coefficients, intercept, scale)
DefaultParamsReader.getAndSetParams(model, metadata)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index 2723f74724..0a4d98cab6 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -39,6 +39,7 @@ import org.apache.spark.mllib.evaluation.RegressionMetrics
import org.apache.spark.mllib.linalg.{Vectors => OldVectors}
import org.apache.spark.mllib.linalg.VectorImplicits._
import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
+import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions._
@@ -500,9 +501,10 @@ object LinearRegressionModel extends MLReadable[LinearRegressionModel] {
val dataPath = new Path(path, "data").toString
val data = sparkSession.read.format("parquet").load(dataPath)
- .select("intercept", "coefficients").head()
- val intercept = data.getDouble(0)
- val coefficients = data.getAs[Vector](1)
+ val Row(intercept: Double, coefficients: Vector) =
+ MLUtils.convertVectorColumnsToML(data, "coefficients")
+ .select("intercept", "coefficients")
+ .head()
val model = new LinearRegressionModel(metadata.uid, coefficients, intercept)
DefaultParamsReader.getAndSetParams(model, metadata)