diff options
author | Yanbo Liang <ybliang8@gmail.com> | 2016-06-28 19:53:07 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2016-06-28 19:53:07 -0700 |
commit | 0df5ce1bc1387a58b33cd185008f4022bd3dcc69 (patch) | |
tree | 3810cd91225bc780259dac5062c4fd3f19f2756a /mllib/src | |
parent | 363bcedeea40fe3f1a92271b96af2acba63e058c (diff) | |
download | spark-0df5ce1bc1387a58b33cd185008f4022bd3dcc69.tar.gz spark-0df5ce1bc1387a58b33cd185008f4022bd3dcc69.tar.bz2 spark-0df5ce1bc1387a58b33cd185008f4022bd3dcc69.zip |
[SPARK-16245][ML] model loading backward compatibility for ml.feature.PCA
## What changes were proposed in this pull request?
model loading backward compatibility for ml.feature.PCA.
## How was this patch tested?
existing unit tests and manual test for loading models saved by Spark 1.6.
Author: Yanbo Liang <ybliang8@gmail.com>
Closes #13937 from yanboliang/spark-16245.
Diffstat (limited to 'mllib/src')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala | 18 |
1 file changed, 8 insertions, 10 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala index 72167b50e3..ef8b08545d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala @@ -206,24 +206,22 @@ object PCAModel extends MLReadable[PCAModel] { override def load(path: String): PCAModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) - // explainedVariance field is not present in Spark <= 1.6 - val versionRegex = "([0-9]+)\\.([0-9]+).*".r - val hasExplainedVariance = metadata.sparkVersion match { - case versionRegex(major, minor) => - major.toInt >= 2 || (major.toInt == 1 && minor.toInt > 6) - case _ => false - } + val versionRegex = "([0-9]+)\\.(.+)".r + val versionRegex(major, _) = metadata.sparkVersion val dataPath = new Path(path, "data").toString - val model = if (hasExplainedVariance) { + val model = if (major.toInt >= 2) { val Row(pc: DenseMatrix, explainedVariance: DenseVector) = sparkSession.read.parquet(dataPath) .select("pc", "explainedVariance") .head() new PCAModel(metadata.uid, pc, explainedVariance) } else { - val Row(pc: DenseMatrix) = sparkSession.read.parquet(dataPath).select("pc").head() - new PCAModel(metadata.uid, pc, Vectors.dense(Array.empty[Double]).asInstanceOf[DenseVector]) + // pc field is the old matrix format in Spark <= 1.6 + // explainedVariance field is not present in Spark <= 1.6 + val Row(pc: OldDenseMatrix) = sparkSession.read.parquet(dataPath).select("pc").head() + new PCAModel(metadata.uid, pc.asML, + Vectors.dense(Array.empty[Double]).asInstanceOf[DenseVector]) } DefaultParamsReader.getAndSetParams(model, metadata) model |