From 1052d3644d7eb0e784eb883293ce63a352a3b123 Mon Sep 17 00:00:00 2001
From: Yanbo Liang
Date: Thu, 19 May 2016 10:25:33 -0700
Subject: [SPARK-15362][ML] Make spark.ml KMeansModel load backwards compatible

## What changes were proposed in this pull request?

[SPARK-14646](https://issues.apache.org/jira/browse/SPARK-14646) makes ```KMeansModel``` store the cluster centers one per row. The ```KMeansModel.load()``` method needs to be updated in order to load models saved with Spark 1.6.

## How was this patch tested?

Since ```save/load``` is ```Experimental``` for 1.6, I think an offline test for backwards compatibility is enough.

Author: Yanbo Liang

Closes #13149 from yanboliang/spark-15362.
---
 .../org/apache/spark/ml/clustering/KMeans.scala    | 22 +++++++++++++++++-----
 1 file changed, 17 insertions(+), 5 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
index 41c0aec0ec..986f7e0fb0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -185,6 +185,12 @@ object KMeansModel extends MLReadable[KMeansModel] {
   /** Helper class for storing model data */
   private case class Data(clusterIdx: Int, clusterCenter: Vector)
 
+  /**
+   * We store all cluster centers in a single row and use this class to store model data saved
+   * by Spark 1.6 and earlier. A model can be loaded from such older data for backward compatibility.
+   */
+  private case class OldData(clusterCenters: Array[OldVector])
+
   /** [[MLWriter]] instance for [[KMeansModel]] */
   private[KMeansModel] class KMeansModelWriter(instance: KMeansModel) extends MLWriter {
 
@@ -211,13 +217,19 @@ object KMeansModel extends MLReadable[KMeansModel] {
       // Import implicits for Dataset Encoder
       val sqlContext = SQLContext.getOrCreate(sc)
       import sqlContext.implicits._
 
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
-      val dataPath = new Path(path, "data").toString
-      val data: Dataset[Data] = sqlContext.read.parquet(dataPath).as[Data]
-      val clusterCenters = data.collect().sortBy(_.clusterIdx).map(_.clusterCenter)
-      val model = new KMeansModel(metadata.uid,
-        new MLlibKMeansModel(clusterCenters.map(OldVectors.fromML)))
+      val dataPath = new Path(path, "data").toString
+      val versionRegex = "([0-9]+)\\.(.+)".r
+      val versionRegex(major, _) = metadata.sparkVersion
+      val clusterCenters = if (major.toInt >= 2) {
+        val data: Dataset[Data] = sqlContext.read.parquet(dataPath).as[Data]
+        data.collect().sortBy(_.clusterIdx).map(_.clusterCenter).map(OldVectors.fromML)
+      } else {
+        // Loads KMeansModel stored with the old format used by Spark 1.6 and earlier.
+        sqlContext.read.parquet(dataPath).as[OldData].head().clusterCenters
+      }
+      val model = new KMeansModel(metadata.uid, new MLlibKMeansModel(clusterCenters))
       DefaultParamsReader.getAndSetParams(model, metadata)
       model
     }
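
The version gate in the reader relies on a Scala regex extractor: `val versionRegex(major, _) = metadata.sparkVersion` pattern-matches the stored version string and binds its major component. A minimal, standalone sketch of that idiom follows; it has no Spark dependency, and the names `VersionGate` and `layoutFor` are illustrative only, not part of the patch.

```scala
// Minimal sketch of the version-gating idiom used in KMeansModelReader.load().
// Only the regex and the `major.toInt >= 2` test mirror the patch; the
// surrounding names are hypothetical.
object VersionGate {
  // Capture the leading major version; ignore everything after the first dot.
  private val versionRegex = "([0-9]+)\\.(.+)".r

  /** Describes the on-disk layout expected for a model saved by `sparkVersion`. */
  def layoutFor(sparkVersion: String): String = {
    // Regex extractor in a val definition: throws a MatchError if the string
    // does not match, the same failure mode as in the patch itself.
    val versionRegex(major, _) = sparkVersion
    if (major.toInt >= 2) "Data: one cluster center per row (Spark 2.0+)"
    else "OldData: all cluster centers in a single row (Spark 1.6 and earlier)"
  }

  def main(args: Array[String]): Unit = {
    Seq("1.6.1", "2.0.0-SNAPSHOT").foreach { v =>
      println(s"$v -> ${layoutFor(v)}")
    }
  }
}
```

The check keys off the Spark version recorded in the model's metadata at save time (`metadata.sparkVersion`), not the version of the cluster doing the loading, which is what lets a 2.x reader recognize and decode a 1.6 artifact.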