path: root/mllib/src/main/scala
author    Yanbo Liang <ybliang8@gmail.com>    2016-05-19 10:25:33 -0700
committer Joseph K. Bradley <joseph@databricks.com>    2016-05-19 10:25:33 -0700
commit    1052d3644d7eb0e784eb883293ce63a352a3b123 (patch)
tree      d23b1f89af5f3bef652ff6105b00fff470e32cab /mllib/src/main/scala
parent    3facca5152e685d9c7da96bff5102169740a4a06 (diff)
download  spark-1052d3644d7eb0e784eb883293ce63a352a3b123.tar.gz
          spark-1052d3644d7eb0e784eb883293ce63a352a3b123.tar.bz2
          spark-1052d3644d7eb0e784eb883293ce63a352a3b123.zip
[SPARK-15362][ML] Make spark.ml KMeansModel load backwards compatible
## What changes were proposed in this pull request?

[SPARK-14646](https://issues.apache.org/jira/browse/SPARK-14646) makes ```KMeansModel``` store the cluster centers one per row. The ```KMeansModel.load()``` method needs to be updated in order to load models saved with Spark 1.6.

## How was this patch tested?

Since ```save/load``` was ```Experimental``` in 1.6, I think an offline test for backwards compatibility is enough.

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #13149 from yanboliang/spark-15362.
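For reference, here is a minimal sketch of the two storage layouts involved. The case class names mirror those in the diff below; this is only an illustration of the schema change from SPARK-14646, not code taken from the patch.

```scala
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.mllib.linalg.{Vector => OldVector}

// Spark 2.0+ layout (SPARK-14646): one row per cluster center.
case class Data(clusterIdx: Int, clusterCenter: Vector)

// Spark 1.6 and earlier layout: all cluster centers stored in a single row.
case class OldData(clusterCenters: Array[OldVector])
```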
Diffstat (limited to 'mllib/src/main/scala')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala | 22
1 file changed, 17 insertions, 5 deletions
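As a usage-level illustration of the backward compatibility this patch restores (hedged: the path below is hypothetical), a model written by spark.ml in Spark 1.6 can now be loaded through the same API again:

```scala
import org.apache.spark.ml.clustering.KMeansModel

// Hypothetical path to a model directory written by spark.ml in Spark 1.6.
val model = KMeansModel.load("/models/kmeans-spark-1.6")

// The loaded model exposes its centers regardless of which layout was on disk.
model.clusterCenters.foreach(println)
```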
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
index 41c0aec0ec..986f7e0fb0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -185,6 +185,12 @@ object KMeansModel extends MLReadable[KMeansModel] {
/** Helper class for storing model data */
private case class Data(clusterIdx: Int, clusterCenter: Vector)
+ /**
+ * We store all cluster centers in a single row and use this class to store model data by
+ * Spark 1.6 and earlier. A model can be loaded from such older data for backward compatibility.
+ */
+ private case class OldData(clusterCenters: Array[OldVector])
+
/** [[MLWriter]] instance for [[KMeansModel]] */
private[KMeansModel] class KMeansModelWriter(instance: KMeansModel) extends MLWriter {
@@ -211,13 +217,19 @@ object KMeansModel extends MLReadable[KMeansModel] {
import sqlContext.implicits._
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
-
val dataPath = new Path(path, "data").toString
- val data: Dataset[Data] = sqlContext.read.parquet(dataPath).as[Data]
- val clusterCenters = data.collect().sortBy(_.clusterIdx).map(_.clusterCenter)
- val model = new KMeansModel(metadata.uid,
- new MLlibKMeansModel(clusterCenters.map(OldVectors.fromML)))
+ val versionRegex = "([0-9]+)\\.(.+)".r
+ val versionRegex(major, _) = metadata.sparkVersion
+
+ val clusterCenters = if (major.toInt >= 2) {
+ val data: Dataset[Data] = sqlContext.read.parquet(dataPath).as[Data]
+ data.collect().sortBy(_.clusterIdx).map(_.clusterCenter).map(OldVectors.fromML)
+ } else {
+ // Loads KMeansModel stored with the old format used by Spark 1.6 and earlier.
+ sqlContext.read.parquet(dataPath).as[OldData].head().clusterCenters
+ }
+ val model = new KMeansModel(metadata.uid, new MLlibKMeansModel(clusterCenters))
DefaultParamsReader.getAndSetParams(model, metadata)
model
}
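A small self-contained sketch of the version gate introduced above: parse ```metadata.sparkVersion``` with the same regex and choose the on-disk layout based on the major version. The version strings here are illustrative only.

```scala
// Same regex as in the patch: capture the major version, ignore the rest.
val versionRegex = "([0-9]+)\\.(.+)".r

def usesPerRowCenters(sparkVersion: String): Boolean = {
  val versionRegex(major, _) = sparkVersion
  major.toInt >= 2
}

assert(usesPerRowCenters("2.0.0"))   // Spark 2.0+: read Dataset[Data], one center per row
assert(!usesPerRowCenters("1.6.1"))  // Spark 1.6 and earlier: read OldData, all centers in one row
```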