aboutsummaryrefslogtreecommitdiff
path: root/mllib/src
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala9
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala8
2 files changed, 10 insertions, 7 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
index 2dbac49ccf..7c51845a25 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
@@ -33,6 +33,7 @@ import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util._
import org.apache.spark.mllib.linalg.VectorImplicits._
import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
+import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions._
@@ -389,10 +390,10 @@ object AFTSurvivalRegressionModel extends MLReadable[AFTSurvivalRegressionModel]
val dataPath = new Path(path, "data").toString
val data = sparkSession.read.parquet(dataPath)
- .select("coefficients", "intercept", "scale").head()
- val coefficients = data.getAs[Vector](0)
- val intercept = data.getDouble(1)
- val scale = data.getDouble(2)
+ val Row(coefficients: Vector, intercept: Double, scale: Double) =
+ MLUtils.convertVectorColumnsToML(data, "coefficients")
+ .select("coefficients", "intercept", "scale")
+ .head()
val model = new AFTSurvivalRegressionModel(metadata.uid, coefficients, intercept, scale)
DefaultParamsReader.getAndSetParams(model, metadata)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index 2723f74724..0a4d98cab6 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -39,6 +39,7 @@ import org.apache.spark.mllib.evaluation.RegressionMetrics
import org.apache.spark.mllib.linalg.{Vectors => OldVectors}
import org.apache.spark.mllib.linalg.VectorImplicits._
import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
+import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions._
@@ -500,9 +501,10 @@ object LinearRegressionModel extends MLReadable[LinearRegressionModel] {
val dataPath = new Path(path, "data").toString
val data = sparkSession.read.format("parquet").load(dataPath)
- .select("intercept", "coefficients").head()
- val intercept = data.getDouble(0)
- val coefficients = data.getAs[Vector](1)
+ val Row(intercept: Double, coefficients: Vector) =
+ MLUtils.convertVectorColumnsToML(data, "coefficients")
+ .select("intercept", "coefficients")
+ .head()
val model = new LinearRegressionModel(metadata.uid, coefficients, intercept)
DefaultParamsReader.getAndSetParams(model, metadata)