aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorSean Owen <sowen@cloudera.com>2015-12-21 10:21:22 +0000
committerSean Owen <sowen@cloudera.com>2015-12-21 10:21:22 +0000
commitd0f695089e4627273133c5f49ef7a83c1840c8f5 (patch)
tree84fdc240575d7da468c116846c9b28e368a98112 /mllib
parentce1798b3af8de326bf955b51ed955a924b019b4e (diff)
downloadspark-d0f695089e4627273133c5f49ef7a83c1840c8f5.tar.gz
spark-d0f695089e4627273133c5f49ef7a83c1840c8f5.tar.bz2
spark-d0f695089e4627273133c5f49ef7a83c1840c8f5.zip
[SPARK-12349][ML] Make spark.ml PCAModel load backwards compatible
Only load explainedVariance in PCAModel if it was written with Spark > 1.6.x jkbradley is this kind of what you had in mind? Author: Sean Owen <sowen@cloudera.com> Closes #10327 from srowen/SPARK-12349.
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala33
1 files changed, 28 insertions, 5 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
index 53d33ea2b8..759be813ee 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
@@ -167,14 +167,37 @@ object PCAModel extends MLReadable[PCAModel] {
private val className = classOf[PCAModel].getName
+ /**
+ * Loads a [[PCAModel]] from data located at the input path. Note that the model includes an
+ * `explainedVariance` member that is not recorded by Spark 1.6 and earlier. A model
+ * can be loaded from such older data but will have an empty vector for
+ * `explainedVariance`.
+ *
+ * @param path path to serialized model data
+ * @return a [[PCAModel]]
+ */
override def load(path: String): PCAModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+
+ // explainedVariance field is not present in Spark <= 1.6
+ val versionRegex = "([0-9]+)\\.([0-9])+.*".r
+ val hasExplainedVariance = metadata.sparkVersion match {
+ case versionRegex(major, minor) =>
+ (major.toInt >= 2 || (major.toInt == 1 && minor.toInt > 6))
+ case _ => false
+ }
+
val dataPath = new Path(path, "data").toString
- val Row(pc: DenseMatrix, explainedVariance: DenseVector) =
- sqlContext.read.parquet(dataPath)
- .select("pc", "explainedVariance")
- .head()
- val model = new PCAModel(metadata.uid, pc, explainedVariance)
+ val model = if (hasExplainedVariance) {
+ val Row(pc: DenseMatrix, explainedVariance: DenseVector) =
+ sqlContext.read.parquet(dataPath)
+ .select("pc", "explainedVariance")
+ .head()
+ new PCAModel(metadata.uid, pc, explainedVariance)
+ } else {
+ val Row(pc: DenseMatrix) = sqlContext.read.parquet(dataPath).select("pc").head()
+ new PCAModel(metadata.uid, pc, Vectors.dense(Array.empty[Double]).asInstanceOf[DenseVector])
+ }
DefaultParamsReader.getAndSetParams(model, metadata)
model
}