aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorYuhao Yang <hhbyyh@gmail.com>2016-06-23 11:00:00 -0700
committerXiangrui Meng <meng@databricks.com>2016-06-23 11:00:00 -0700
commit60398dabc50d402bbab4190fbe94ebed6d3a48dc (patch)
treea189f8ab78eb58304a6151981d896dd563f4dca9 /mllib
parentd85bb10ce49926b8b661bd2cb97392205742fc14 (diff)
downloadspark-60398dabc50d402bbab4190fbe94ebed6d3a48dc.tar.gz
spark-60398dabc50d402bbab4190fbe94ebed6d3a48dc.tar.bz2
spark-60398dabc50d402bbab4190fbe94ebed6d3a48dc.zip
[SPARK-16130][ML] model loading backward compatibility for ml.classfication.LogisticRegression
## What changes were proposed in this pull request? jira: https://issues.apache.org/jira/browse/SPARK-16130 model loading backward compatibility for ml.classfication.LogisticRegression ## How was this patch tested? existing ut and manual test for loading old models. Author: Yuhao Yang <hhbyyh@gmail.com> Closes #13841 from hhbyyh/lrcomp.
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala10
1 files changed, 5 insertions, 5 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index be69d46eeb..9c9f5ced4e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -674,12 +674,12 @@ object LogisticRegressionModel extends MLReadable[LogisticRegressionModel] {
val dataPath = new Path(path, "data").toString
val data = sparkSession.read.format("parquet").load(dataPath)
- .select("numClasses", "numFeatures", "intercept", "coefficients").head()
+
// We will need numClasses, numFeatures in the future for multinomial logreg support.
- // val numClasses = data.getInt(0)
- // val numFeatures = data.getInt(1)
- val intercept = data.getDouble(2)
- val coefficients = data.getAs[Vector](3)
+ val Row(numClasses: Int, numFeatures: Int, intercept: Double, coefficients: Vector) =
+ MLUtils.convertVectorColumnsToML(data, "coefficients")
+ .select("numClasses", "numFeatures", "intercept", "coefficients")
+ .head()
val model = new LogisticRegressionModel(metadata.uid, coefficients, intercept)
DefaultParamsReader.getAndSetParams(model, metadata)