From 1eda2f10d9f7add319e5b271488045c44ea30c03 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Fri, 29 Apr 2016 16:46:25 -0700 Subject: [SPARK-14646][ML] Modified Kmeans to store cluster centers with one per row ## What changes were proposed in this pull request? Modified Kmeans to store cluster centers with one per row ## How was this patch tested? Existing tests Author: Joseph K. Bradley Closes #12792 from jkbradley/kmeans-save-fix. --- .../scala/org/apache/spark/ml/clustering/KMeans.scala | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index bf2ab98673..7c9ac02521 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -169,18 +169,21 @@ object KMeansModel extends MLReadable[KMeansModel] { @Since("1.6.0") override def load(path: String): KMeansModel = super.load(path) + /** Helper class for storing model data */ + private case class Data(clusterIdx: Int, clusterCenter: Vector) + /** [[MLWriter]] instance for [[KMeansModel]] */ private[KMeansModel] class KMeansModelWriter(instance: KMeansModel) extends MLWriter { - private case class Data(clusterCenters: Array[Vector]) - override protected def saveImpl(path: String): Unit = { // Save metadata and Params DefaultParamsWriter.saveMetadata(instance, path, sc) // Save model data: cluster centers - val data = Data(instance.clusterCenters) + val data: Array[Data] = instance.clusterCenters.zipWithIndex.map { case (center, idx) => + Data(idx, center) + } val dataPath = new Path(path, "data").toString - sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + sqlContext.createDataFrame(data).repartition(1).write.parquet(dataPath) } } @@ -190,11 +193,15 @@ object KMeansModel extends MLReadable[KMeansModel] { private val className = classOf[KMeansModel].getName override def load(path: String): KMeansModel = { + // Import implicits for Dataset Encoder + val sqlContext = super.sqlContext + import sqlContext.implicits._ + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString - val data = sqlContext.read.parquet(dataPath).select("clusterCenters").head() - val clusterCenters = data.getAs[Seq[Vector]](0).toArray + val data: Dataset[Data] = sqlContext.read.parquet(dataPath).as[Data] + val clusterCenters = data.collect().sortBy(_.clusterIdx).map(_.clusterCenter) val model = new KMeansModel(metadata.uid, new MLlibKMeansModel(clusterCenters)) DefaultParamsReader.getAndSetParams(model, metadata) -- cgit v1.2.3