diff options
author | Joseph K. Bradley <joseph@databricks.com> | 2016-04-29 16:46:25 -0700 |
---|---|---|
committer | Joseph K. Bradley <joseph@databricks.com> | 2016-04-29 16:46:25 -0700 |
commit | 1eda2f10d9f7add319e5b271488045c44ea30c03 (patch) | |
tree | b5cc1996626f6ffc6689c319dfa4d80ac0f9dc9f | |
parent | d33e3d572ed7143f151f9c96fd08407f8de340f4 (diff) | |
download | spark-1eda2f10d9f7add319e5b271488045c44ea30c03.tar.gz spark-1eda2f10d9f7add319e5b271488045c44ea30c03.tar.bz2 spark-1eda2f10d9f7add319e5b271488045c44ea30c03.zip |
[SPARK-14646][ML] Modified Kmeans to store cluster centers with one per row
## What changes were proposed in this pull request?
Modified Kmeans to store cluster centers with one per row
## How was this patch tested?
Existing tests
Author: Joseph K. Bradley <joseph@databricks.com>
Closes #12792 from jkbradley/kmeans-save-fix.
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala | 19 |
1 files changed, 13 insertions, 6 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index bf2ab98673..7c9ac02521 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -169,18 +169,21 @@ object KMeansModel extends MLReadable[KMeansModel] { @Since("1.6.0") override def load(path: String): KMeansModel = super.load(path) + /** Helper class for storing model data */ + private case class Data(clusterIdx: Int, clusterCenter: Vector) + /** [[MLWriter]] instance for [[KMeansModel]] */ private[KMeansModel] class KMeansModelWriter(instance: KMeansModel) extends MLWriter { - private case class Data(clusterCenters: Array[Vector]) - override protected def saveImpl(path: String): Unit = { // Save metadata and Params DefaultParamsWriter.saveMetadata(instance, path, sc) // Save model data: cluster centers - val data = Data(instance.clusterCenters) + val data: Array[Data] = instance.clusterCenters.zipWithIndex.map { case (center, idx) => + Data(idx, center) + } val dataPath = new Path(path, "data").toString - sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + sqlContext.createDataFrame(data).repartition(1).write.parquet(dataPath) } } @@ -190,11 +193,15 @@ object KMeansModel extends MLReadable[KMeansModel] { private val className = classOf[KMeansModel].getName override def load(path: String): KMeansModel = { + // Import implicits for Dataset Encoder + val sqlContext = super.sqlContext + import sqlContext.implicits._ + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString - val data = sqlContext.read.parquet(dataPath).select("clusterCenters").head() - val clusterCenters = data.getAs[Seq[Vector]](0).toArray + val data: Dataset[Data] = sqlContext.read.parquet(dataPath).as[Data] + val clusterCenters = data.collect().sortBy(_.clusterIdx).map(_.clusterCenter) val model = new KMeansModel(metadata.uid, new MLlibKMeansModel(clusterCenters)) DefaultParamsReader.getAndSetParams(model, metadata) |