aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJoseph K. Bradley <joseph@databricks.com>2016-04-29 16:46:25 -0700
committerJoseph K. Bradley <joseph@databricks.com>2016-04-29 16:46:25 -0700
commit1eda2f10d9f7add319e5b271488045c44ea30c03 (patch)
treeb5cc1996626f6ffc6689c319dfa4d80ac0f9dc9f
parentd33e3d572ed7143f151f9c96fd08407f8de340f4 (diff)
downloadspark-1eda2f10d9f7add319e5b271488045c44ea30c03.tar.gz
spark-1eda2f10d9f7add319e5b271488045c44ea30c03.tar.bz2
spark-1eda2f10d9f7add319e5b271488045c44ea30c03.zip
[SPARK-14646][ML] Modified Kmeans to store cluster centers with one per row
## What changes were proposed in this pull request?

Modified Kmeans to store cluster centers with one per row

## How was this patch tested?

Existing tests

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #12792 from jkbradley/kmeans-save-fix.
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala19
1 file changed, 13 insertions, 6 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
index bf2ab98673..7c9ac02521 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -169,18 +169,21 @@ object KMeansModel extends MLReadable[KMeansModel] {
@Since("1.6.0")
override def load(path: String): KMeansModel = super.load(path)
+ /** Helper class for storing model data */
+ private case class Data(clusterIdx: Int, clusterCenter: Vector)
+
/** [[MLWriter]] instance for [[KMeansModel]] */
private[KMeansModel] class KMeansModelWriter(instance: KMeansModel) extends MLWriter {
- private case class Data(clusterCenters: Array[Vector])
-
override protected def saveImpl(path: String): Unit = {
// Save metadata and Params
DefaultParamsWriter.saveMetadata(instance, path, sc)
// Save model data: cluster centers
- val data = Data(instance.clusterCenters)
+ val data: Array[Data] = instance.clusterCenters.zipWithIndex.map { case (center, idx) =>
+ Data(idx, center)
+ }
val dataPath = new Path(path, "data").toString
- sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+ sqlContext.createDataFrame(data).repartition(1).write.parquet(dataPath)
}
}
@@ -190,11 +193,15 @@ object KMeansModel extends MLReadable[KMeansModel] {
private val className = classOf[KMeansModel].getName
override def load(path: String): KMeansModel = {
+ // Import implicits for Dataset Encoder
+ val sqlContext = super.sqlContext
+ import sqlContext.implicits._
+
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val dataPath = new Path(path, "data").toString
- val data = sqlContext.read.parquet(dataPath).select("clusterCenters").head()
- val clusterCenters = data.getAs[Seq[Vector]](0).toArray
+ val data: Dataset[Data] = sqlContext.read.parquet(dataPath).as[Data]
+ val clusterCenters = data.collect().sortBy(_.clusterIdx).map(_.clusterCenter)
val model = new KMeansModel(metadata.uid, new MLlibKMeansModel(clusterCenters))
DefaultParamsReader.getAndSetParams(model, metadata)