author     Yanbo Liang <ybliang8@gmail.com>      2016-06-28 19:53:07 -0700
committer  Xiangrui Meng <meng@databricks.com>   2016-06-28 19:53:07 -0700
commit     0df5ce1bc1387a58b33cd185008f4022bd3dcc69 (patch)
tree       3810cd91225bc780259dac5062c4fd3f19f2756a /mllib
parent     363bcedeea40fe3f1a92271b96af2acba63e058c (diff)
[SPARK-16245][ML] model loading backward compatibility for ml.feature.PCA
## What changes were proposed in this pull request?

Model loading backward compatibility for ml.feature.PCA.

## How was this patch tested?

Existing unit tests, plus a manual test loading models saved by Spark 1.6.

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #13937 from yanboliang/spark-16245.
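The core of the change is the version gate: rather than regex-matching both the major and minor version components, the loader now extracts only the major version and assumes the new persistence format for Spark 2.x and later. A minimal sketch of that logic, where `hasNewFormat` is a hypothetical helper name standing in for the inline check in the patch:

```scala
// Sketch of the version gate introduced by this patch. The regex keeps only
// the major version; everything after the first dot is ignored.
val versionRegex = "([0-9]+)\\.(.+)".r

// Hypothetical helper: the argument stands in for metadata.sparkVersion,
// the version string recorded when the model was saved.
def hasNewFormat(sparkVersion: String): Boolean = {
  val versionRegex(major, _) = sparkVersion
  // Spark >= 2.0 persists both "pc" (new ml.linalg matrix) and "explainedVariance"
  major.toInt >= 2
}

hasNewFormat("1.6.2")  // false -> fall back to the Spark 1.6 on-disk format
hasNewFormat("2.0.0")  // true  -> read "pc" and "explainedVariance" directly
```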
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala  18
1 file changed, 8 insertions(+), 10 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
index 72167b50e3..ef8b08545d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
@@ -206,24 +206,22 @@ object PCAModel extends MLReadable[PCAModel] {
     override def load(path: String): PCAModel = {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
 
-      // explainedVariance field is not present in Spark <= 1.6
-      val versionRegex = "([0-9]+)\\.([0-9]+).*".r
-      val hasExplainedVariance = metadata.sparkVersion match {
-        case versionRegex(major, minor) =>
-          major.toInt >= 2 || (major.toInt == 1 && minor.toInt > 6)
-        case _ => false
-      }
+      val versionRegex = "([0-9]+)\\.(.+)".r
+      val versionRegex(major, _) = metadata.sparkVersion
 
       val dataPath = new Path(path, "data").toString
-      val model = if (hasExplainedVariance) {
+      val model = if (major.toInt >= 2) {
         val Row(pc: DenseMatrix, explainedVariance: DenseVector) =
           sparkSession.read.parquet(dataPath)
             .select("pc", "explainedVariance")
             .head()
         new PCAModel(metadata.uid, pc, explainedVariance)
       } else {
-        val Row(pc: DenseMatrix) = sparkSession.read.parquet(dataPath).select("pc").head()
-        new PCAModel(metadata.uid, pc, Vectors.dense(Array.empty[Double]).asInstanceOf[DenseVector])
+        // pc field is the old matrix format in Spark <= 1.6
+        // explainedVariance field is not present in Spark <= 1.6
+        val Row(pc: OldDenseMatrix) = sparkSession.read.parquet(dataPath).select("pc").head()
+        new PCAModel(metadata.uid, pc.asML,
+          Vectors.dense(Array.empty[Double]).asInstanceOf[DenseVector])
       }
       DefaultParamsReader.getAndSetParams(model, metadata)
       model
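For readers following the fallback branch: `OldDenseMatrix` is the pre-2.0 `org.apache.spark.mllib.linalg.DenseMatrix`, and `asML` (added in Spark 2.0) converts it to the `org.apache.spark.ml.linalg.DenseMatrix` that `PCAModel` now expects. A small self-contained sketch of that conversion, with illustrative values:

```scala
import org.apache.spark.mllib.linalg.{DenseMatrix => OldDenseMatrix}
import org.apache.spark.ml.linalg.{DenseMatrix => NewDenseMatrix}

// A 2x2 identity matrix in the old mllib type, as Spark 1.6 would have saved it
// (values are column-major).
val oldPc = new OldDenseMatrix(2, 2, Array(1.0, 0.0, 0.0, 1.0))

// asML bridges the old linalg type to the new ml.linalg type used by PCAModel.
val newPc: NewDenseMatrix = oldPc.asML
```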