 mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala | 19 +++++++++++++------
 1 file changed, 13 insertions(+), 6 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
index bf2ab98673..7c9ac02521 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -169,18 +169,21 @@ object KMeansModel extends MLReadable[KMeansModel] {
   @Since("1.6.0")
   override def load(path: String): KMeansModel = super.load(path)
 
+  /** Helper class for storing model data */
+  private case class Data(clusterIdx: Int, clusterCenter: Vector)
+
   /** [[MLWriter]] instance for [[KMeansModel]] */
   private[KMeansModel] class KMeansModelWriter(instance: KMeansModel) extends MLWriter {
 
-    private case class Data(clusterCenters: Array[Vector])
-
     override protected def saveImpl(path: String): Unit = {
       // Save metadata and Params
       DefaultParamsWriter.saveMetadata(instance, path, sc)
       // Save model data: cluster centers
-      val data = Data(instance.clusterCenters)
+      val data: Array[Data] = instance.clusterCenters.zipWithIndex.map { case (center, idx) =>
+        Data(idx, center)
+      }
       val dataPath = new Path(path, "data").toString
-      sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+      sqlContext.createDataFrame(data).repartition(1).write.parquet(dataPath)
     }
   }
 
@@ -190,11 +193,15 @@ object KMeansModel extends MLReadable[KMeansModel] {
     private val className = classOf[KMeansModel].getName
 
     override def load(path: String): KMeansModel = {
+      // Import implicits for Dataset Encoder
+      val sqlContext = super.sqlContext
+      import sqlContext.implicits._
+
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
       val dataPath = new Path(path, "data").toString
-      val data = sqlContext.read.parquet(dataPath).select("clusterCenters").head()
-      val clusterCenters = data.getAs[Seq[Vector]](0).toArray
+      val data: Dataset[Data] = sqlContext.read.parquet(dataPath).as[Data]
+      val clusterCenters = data.collect().sortBy(_.clusterIdx).map(_.clusterCenter)
       val model = new KMeansModel(metadata.uid, new MLlibKMeansModel(clusterCenters))
       DefaultParamsReader.getAndSetParams(model, metadata)
       model
     }
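
The substance of the patch: the old format persisted all cluster centers in a single Parquet row as an Array[Vector], while the new format writes one row per center, keyed by clusterIdx. Since Parquet gives no guarantee about the order rows come back in, the reader sorts by the stored index before rebuilding the model. Below is a minimal, self-contained sketch of that round-trip logic, not part of the patch: plain Scala collections stand in for the DataFrame, Random.shuffle stands in for nondeterministic Parquet row order, and the ml.linalg import path is an assumption about this file's Vector type.

import scala.util.Random
import org.apache.spark.ml.linalg.{Vector, Vectors}

// Mirrors the Data helper introduced by this patch.
case class Data(clusterIdx: Int, clusterCenter: Vector)

object CenterOrderSketch extends App {
  val centers: Array[Vector] =
    Array(Vectors.dense(0.0, 0.0), Vectors.dense(1.0, 1.0), Vectors.dense(2.0, 2.0))

  // Writer side: tag each center with its position, one row per center.
  val rows = centers.zipWithIndex.map { case (center, idx) => Data(idx, center) }

  // Reader side: simulate Parquet returning rows in arbitrary order ...
  val readBack = Random.shuffle(rows.toList)
  // ... then restore the original order by sorting on the stored index.
  val restored = readBack.sortBy(_.clusterIdx).map(_.clusterCenter).toArray

  assert(restored.sameElements(centers)) // holds for any read order
}

One more design note: the import sqlContext.implicits._ added on the reader side is what brings an Encoder[Data] into scope, so the untyped DataFrame from parquet(dataPath) can be converted to a typed Dataset[Data] via .as[Data].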