aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main/scala
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/main/scala')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala22
1 file changed, 17 insertions, 5 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
index 41c0aec0ec..986f7e0fb0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -185,6 +185,12 @@ object KMeansModel extends MLReadable[KMeansModel] {
/** Helper class for storing model data */
private case class Data(clusterIdx: Int, clusterCenter: Vector)
+ /**
+ * This class stores model data in the format used by Spark 1.6 and earlier, in which all
+ * cluster centers are kept in a single row. A model can be loaded from such older data for
+ * backward compatibility.
+ */
+ private case class OldData(clusterCenters: Array[OldVector])
+
/** [[MLWriter]] instance for [[KMeansModel]] */
private[KMeansModel] class KMeansModelWriter(instance: KMeansModel) extends MLWriter {
@@ -211,13 +217,19 @@ object KMeansModel extends MLReadable[KMeansModel] {
import sqlContext.implicits._
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
-
val dataPath = new Path(path, "data").toString
- val data: Dataset[Data] = sqlContext.read.parquet(dataPath).as[Data]
- val clusterCenters = data.collect().sortBy(_.clusterIdx).map(_.clusterCenter)
- val model = new KMeansModel(metadata.uid,
- new MLlibKMeansModel(clusterCenters.map(OldVectors.fromML)))
+ val versionRegex = "([0-9]+)\\.(.+)".r
+ val versionRegex(major, _) = metadata.sparkVersion
+
+ val clusterCenters = if (major.toInt >= 2) {
+ val data: Dataset[Data] = sqlContext.read.parquet(dataPath).as[Data]
+ data.collect().sortBy(_.clusterIdx).map(_.clusterCenter).map(OldVectors.fromML)
+ } else {
+ // Loads KMeansModel stored with the old format used by Spark 1.6 and earlier.
+ sqlContext.read.parquet(dataPath).as[OldData].head().clusterCenters
+ }
+ val model = new KMeansModel(metadata.uid, new MLlibKMeansModel(clusterCenters))
DefaultParamsReader.getAndSetParams(model, metadata)
model
}